Ejemplo n.º 1
0
def run(soltab, axesToNorm, normVal=1., log=False):
    """
    Normalize the solutions to a given value
    WEIGHT: Weights compliant

    For every selection produced by iterating over ``axesToNorm``, the mean of
    the unflagged (weight != 0) values is computed and only the unflagged values
    are rescaled so that this mean becomes ``normVal``. Flagged values are left
    untouched. Selections that are fully flagged, or whose mean cannot be used
    for rescaling (zero in linear space, or not finite), are skipped.

    Parameters
    ----------
    axesToNorm : array of str
        Axes along which compute the normalization.

    normVal : float, optional
        Number to normalize to vals = vals * (normVal/valsMean), by default 1.

    log : bool, optional
        Normalization is done in log10 space, by default False. The mean of
        log10(vals) is then shifted to log10(normVal), i.e. the geometric mean
        of the unflagged values becomes normVal.

    Returns
    -------
    int
        0 on success, 1 if one of ``axesToNorm`` is not an axis of the soltab.
    """
    import numpy as np

    logging.info("Normalizing soltab: " + soltab.name)

    # input check
    axesNames = soltab.getAxesNames()
    for normAxis in axesToNorm:
        if normAxis not in axesNames:
            logging.error('Normalization axis ' + normAxis + ' not found.')
            return 1

    if soltab.getType() == 'amplitude' and not log:
        logging.warning(
            'Amplitude solution tab detected and log=False. Amplitude solution tables should be treated in log space.'
        )

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=axesToNorm, weight=True):

        unflagged = (weights != 0)
        if not unflagged.any():
            continue  # skip flagged selections

        if log:
            # flagged entries may be 0: log10 gives -inf there, which is harmless
            # because they are never rescaled and 10**-inf == 0 on the way back
            with np.errstate(divide='ignore', invalid='ignore'):
                vals = np.log10(vals)

        valsMean = np.nanmean(vals[unflagged])

        # dividing by a zero (linear) mean, or shifting by a non-finite one,
        # would fill the whole selection with inf/nan
        if not np.isfinite(valsMean) or (not log and valsMean == 0):
            logging.warning('Cannot normalize selection %s: mean of unflagged values is %s. Skipping it.'
                            % (coord, valsMean))
            continue

        # rescale solutions
        if log:
            shift = np.log10(normVal) - valsMean
            vals[unflagged] += shift
            factor = 10**shift  # equivalent linear rescaling factor
        else:
            factor = normVal / valsMean
            vals[unflagged] *= factor

        logging.debug("Rescaling by: " + str(factor))

        # writing back the solutions
        if log:
            vals = 10**vals
        soltab.setValues(vals, selection)

    soltab.flush()
    soltab.addHistory('NORM (on axis %s)' % (axesToNorm))

    return 0
Ejemplo n.º 2
0
def getStepSoltabs(parser, step, H):
    """
    Return a list of soltabs object for a step and apply selection criteria

    The soltab patterns are read from the step's 'soltab' option, falling back
    to the '_global' section and then to '.*/.*' (everything). Each pattern is a
    regular expression matched (``re.match``, i.e. anchored at the start) against
    '<solset name>/<soltab name>'. Every selected soltab then gets the step's
    per-axis selection applied.

    Parameters
    ----------
    parser : parser obj
        configuration file

    step : str
        current step

    H : h5parm obj
        the h5parm object

    Returns
    -------
    list
        list of soltab obj with applied selection
    """

    # selection on soltabs
    if parser.has_option(step, 'soltab'):
        stsel = parser.getarraystr(step, 'soltab')
    elif parser.has_option('_global', 'soltab'):
        stsel = parser.getarraystr('_global', 'soltab')
    else:
        stsel = ['.*/.*']  # select all

    # loop invariants: compile every pattern and resolve the cache policy once
    patterns = [re.compile(this_stsel) for this_stsel in stsel]
    useCache = parser.getstr(step, 'operation').lower() in cacheSteps

    soltabs = []
    for solset in H.getSolsets():
        for soltabName in solset.getSoltabNames():
            fullName = solset.name + '/' + soltabName
            if any(pattern.match(fullName) for pattern in patterns):
                soltabs.append(solset.getSoltab(soltabName, useCache=useCache))

    if not soltabs:
        logging.warning('No soltabs selected for step %s.' % step)

    # axes selection
    for soltab in soltabs:
        userSel = {}
        for axisName in soltab.getAxesNames():
            userSel[axisName] = getParAxis(parser, step, axisName)
        soltab.setSelection(**userSel)

    return soltabs
Ejemplo n.º 3
0
 def checkSpelling(self, s, soltab, availValues=None):
     """
     Warn about every option of step ``s`` that is not a known keyword.

     Known keywords are 'soltab', 'operation', the entries of ``availValues``
     and, for every axis of ``soltab``, the axis name itself plus its
     '<axis>.minmaxstep' and '<axis>.regexp' variants. The comparison is
     case-insensitive. Unknown options only produce a warning.

     Parameters
     ----------
     s : str
         Name of the section (step) whose options are checked.
     soltab : soltab obj
         Soltab providing the axis names.
     availValues : list of str, optional
         Extra option names accepted by the step's operation. Default: none.
     """
     axesNames = list(soltab.getAxesNames())
     known = ['soltab', 'operation'] + list(availValues or []) + axesNames \
         + [a + '.minmaxstep' for a in axesNames] \
         + [a + '.regexp' for a in axesNames]
     # a set gives O(1) membership tests in the loop below
     known = {x.lower() for x in known}
     for e in (x.lower() for x in dict(self.items(s))):
         if e not in known:
             logging.warning('Mispelled option: %s - Ignoring!' % e)
Ejemplo n.º 4
0
def _run_timestep(t,coord_rr,coord_ll,weights,vals,solType,coord,maxResidual):
    c = 2.99792458e8
    if solType == 'phase':
        idx       = ((weights[coord_rr,:] != 0.) & (weights[coord_ll,:] != 0.))
        freq      = np.copy(coord['freq'])[idx]
        phase_rr  = vals[coord_rr,:][idx]
        phase_ll  = vals[coord_ll,:][idx]
        # RR-LL to be consistent with BBS/NDPPP
        phase_diff  = (phase_rr - phase_ll)      # not divide by 2 otherwise jump problem, then later fix this
    else: # rotation table
        idx        = ((weights[:] != 0.) & (weights[:] != 0.))
        freq       = np.copy(coord['freq'])[idx]
        phase_diff = 2.*vals[:][idx] # a rotation is between -pi and +pi

    if len(freq) < 20:
        fitresultrm_wav = [0]
        weight = 0
        logging.warning('No valid data found for Faraday fitting for antenna: '+coord['ant']+' at timestamp '+str(t))
    else:
        # if more than 1/4 of chans are flagged
        if (len(idx) - len(freq))/float(len(idx)) > 1/4.:
            logging.debug('High number of filtered out data points for the timeslot %i: %i/%i' % (t, len(idx) - len(freq), len(idx)) )

        wav = c/freq

        #fitresultrm_wav, success = scipy.optimize.leastsq(rmwavcomplex, [fitrmguess], args=(wav, phase_diff))
        ranges = slice(-0.1, 0.1, 2e-4)
        fitresultrm_wav = scipy.optimize.brute(costfunctionRM, (ranges,), finish=scipy.optimize.leastsq, args=(wav, phase_diff))        

        # fractional residual
        residual = np.nanmean(np.abs(np.mod((2.*fitresultrm_wav*wav*wav)-phase_diff + np.pi, 2.*np.pi) - np.pi))
        if maxResidual == 0 or residual < maxResidual:
            fitrmguess = fitresultrm_wav[0] # Don't know what this is for...
            weight = 1
        else:       
            # high residual, flag
            logging.warning('Bad solution for ant: '+coord['ant']+' (time: '+str(t)+', resdiaul: '+str(residual)+').')
            weight = 0

    return fitresultrm_wav[0],weight
Ejemplo n.º 5
0
def run(soltab, soltabsToSub, ratio=False):
    """
    Subtract/divide two tables or a clock/tec/tec3rd/rm from a phase.

    Each soltab named in ``soltabsToSub`` is looked up in the same solset as
    ``soltab`` and removed from it in turn. Flags (zero weights) of the
    subtracted table are propagated to ``soltab``.

    Parameters
    ----------
    soltabsToSub : list of str
        List of soltabs to subtract

    ratio : bool, optional
        Return the ratio instead of subtracting, by default False.
        Computed as (vals - valsSub) / valsSub. It is ignored when the table
        to subtract is of type clock/tec/tec3rd/rotationmeasure.

    Returns
    -------
    int
        0 on success. 1 if a clock/tec/tec3rd/rotationmeasure table is to be
        removed from a non-phase soltab. Despite the "Skipping it" message,
        this stops the whole operation; tables processed earlier in the loop
        keep their changes.

    Raises
    ------
    AssertionError
        If an axis shared by the two tables has different values, or if a
        rotationmeasure table is removed from a soltab whose 'pol' axis does
        not have exactly 2 entries.
    """
    import numpy as np

    logging.info("Subtract soltab: " + soltab.name)

    solset = soltab.getSolset()
    for soltabToSub in soltabsToSub:
        soltabsub = solset.getSoltab(soltabToSub)

        # clock/tec/rm tables only make sense when removed from a phase table
        if soltab.getType() != 'phase' and (
                soltabsub.getType() == 'tec' or soltabsub.getType() == 'clock'
                or soltabsub.getType() == 'rotationmeasure'
                or soltabsub.getType() == 'tec3rd'):
            logging.warning(
                soltabToSub +
                ' is of type clock/tec/rm and should be subtracted from a phase. Skipping it.'
            )
            return 1
        logging.info('Subtracting table: ' + soltabToSub)

        # a major speed up if tables are assumed with same axes, check that (should be the case in almost any case)
        for i, axisName in enumerate(soltabsub.getAxesNames()):
            # also armonise selection by copying only the axes present in the outtable and in the right order
            soltabsub.selection[i] = soltab.selection[
                soltab.getAxesNames().index(axisName)]
            assert (soltabsub.getAxisValues(axisName) == soltab.getAxisValues(
                axisName)).all()  # table not conform

        # different shapes mean the table to subtract lacks some of soltab's
        # axes and must be broadcast explicitly below
        if soltab.getValues(retAxesVals=False,
                            weight=False).shape != soltabsub.getValues(
                                retAxesVals=False, weight=False).shape:
            hasMissingAxes = True
        else:
            hasMissingAxes = False

        # clock/tec/tec3rd/rm need a frequency-dependent conversion to phase,
        # so they take the same explicit broadcasting path as missing axes
        if soltabsub.getType() == 'clock' or soltabsub.getType(
        ) == 'tec' or soltabsub.getType() == 'tec3rd' or soltabsub.getType(
        ) == 'rotationmeasure' or hasMissingAxes:

            freq = soltab.getAxisValues('freq')
            vals = soltab.getValues(retAxesVals=False, weight=False)
            weights = soltab.getValues(retAxesVals=False, weight=True)
            #print 'vals', vals.shape

            # valsSub doesn't have freq
            valsSub = soltabsub.getValues(retAxesVals=False, weight=False)
            weightsSub = soltabsub.getValues(retAxesVals=False, weight=True)
            #print 'valsSub', valsSub.shape

            # add missing axes and move it to the last position
            # (np.resize prepends the missing axes, repeating the data to fill them)
            expand = [
                soltab.getAxisLen(ax) for ax in soltab.getAxesNames()
                if ax not in soltabsub.getAxesNames()
            ]
            #print "expand:", expand
            valsSub = np.resize(valsSub, expand + list(valsSub.shape))
            weightsSub = np.resize(weightsSub, expand + list(weightsSub.shape))
            #print 'valsSub missing axes', valsSub.shape

            # reorder axes to match soltab
            names = [
                ax for ax in soltab.getAxesNames()
                if ax not in soltabsub.getAxesNames()
            ] + soltabsub.getAxesNames()
            #print names, soltab.getAxesNames()
            valsSub = reorderAxes(valsSub, names, soltab.getAxesNames())
            weightsSub = reorderAxes(weightsSub, names, soltab.getAxesNames())
            weights[weightsSub == 0] = 0  # propagate flags
            #print 'valsSub reorder', valsSub.shape

            # put freq axis at the end
            idxFreq = soltab.getAxesNames().index('freq')
            vals = np.swapaxes(vals, idxFreq, len(vals.shape) - 1)
            valsSub = np.swapaxes(valsSub, idxFreq, len(valsSub.shape) - 1)
            #print 'vals reshaped', valsSub.shape

            # a multiplication will go along the last axis of the array
            # clock: phase = 2*pi*clock*freq
            if soltabsub.getType() == 'clock':
                vals -= 2. * np.pi * valsSub * freq

            # tec: phase = -8.44797245e9*tec/freq
            elif soltabsub.getType() == 'tec':
                vals -= -8.44797245e9 * valsSub / freq

            # tec3rd: phase = -1e21*tec3rd/freq**3
            elif soltabsub.getType() == 'tec3rd':
                vals -= -1.e21 * valsSub / np.power(freq, 3)

            # rotationmeasure: phase = RM*wavelength**2, with opposite signs
            # on the two polarizations
            elif soltabsub.getType() == 'rotationmeasure':
                # put pol axis at the beginning
                idxPol = soltab.getAxesNames().index('pol')
                if idxPol == len(vals.shape) - 1:
                    idxPol = idxFreq  # maybe freq swapped with pol
                vals = np.swapaxes(vals, idxPol, 0)
                valsSub = np.swapaxes(valsSub, idxPol, 0)
                #print 'vals reshaped 2', valsSub.shape

                wav = 2.99792458e8 / freq
                ph = wav * wav * valsSub
                #if coord['pol'] == 'XX' or coord['pol'] == 'RR':
                pols = soltab.getAxisValues('pol')
                assert len(pols) == 2  # full jons not supported
                if (pols[0] == 'XX' and pols[1] == 'YY') or \
                   (pols[0] == 'RR' and pols[1] == 'LL'):
                    vals[0] -= ph[0]
                    vals[1] += ph[1]
                else:
                    vals[0] += ph[0]
                    vals[1] -= ph[1]

                vals = np.swapaxes(vals, 0, idxPol)

            # generic table of the same type: plain difference or relative ratio
            else:
                if ratio:
                    vals = (vals - valsSub) / valsSub
                else:
                    vals -= valsSub

            # move freq axis back
            vals = np.swapaxes(vals, len(vals.shape) - 1, idxFreq)

            soltab.setValues(vals)
            soltab.setValues(weights, weight=True)
        else:
            # same shape and a generic type: operate directly on the full arrays
            if ratio:
                soltab.setValues((soltab.getValues(retAxesVals=False) -
                                  soltabsub.getValues(retAxesVals=False)) /
                                 soltabsub.getValues(retAxesVals=False))
            else:
                soltab.setValues(
                    soltab.getValues(retAxesVals=False) -
                    soltabsub.getValues(retAxesVals=False))
            # propagate flags of the subtracted table
            weight = soltab.getValues(retAxesVals=False, weight=True)
            weight[soltabsub.getValues(retAxesVals=False, weight=True) ==
                   0] = 0
            soltab.setValues(weight, weight=True)

    soltab.addHistory('RESIDUALS by subtracting tables ' +
                      ' '.join(soltabsToSub))

    return 0
Ejemplo n.º 6
0
def run(soltab, refAnt='', soltabError='', ncpu=0):
    """
    Remove jumps from TEC solutions.
    WEIGHT: uses the errors.

    Parameters
    ----------
    soltabError : str, optional
        The table name with solution errors. By default it has the same name of soltab with "error" in place of "tec".

    refAnt : str, optional
        Reference antenna for phases. By default None.

    ncpu : int, optional
        Number of processes handed to multiprocManager, by default 0.

    Returns
    -------
    int
        0 on success (also when no jump is detected), 1 if the soltab is not
        of type 'tec' or the error soltab cannot be loaded.
    """

    import scipy.ndimage.filters
    import numpy as np
    from scipy.optimize import minimize
    from scipy.interpolate import griddata
    import scipy.cluster.vq as vq

    def getPhaseWrapBase(freqs):
        """
        freqs: frequency grid of the data
        return the step size from a local minima (2pi phase wrap) to the others [0]: TEC, [1]: clock
        """
        freqs = np.array(freqs)
        nF = freqs.shape[0]
        # builtin float: the np.float alias was removed in NumPy 1.24
        A = np.zeros((nF, 2), dtype=float)
        A[:, 1] = freqs * 2 * np.pi * 1e-9
        A[:, 0] = -8.44797245e9 / freqs
        # least-squares solution of A @ steps = 2*pi (normal equations)
        steps = np.dot(np.dot(np.linalg.inv(np.dot(A.T, A)), A.T),
                       2 * np.pi * np.ones((nF, ), dtype=float))
        return steps

    if soltab.getType() != 'tec':
        logging.error('TECJUMP works only on tec solutions.')
        return 1

    if soltabError == '': soltabError = soltab.name.replace('tec', 'error')
    solset = soltab.getSolset()
    # the lookup must be inside the try, otherwise the error path is unreachable
    try:
        soltab_e = solset.getSoltab(soltabError)
    except Exception:
        logging.error('Cannot fine error solution table %s.' % soltabError)
        return 1
    vals_e_all = soltab_e.getValues(retAxesVals=False, weight=False)

    logging.info("Removing TEC jumps from soltab: " + soltab.name)

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and refAnt not in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        # fall back to the first selected antenna and report that same one
        logging.warning('Reference antenna ' + refAnt + ' not found. Using: ' +
                        ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = None

    # Get the theoretical tec jump
    tec_jump_theory = abs(getPhaseWrapBase([42.308e6, 42.308e6 + 23828e6])[0])

    # Estimate the jump as the median of all time-consecutive differences (over
    # all antennas) lying between 1 and 1.5 times tec_jump_theory
    vals = soltab.getValues(retAxesVals=False, refAnt=refAnt)
    timeAxis = soltab.getAxesNames().index('time')
    vals = np.swapaxes(vals, 0, timeAxis)
    vals = vals[1:, ...] - vals[:-1, ...]  # vals[t+1] - vals[t]
    vals = vals[(vals > tec_jump_theory * 1) & (vals < tec_jump_theory * 1.5)]
    if len(vals) == 0:
        logging.info('TEC jump - theoretical: %.5f TECU - NO JUMP FOUND' %
                     (tec_jump_theory))
        return 0
    tec_jump = np.nanmedian(vals)

    logging.info('TEC jump - theoretical: %.5f TECU - estimated: %.5f TECU' %
                 (tec_jump_theory, tec_jump))

    mpm = multiprocManager(ncpu, _run_antenna)

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes='time', weight=True, refAnt=refAnt):

        # skip all flagged
        if (weights == 0).all(): continue
        # skip reference
        if (vals[(weights != 0)] == 0).all(): continue

        logging.info('Working on ant: %s' % (coord['ant']))

        # 1d linear interp to fill flagged data
        vals[np.where(weights == 0)] = np.interp(
            np.where(weights == 0)[0],
            np.where(weights != 0)[0], vals[np.where(weights != 0)])

        vals_init = np.copy(vals)  # backup for final check
        vals_e = np.squeeze(vals_e_all[selection])
        vals_e[np.where(weights == 0)] = 1.

        mpm.put([
            vals, vals_e, vals_init, weights, selection, tec_jump, coord['ant']
        ])

    mpm.wait()
    for (vals, weights, selection) in mpm.get():
        # set back to 0 the values for flagged data
        vals[weights == 0] = 0
        soltab.setValues(vals, selection)
    soltab.addHistory('TECJUMP')

    return 0
Ejemplo n.º 7
0
def run(soltab,
        chanWidth='',
        outSoltabName='bandpass',
        BadSBList='',
        interpolate=True,
        removeTimeAxis=True,
        autoFlag=False,
        nSigma=5.0,
        maxFlaggedFraction=0.5,
        maxStddev=0.01,
        ncpu=0):
    """
    This operation for LoSoTo implements the Prefactor bandpass operation
    WEIGHT: flag-only compliant, no need for weight

    Parameters
    ----------
    chanWidth : str or float, optional
        the width of each channel in the data from which solutions were obtained. Can be
        either a string like "48kHz" or a float in Hz. If interpolate = True, chanWidth
        must be specified
    BadSBList : str, optional
        a list of bad subbands that will be flagged
    outSoltabName : str, optional
        Name of the output bandpass soltab. An existing soltab with this name will be
        overwritten
    interpolate : bool, optional
        If True, interpolate to a regular frequency grid and then smooth, ignoring bad
        subbands. If False, neither interpolation nor smoothing is done and the output
        frequency grid is the same as the input one. If interpolate = True, chanWidth
        must be specified
    removeTimeAxis : bool, optional
        If True, the time axis of the output bandpass soltab is removed by doing a median
        over time. If False, the output time grid is the same as the input one
    autoFlag : bool, optional
        If True, automatically flag bad frequencies and stations
    nSigma : float, optional
        Number of sigma for autoFlagging. Amplitudes outside of nSigma*stddev are flagged
    maxFlaggedFraction : float, optional
        Maximum allowable fraction of flagged frequencies for autoFlagging. Stations with
        higher fractions will be completely flagged
    maxStddev : float, optional
        Maximum allowable standard deviation for autoFlagging
    ncpu : int, optional
        Number of CPUs to use during autoFlagging (0 = all)

    Returns
    -------
    int
        0 on success, 1 if the soltab is not of type 'amplitude'.

    Raises
    ------
    ValueError
        If outSoltabName equals the input soltab name while removeTimeAxis or
        interpolate is True; if interpolate is True and chanWidth is missing or
        has an unknown unit; or if the interpolated grid has fewer than 20
        subbands.

    Notes
    -----
    When interpolate is True, per-subband median amplitudes are also written
    to 'calibrator_amplitude_array.txt' in the current working directory.
    """
    import numpy as np
    import scipy
    import scipy.ndimage

    logging.info("Running prefactor_bandpass on: " + soltab.name)
    solset = soltab.getSolset()

    solType = soltab.getType()
    if solType != 'amplitude':
        logging.warning("Soltab type of " + soltab.name + " is: " + solType +
                        " should be amplitude. Ignoring.")
        return 1
    if soltab.name == outSoltabName and (removeTimeAxis or interpolate):
        logging.error(
            "If removeTimeAxis = True or interpolate = True, outSoltabName must specify a new soltab."
        )
        raise ValueError(
            "If removeTimeAxis = True or interpolate = True, outSoltabName must specify a new soltab."
        )

    if BadSBList == '':
        bad_sblist = []
    else:
        bad_sblist = [int(SB) for SB in BadSBList.strip('\"\'').split(';')]

    # NOTE(review): only the arrays of the last iteration are kept; with exactly
    # the axes time/ant/freq/pol there is a single iteration.
    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['pol', 'ant', 'freq', 'time'], weight=True):
        vals = reorderAxes(vals, soltab.getAxesNames(),
                           ['time', 'ant', 'freq', 'pol'])
        weights = reorderAxes(weights, soltab.getAxesNames(),
                              ['time', 'ant', 'freq', 'pol'])

    amplitude_arraytmp = vals  # axes are [time, ant, freq, pol]
    weights_arraytmp = weights  # axes are [time, ant, freq, pol]
    # amplitudes exactly equal to 1.0 are treated as flagged
    flagged = np.where(amplitude_arraytmp == 1.0)
    weights_arraytmp[flagged] = 0.0
    ntimes = len(soltab.time[:])
    nants = len(soltab.ant[:])

    subbandHz = 195.3125e3
    if interpolate:
        if chanWidth == '':
            logging.error(
                "If interpolate = True, chanWidth must be specified.")
            raise ValueError(
                "If interpolate = True, chanWidth must be specified.")
        if isinstance(chanWidth, str):
            # split e.g. "48kHz" into number and unit: trailing letters form the unit
            letters = [1 for s in chanWidth[::-1] if s.isalpha()]
            indx = len(chanWidth) - sum(letters)
            unit = chanWidth[indx:]
            if unit.strip().lower() == 'hz':
                conversion = 1.0
            elif unit.strip().lower() == 'khz':
                conversion = 1e3
            elif unit.strip().lower() == 'mhz':
                conversion = 1e6
            else:
                logging.error("The unit on chanWidth was not understood.")
                raise ValueError("The unit on chanWidth was not understood.")
            chanWidthHz = float(chanWidth[:indx]) * conversion
        else:
            chanWidthHz = chanWidth
        offsetHz = subbandHz / 2.0 - 0.5 * chanWidthHz
        freqmin = np.min(
            soltab.freq[:]) + offsetHz  # central frequency of first subband
        freqmax = np.max(
            soltab.freq[:]) + offsetHz  # central frequency of last subband
        # subband index (0-based) of every input channel
        SBgrid = np.floor(
            (soltab.freq[:] - np.min(soltab.freq[:])) / subbandHz)
        freqs_new = np.arange(freqmin, freqmax + 100e3, subbandHz)
        amps_array_flagged = np.zeros((nants, ntimes, len(freqs_new), 2),
                                      dtype='float')
        amps_array = np.zeros((nants, ntimes, len(freqs_new), 2),
                              dtype='float')
        weights_array = np.ones((nants, ntimes, len(freqs_new), 2),
                                dtype='float')

        # SBgrid holds 0-based indices, so the count is the last index + 1
        logging.info("Have " + str(int(max(SBgrid)) + 1) + " subbands.")
        if len(freqs_new) < 20:
            logging.error(
                "Frequency span is less than 20 subbands! The filtering will not work!"
            )
            logging.error(
                "Please run the calibrator pipeline on the full calibrator bandwidth."
            )
            raise ValueError(
                "Frequency span is less than 20 subbands! Amplitude filtering will not work!"
            )

        # make a mapping of new frequencies to old ones
        freq_mapping = {}
        for fn in freqs_new:
            ind = np.where(
                np.logical_and(soltab.freq < fn + subbandHz / 2.0,
                               soltab.freq >= fn - subbandHz / 2.0))
            freq_mapping['{}'.format(fn)] = ind

    # remove bad subbands specified by user
    for bad_sb in bad_sblist:
        logging.info('Removing user-specified subband: ' + str(bad_sb))
        weights_arraytmp[:, :, bad_sb, :] = 0.0

    # remove bad solutions relative to the model bandpass
    if autoFlag:
        if ncpu == 0:
            import multiprocessing
            ncpu = multiprocessing.cpu_count()
        mpm = multiprocManager(ncpu, _flag_amplitudes)
        for s in range(nants):
            mpm.put([
                soltab.freq[:], amplitude_arraytmp[:, s, :, :],
                weights_arraytmp[:, s, :, :], nSigma, maxFlaggedFraction,
                maxStddev, False, s
            ])
        mpm.wait()
        for (s, w) in mpm.get():
            weights_arraytmp[:, s, :, :] = w

    # Now interpolate over flagged values and smooth over frequency and time axes
    if interpolate:
        freq_tmp = soltab.freq[:]  # loop invariant
        for antenna_id in range(nants):
            for time in range(ntimes):
                for p in range(2):  # XX, YY
                    amp_tmp = np.copy(amplitude_arraytmp[time, antenna_id, :, p])
                    assert len(amp_tmp[:]) == len(freq_tmp[:])
                    mask = np.not_equal(weights_arraytmp[time, antenna_id, :, p], 0.0)
                    if np.sum(mask) > 2:
                        amps_array_flagged[antenna_id, time, :, p] = np.interp(
                            freqs_new, freq_tmp[mask], amp_tmp[mask])
                    elif time > 0:
                        # too few unflagged channels: reuse the previous timeslot
                        amps_array_flagged[antenna_id, time, :, p] = \
                            amps_array_flagged[antenna_id, time - 1, :, p]

        # the context manager guarantees the file is closed, even on errors
        with open('calibrator_amplitude_array.txt', 'w') as ampsoutfile:
            ampsoutfile.write(
                '# Antenna name, Antenna ID, subband, XXamp, YYamp, frequency\n')
            for antenna_id in range(nants):
                if np.all(weights_arraytmp[:, antenna_id, :, :] == 0.0):
                    # fully flagged antenna: flag all of its output
                    weights_array[antenna_id, :, :, :] = 0.0
                    continue

                amp_xx = np.copy(amps_array_flagged[antenna_id, :, :, 0])
                amp_yy = np.copy(amps_array_flagged[antenna_id, :, :, 1])

                # median filtering over the [time, freq] plane
                amp_xx = scipy.ndimage.median_filter(amp_xx, (3, 3))
                amp_xx = scipy.ndimage.median_filter(amp_xx, (7, 1))
                amp_yy = scipy.ndimage.median_filter(amp_yy, (3, 3))
                amp_yy = scipy.ndimage.median_filter(amp_yy, (7, 1))

                for i in range(len(freqs_new)):
                    ampsoutfile.write(
                        '%s %s %s %s %s %s\n' %
                        (soltab.ant[antenna_id], antenna_id, i,
                         np.median(amp_xx[:, i], axis=0),
                         np.median(amp_yy[:, i], axis=0), freqs_new[i]))

                # Savitzky-Golay smoothing along frequency
                for time in range(ntimes):
                    amps_array[antenna_id, time, :, 0] = np.copy(
                        _savitzky_golay(amp_xx[time, :], 17, 2))
                    amps_array[antenna_id, time, :, 1] = np.copy(
                        _savitzky_golay(amp_yy[time, :], 17, 2))

                for i in range(len(freqs_new)):
                    # replace each output channel by its median over time
                    amps_array[antenna_id, :, i,
                               0] = np.median(amps_array[antenna_id, :, i, 0])
                    amps_array[antenna_id, :, i,
                               1] = np.median(amps_array[antenna_id, :, i, 1])
                    ind = freq_mapping['{}'.format(freqs_new[i])]
                    for p in range(2):
                        # If half or more of original frequencies are flagged, flag the
                        # output frequency as well
                        nflagged = len(
                            np.where(weights_arraytmp[:, antenna_id, ind,
                                                      p] == 0.0)[0])
                        ntot = weights_arraytmp.shape[0] * len(ind[0])
                        if ntot > 0:
                            if float(nflagged) / float(ntot) >= 0.5:
                                weights_array[antenna_id, :, i, p] = 0.0
        # back to [time, ant, freq, pol]
        amps_array = amps_array.swapaxes(0, 1)
        weights_array = weights_array.swapaxes(0, 1)
    else:
        amps_array = amplitude_arraytmp
        weights_array = weights_arraytmp
        freqs_new = soltab.freq[:]

    # delete existing bandpass soltab if needed and write solutions
    if soltab.name != outSoltabName:
        try:
            new_soltab = solset.getSoltab(outSoltabName)
            new_soltab.delete()
        except Exception:
            # best effort: nothing to delete if the lookup fails
            pass

    if removeTimeAxis:
        # Write bandpass, taking median over the time axis
        new_soltab = solset.makeSoltab(
            soltype='amplitude',
            soltabName=outSoltabName,
            axesNames=['ant', 'freq', 'pol'],
            axesVals=[soltab.ant, freqs_new, ['XX', 'YY']],
            vals=np.median(amps_array, axis=0),
            weights=np.median(weights_array, axis=0))
        new_soltab.addHistory(
            'CREATE (by PREFACTOR_BANDPASS operation) with BadSBList = {0}, '
            'interpolate={1}, removeTimeAxis={2}, autoFlag={3}, nSigma={4}, '
            'maxFlaggedFraction={5}, maxStddev={6}'.format(
                BadSBList, interpolate, removeTimeAxis, autoFlag, nSigma,
                maxFlaggedFraction, maxStddev))
    else:
        # Write bandpass, preserving the time axis
        if soltab.name == outSoltabName:
            # NOTE(review): amps_array is in [time, ant, freq, pol] order here;
            # confirm the soltab's own axis order matches before writing back.
            soltab.setValues(amps_array)
            soltab.setValues(weights_array, weight=True)
            soltab.addHistory(
                'BANDPASS processed with BadSBList = {0}, interpolate={1}, '
                'removeTimeAxis={2}, autoFlag={3}, nSigma={4}, '
                'maxFlaggedFraction={5}, maxStddev={6}'.format(
                    BadSBList, interpolate, removeTimeAxis, autoFlag, nSigma,
                    maxFlaggedFraction, maxStddev))
        else:
            new_soltab = solset.makeSoltab(
                soltype='amplitude',
                soltabName=outSoltabName,
                axesNames=['time', 'ant', 'freq', 'pol'],
                axesVals=[soltab.time, soltab.ant, freqs_new, ['XX', 'YY']],
                vals=amps_array,
                weights=weights_array)
            new_soltab.addHistory(
                'CREATE (by PREFACTOR_BANDPASS operation) with BadSBList = {0}, '
                'interpolate={1}, removeTimeAxis={2}, autoFlag={3}, nSigma={4}, '
                'maxFlaggedFraction={5}, maxStddev={6}'.format(
                    BadSBList, interpolate, removeTimeAxis, autoFlag, nSigma,
                    maxFlaggedFraction, maxStddev))

    return 0
Ejemplo n.º 8
0
def run(soltab,
        tecsoltabOut='tec000',
        clocksoltabOut='clock000',
        offsetsoltabOut='phase_offset000',
        tec3rdsoltabOut='tec3rd000',
        flagBadChannels=True,
        flagCut=5.,
        chi2cut=3000.,
        combinePol=False,
        removePhaseWraps=True,
        fit3rdorder=False,
        circular=False,
        reverse=False,
        invertOffset=False,
        nproc=10):
    """
    Separate phase solutions into Clock and TEC.

    The Clock and TEC values are stored in the specified output soltabs with
    type 'clock' and 'tec' (plus 'tec3rd' when ``fit3rdorder`` is True); the
    fitted phase offsets are stored in an output soltab of type 'phase'.

    Parameters
    ----------
    tecsoltabOut : str, optional
        Name of the output TEC soltab, by default 'tec000'.

    clocksoltabOut : str, optional
        Name of the output clock soltab, by default 'clock000'. The fitted
        clock values are multiplied by 1e-9 before being stored.

    offsetsoltabOut : str, optional
        Name of the output phase-offset soltab, by default 'phase_offset000'.

    tec3rdsoltabOut : str, optional
        Name of the output 3rd-order TEC soltab (only written when
        ``fit3rdorder`` is True), by default 'tec3rd000'.

    flagBadChannels : bool, optional
        Detect and remove bad channel before fitting, by default True.

    flagCut : float, optional
        Threshold passed to the fitter as ``flagcut``, by default 5.
        (presumably the bad-channel rejection cut -- see ``_fitClockTEC.doFit``).

    chi2cut : float, optional
        Threshold passed to the fitter as ``chi2cut``, by default 3000.
        (presumably a chi-square rejection cut -- see ``_fitClockTEC.doFit``).

    combinePol : bool, optional
        Find a combined polarization solution, by default False.

    removePhaseWraps : bool, optional
        Detect and remove phase wraps, by default True.

    fit3rdorder : bool, optional
        Fit a 3rd order ionospheric component (useful < 40 MHz). By default False.

    circular : bool, optional
        Assume circular polarization with FR not removed. By default False.

    reverse : bool, optional
        Reverse the time axis. By default False.

    invertOffset : bool, optional
        Invert (reverse the sign of) the phase offsets. By default False. Set to True
        if you want to use them with the residuals operation.

    nproc : int, optional
        Number of processes passed to the fitter as ``n_proc``, by default 10.

    Returns
    -------
    int
        0 on success; 1 if the soltab is not of type phase, or a selection has
        fewer than 2 antennas or fewer than 10 frequency channels.
    """
    import numpy as np
    from ._fitClockTEC import doFit

    logging.info("Clock/TEC separation on soltab: " + soltab.name)

    # some checks
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is: " + solType +
                        " should be phase. Ignoring.")
        return 1

    # Collect station positions (first three entries of each antenna record).
    # The builtin ``float`` is used as dtype: the ``np.float`` alias was
    # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError there).
    solset = soltab.getSolset()
    station_dict = solset.getAnt()
    stations = soltab.getAxisValues('ant')
    station_positions = np.zeros((len(stations), 3), dtype=float)
    for i, station_name in enumerate(stations):
        position = station_dict[station_name]
        station_positions[i, :] = (position[0], position[1], position[2])

    def make_output(soltype, name, axesNames, axesVals, vals, weights):
        """Create one output soltab and stamp it with this operation's history."""
        out = solset.makeSoltab(soltype,
                                soltabName=name,
                                axesNames=axesNames,
                                axesVals=axesVals,
                                vals=vals,
                                weights=weights)
        out.addHistory('CREATE (by CLOCKTECFIT operation)')

    returnAxes = ['ant', 'freq', 'pol', 'time']
    for vals, flags, coord, selection in soltab.getValuesIter(
            returnAxes=returnAxes, weight=True):

        if len(coord['ant']) < 2:
            logging.error(
                'Clock/TEC separation needs at least 2 antennas selected.')
            return 1
        if len(coord['freq']) < 10:
            logging.error(
                'Clock/TEC separation needs at least 10 frequency channels, preferably distributed over a wide range'
            )
            return 1

        freqs = coord['freq']
        stations = coord['ant']
        times = coord['time']

        # order in which the returned axes appear in this soltab
        axes = [i for i in soltab.getAxesNames() if i in returnAxes]

        # reverse time axes (equivalent to the swap-reverse-swap idiom)
        if reverse:
            time_axis = axes.index('time')
            vals = np.flip(vals, axis=time_axis)
            flags = np.flip(flags, axis=time_axis)

        result = doFit(vals,
                       flags == 0,
                       freqs,
                       stations,
                       station_positions,
                       axes,
                       flagBadChannels=flagBadChannels,
                       flagcut=flagCut,
                       chi2cut=chi2cut,
                       combine_pol=combinePol,
                       removePhaseWraps=removePhaseWraps,
                       fit3rdorder=fit3rdorder,
                       circular=circular,
                       n_proc=nproc)
        if fit3rdorder:
            clock, tec, offset, tec3rd = result
        else:
            clock, tec, offset = result
            tec3rd = None
        if reverse:
            # restore the original time order (time is the first axis of the
            # fitted arrays, as written to the output soltabs below)
            clock = clock[::-1, :]
            tec = tec[::-1, :]
            if tec3rd is not None:
                tec3rd = tec3rd[::-1, :]
        if invertOffset:
            offset *= -1.0

        # Solutions with tec <= -5 are zeroed and flagged (presumably the
        # fitter's marker for a failed fit -- TODO confirm in doFit).
        weights = tec > -5
        tec[np.logical_not(weights)] = 0
        clock[np.logical_not(weights)] = 0
        weights = np.float16(weights)

        # NOTE(review): output soltabs are created inside this loop; if the
        # iteration yields more than one selection (e.g. over a 'dir' axis),
        # makeSoltab is called again with the same names -- confirm intended.
        if combinePol or 'pol' not in soltab.getAxesNames():
            # single solution: keep only the first entry of the last axis
            make_output('tec', tecsoltabOut, ['time', 'ant'],
                        [times, stations], tec[:, :, 0], weights[:, :, 0])
            make_output('clock', clocksoltabOut, ['time', 'ant'],
                        [times, stations], clock[:, :, 0] * 1e-9,
                        weights[:, :, 0])
            make_output('phase', offsetsoltabOut, ['ant'], [stations],
                        offset[:, 0],
                        np.ones_like(offset[:, 0], dtype=np.float16))
            if fit3rdorder:
                make_output('tec3rd', tec3rdsoltabOut, ['time', 'ant'],
                            [times, stations], tec3rd[:, :, 0],
                            weights[:, :, 0])
        else:
            pols = ['XX', 'YY']
            make_output('tec', tecsoltabOut, ['time', 'ant', 'pol'],
                        [times, stations, pols], tec, weights)
            make_output('clock', clocksoltabOut, ['time', 'ant', 'pol'],
                        [times, stations, pols], clock * 1e-9, weights)
            make_output('phase', offsetsoltabOut, ['ant', 'pol'],
                        [stations, pols], offset,
                        np.ones_like(offset, dtype=np.float16))
            if fit3rdorder:
                make_output('tec3rd', tec3rdsoltabOut, ['time', 'ant', 'pol'],
                            [times, stations, pols], tec3rd, weights)
    return 0
Ejemplo n.º 9
0
def run(soltab,
        mode,
        maxFlaggedFraction=0.5,
        nSigma=5.0,
        maxStddev=None,
        ampRange=None,
        telescope='lofar',
        skipInternational=False,
        refAnt='',
        soltabExport='',
        ncpu=0):
    """
    This operation for LoSoTo implements a station-flagging procedure. Flags are time-independent.
    WEIGHT: compliant

    Parameters
    ----------
    mode: str
        Fitting algorithm: bandpass or resid (case-insensitive). Bandpass mode clips amplitudes relative to a model bandpass (only LOFAR is currently supported). Resid mode clips residual phases or log(amplitudes).

    maxFlaggedFraction : float, optional
        This sets the maximum allowable fraction of flagged solutions above which the entire station is flagged.

    nSigma : float, optional
        This sets the number of standard deviations considered when outlier clipping is done.

    maxStddev : float, optional
        Maximum allowable standard deviation when outlier clipping is done. For phases, this value
        should be in radians, for amplitudes in log(amp). If None (or not positive), a value of 0.1 rad is
        used for phases and 0.01 for amplitudes.

    ampRange : array, optional
        2-element array of the median amplitude level to be acceptable, ampRange[0]: lower limit, ampRange[1]: upper limit.
        If None or [0, 0], a reasonable range for typical observations is used.

    telescope : str, optional
        Specifies the telescope if mode = 'bandpass'.

    skipInternational : bool, optional
        If True, skip flagging of international LOFAR stations (names containing neither 'CS'
        nor 'RS'). Only used if mode = 'bandpass' and telescope = 'lofar'.

    refAnt : str, optional
        If mode = resid, this sets the reference antenna for phase solutions, by default None.
        If 'nearest', each antenna is referenced to its closest antenna that is not fully flagged.

    soltabExport : str, optional
        Soltab to export station flags to. Note: exported flags are not time- or frequency-dependent.

    ncpu : int, optional
        Number of cpu to use, by default all available.

    Returns
    -------
    int
        0 on success, 1 on invalid input (unknown mode, missing axes or wrong soltab type).
    """
    import numpy as np

    logging.info("Flagging on soltab: " + soltab.name)

    # input check
    if refAnt == '':
        refAnt = None
    if soltabExport == '':
        soltabExport = None
    if mode is None or mode.lower() not in ['bandpass', 'resid']:
        logging.error('Mode must be one of bandpass or resid')
        return 1
    # The check above is case-insensitive, so normalise here; otherwise e.g.
    # 'Bandpass' would pass validation and then silently run resid mode.
    mode = mode.lower()
    solType = soltab.getType()
    if maxStddev is None or maxStddev <= 0.0:
        if solType == 'phase':
            maxStddev = 0.1  # in radians
        else:
            maxStddev = 0.01  # in log10(amp)

    # Work internally with axis order [time, ant, freq, pol(, dir)]
    axis_names = soltab.getAxesNames()
    if ('freq' not in axis_names or 'pol' not in axis_names
            or 'time' not in axis_names or 'ant' not in axis_names):
        logging.error("Currently, flagstation requires the following axes: "
                      "freq, pol, time, and ant.")
        return 1
    has_dir = 'dir' in axis_names
    work_axes = ['time', 'ant', 'freq', 'pol'] + (['dir'] if has_dir else [])
    perm = [axis_names.index(ax) for ax in work_axes]
    # Inverse permutation, used to restore the soltab's own axis order when
    # writing back. Re-applying ``perm`` is only correct if it happens to be
    # its own inverse (e.g. it scrambles ['ant', 'freq', 'pol', 'time']).
    inv_perm = [perm.index(k) for k in range(len(perm))]
    vals_arraytmp = soltab.val[:].transpose(perm)
    weights_arraytmp = soltab.weight[:].transpose(perm)

    # Check for NaN solutions and flag
    weights_arraytmp[np.isnan(vals_arraytmp)] = 0.0

    # Directions to process; None means "no dir axis".
    dir_indices = range(len(soltab.dir)) if has_dir else [None]

    def station_index(s, d):
        """Index tuple selecting station ``s`` (and direction ``d`` if any)."""
        idx = (slice(None), s, slice(None), slice(None))
        return idx if d is None else idx + (d, )

    if mode == 'bandpass':
        if solType != 'amplitude':
            logging.error(
                "Soltab must be of type amplitude for bandpass mode.")
            return 1

        # Fill the queue (one multiprocessing pass per direction)
        for d in dir_indices:
            mpm = multiprocManager(ncpu, _flag_bandpass)
            for s in range(len(soltab.ant)):
                if ('CS' not in soltab.ant[s] and 'RS' not in soltab.ant[s]
                        and skipInternational
                        and telescope.lower() == 'lofar'):
                    continue
                idx = station_index(s, d)
                mpm.put([
                    soltab.freq[:], vals_arraytmp[idx], weights_arraytmp[idx],
                    telescope, nSigma, ampRange, maxFlaggedFraction,
                    maxStddev, False, soltab.ant[:], s
                ])
            mpm.wait()
            for (s, w) in mpm.get():
                weights_arraytmp[station_index(s, d)] = w

        history = (
            'FLAGSTATION (mode=bandpass, telescope={0}, maxFlaggedFraction={1}, '
            'nSigma={2})'.format(telescope, maxFlaggedFraction, nSigma))
    else:
        if solType not in ['phase', 'amplitude']:
            logging.error(
                "Soltab must be of type phase or amplitude for resid mode.")
            return 1

        # Subtract reference phases
        if refAnt is not None:
            if solType != 'phase':
                logging.error(
                    'Reference possible only for phase solution tables. Ignoring referencing.'
                )
            elif refAnt == 'nearest':
                solset = soltab.getSolset()
                all_ants = soltab.getAxisValues(
                    'ant', ignoreSelection=True).tolist()
                bad_ants = list(soltab._getFullyFlaggedAnts())
                for i, antToRef in enumerate(soltab.getAxisValues('ant')):
                    # get the closest antenna
                    antDists = solset.getAntDist(antToRef)  # this is a dict
                    for badAnt in bad_ants:
                        antDists.pop(badAnt, None)  # remove bad ants
                    # second-closest antenna (the closest is antToRef itself)
                    reference = sorted(antDists, key=antDists.get)[1]
                    refInd = all_ants.index(reference)
                    vals_arraytmp[:, i, ...] -= vals_arraytmp[:, refInd,
                                                              ...].copy()
            else:
                ants = soltab.getAxisValues('ant')
                if refAnt not in ants:
                    logging.warning('Reference antenna ' + refAnt +
                                    ' not found. Using: ' + ants[0])
                    refAnt = ants[0]
                refInd = ants.tolist().index(refAnt)
                vals_arrayref = vals_arraytmp[:, refInd, ...].copy()
                for i in range(len(soltab.ant)):
                    vals_arraytmp[:, i, ...] -= vals_arrayref

        # Fill the queue (one multiprocessing pass per direction)
        for d in dir_indices:
            mpm = multiprocManager(ncpu, _flag_resid)
            for s in range(len(soltab.ant)):
                idx = station_index(s, d)
                mpm.put([
                    vals_arraytmp[idx], weights_arraytmp[idx], solType,
                    nSigma, maxFlaggedFraction, maxStddev, soltab.ant[:], s
                ])
            mpm.wait()
            for (s, w) in mpm.get():
                weights_arraytmp[station_index(s, d)] = w

        history = ('FLAGSTATION (mode=resid, maxFlaggedFraction={0}, '
                   'nSigma={1})'.format(maxFlaggedFraction, nSigma))

    # Make sure that fully flagged stations have all pols flagged
    for s in range(len(soltab.ant)):
        for p in range(len(soltab.pol)):
            if np.all(weights_arraytmp[:, s, :, p] == 0.0):
                weights_arraytmp[:, s, :, :] = 0.0
                break

    # Write new weights, restoring the soltab's axis order
    soltab.setValues(weights_arraytmp.transpose(inv_perm), weight=True)
    soltab.addHistory(history)

    if soltabExport is not None:
        # Transfer station flags to soltabExport
        solset = soltab.getSolset()
        soltabexp = solset.getSoltab(soltabExport)
        axis_namesexp = soltabexp.getAxesNames()
        ant_list = soltab.ant[:].tolist()
        pol_list = soltab.pol[:].tolist()

        for stat in soltabexp.ant:
            if stat not in ant_list:
                continue
            s = ant_list.index(stat)
            if 'pol' in axis_namesexp:
                for pol in soltabexp.pol:
                    if pol in pol_list:
                        soltabexp.setSelection(ant=stat, pol=pol)
                        p = pol_list.index(pol)
                        if np.all(weights_arraytmp[:, s, :, p] == 0):
                            soltabexp.setValues(np.zeros(
                                soltabexp.weight.shape),
                                                weight=True)
            else:
                soltabexp.setSelection(ant=stat)
                if np.all(weights_arraytmp[:, s, :, :] == 0):
                    soltabexp.setValues(np.zeros(soltabexp.weight.shape),
                                        weight=True)
        soltabexp.addHistory('WEIGHT imported by FLAGSTATION from ' +
                             soltab.name + '.')

    return 0
Ejemplo n.º 10
0
def run( soltab, soltabOut='phasediff', maxResidual=1., fitOffset=False, average=False, replace=False, minFreq=0, refAnt='' ):
    """
    Estimate polarization misalignment as delay.

    For each antenna and time slot the phase difference between the two
    parallel-hand polarizations (RR-LL if available, otherwise XX-YY) is
    fitted as ``delay * freq`` (plus an offset if ``fitOffset``). The output
    soltab starts as a copy of the input; for every processed antenna the
    first polarization is set to 0 and the second to minus the wrapped
    fitted phase.

    Parameters
    ----------
    soltabOut : str, optional
        output table name (same solset), by default "phasediff".

    maxResidual : float, optional
        Maximum acceptable mean absolute residual of the fit in radians before flagging, by default 1. If 0: No check.

    fitOffset : bool, optional
        Assume that together with a delay each station has also a differential phase offset (important for old LBA observations). By default False.

    average : bool, optional
        Mean-average in time the resulting delays/offset (the offset with a circular mean). By default False.

    replace : bool, optional
        Only used with ``average``: if True, time slots whose fit was flagged get the
        averaged value and are unflagged; if False they keep their own fit and stay
        flagged. By default False.

    minFreq : float, optional
        Only frequencies strictly above this value [Hz] are used in the fit. By default 0 (all freqs).

    refAnt : str, optional
        Reference antenna. If '' (default) or not found, the second antenna of the
        selected antenna axis is used.

    Returns
    -------
    int
        0 on success; 1 if the soltab is not of type phase or has neither an
        RR/LL nor an XX/YY polarization pair.
    """
    import numpy as np

    logging.info("Finding polarization align for soltab: "+soltab.name)

    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of "+soltab.name+" is of type "+solType+", should be phase. Ignoring.")
        return 1

    # Pick the parallel-hand pair before touching the solset, so that an
    # unsupported polarization axis neither leaves an orphan output soltab
    # behind nor crashes later on undefined indices.
    pols = list(soltab.getAxisValues('pol'))
    if 'RR' in pols and 'LL' in pols:
        coord1, coord2 = pols.index('RR'), pols.index('LL')
    elif 'XX' in pols and 'YY' in pols:
        coord1, coord2 = pols.index('XX'), pols.index('YY')
    else:
        logging.error('Cannot reference to known polarisation.')
        return 1

    if refAnt != '' and refAnt != 'closest' and refAnt not in soltab.getAxisValues('ant', ignoreSelection=True):
        logging.warning('Reference antenna '+refAnt+' not found. Using: '+soltab.getAxisValues('ant')[1])
        refAnt = soltab.getAxisValues('ant')[1]
    if refAnt == '':
        refAnt = soltab.getAxisValues('ant')[1]

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')

    # create new table, initialised as a copy of the input one
    solset = soltab.getSolset()
    axesNames = soltab.getAxesNames()
    soltabout = solset.makeSoltab(soltype=soltab.getType(), soltabName=soltabOut, axesNames=axesNames,
                                  axesVals=[soltab.getAxisValues(axisName) for axisName in axesNames],
                                  vals=soltab.getValues(retAxesVals=False),
                                  weights=soltab.getValues(weight=True, retAxesVals=False))
    soltabout.addHistory('Created by POLALIGN operation from %s.' % soltab.name)

    # per-antenna arrays are processed with axes [pol, freq, time]
    workAxes = ['pol', 'freq', 'time']
    soltabAxes = [ax for ax in axesNames if ax in workAxes]
    doplot = False  # debug only: set True to save diagnostic plots (see below)

    for vals, weights, coord, selection in soltab.getValuesIter(returnAxes=['freq','pol','time'], weight=True, refAnt=refAnt):

        # reorder axes
        vals = reorderAxes(vals, axesNames, workAxes)
        weights = reorderAxes(weights, axesNames, workAxes)

        if np.all(weights == 0.):
            logging.warning('Skipping flagged antenna: '+coord['ant'])
            weights[:] = 0
        else:

            fit_delays = []; fit_offset = []; fit_weights = []
            for t in range(len(times)):

                # apply flags
                idx = (weights[coord1,:,t] != 0.) & (weights[coord2,:,t] != 0.) & (coord['freq'] > minFreq)
                freq = np.copy(coord['freq'])[idx]
                phase1 = vals[coord1,:,t][idx]
                phase2 = vals[coord2,:,t][idx]

                if len(freq) < 30:
                    fit_weights.append(0.)
                    fit_delays.append(0.)
                    fit_offset.append(0.)
                    logging.debug('Not enough unflagged point for the timeslot '+str(t))
                    continue

                # if more than 1/2 of chans are flagged
                if (len(idx) - len(freq))/float(len(idx)) > 1/2.:
                    logging.debug('High number of filtered out data points for the timeslot %i: %i/%i' % (t, len(idx) - len(freq), len(idx)) )

                # wrap the difference into [-pi, pi) and unwrap it along frequency
                phase_diff = np.mod(phase1 - phase2 + np.pi, 2.*np.pi) - np.pi
                phase_diff = np.unwrap(phase_diff)

                A = np.vstack([freq, np.ones(len(freq))]).T
                fitresultdelay = np.linalg.lstsq(A, phase_diff, rcond=None)[0]
                # get the closest n*(2pi) to the intercept and refit with only 1 parameter
                if not fitOffset:
                    numjumps = np.around(fitresultdelay[1]/(2*np.pi))
                    A = np.reshape(freq, (-1,1)) # no b
                    phase_diff -= numjumps * 2 * np.pi
                    fitresultdelay = np.linalg.lstsq(A, phase_diff, rcond=None)[0]
                    fitresultdelay = [fitresultdelay[0], 0.] # set offset to 0 to keep the rest of the script equal

                # mean absolute residual of the fit
                residual = np.mean(np.abs(fitresultdelay[0]*freq + fitresultdelay[1] - phase_diff))

                fit_delays.append(fitresultdelay[0])
                fit_offset.append(fitresultdelay[1])
                if maxResidual == 0 or residual < maxResidual:
                    fit_weights.append(1.)
                else:
                    # high residual, flag
                    logging.debug('Bad solution for ant: '+coord['ant']+' (time: '+str(t)+', residual: '+str(residual)+') -> ignoring.')
                    fit_weights.append(0.)

                # Debug plot (every 10th time slot of one hard-coded antenna)
                if doplot and t % 10 == 0 and coord['ant'] == 'W04':
                    import sys
                    if 'matplotlib' not in sys.modules:
                        import matplotlib as mpl
                        mpl.rc('figure.subplot',left=0.05, bottom=0.05, right=0.95, top=0.95,wspace=0.22, hspace=0.22 )
                        mpl.use("Agg")
                    import matplotlib.pyplot as plt

                    fig = plt.figure()
                    fig.subplots_adjust(wspace=0)
                    ax = fig.add_subplot(111)

                    # plot rm fit
                    plotdelay = lambda delay, offset, freq: np.mod( delay*freq + offset + np.pi, 2.*np.pi) - np.pi
                    ax.plot(freq, fitresultdelay[0]*freq + fitresultdelay[1], "-", color='purple', label=r'delay:%f$\nu$ (ns) + %f ' % (fitresultdelay[0]*1e9,fitresultdelay[1]) )

                    ax.plot(freq, np.mod(phase1 + np.pi, 2.*np.pi) - np.pi, 'ob' )
                    ax.plot(freq, np.mod(phase2 + np.pi, 2.*np.pi) - np.pi, 'og' )
                    ax.plot(freq, phase_diff, '.', color='purple' )

                    plot_residual = np.mod(plotdelay(fitresultdelay[0], fitresultdelay[1], freq)-phase_diff + np.pi,2.*np.pi)-np.pi
                    ax.plot(freq, plot_residual, '.', color='yellow')

                    ax.set_xlabel('freq')
                    ax.set_ylabel('phase')

                    logging.warning('Save pic: '+str(t)+'_'+coord['ant']+'.png')
                    fig.legend(loc='upper left')
                    plt.savefig(coord['ant']+'_'+str(t)+'.png', bbox_inches='tight')
                    plt.close(fig)
            # end cycle in time

            fit_weights = np.array(fit_weights)
            fit_delays = np.array(fit_delays)
            fit_offset = np.array(fit_offset)

            # avg in time
            if average:
                flagged = fit_weights == 0
                fit_delays_bkp = fit_delays[flagged]
                fit_offset_bkp = fit_offset[flagged]
                fit_delays[flagged] = np.nan
                fit_offset[flagged] = np.nan
                fit_delays[:] = np.nanmean(fit_delays)
                # angle (circular) mean
                fit_offset[:] = np.angle(np.nansum(np.exp(1j*fit_offset)) / np.count_nonzero(~np.isnan(fit_offset)))

                if replace:
                    fit_weights[flagged] = 1.
                    fit_weights[np.isnan(fit_delays)] = 0. # all the time slots were flagged: cannot extrapolate a value
                else:
                    fit_delays[flagged] = fit_delays_bkp
                    fit_offset[flagged] = fit_offset_bkp

            logging.info('%s: average delay: %f ns (offset: %f)' % ( coord['ant'], np.mean(fit_delays)*1e9, np.mean(fit_offset)))

            # Model on the (freq, time) grid: first pol fixed to 0, second pol
            # set to minus the wrapped fitted phase of each time slot.
            phase = np.mod(np.outer(coord['freq'], fit_delays) + fit_offset + np.pi, 2.*np.pi) - np.pi
            vals[coord1] = 0
            vals[coord2] = -1.*phase
            weights[coord1] = fit_weights
            weights[coord2] = fit_weights

        # reorder axes back to the original order, needed for setValues
        vals = reorderAxes(vals, workAxes, soltabAxes)
        weights = reorderAxes(weights, workAxes, soltabAxes)
        soltabout.setSelection(**coord)
        soltabout.setValues( vals )
        soltabout.setValues( weights, weight=True )

    return 0
Ejemplo n.º 11
0
def run(soltab,
        outsoltab,
        axisToRegrid,
        newdelta,
        delta='',
        maxFlaggedWidth=0,
        log=False):
    """
    This operation for LoSoTo implements regridding and linear interpolation of data for an axis.
    WEIGHT: compliant

    Parameters
    ----------
    soltab : solution table
        Input soltab; the regridded data are written to a new soltab of the same type.

    outsoltab: str
        Name of output soltab

    axisToRegrid : str
        Name of the axis for which regridding/interpolation will be done
        (must be "time" or "freq")

    newdelta : float or str
        Fundamental width between samples after regridding. E.g., "100kHz" or "10s"

    delta : float or str, optional
        Fundamental width between samples in axisToRegrid. E.g., "100kHz" or "10s". If "",
        it is calculated as the minimum spacing of the axisToRegrid values

    maxFlaggedWidth : int, optional
        Maximum allowable width in number of samples (after regridding) above which
        interpolated values are flagged (e.g., maxFlaggedWidth = 5 would allow gaps of
        5 samples or less to be interpolated across but gaps of 6 or more would be
        flagged). 0 (default) disables unflagging of gaps. Samples without an
        interpolated value (outside the range of the unflagged data) always stay flagged.

    log : bool, optional
        Interpolation is done in log10 space, by default False

    Returns
    -------
    int
        0 on success, 1 if the axis is missing or not time/freq.
    """
    import numpy as np
    import scipy.ndimage as nd

    # Check inputs
    if axisToRegrid not in soltab.getAxesNames():
        logging.error('Axis \"' + axisToRegrid + '\" not found.')
        return 1
    if axisToRegrid not in ['freq', 'time']:
        logging.error('Axis \"' + axisToRegrid +
                      '\" must be either time or freq.')
        return 1
    newdelta = _convert_strval(newdelta)
    if delta == "":
        deltas = soltab.getAxisValues(axisToRegrid)[1:] - soltab.getAxisValues(
            axisToRegrid)[:-1]
        delta = np.min(deltas)
        logging.info('Using {} for delta'.format(delta))
    else:
        delta = _convert_strval(delta)
    if soltab.getType() == 'amplitude' and not log:
        logging.warning(
            'Amplitude solution tab detected and log=False. Amplitude solution tables should be treated in log space.'
        )

    # Regrid axis
    axisind = soltab.getAxesNames().index(axisToRegrid)
    orig_axisvals = soltab.getAxisValues(axisToRegrid)
    new_axisvals = _regrid_axis(orig_axisvals, delta, newdelta)
    new_shape = list(soltab.val.shape)
    new_shape[axisind] = len(new_axisvals)
    # Slices that cannot be interpolated keep value 0 and weight 0 (flagged)
    new_vals = np.zeros(new_shape, dtype='float')
    new_weights = np.zeros(new_shape, dtype='float')

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=[axisToRegrid], weight=True):
        flagged = np.logical_or(np.equal(weights, 0.0), np.isnan(vals))
        weights[flagged] = 0.0
        unflagged = np.not_equal(weights, 0.0)
        if np.count_nonzero(unflagged) < 2:
            # Linear interpolation needs at least two unflagged points
            continue

        # Output index of this 1-D slice: same position on the other axes,
        # whole regridded axis. Build a tuple (a list of slices is not a valid
        # NumPy index) and leave the iterator's selection untouched.
        out_sel = list(selection)
        out_sel[axisind] = slice(None)
        out_sel = tuple(out_sel)

        # Interpolate values through the unflagged points only; new samples
        # outside the unflagged range get NaN
        fit_vals = vals[unflagged]
        if log:
            fit_vals = np.log10(fit_vals)
        interp_vals = np.interp(new_axisvals,
                                orig_axisvals[unflagged],
                                fit_vals,
                                left=np.nan,
                                right=np.nan)

        # For the weights, interpolate without the mask so flagged regions
        # stay flagged; outside the original axis range the weight is 0
        interp_weights = np.round(
            np.interp(new_axisvals,
                      orig_axisvals,
                      weights,
                      left=0.0,
                      right=0.0))
        # Next to the first/last unflagged sample a weight can round to 1
        # while no value could be interpolated there: keep those flagged
        interp_weights[np.isnan(interp_vals)] = 0.0

        # Check for flagged gaps
        if maxFlaggedWidth > 0:
            gap_labels, ngaps = nd.label(interp_weights == 0.0)
            for label in range(1, ngaps + 1):
                gap = gap_labels == label
                if np.count_nonzero(gap) <= maxFlaggedWidth:
                    # Unflag narrow gaps, but only where a value was interpolated
                    interp_weights[gap & np.isfinite(interp_vals)] = 1.0

        new_vals[out_sel] = interp_vals
        new_weights[out_sel] = interp_weights

    # Write new soltab
    solset = soltab.getSolset()
    axesVals = []
    for axisName in soltab.getAxesNames():
        if axisName == axisToRegrid:
            axesVals.append(new_axisvals)
        else:
            axesVals.append(soltab.getAxisValues(axisName))
    if log:
        new_vals = 10**new_vals
    s = solset.makeSoltab(soltab.getType(),
                          outsoltab,
                          axesNames=soltab.getAxesNames(),
                          axesVals=axesVals,
                          vals=new_vals,
                          weights=new_weights)
    s.addHistory('CREATE by INTERPOLATE operation from ' + soltab.name + '.')

    return 0
Ejemplo n.º 12
0
def run(soltab,
        outSoltab='tecscreen',
        height=200.0e3,
        order=12,
        beta=5.0 / 3.0,
        ncpu=0):
    """
    Fits a screen to TEC or scalar-phase values.

    The results of the fit are stored in the soltab parent solset in
    "outSoltab" and the residuals (actual - screen) are stored in
    "outSoltab" + "resid". These values are the screen values per station per
    direction per solution interval. The pierce point locations are stored in
    an auxiliary "piercepoint" array inside the output screen soltab group.

    Screens can be plotted with the PLOTSCREEN operation.

    Parameters
    ----------
    soltab: solution table
        Soltab of type "phase" or "tec". It must have "ant" and "time" axes;
        a "dir" axis and a single-valued "freq" axis are optional
    outSoltab: str, optional
        Name of output soltab
    height : float, optional
        Height in m of screen
    order : int, optional
        Order of screen (i.e., number of KL base vectors to keep).
    beta: float, optional
        Power-law index for phase structure function (5/3 => pure Kolmogorov
        turbulence)
    ncpu: int, optional
        Number of CPUs to use. If 0, all are used

    Returns
    -------
    int
        0 on success, 1 if the soltab type or its freq axis is not supported.
    """
    import numpy as np
    from numpy import newaxis

    # Get screen type
    screen_type = soltab.getType()
    if screen_type not in ['phase', 'tec']:
        logging.error(
            'Screens can only be fit to soltabs of type "phase" or "tec".')
        return 1
    logging.info('Using solution table {0} to calculate {1} screens'.format(
        soltab.name, screen_type))

    # Load values. Copy the axis names: the list is edited below and must not
    # alias anything owned by the soltab.
    axis_names = list(soltab.getAxesNames())
    r = np.array(soltab.val)
    weights = soltab.weight[:]
    has_freq = 'freq' in axis_names
    if has_freq:
        freqs = soltab.freq[:]
        if len(freqs) > 1:
            logging.error('Screens can only be fit at a single frequency')
            return 1
        freq = freqs[0]

        # remove degenerate freq axis
        freq_ind = axis_names.index('freq')
        r = np.squeeze(r, axis=freq_ind)
        weights = np.squeeze(weights, axis=freq_ind)
        axis_names.pop(freq_ind)

    # fix for missing dir axis: append a length-1 dir axis as the LAST axis,
    # so the indices of the existing axes stay valid
    if 'dir' not in axis_names:
        r = r[..., newaxis]
        weights = weights[..., newaxis]
        axis_names.append('dir')
        source_names = ['POINTING']
    else:
        source_names = soltab.dir[:]
    dir_ind = axis_names.index('dir')

    time_ind = axis_names.index('time')
    ant_ind = axis_names.index('ant')
    # order is now [dir, ant, time]
    r = r.transpose([dir_ind, ant_ind, time_ind])
    weights = weights.transpose([dir_ind, ant_ind, time_ind])
    times = np.array(soltab.time)

    # Collect station and source names and positions and times, making sure
    # that they are ordered correctly.
    solset = soltab.getSolset()
    source_dict = solset.getSou()
    source_positions = [source_dict[source] for source in source_names]
    station_names = soltab.ant[:]
    station_dict = solset.getAnt()
    station_positions = [station_dict[station] for station in station_names]
    N_sources = len(source_names)
    N_times = len(times)
    N_stations = len(station_names)

    logging.info('Using height = {0} m and order = {1}'.format(height, order))
    if height < 100e3:
        logging.warning("Height is less than 100e3 m. This is likely too low.")

    # Initialize output arrays, [dir, ant, time]
    N_piercepoints = N_sources * N_stations
    screen = np.zeros((N_sources, N_stations, N_times))
    residual = np.zeros((N_sources, N_stations, N_times))
    r_0 = 100.0  # shouldn't matter what we choose
    rr = np.reshape(r, [N_piercepoints, N_times])

    # Find pierce points and airmass values for given screen height
    pp, airmass, midRA, midDec = _calculate_piercepoints(
        np.array(station_positions), np.array(source_positions),
        np.array(times), height)

    # Fit the screens, one job per solution time
    station_weights = np.reshape(weights, [N_piercepoints, N_times])
    fit_function = _fit_phase_screen if screen_type == 'phase' else _fit_tec_screen
    mpm = multiprocManager(ncpu, fit_function)
    for tindx, t in enumerate(times):
        w = np.diag(station_weights[:, tindx])[:, :, newaxis]
        mpm.put([
            station_names, source_names, pp[tindx, newaxis, :, :],
            airmass[tindx, newaxis, :], rr[:, tindx, newaxis], w, [t],
            height, order, r_0, beta
        ])
    mpm.wait()
    # results may arrive in any order: map each job's time back to its index
    time_index = {tm: i for i, tm in enumerate(times.tolist())}
    for result in mpm.get():
        if screen_type == 'phase':
            # the real/imaginary screens and residuals are not stored
            (_, _, _, _, scr, res, t) = result
        else:
            (scr, res, t) = result
        i = time_index[t[0]]
        screen[:, :, i] = scr[0, :, :]
        residual[:, :, i] = res[0, :, :]
    weights = np.reshape(station_weights, (N_sources, N_stations, N_times))

    # Write the results to the output solset
    dirs_out = source_names
    times_out = times
    ants_out = station_names

    # Store screen values
    weights = weights.transpose([0, 2, 1])  # order is now [dir, time, ant]
    vals = screen.transpose([0, 2, 1])
    screen_st = solset.makeSoltab('{}screen'.format(screen_type),
                                  outSoltab,
                                  axesNames=['dir', 'time', 'ant'],
                                  axesVals=[dirs_out, times_out, ants_out],
                                  vals=vals,
                                  weights=weights)
    vals = residual.transpose([0, 2, 1])
    resscreen_st = solset.makeSoltab('{}screenresid'.format(screen_type),
                                     outSoltab + 'resid',
                                     axesNames=['dir', 'time', 'ant'],
                                     axesVals=[dirs_out, times_out, ants_out],
                                     vals=vals,
                                     weights=weights)

    # Store beta, r_0, height, and order as attributes of the screen soltabs
    screen_st.obj._v_attrs['beta'] = beta
    screen_st.obj._v_attrs['r_0'] = r_0
    screen_st.obj._v_attrs['height'] = height
    screen_st.obj._v_attrs['order'] = order
    if has_freq:
        screen_st.obj._v_attrs['freq'] = freq

    # Store piercepoint table. Note that it does not conform to the axis
    # shapes, so we cannot use makeSoltab()
    solset.obj._v_file.create_array('/' + solset.name + '/' +
                                    screen_st.obj._v_name,
                                    'piercepoint',
                                    obj=pp)

    screen_st.addHistory('CREATE (by DIRECTIONSCREEN operation)')
    resscreen_st.addHistory('CREATE (by DIRECTIONSCREEN operation)')

    return 0
Ejemplo n.º 13
0
def create_h5parm(instrumentdbFiles,
                  antennaFile,
                  fieldFile,
                  skydbFile,
                  h5parmFile,
                  complevel,
                  solsetName,
                  globaldbFile=None,
                  verbose=False):
    """
    Create the h5parm file.
    Input:
       instrumentdbFiles - list of the filenames of the solutions.
       antennaFile - file name of the antenna table.
       fieldFile - file name of the field table.
       skydbFile - file name of the sky table.
       h5parmFile - file name of the h5parm file that will be created.
       complevel - level of compression. It is usually 5.
       solsetName - Name of the solution set. Usually "sol###".
       globaldbFile (optional) - Name of the globaldbFile. Used only for
         logging purposes.
       verbose (optional) - if True, log a summary of the h5parm at the end.

    The solution types are read from the NAMES subtable of the first
    instrumentdb and one solution table is created per supported type.
    Calls sys.exit(1) if directions other than the pointing are found and
    skydbFile is not a directory.
    """

    # open/create the h5parm file and the solution-set
    h5parm = h5parm_mod(h5parmFile, readonly=False, complevel=complevel)

    solset = h5parm.makeSolset(solsetName)

    # Create tables using the first instrumentdb
    # TODO: all the instrument tables should be checked
    namesTable = pt.table(instrumentdbFiles[0] + "/NAMES")
    names = namesTable.getcol("NAME")
    namesTable.close()

    rawTypes = list(set(name.split(":")[0] for name in names))
    logging.info('Found solution types: ' + ', '.join(rawTypes))

    # Every solType is a LIST of substrings that must all appear in a parm
    # name. Equivalent parmdb types are put together:
    # Gain <-> DirectionalGain
    # CommonRotationAngle <-> RotationAngle
    # CommonScalarPhase <-> ScalarPhase
    # CommonScalarAmplitude <-> ScalarAmplitude
    # and Gain is separated into its Real/Imag/Ampl/Phase components.
    gainTypes = [['Gain', 'Real'], ['Gain', 'Imag'], ['Gain', 'Ampl'],
                 ['Gain', 'Phase']]
    mergedTypes = [
        ('Gain', gainTypes),
        ('DirectionalGain', gainTypes),
        ('RotationAngle', [['RotationAngle']]),
        ('CommonRotationAngle', [['RotationAngle']]),
        ('RotationMeasure', [['RotationMeasure']]),
        ('ScalarPhase', [['ScalarPhase']]),
        ('CommonScalarPhase', [['ScalarPhase']]),
        ('CommonScalarAmplitude', [['ScalarAmplitude']]),
    ]
    mergedNames = [rawType for rawType, _ in mergedTypes]
    # any other type (e.g. Clock, TEC) is kept as a one-element list too
    solTypes = [[rawType] for rawType in rawTypes if rawType not in mergedNames]
    for rawType, parts in mergedTypes:
        if rawType in rawTypes:
            for part in parts:
                # equivalent types map to the same entries: keep them once
                if part not in solTypes:
                    solTypes.append(part)
    logging.debug('Solution types to import: ' + str(solTypes))

    # parmdb solution type -> h5parm soltab type
    soltabTypes = {
        ('RotationAngle', ): 'rotation',
        ('RotationMeasure', ): 'rotationmeasure',
        ('ScalarPhase', ): 'scalarphase',
        ('ScalarAmplitude', ): 'scalaramplitude',
        ('Clock', ): 'clock',
        ('TEC', ): 'tec',
        ('Gain', 'Real'): 'amplitude',
        ('Gain', 'Ampl'): 'amplitude',
        ('Gain', 'Imag'): 'phase',
        ('Gain', 'Phase'): 'phase',
    }

    # every soltype creates a different solution-table
    for solType in solTypes:

        # skip missing solTypes (not all parmdbs have e.g. TEC)
        if not any(
                all(part in name for part in solType) for name in names):
            continue

        pols = set()
        dirs = set()
        ants = set()
        freqs = set()
        times = set()
        ptype = set()

        logging.info('Reading ' + ':'.join(solType) + '.')

        for instrumentdbFile in sorted(instrumentdbFiles):

            # create the axes grid, necessary if not all entries have the same axes length
            data = getValuesGrid(instrumentdbFile, solType)
            # check good instrument table
            if len(data) == 0:
                logging.error('Instrument table %s is empty, ignoring.' %
                              instrumentdbFile)
                continue

            for solEntry in data:

                pol, direction, ant, parm = parmdbToAxes(solEntry)
                if pol is not None: pols.add(pol)
                if direction is not None: dirs.add(direction)
                if ant is not None: ants.add(ant)
                freqs.update(data[solEntry]['freqs'])
                times.update(data[solEntry]['times'])

        pols = np.sort(list(pols))
        dirs = np.sort(list(dirs))
        ants = np.sort(list(ants))
        freqs = np.sort(list(freqs))
        times = np.sort(list(times))
        # axes without entries (e.g. no pol or no dir for this type) are
        # dropped; names, values and array shape all follow the same rule
        axes = [(axisName, axisVals)
                for axisName, axisVals in (('pol', pols), ('dir', dirs),
                                           ('ant', ants), ('freq', freqs),
                                           ('time', times))
                if len(axisVals) != 0]
        axesNames = [axisName for axisName, _ in axes]
        axesVals = [axisVals for _, axisVals in axes]
        shape = [len(axisVals) for axisVals in axesVals]
        vals = np.empty(shape)
        vals[:] = np.nan
        weights = np.zeros(shape, dtype=np.float16)

        logging.info('Filling table.')

        for instrumentdbFile in instrumentdbFiles:

            # fill the values
            data = getValuesGrid(instrumentdbFile, solType)
            # Real/Imag are converted into Ampl/Phase: the other part is needed
            if 'Real' in solType:
                dataIm = getValuesGrid(instrumentdbFile, [solType[0], 'Imag'])
            if 'Imag' in solType:
                dataRe = getValuesGrid(instrumentdbFile, [solType[0], 'Real'])
            for solEntry in data:

                pol, direction, ant, parm = parmdbToAxes(solEntry)
                ptype.add(solEntry.split(':')[0])  # original parmdb solution type

                freq = data[solEntry]['freqs']
                time = data[solEntry]['times']

                val = data[solEntry]['values']

                # convert Real and Imag in Amp and Phase respectively
                if parm == 'Real':
                    valI = dataIm[solEntry.replace('Real', 'Imag')]['values']
                    val = np.sqrt((val**2) + (valI**2))
                if parm == 'Imag':
                    valR = dataRe[solEntry.replace('Imag', 'Real')]['values']
                    val = np.arctan2(val, valR)

                coords = []
                if pol is not None:
                    coords.append(np.searchsorted(pols, pol))
                if direction is not None:
                    coords.append(np.searchsorted(dirs, direction))
                if ant is not None:
                    coords.append(np.searchsorted(ants, ant))
                freqCoord = np.searchsorted(freqs, freq)
                timeCoord = np.searchsorted(times, time)
                # integer coords give a view, so the fancy assignment writes
                # into vals/weights
                vals[tuple(coords)][np.ix_(freqCoord, timeCoord)] = val.T[:, :,
                                                                          0]
                weights[tuple(coords)][np.ix_(freqCoord, timeCoord)] = 1

        np.putmask(vals, ~np.isfinite(vals), 0)  # put inf and nans to 0

        soltabType = soltabTypes.get(tuple(solType))
        if soltabType is None:
            logging.warning('Solution type %s is not supported, no table created.' %
                            ':'.join(solType))
        else:
            if soltabType == 'amplitude':
                # nans were put to 0 before, set them to 1 and flag them
                np.putmask(vals, vals == 0, 1)
                np.putmask(weights, vals == 1., 0)  # flag where val=1
            else:
                np.putmask(weights, vals == 0., 0)  # flag where val=0
            solset.makeSoltab(soltabType,
                              axesNames=axesNames,
                              axesVals=axesVals,
                              vals=vals,
                              weights=weights,
                              parmdbType=', '.join(list(ptype)))

        logging.info('Flagged data: %.3f%%' %
                     (100. * (len(weights.flat) - np.count_nonzero(weights)) /
                      len(weights.flat)))

    logging.info('Collecting information from the ANTENNA table.')
    antennaTable = pt.table(antennaFile, ack=False)
    antennaNames = antennaTable.getcol('NAME')
    antennaPositions = antennaTable.getcol('POSITION')
    antennaTable.close()
    antennaTable = solset.obj._f_get_child('antenna')
    antennaTable.append(list(zip(*(antennaNames, antennaPositions))))

    logging.info('Collecting information from the FIELD table.')
    fieldTable = pt.table(fieldFile, ack=False)
    phaseDir = fieldTable.getcol('PHASE_DIR')
    pointing = phaseDir[0, 0, :]
    fieldTable.close()

    sourceTable = solset.obj._f_get_child('source')
    # add the field centre, that is also the direction for Gain and Common*
    sourceTable.append([('pointing', pointing)])

    dirs = []
    for tab in solset.obj._v_children:
        c = solset.obj._f_get_child(tab)
        if c._v_name != 'antenna' and c._v_name != 'source':
            if 'dir' in c:
                dirs.extend(list(set(c.dir)))
    # remove duplicates
    dirs = list(set(dirs))
    # remove any pointing (already in the table)
    if 'pointing' in dirs:
        dirs.remove('pointing')

    if not os.path.isdir(skydbFile) and dirs != []:
        logging.critical('Missing skydb table.')
        sys.exit(1)

    if dirs != []:
        logging.info('Collecting information from the sky table.')
        src_table = pt.table(skydbFile + '/SOURCES', ack=False)
        has_patches_subtable = any('PATCHES' in sub_table
                                   for sub_table in src_table.getsubtables())
        src_table.close()
        vals = []
        if has_patches_subtable:
            # Read values from PATCHES subtable
            src_table = pt.table(skydbFile + '/SOURCES/PATCHES', ack=False)
            patch_names = src_table.getcol('PATCHNAME')
            patch_ras = src_table.getcol('RA')
            patch_decs = src_table.getcol('DEC')
            src_table.close()
            for source in dirs:
                try:
                    patch_indx = patch_names.index(source)
                    ra = patch_ras[patch_indx]
                    dec = patch_decs[patch_indx]
                except ValueError:
                    ra = np.nan
                    dec = np.nan
                    logging.error('Cannot find the source ' + source +
                                  '. I leave NaNs.')
                vals.append([ra, dec])
        else:
            # Try to read default values from parmdb instead
            skydb = lofar.parmdb.parmdb(skydbFile)

            for source in dirs:
                try:
                    ra = skydb.getDefValues('Ra:' + source)['Ra:' +
                                                            source][0][0]
                    dec = skydb.getDefValues('Dec:' + source)['Dec:' +
                                                              source][0][0]
                except KeyError:
                    # Source not found in skymodel parmdb, try to find components
                    logging.warning('Cannot find the source ' + source +
                                    '. Trying components.')
                    ra = np.array(
                        list(
                            skydb.getDefValues('Ra:*' + source +
                                               '*').values()))
                    dec = np.array(
                        list(
                            skydb.getDefValues('Dec:*' + source +
                                               '*').values()))
                    if len(ra) == 0 or len(dec) == 0:
                        ra = np.nan
                        dec = np.nan
                        logging.error('Cannot find the source ' + source +
                                      '. I leave NaNs.')
                    else:
                        ra = ra.mean()
                        dec = dec.mean()
                        logging.info('Found average direction for ' + source +
                                     ' at ra:' + str(ra) + ' - dec:' +
                                     str(dec))
                vals.append([ra, dec])
        sourceTable.append(list(zip(*(dirs, vals))))

    logging.info("Total file size: " +
                 str(int(h5parm.H.get_filesize() / 1024. / 1024.)) + " M.")

    # Add CREATE entry to history and print summary of tables if verbose
    source_label = "manual list" if globaldbFile is None else globaldbFile
    for st in solset.getSoltabs():
        st.addHistory(
            'CREATE (by H5parm_importer.py from %s:%s/%s)' %
            (socket.gethostname(), os.path.abspath(''), source_label))
    if verbose:
        logging.info(str(h5parm))

    del h5parm
    logging.info('Done.')
def run(soltab, interp_dirs, soltabOut=None, prefix='interp_', ncpu=0):
    """
    Add interpolated directions to h5parm.

    The values of ``soltab`` are interpolated (one job per antenna, run in a
    process pool) to the requested directions. They are written, together with
    the original directions, to a new soltab, and the new directions are added
    to the solset source table.

    Parameters
    ----------
    interp_dirs : 2d array of floats
        Shape n x 2, contains ra/dec in degree.
        For example: [[ra1,dec1],[ra2,dec2],...]. A single [ra, dec] pair is
        also accepted.

    soltabOut : string,  optional
        Default: Guess from soltype. If specifically set to input soltab, the input soltab will be overwritten.

    prefix : string, optional, default = "interp_".
        Name prefix of interpolated directions.

    ncpu : int, optional, default = 0.
        Number of worker processes; 0 means use all available cores.

    Returns
    -------
    int
        0 on success, 1 if the soltab cannot be interpolated.
    """
    import multiprocessing as mp
    import re

    # need scipy commit f0a478c4a4172c4d2225910216de6b62721db161 for multidimensional interp. otherwise slow...
    # Compare (major, minor) numerically: as strings, '1.10.0' < '1.4.0' would be True.
    scipy_version = tuple(
        int(v) for v in re.findall(r'\d+', scipy.__version__)[:2])
    if scipy_version < (1, 4):
        raise ImportError(
            'SciPy version >= 1.4.0 is required to support multidimensional rbf interpolation.'
        )

    logging.info("Working on soltab: " + soltab.name)

    soltype = soltab.getType()
    solset = soltab.getSolset()
    soltabOut = None if soltabOut == '' else soltabOut

    # check input
    if soltype == 'phase':
        interp_kind = 'wrap'
    elif soltype == 'amplitude':
        interp_kind = 'amp'
    elif soltype in ['tec', 'rotationmeasure', 'tec3rd']:
        interp_kind = 'lin'
    else:
        logging.error('Soltab type {} not supported.'.format(soltype))
        return 1
    if 'dir' not in soltab.getAxesNames():
        logging.error(
            'Data without dir axis cannot be interpolated in directions...')
        return 1
    if 'ant' not in soltab.getAxesNames():
        logging.error('Data without ant axis not supported...')
        return 1
    if len(soltab.dir) < 10:
        logging.error(
            'Not enough directions. Use at least ten for interpolation.')
        return 1

    ax_ord = soltab.getAxesNames()  # original order
    dir_ax = ax_ord.index('dir')

    # convert to rad (accept lists as well as arrays)
    interp_dirs = np.asarray(interp_dirs, dtype=float)
    if interp_dirs.shape == (2, ):  # if only one direction, add dir axis
        interp_dirs = interp_dirs[np.newaxis]
    interp_dirs = np.deg2rad(interp_dirs)

    # prepare arrays for interpolated values: same shape as the input values,
    # except that the direction axis holds the new directions
    val_shape_interp = list(soltab.val.shape)
    val_shape_interp[dir_ax] = len(interp_dirs)
    interp_vals = np.zeros(val_shape_interp)
    interp_weights = np.ones(val_shape_interp)

    # ra/dec of calibrator dirs. Make sure order is the same as in soltab.
    sources = solset.getSou()
    cal_dirs = np.array([sources[k] for k in soltab.dir])

    # Collect arguments for parallelization on antennas
    returnAxes = [ax for ax in ax_ord if ax != 'ant']
    dir_ax_sel = returnAxes.index('dir')
    args, selections = [], []
    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=returnAxes, weight=True):
        args.append(
            [vals, weights, cal_dirs, dir_ax_sel, interp_dirs, interp_kind])
        selections.append(selection)

    # run the interpolation
    ncpu = mp.cpu_count() if ncpu == 0 else ncpu  # default use all cores
    with mp.Pool(ncpu) as pool:
        logging.info('Start interpolation.')
        results = pool.starmap(interpolate_directions3d, args)

    # reorder results
    for selection, (sel_vals, sel_weights) in zip(selections, results):
        sel = tuple(selection)
        # the reshape restores the degenerate dim for the ant axis
        interp_vals[sel] = sel_vals.reshape(interp_vals[sel].shape)
        interp_weights[sel] = sel_weights.reshape(interp_weights[sel].shape)

    # concatenate existing values and interpolated direction values
    vals = np.concatenate([soltab.val, interp_vals], axis=dir_ax)
    weights = np.concatenate([soltab.weight, interp_weights], axis=dir_ax)

    # set names for the interpolated directions, e.g. interp_000, interp_001, ...
    interp_dir_names = [
        '{}{:03d}'.format(prefix, i) for i in range(len(interp_dirs))
    ]
    # prepare axes values, append directions
    axes_vals = [soltab.getAxisValues(axisName) for axisName in ax_ord]
    axes_vals[dir_ax] = np.concatenate([axes_vals[dir_ax], interp_dir_names])

    # prepare output - check if soltabOut exists
    if soltabOut in solset.getSoltabNames():
        logging.warning('Soltab {} exists. Overwriting...'.format(soltabOut))
        solset.getSoltab(soltabOut).delete()

    # make soltabOut
    soltabout = solset.makeSoltab(soltype=soltype,
                                  soltabName=soltabOut,
                                  axesNames=ax_ord,
                                  axesVals=axes_vals,
                                  vals=vals,
                                  weights=weights)

    newSources = dict(zip(interp_dir_names, interp_dirs))
    # append interpolated dirs to solset source table; existing entries with
    # the same name are updated in place instead of duplicated
    sourceTable = solset.obj._f_get_child('source')
    for row in sourceTable.iterrows():
        _name, _dir = row['name'].decode(), row['dir']
        if _name in newSources:
            if not np.all(np.isclose(_dir, newSources[_name])):
                logging.debug('Overwrite direction: {}'.format(_name))
                row['dir'] = newSources[_name]
            else:
                logging.debug('Source already in soltab: {}'.format(_name))
            del newSources[_name]

    if len(newSources) > 0:
        logging.debug('Adding to source table: {}'.format(newSources.keys()))
        sourceTable.append(list(newSources.items()))

    # Add CREATE entry to history
    soltabout.addHistory(
        'Created by INTERPOLATEDIRECTIONS operation from %s.' % soltab.name)
    return 0
Ejemplo n.º 15
0
def run(soltab, soltabOut='tec000', refAnt=''):
    """
    Bruteforce global delay extraction from phase solutions.

    For every antenna except the reference one, a single delay d (shared by all
    times) is found with a grid search (scipy.optimize.brute over
    [-1e-7, 1e-7] s). The search minimizes the mean absolute wrapped difference
    between 2*pi*d*freq and the phases. Flagged points (weight 0) are ignored.
    The result is written to a new 'clock' soltab with axes ['ant', 'time'].

    Parameters
    ----------
    soltabOut : str, optional
        output table name (same solset), by default "tec000".

    refAnt : str, optional
        Reference antenna, by default the first.

    Returns
    -------
    int
        0 on success, 1 on error.
    """
    import numpy as np
    import scipy.optimize

    def mod(d):
        # wrap phases into [-pi, pi)
        return np.mod(d + np.pi, 2. * np.pi) - np.pi

    def cost_f(d, freq, y):
        # mean absolute wrapped residual; NaN entries (flagged) are skipped
        nfreq, ntime = y.shape
        phase = mod(2 * np.pi * d * freq).repeat(ntime).reshape(nfreq, ntime)
        dist = np.abs(mod(phase - y))
        ngood = np.sum(~np.isnan(dist))
        return np.nansum(dist / ngood)

    logging.info("Find global DELAY for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    # check before creating the output table so nothing is left half-written
    if 'freq' not in soltab.getAxesNames() or len(
            soltab.getAxisValues('freq')) < 5:
        logging.error(
            'Delay estimation needs at least 5 frequency channels, preferably distributed over a wide range.'
        )
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt not in ('', 'closest') and refAnt not in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.warning('Reference antenna ' + refAnt + ' not found. Using: ' +
                        ants[0])
        refAnt = ants[0]
    if refAnt == '':
        refAnt = ants[0]

    # times and ants needs to be complete or selection is much slower
    ntime = len(soltab.getAxisValues('time'))

    # create new table
    solset = soltab.getSolset()
    out_shape = (soltab.getAxisLen('ant'), soltab.getAxisLen('time'))
    soltabout = solset.makeSoltab(
        soltype='clock',
        soltabName=soltabOut,
        axesNames=['ant', 'time'],
        axesVals=[soltab.getAxisValues(axisName) for axisName in ['ant', 'time']],
        vals=np.zeros(shape=out_shape),
        weights=np.ones(shape=out_shape))
    soltabout.addHistory('Created by GLOBALDELAY operation from %s.' %
                         soltab.name)

    ranges = (-1e-7, 1e-7)
    Ns = 1001

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['freq', 'time'], weight=True, refAnt=refAnt):

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(), ['freq', 'time'])
        weights = reorderAxes(weights, soltab.getAxesNames(), ['freq', 'time'])

        delay_fitresult = np.zeros(ntime)
        weights_fitresult = np.ones(ntime)

        if coord['ant'] != refAnt:
            if np.all(weights == 0.):
                logging.warning('Skipping flagged antenna: ' + coord['ant'])
                weights_fitresult[:] = 0
            else:
                # mask flagged points: cost_f ignores NaNs
                vals = np.array(vals, dtype=float)
                vals[weights == 0.] = np.nan
                freq = np.copy(coord['freq'])

                # brute force
                fit = scipy.optimize.brute(cost_f,
                                           ranges=(ranges, ),
                                           Ns=Ns,
                                           args=(freq, vals))
                # brute returns a length-1 array for a 1-D search
                delay = float(np.ravel(fit)[0])
                delay_fitresult[:] = delay

                logging.info('%s: average delay: %f ns' %
                             (coord['ant'], delay * 1e9))

        # write this antenna's row of the output table
        soltabout.setSelection(ant=coord['ant'])
        soltabout.setValues(delay_fitresult)
        soltabout.setValues(weights_fitresult, weight=True)

    return 0
Ejemplo n.º 16
0
def run(soltab, soltabOutTEC='tec000', soltabOutBP='phase000', refAnt=''):
    """
    Isolate BP from TEC in a calibrator (uGMRT data).

    NOTE(review): the bandpass/TEC separation algorithm is not implemented.
    This operation only validates its input, logs an error and returns 1.
    No output soltab is created.

    Parameters
    ----------
    soltabOutTEC : str, optional
        output TEC table name (same solset), by default "tec000".

    soltabOutBP : str, optional
        output bandpass table name (same solset), by default "phase000".

    refAnt : str, optional
        Reference antenna. Currently unused.

    Returns
    -------
    int
        1 (failure): either the soltab is not of type phase, or the separation
        is not available.
    """
    logging.info("Find BANDPASS+TEC for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    # Without a separation step there are no values for the TEC/bandpass
    # outputs. Fail explicitly here, before creating any table, instead of
    # crashing later on undefined fit results.
    logging.error('BANDPASSTEC: bandpass/TEC separation is not implemented. '
                  'Tables {} and {} were not created.'.format(
                      soltabOutTEC, soltabOutBP))
    return 1
Ejemplo n.º 17
0
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from losoto.lib_operations import *
from losoto._logging import logger as logging
from losoto.operations._faraday_timestep import _run_timestep

# Emitted at import time: a debug trace of the module load and a warning
# that the FRjump operation is experimental.
logging.debug('Loading FRjump module.')
logging.warning('FRjump module is still experimental - we strongly recommend to check the results carefully')


def _run_parser(soltab, parser, step):
    """
    Read the FRjump options for ``step`` from the parset and call ``run``.

    Options (with defaults): soltabOut ('rotationmeasure002'),
    soltabPhase ('phase000'), clipping ([0, 1e9]) and frequencies ([]).
    The last two are converted to float arrays.
    """
    out_name = parser.getstr(step, 'soltabOut', 'rotationmeasure002')
    phase_name = parser.getstr(step, 'soltabPhase', 'phase000')
    clip_range = np.array(parser.getarray(step, 'clipping', [0, 1e9]),
                          dtype=float)
    freq_list = np.array(parser.getarray(step, 'frequencies', []),
                         dtype=float)

    # warn about misspelled option names in the parset
    known_options = ['soltabOut', 'soltabPhase', 'clipping', 'frequencies']
    parser.checkSpelling(step, soltab, known_options)
    return run(soltab, out_name, clip_range, phase_name, freq_list)

def costfunctionRM(RM, wav, phase):
    """
    Cost of a rotation-measure model against observed phases.

    The model phase is 2 * RM[0] * wav**2. The cost is the summed absolute
    difference of the cosines plus that of the sines. Comparing on the unit
    circle makes the cost insensitive to 2*pi wraps.
    """
    model = 2. * RM[0] * wav * wav
    cos_term = abs(np.cos(model) - np.cos(phase))
    sin_term = abs(np.sin(model) - np.sin(phase))
    return np.sum(cos_term + sin_term)

def getPhaseWrapBase(wavels):
    """
    Intended to return the step size from a local minimum (2pi phase wrap) to
    the others.

    NOTE(review): this definition looks truncated. It only builds the
    (nF, 1) zero matrix ``A`` and has no return statement, so it currently
    returns None.

    Parameters
    ----------
    wavels : array-like
        Wavelength grid of the data.
    """
    wavels = np.array(wavels)
    nF = wavels.shape[0]
    # builtin float: the np.float alias was removed in NumPy 1.24
    A = np.zeros((nF, 1), dtype=float)
Ejemplo n.º 18
0
def _flag_bandpass(freqs, amps, weights, telescope, nSigma, ampRange,
                   maxFlaggedFraction, maxStddev, plot, ants, s, outQueue):
    """
    Flags bad amplitude solutions relative to median bandpass (in log space) by setting
    the corresponding weights to 0.0

    Note: A median over the time axis is done before flagging, so the flags are not time-
    dependent

    Parameters
    ----------
    freqs : array
        Array of frequencies

    amps : array
        Array of amplitudes as [time, ant, freq, pol]

    weights : array
        Array of weights as [time, ant, freq, pol]

    telescope : str, optional
        Specifies the telescope for the bandpass model

    nSigma : float
        Number of sigma for flagging. Amplitudes outside of nSigma*stddev are flagged

    maxFlaggedFraction : float
        Maximum allowable fraction of flagged frequencies. Stations with higher fractions
        will be completely flagged

    maxStddev : float
        Maximum allowable standard deviation

    plot : bool
        If True, the bandpass with flags and best-fit line is plotted for each station

    ants : list
        List of station names

    s : int
        Station index

    ampRange : array
        2-element array of the median amplitude level to be acceptable, ampRange[0]: lower limit, ampRange[1]: upper limit

    Returns
    -------
    indx, weights : int, array
        Station index, modified weights array
    """
    def _B(x, k, i, t, extrap, invert):
        if k == 0:
            if extrap:
                if invert:
                    return -1.0
                else:
                    return 1.0
            else:
                return 1.0 if t[i] <= x < t[i + 1] else 0.0
        if t[i + k] == t[i]:
            c1 = 0.0
        else:
            c1 = (x - t[i]) / (t[i + k] - t[i]) * _B(x, k - 1, i, t, extrap,
                                                     invert)
        if t[i + k + 1] == t[i + 1]:
            c2 = 0.0
        else:
            c2 = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * _B(
                x, k - 1, i + 1, t, extrap, invert)
        return c1 + c2

    def _bspline(x, t, c, k):
        n = len(t) - k - 1
        assert (n >= k + 1) and (len(c) >= n)
        invert = False
        extrap = [False] * n
        if x >= t[n]:
            extrap[-1] = True
        elif x < t[k]:
            extrap[0] = True
            invert = False
        return sum(c[i] * _B(x, k, i, t, e, invert)
                   for i, e in zip(list(range(n)), extrap))

    def _bandpass_LBA(freq, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12,
                      c13):
        """
        Defines the functional form of the LBA bandpass in terms of splines of degree 3

        The spline fit was done using LSQUnivariateSpline() on the median bandpass between
        30 MHz and 78 MHz. The knots were set by hand to acheive a good fit with a
        minimum number of parameters.

        Parameters
        ----------
        freq : array
            Array of frequencies

        c1-c13 : float
            Spline coefficients

        Returns
        -------
        bandpass : list
            List of bandpass values as function of frequency
        """
        knots = np.array([
            30003357.0, 30003357.0, 30003357.0, 30003357.0, 40000000.0,
            50000000.0, 55000000.0, 56000000.0, 60000000.0, 62000000.0,
            63000000.0, 64000000.0, 70000000.0, 77610779.0, 77610779.0,
            77610779.0, 77610779.0
        ])
        coeffs = np.array(
            [c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13])
        return [_bspline(f, knots, coeffs, 3) for f in freq]

    def _bandpass_HBA_low(freq, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10):
        """
        Defines the functional form of the HBA-low bandpass in terms of splines of degree
        3

        The spline fit was done using LSQUnivariateSpline() on the median bandpass between
        120 MHz and 188 MHz. The knots were set by hand to acheive a good fit with a
        minimum number of parameters.

        Parameters
        ----------
        freq : array
            Array of frequencies

        c1-c10 : float
            Spline coefficients

        Returns
        -------
        bandpass : list
            List of bandpass values as function of frequency
        """
        knots = np.array([
            1.15e+08, 1.15e+08, 1.15e+08, 1.15e+08, 1.30e+08, 1.38e+08,
            1.48e+08, 1.60e+08, 1.68e+08, 1.78e+08, 1.90e+08, 1.90e+08,
            1.9e+08, 1.9e+08
        ])
        coeffs = np.array([c1, c2, c3, c4, c5, c6, c7, c8, c9, c10])
        return [_bspline(f, knots, coeffs, 3) for f in freq]

    def _fit_bandpass(freq, logamp, sigma, band, do_fit=True):
        """
        Fits amplitudes with one of the bandpass functions

        The initial coefficients were determined from a LSQUnivariateSpline() fit on the
        median bandpass of the appropriate band. The allowable fitting ranges were set by
        hand through testing on a number of observations (to allow the bandpass function
        to adjust for the differences between stations but not to fit to RFI, etc.).

        Parameters
        ----------
        freq : array
            Array of frequencies

        amps : array
            Array of log10(amplitudes)

        sigma : array
            Array of sigma (1/weights**2)

        band : str
            Band name ('hba_low', etc.)

        do_fit : bool, optional
            If True, the fitting is done. If False, the unmodified model bandpass is
            returned

        Returns
        -------
        fit_parms, bandpass : list, list
            List of best-fit parameters, List of bandpass values as function of frequency
        """
        from scipy.optimize import curve_fit

        if band.lower() == 'hba_low':
            bandpass_function = _bandpass_HBA_low
            init_coeffs = np.array([
                -0.01460369, 0.05062699, 0.02827004, 0.03738518, -0.05729109,
                0.02303295, -0.03550487, -0.0803113, -0.2394929, -0.358301
            ])
            bounds_deltas_lower = [
                0.06, 0.05, 0.04, 0.04, 0.04, 0.04, 0.1, 0.1, 0.2, 0.5
            ]
            bounds_deltas_upper = [
                0.06, 0.1, 0.1, 0.1, 0.04, 0.04, 0.04, 0.04, 0.05, 0.06
            ]
        elif band.lower() == 'lba':
            bandpass_function = _bandpass_LBA
            init_coeffs = np.array([
                -0.22654016, -0.1950495, -0.07763014, 0.10002095, 0.32797671,
                0.46900048, 0.47155583, 0.31945897, 0.29072278, 0.08064795,
                -0.15761538, -0.36020451, -0.51163338
            ])
            bounds_deltas_lower = [
                0.25, 0.2, 0.05, 0.05, 0.05, 0.1, 0.1, 0.16, 0.2, 0.15, 0.15,
                0.25, 0.3
            ]
            bounds_deltas_upper = [
                0.4, 0.3, 0.15, 0.05, 0.05, 0.05, 0.08, 0.05, 0.08, 0.15, 0.15,
                0.25, 0.35
            ]
        else:
            logging.error('The "{}" band is not supported'.format(band))
            return None, None

        if do_fit:
            lower = [c - b for c, b in zip(init_coeffs, bounds_deltas_lower)]
            upper = [c + b for c, b in zip(init_coeffs, bounds_deltas_upper)]
            param_bounds = (lower, upper)
            try:
                popt, pcov = curve_fit(bandpass_function,
                                       freq,
                                       logamp,
                                       sigma=sigma,
                                       bounds=param_bounds,
                                       method='dogbox',
                                       ftol=1e-3,
                                       xtol=1e-3,
                                       gtol=1e-3)
                return popt, bandpass_function(freq, *tuple(popt))
            except RuntimeError:
                logging.error('Fitting failed.')
                return None, bandpass_function(freq, *tuple(init_coeffs))
        else:
            return None, bandpass_function(freq, *tuple(init_coeffs))

    # Check that telescope and band is supported. Skip flagging if not
    if telescope.lower() == 'lofar':
        # Determine which band we're in
        if np.median(freqs) < 180e6 and np.median(freqs) > 110e6:
            band = 'hba_low'
        elif np.median(freqs) < 90e6:
            band = 'lba'
        else:
            logging.warning(
                'The median frequency of {} Hz is outside of the currently supported LOFAR bands '
                '(LBA and HBA-low). Flagging will be skipped'.format(
                    np.median(freqs)))
            outQueue.put([s, weights])
            return
    else:
        logging.warning(
            "Only telescope = 'lofar' is currently supported for bandpass mode. "
            "Flagging will be skipped")
        outQueue.put([s, weights])
        return

    # Skip fully flagged stations
    if np.all(weights == 0.0):
        outQueue.put([s, weights])
        return

    # Build arrays for fitting
    flagged = np.where(np.logical_or(weights == 0.0, np.isnan(amps)))
    amps_flagged = amps.copy()
    amps_flagged[flagged] = np.nan
    sigma = weights.copy()
    sigma[flagged] = 1.0
    sigma = np.sqrt(1.0 / sigma)
    sigma[flagged] = 1e8

    # Set range of allowed values for the median
    if ampRange is None or ampRange == [0.0, 0.0]:
        # Use sensible values depending on correlator
        if np.nanmedian(amps_flagged) > 1.0:
            # new correlator
            ampRange = [50.0, 325.0]
        else:
            # old correlator
            ampRange = [0.0004, 0.0018]
    median_min = ampRange[0]
    median_max = ampRange[-1]

    # Iterate over polarizations
    npols = amps.shape[2]
    for pol in range(npols):
        # Skip fully flagged polarizations
        if np.all(weights[:, :, pol] == 0.0):
            continue

        # Take median over time and divide out the median offset
        with np.warnings.catch_warnings():
            # Filter NaN warnings -- we deal with NaNs below
            np.warnings.filterwarnings('ignore',
                                       r'All-NaN (slice|axis) encountered')
            amps_div = np.nanmedian(amps_flagged[:, :, pol], axis=0)
            median_val = np.nanmedian(amps_div)
        amps_div /= median_val
        sigma_div = np.median(sigma[:, :, pol], axis=0)
        sigma_orig = sigma_div.copy()
        unflagged = np.where(~np.isnan(amps_div))
        nsols_unflagged = len(unflagged[0])
        median_flagged = np.where(np.isnan(amps_div))
        amps_div[median_flagged] = 1.0
        sigma_div[median_flagged] = 1e8
        median_flagged = np.where(amps_div <= 0.0)
        amps_div[median_flagged] = 1.0
        sigma_div[median_flagged] = 1e8

        # Before doing the fitting, renormalize and flag any solutions that deviate from
        # the model bandpass by a large factor to avoid biasing the first fit
        _, bp_sp = _fit_bandpass(freqs,
                                 np.log10(amps_div),
                                 sigma_div,
                                 band,
                                 do_fit=False)
        normval = np.median(np.log10(amps_div) -
                            bp_sp)  # value to normalize model to data
        amps_div /= 10**normval
        bad = np.where(np.abs(np.array(bp_sp) - np.log10(amps_div)) > 0.2)
        sigma_div[bad] = 1e8
        if np.all(sigma_div > 1e7):
            logging.info('Flagged {0} (pol {1}) due to poor match to '
                         'baseline bandpass model'.format(ants[s], pol))
            weights[:, :, pol] = 0.0
            outQueue.put([s, weights])
            return

        # Iteratively fit and flag
        maxiter = 5
        niter = 0
        nflag = 0
        nflag_prev = -1
        while nflag != nflag_prev and niter < maxiter:
            p, bp_sp = _fit_bandpass(freqs, np.log10(amps_div), sigma_div,
                                     band)
            stdev_all = np.sqrt(
                np.average((bp_sp - np.log10(amps_div))**2,
                           weights=(1 / sigma_div)**2))
            stdev = min(maxStddev, stdev_all)
            bad = np.where(np.abs(bp_sp - np.log10(amps_div)) > nSigma * stdev)
            nflag = len(bad[0])
            if nflag == 0 or nflag == nsols_unflagged:
                break
            if niter > 0:
                nflag_prev = nflag
            sigma_div = sigma_orig.copy()  # reset flags to original ones
            sigma_div[bad] = 1e8
            niter += 1

        if plot:
            import matplotlib.pyplot as plt
            plt.plot(freqs, bp_sp, 'g-', lw=3)
            plt.plot(freqs, np.log10(amps_div), 'o', c='g')
            plt.plot(freqs[bad], np.log10(amps_div)[bad], 'o', c='r')
            plt.show()

        # Check whether entire station is bad (high stdev or high flagged fraction). If
        # so, flag all frequencies and polarizations
        if stdev_all > nSigma * maxStddev:
            # Station has high stddev relative to median bandpass
            logging.info('Flagged {0} (pol {1}) due to high stddev '
                         '({2})'.format(ants[s], pol, stdev_all))
            weights[:, :, pol] = 0.0
        elif float(len(bad[0])) / float(nsols_unflagged) > maxFlaggedFraction:
            # Station has high fraction of initially unflagged solutions that are now flagged
            logging.info('Flagged {0} (pol {1}) due to high flagged fraction '
                         '({2:.2f})'.format(
                             ants[s], pol,
                             float(len(bad[0])) / float(nsols_unflagged)))
            weights[:, :, pol] = 0.0
        else:
            flagged = np.where(sigma_div > 1e3)
            nflagged_orig = len(np.where(weights[:, :, pol] == 0.0)[0])
            weights[:, flagged[0], pol] = 0.0
            nflagged_new = len(np.where(weights[:, :, pol] == 0.0)[0])
            median_val = np.nanmedian(amps[np.where(weights[:, :, pol] > 0.0)])
            if median_val < median_min or median_val > median_max:
                # Station has extreme median value
                logging.info(
                    'Flagged {0} (pol {1}) due to extreme median value '
                    '({2})'.format(ants[s], pol, median_val))
                weights[:, :, pol] = 0.0
            else:
                # Station is OK, flag bad points only
                prcnt = float(nflagged_new - nflagged_orig) / float(
                    np.product(weights.shape[:-1])) * 100.0
                logging.info(
                    'Flagged {0:.1f}% of solutions for {1} (pol {2})'.format(
                        prcnt, ants[s], pol))

    outQueue.put([s, weights])
Ejemplo n.º 19
0
def run(soltab,
        axesToSmooth,
        size=[],
        mode='runningmedian',
        degree=1,
        replace=False,
        log=False,
        refAnt=''):
    """
    A smoothing function: running-median on an arbitrary number of axes, running polyfit and Savitzky-Golay on one axis, or set all solutions to the mean/median value.
    WEIGHT: flag ready.

    Parameters
    ----------
    axesToSmooth : array of str
        Axes used to compute the smoothing function. Axes not present in the
        soltab are ignored (with a warning) together with their size entry.

    size : array of int, optional
        Window size for the runningmedian, savitzky-golay, and runningpoly (array of same size of axesToSmooth), by default [].
        Even sizes are increased by one so the window is centred.

    mode : {'runningmedian','runningpoly','savitzky-golay','mean','median'}, optional
        Runningmedian or runningpoly or Savitzky-Golay or mean or median (these last two values set all the solutions to the mean/median), by default "runningmedian".
        Flagged points are excluded from the mean/median.

    degree : int, optional
        Degrees of the polynomia for the runningpoly or savitzky-golay modes, by default 1.

    replace : bool, optional
        Flagged data are replaced with smoothed value and unflagged, by default False.

    log : bool, optional
        Smoothing is done in log10 space, by default False

    refAnt : str, optional
        Reference antenna for phases. By default None.

    Returns
    -------
    int
        0 on success, 1 on invalid input.

    Notes
    -----
    The ``axesToSmooth`` and ``size`` lists passed by the caller are not modified.
    """

    import numpy as np
    from scipy.ndimage import generic_filter

    # Private copies: the odd-size fix-up and the axis filtering below must not
    # modify the caller's lists (nor the shared ``size=[]`` default).
    axesToSmooth = list(axesToSmooth)
    size = list(size)

    if mode not in ('runningmedian', 'runningpoly', 'savitzky-golay', 'median', 'mean'):
        logging.error(
            'Mode must be: runningmedian, runningpoly, savitzky-golay, median or mean'
        )
        return 1

    if refAnt == '': refAnt = None
    elif not refAnt in soltab.getAxisValues('ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      soltab.getAxisValues('ant')[1])
        refAnt = soltab.getAxisValues('ant')[1]

    if mode == "runningmedian" and len(axesToSmooth) != len(size):
        logging.error("Axes and Size lengths must be equal for runningmedian.")
        return 1

    if mode in ("runningpoly", "savitzky-golay"):
        if len(axesToSmooth) != 1 or len(size) != 1:
            logging.error(
                "Axes and size lengths must be 1 for runningpoly or savitzky-golay."
            )
            return 1
        if soltab.getType() == 'phase':
            logging.error(
                "Runningpoly and savitzky-golay modes cannot work on phases.")
            return 1

    # windows must have odd length to be centred on each sample
    for i, s in enumerate(size):
        if s % 2 == 0:
            logging.warning('Size should be odd, adding 1.')
            size[i] += 1

    logging.info("Smoothing soltab: " + soltab.name)

    # Drop axes missing from the soltab together with their size entry. New
    # lists are built instead of deleting by index while iterating, which
    # shifted indices after the first removal and failed when ``size`` was
    # shorter than ``axesToSmooth`` (e.g. median/mean with size=[]).
    axesNames = soltab.getAxesNames()
    missing = set()
    for i, axis in enumerate(axesToSmooth):
        if axis not in axesNames:
            logging.warning('Axis \"' + axis + '\" not found. Ignoring.')
            missing.add(i)
    axesToSmooth = [a for i, a in enumerate(axesToSmooth) if i not in missing]
    size = [s for i, s in enumerate(size) if i not in missing]

    isPhase = soltab.getType() == 'phase'

    if soltab.getType() == 'amplitude' and not log:
        logging.warning(
            'Amplitude solution tab detected and log=False. Amplitude solution tables should be treated in log space.'
        )

    if mode in ('median', 'mean'):
        vals = soltab.getValues(retAxesVals=False, reference=refAnt)
        if log: vals = np.log10(vals)
        weights = soltab.getValues(retAxesVals=False, weight=True)
        # flagged points become NaN so the nan-aware statistics ignore them
        np.putmask(vals, weights == 0, np.nan)
        # numpy reductions accept several axes only as a tuple, not a list
        idx_axes = tuple(axesNames.index(axis) for axis in axesToSmooth)

        # handle phases by using a complex array
        if isPhase:
            vals = np.exp(1j * vals)

        if mode == 'median':
            vals[:] = np.nanmedian(vals, axis=idx_axes, keepdims=True)
        else:
            vals[:] = np.nanmean(vals, axis=idx_axes, keepdims=True)

        # go back to phases
        if isPhase:
            vals = np.angle(vals)

        # write back
        if log: vals = 10**vals
        soltab.setValues(vals)
        if replace:
            weights[(weights == 0)] = 1
            # all the slice was flagged, cannot extrapolate value
            weights[np.isnan(vals)] = 0
            soltab.setValues(weights, weight=True)

    else:

        def _polyfitCenter(data):
            """Fit a degree-`degree` polynomial to the non-NaN samples of the
            window `data` and evaluate it at the window centre."""
            good = ~np.isnan(data)
            if not good.any():
                return np.nan  # all size is flagged
            x = np.arange(len(data))[good]
            p = np.polynomial.polynomial.polyfit(x, data[good], deg=degree)
            # polyval has opposite convention for polynomial order
            return np.polyval(p[::-1], (size[0] - 1) / 2)

        for vals, weights, coord, selection in soltab.getValuesIter(
                returnAxes=axesToSmooth, weight=True, reference=refAnt):

            # skip completely flagged selections
            if (weights == 0).all(): continue
            if log: vals = np.log10(vals)

            flagged = (weights == 0)
            vals_bkp = vals[flagged]  # original values of flagged points

            if mode == 'runningmedian':
                if isPhase:
                    # smooth real and imaginary parts of the phasor separately
                    phasor = np.exp(1j * vals)
                    valsreal = np.real(phasor)
                    valsimag = np.imag(phasor)
                    np.putmask(valsreal, flagged, np.nan)
                    np.putmask(valsimag, flagged, np.nan)
                    valsrealnew = generic_filter(valsreal,
                                                 np.nanmedian,
                                                 size=size,
                                                 mode='constant',
                                                 cval=np.nan)
                    valsimagnew = generic_filter(valsimag,
                                                 np.nanmedian,
                                                 size=size,
                                                 mode='constant',
                                                 cval=np.nan)
                    # back to complex, then back to phases
                    valsnew = np.angle(valsrealnew + 1j * valsimagnew)
                else:
                    np.putmask(vals, flagged, np.nan)
                    valsnew = generic_filter(vals,
                                             np.nanmedian,
                                             size=size,
                                             mode='constant',
                                             cval=np.nan)

            elif mode == 'runningpoly':
                # flags and window edges are passed as NaN and ignored by the fit
                np.putmask(vals, flagged, np.nan)
                valsnew = generic_filter(vals,
                                         _polyfitCenter,
                                         size=size[0],
                                         mode='constant',
                                         cval=np.nan)

            else:  # savitzky-golay
                np.putmask(vals, flagged, np.nan)
                valsnew = _savitzky_golay(vals, size[0], degree)

            if replace:
                weights[flagged] = 1
                # all the window was flagged, cannot extrapolate value
                weights[np.isnan(valsnew)] = 0
            else:
                # flagged points keep their original value
                valsnew[flagged] = vals_bkp

            if log: valsnew = 10**valsnew
            soltab.setValues(valsnew, selection)
            if replace: soltab.setValues(weights, selection, weight=True)

    soltab.flush()
    soltab.addHistory('SMOOTH (over %s with mode = %s)' % (axesToSmooth, mode))
    return 0
Ejemplo n.º 20
0
def run(soltab, doUnwrap=False, refAnt='', plotName='', ndiv=1):
    """
    Find the structure function from phase solutions of core stations.

    Phases of the core stations (``CS*``) are averaged over polarizations,
    differenced for every antenna pair, and averaged in frequency. Their time
    variance, rescaled to 150 MHz, is then fitted as a power law of the
    baseline length: log10(var) = beta * log10(D) + c. The fit results are
    logged as beta and R_diff, the distance at which the fitted line crosses
    log10(var) = 0.

    Parameters
    ----------
    doUnwrap : bool, optional
        Unwrap the phases in the freq/time plane before differencing, by
        default False. Requires a reference antenna.

    refAnt : str, optional
        Reference antenna, by default none. If doUnwrap is set and no
        reference is given, the second antenna of the selection is used.

    plotName : str, optional
        Plot file name, by default no plot. A '.png' extension is added if missing.

    ndiv : int, optional
        Number of equal time chunks into which the data are split, each
        fitted separately, by default 1. Leading timeslots are dropped so that
        the time axis is divisible by ndiv.

    Returns
    -------
    int
        0 on success, 1 if the soltab is not of type phase.
    """
    import numpy as np

    logging.info("Find structure function for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase.")
        return 1

    if doUnwrap:
        # imported lazily: only the unwrapping path needs it
        from losoto.lib_unwrap import unwrap_2d

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[1])
        refAnt = ants[1]
    if refAnt == '' and doUnwrap:
        logging.error('Unwrap requires reference antenna. Using: ' + ants[1])
        refAnt = ants[1]
    if refAnt == '': refAnt = None

    # the structure function is computed on core stations only
    soltab.setSelection(ant='CS*', update=True)

    posAll = soltab.getSolset().getAnt()

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['freq', 'pol', 'ant', 'time'],
            weight=True,
            reference=refAnt):

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(),
                           ['pol', 'ant', 'freq', 'time'])
        weights = reorderAxes(weights, soltab.getAxesNames(),
                              ['pol', 'ant', 'freq', 'time'])

        # antenna positions in the same order as the data
        pos = np.array([list(posAll[ant]) for ant in coord['ant']])

        # average pols as phasors; a point is flagged if either of the first
        # two pols is flagged
        vals = np.angle(np.nansum(np.cos(vals) + 1.j * np.sin(vals), axis=0))
        flags = np.array((weights[0] == 0) | (weights[1] == 0), dtype=bool)

        # unwrap
        if doUnwrap:
            for a, ant in enumerate(coord['ant']):
                if not flags[a].all() and ant != refAnt:
                    logging.debug('Unwrapping: ' + ant)
                    # remove mean to facilitate unwrapping
                    mean = np.angle(np.nanmean(np.exp(1j * vals[a].flatten())))
                    vals[a] -= mean
                    vals[a] = np.mod(vals[a] + np.pi, 2 * np.pi) - np.pi
                    vals[a, :, :] = unwrap_2d(vals[a, :, :], flags[a, :, :],
                                              coord['freq'], coord['time'])

        logging.debug('Computing differential values...')
        t1 = np.ma.array(vals, mask=flags)  # mask flagged data
        dph = t1[np.newaxis] - t1[:, np.newaxis]  # ant x ant x freq x time
        D = pos[np.newaxis] - pos[:, np.newaxis]  # ant x ant x 3
        # baseline lengths; keep only the upper triangle (each pair once)
        D2 = np.triu(np.sqrt(np.sum(D**2, axis=-1)))
        myselect = (D2 > 0)

        if not doUnwrap:
            logging.debug('Re-normalising...')
            dph = np.mod(dph + np.pi, 2 * np.pi) - np.pi
            # avg in freq (can do because is between -pi and pi)
            avgdph = np.ma.average(dph, axis=2)
            # one extra step to remove most (all) phase wraps, which disturb
            # the averaging: center around the avg value
            center = np.ma.average(avgdph, axis=-1)[:, :, np.newaxis, np.newaxis]
            dph = np.remainder(dph - center + np.pi, 2 * np.pi) + center - np.pi

        logging.debug('Computing sructure function...')
        avgdph = np.ma.average(dph, axis=2)  # avg in freq to reduce noise

        variances = []
        pars = []
        # remove a few timeslots to make the array divisible by np.split
        avgdph = avgdph[..., avgdph.shape[-1] % ndiv:]
        for i, avgdphSplit in enumerate(np.split(avgdph, ndiv, axis=-1)):
            # get time variance and rescale to 150 MHz
            variance = np.ma.var(avgdphSplit, axis=-1) * (
                np.average(coord['freq']) / 150.e6)**2

            # linear regression in log-log space
            # getmaskarray: always a full boolean array, even without masked points
            mask = np.ma.getmaskarray(variance[myselect])
            A = np.vstack([
                np.log10(D2[myselect][~mask]),
                np.ones(len(D2[myselect][~mask]))
            ])
            par = np.linalg.lstsq(A.T,
                                  np.log10(variance[myselect][~mask]),
                                  rcond=None)[0]
            S0 = 10**(-1 * par[1] / par[0])
            logging.info(r't%i: beta=%.2f - R_diff=%.2f km' %
                         (i, par[0], S0 / 1.e3))
            variances.append(variance)
            pars.append(par)

        if plotName != '':
            if plotName.split('.')[-1] != 'png': plotName += '.png'  # add png

            if not 'matplotlib' in sys.modules:
                import matplotlib as mpl
                mpl.use("Agg")
            import matplotlib.pyplot as plt

            fig = plt.figure()
            fig.subplots_adjust(wspace=0)
            ax = fig.add_subplot(111)
            ax1 = ax.twinx()

            for i, variance in enumerate(variances):
                if len(variances) > 1:
                    color = plt.cm.jet(
                        i / float(len(variances) - 1))  # from 0 to 1
                else:
                    color = 'black'
                ax.plot(D2[myselect] / 1.e3,
                        variance[myselect],
                        marker='o',
                        linestyle='',
                        color=color,
                        markeredgecolor='none',
                        label='T')

                # regression
                par = pars[i]
                x = D2[myselect]
                S0 = 10**(-1 * par[1] / par[0])
                if color == 'black':
                    color = 'red'  # in case of single color, use red line that is more visible
                ax1.plot(x.flatten() / 1.e3,
                         par[0] * np.log10(x.flatten()) + par[1],
                         linestyle='-',
                         color=color,
                         label=r'$\beta=%.2f$ - $R_{\rm diff}=%.2f$ km' %
                         (par[0], S0 / 1.e3))

            ax.set_xlabel('Distance (km)')
            ax.set_ylabel(r'Phase variance @150 MHz (rad$^2$)')
            ax.set_xscale('log')
            ax.set_yscale('log')

            # y range covering every time chunk, not only the last one
            allVariances = np.ma.concatenate([v[myselect] for v in variances])
            ymin = np.ma.min(allVariances)
            ymax = np.ma.max(allVariances)
            ax.set_xlim(xmin=0.1, xmax=3)
            ax.set_ylim(ymin, ymax)
            ax1.set_ylim(np.log10(ymin), np.log10(ymax))
            ax1.legend(loc='lower right', frameon=False)
            ax1.set_yticks([])

            logging.warning('Save pic: %s' % plotName)
            plt.savefig(plotName, bbox_inches='tight')
            plt.close(fig)  # release the figure (one is created per iteration)

    return 0
Ejemplo n.º 21
0
def run(soltab, soltabOut='rotationmeasure000', refAnt='', maxResidual=1.):
    """
    Faraday rotation extraction from either a rotation table or a circular phase (of which the operation get the polarisation difference).

    For every antenna (except the reference) and timeslot, a rotation measure
    RM is fitted to the phase difference as a function of wavelength with the
    model phase = 2 * RM * lambda^2. The phase difference is RR-LL for phase
    tables and twice the rotation angle for rotation tables. Results are
    written to a new 'rotationmeasure' soltab with axes (ant, time) in the
    same solset.

    Parameters
    ----------
    soltabOut : str, optional
        output table name (same solset), by default "rotationmeasure000".

    refAnt : str, optional
        Reference antenna, by default the first.

    maxResidual : float, optional
        Max average residual in radians before flagging datapoint, by default 1. If 0: no check.

    Returns
    -------
    int
        0 on success, 1 on invalid input.
    """
    import numpy as np
    import scipy.optimize

    # residual of the model phase = 2 RM lambda^2, compared as points on the
    # unit circle so that phase wraps do not matter
    rmwavcomplex = lambda RM, wav, y: abs(
        np.cos(2. * RM[0] * wav * wav) - np.cos(y)) + abs(
            np.sin(2. * RM[0] * wav * wav) - np.sin(y))
    c = 2.99792458e8  # speed of light [m/s]

    logging.info("Find FR for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType == 'phase':
        returnAxes = ['pol', 'freq', 'time']
        pols = soltab.getAxisValues('pol')
        if 'RR' in pols and 'LL' in pols:
            coord_rr = np.where(pols == 'RR')[0][0]
            coord_ll = np.where(pols == 'LL')[0][0]
        elif 'XX' in pols and 'YY' in pols:
            logging.warning(
                'Linear polarization detected, LoSoTo assumes XX->RR and YY->LL.'
            )
            coord_rr = np.where(pols == 'XX')[0][0]
            coord_ll = np.where(pols == 'YY')[0][0]
        else:
            logging.error(
                "Cannot proceed with Faraday estimation with polarizations: " +
                str(pols))
            return 1
    elif solType == 'rotation':
        returnAxes = ['freq', 'time']
    else:
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase or rotation. Ignoring.")
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = ants[0]

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')

    # create new table
    solset = soltab.getSolset()
    soltabout = solset.makeSoltab('rotationmeasure',
                                  soltabName=soltabOut,
                                  axesNames=['ant', 'time'],
                                  axesVals=[ants, times],
                                  vals=np.zeros((len(ants), len(times))),
                                  weights=np.ones((len(ants), len(times))))
    soltabout.addHistory('Created by FARADAY operation from %s.' % soltab.name)

    doplot = False  # developer debug plots of single fits

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=returnAxes, weight=True, reference=refAnt):

        if len(coord['freq']) < 10:
            logging.error(
                'Faraday rotation estimation needs at least 10 frequency channels, preferably distributed over a wide range.'
            )
            return 1

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(), returnAxes)
        weights = reorderAxes(weights, soltab.getAxesNames(), returnAxes)
        weights[np.isnan(vals)] = 0.

        fitrm = np.zeros(len(times))
        fitweights = np.ones(len(times))  # all unflagged to start
        fitrmguess = 0.001  # good guess

        # the reference antenna keeps RM = 0 with weight 1
        if coord['ant'] != refAnt:
            logging.debug('Working on ant: ' + coord['ant'] + '...')

            if (weights == 0.).all():
                logging.warning('Skipping flagged antenna: ' + coord['ant'])
                fitweights[:] = 0
            else:

                for t, time in enumerate(times):

                    if solType == 'phase':
                        idx = ((weights[coord_rr, :, t] != 0.) &
                               (weights[coord_ll, :, t] != 0.))
                        freq = np.copy(coord['freq'])[idx]
                        phase_rr = vals[coord_rr, :, t][idx]
                        phase_ll = vals[coord_ll, :, t][idx]
                        # RR-LL to be consistent with BBS/NDPPP; not divided
                        # by 2 otherwise jump problem, then later fix this
                        phase_diff = (phase_rr - phase_ll)
                    else:  # rotation table
                        idx = (weights[:, t] != 0.)
                        freq = np.copy(coord['freq'])[idx]
                        # a rotation is between -pi and +pi
                        phase_diff = 2. * vals[:, t][idx]

                    if len(freq) < 20:
                        fitweights[t] = 0
                        logging.warning(
                            'No valid data found for Faraday fitting for antenna: '
                            + coord['ant'] + ' at timestamp ' + str(t))
                        continue

                    # if more than 1/4 of chans are flagged
                    if (len(idx) - len(freq)) / float(len(idx)) > 1 / 4.:
                        logging.debug(
                            'High number of filtered out data points for the timeslot %i: %i/%i'
                            % (t, len(idx) - len(freq), len(idx)))

                    wav = c / freq

                    fitresultrm_wav, success = scipy.optimize.leastsq(
                        rmwavcomplex, [fitrmguess], args=(wav, phase_diff))
                    # mean absolute residual (wrapped to [-pi, pi])
                    residual = np.nanmean(
                        np.abs(
                            np.mod((2. * fitresultrm_wav * wav * wav) -
                                   phase_diff + np.pi, 2. * np.pi) - np.pi))

                    if maxResidual == 0 or residual < maxResidual:
                        # good fit: use it as starting guess for the next timeslot
                        fitrmguess = fitresultrm_wav[0]
                        weight = 1
                    else:
                        # high residual, flag
                        logging.warning('Bad solution for ant: ' +
                                        coord['ant'] + ' (time: ' + str(t) +
                                        ', resdiaul: ' + str(residual) + ').')
                        weight = 0

                    fitrm[t] = fitresultrm_wav[0]
                    fitweights[t] = weight

                    # Debug plot
                    if doplot and coord['ant'] == 'RS310LBA' and t % 10 == 0:
                        print("Plotting")
                        if not 'matplotlib' in sys.modules:
                            import matplotlib as mpl
                            mpl.rc('font', size=8)
                            mpl.rc('figure.subplot',
                                   left=0.05,
                                   bottom=0.05,
                                   right=0.95,
                                   top=0.95,
                                   wspace=0.22,
                                   hspace=0.22)
                            mpl.use("Agg")
                        import matplotlib.pyplot as plt

                        fig = plt.figure()
                        fig.subplots_adjust(wspace=0)
                        ax = fig.add_subplot(111)

                        # plot rm fit (notice the factor of 2)
                        plotrm = lambda RM, wav: np.mod(
                            (2. * RM * wav * wav) + np.pi, 2. * np.pi) - np.pi
                        ax.plot(freq,
                                plotrm(fitresultrm_wav, c / freq[:]),
                                "-",
                                color='purple')

                        if solType == 'phase':
                            ax.plot(
                                freq,
                                np.mod(phase_rr + np.pi, 2. * np.pi) - np.pi,
                                'ob')
                            ax.plot(
                                freq,
                                np.mod(phase_ll + np.pi, 2. * np.pi) - np.pi,
                                'og')
                        ax.plot(freq,
                                np.mod(phase_diff + np.pi, 2. * np.pi) - np.pi,
                                '.',
                                color='purple')

                        residuals = np.mod(
                            plotrm(fitresultrm_wav, c / freq[:]) - phase_diff +
                            np.pi, 2. * np.pi) - np.pi
                        ax.plot(freq, residuals, '.', color='yellow')

                        ax.set_xlabel('freq')
                        ax.set_ylabel('phase')
                        ax.set_ylim(ymin=-np.pi, ymax=np.pi)

                        logging.warning('Save pic: ' + str(t) + '_' +
                                        coord['ant'] + '.png')
                        plt.savefig(str(t) + '_' + coord['ant'] + '.png',
                                    bbox_inches='tight')
                        plt.close(fig)

        soltabout.setSelection(ant=coord['ant'], time=coord['time'])
        soltabout.setValues(np.expand_dims(fitrm, axis=1))
        soltabout.setValues(np.expand_dims(fitweights, axis=1), weight=True)

    return 0
Ejemplo n.º 22
0
def run(soltab, axesInPlot, axisInTable='', axisInCol='', axisDiff='', NColFig=0, figSize=[0,0], markerSize=2, minmax=[0,0], log='', \
               plotFlag=False, doUnwrap=False, refAnt='', soltabsToAdd='', makeAntPlot=False, makeMovie=False, prefix='', ncpu=0):
    """
    This operation for LoSoTo implements basic plotting
    WEIGHT: flag-only compliant, no need for weight

    Parameters
    ----------
    axesInPlot : array of str
        1- or 2-element array which says the coordinates to plot (2 for 3D plots).

    axisInTable : str, optional
        the axis to plot on a page - e.g. ant to get all antenna's on one file. By default ''.

    axisInCol : str, optional
        The axis to plot in different colours - e.g. pol to get correlations with different colors. By default ''.

    axisDiff : str, optional
        This must be a len=2 axis and the plot will have the differential value - e.g. 'pol' to plot XX-YY. By default ''.

    NColFig : int, optional
        Number of columns in a multi-table image. By default is automatically chosen.

    figSize : array of int, optional
        Size of the image [x,y], if one of the values is 0, then it is automatically chosen. By default automatic set.

    markerSize : int, optional
        Size of the markers in the 2D plot. By default 2.

    minmax : array of float, optional
        Min max value for the independent variable (0 means automatic). By default 0.

    log : bool, optional
        Use Log='XYZ' to set which axes to put in Log. By default ''.

    plotFlag : bool, optional
        Whether to plot also flags as red points in 2D plots. By default False.

    doUnwrap : bool, optional
        Unwrap phases. By default False.

    refAnt : str, optional
        Reference antenna for phases. By default None.

    soltabsToAdd : str, optional
        Tables to "add" (e.g. 'sol000/tec000'), it works only for tec and clock to be added to phases. By default None.

    makeAntPlot : bool, optional
        Make a plot containing antenna coordinates in x,y and in color the value to plot, axesInPlot must be [ant]. By default False.

    makeMovie : bool, optional
        Make a movie summing up all the produced plots, by default False.

    prefix : str, optional
        Prefix to add before the self-generated filename, by default None.

    ncpu : int, optional
        Number of cpus, by default all available.
    """
    import os, random
    import numpy as np
    from losoto.lib_unwrap import unwrap, unwrap_2d

    logging.info("Plotting soltab: " + soltab.name)

    # input check

    # str2list
    if axisInTable == '': axisInTable = []
    else: axisInTable = [axisInTable]
    if axisInCol == '': axisInCol = []
    else: axisInCol = [axisInCol]
    if axisDiff == '': axisDiff = []
    else: axisDiff = [axisDiff]

    if len(set(axisInTable + axesInPlot + axisInCol +
               axisDiff)) != len(axisInTable + axesInPlot + axisInCol +
                                 axisDiff):
        logging.error('Axis defined multiple times.')
        return 1

    # just because we use lists, check that they are 1-d
    if len(axisInTable) > 1 or len(axisInCol) > 1 or len(axisDiff) > 1:
        logging.error(
            'Too many TableAxis/ColAxis/DiffAxis, they must be at most one each.'
        )
        return 1

    for axis in axesInPlot + axisInCol + axisDiff:
        if axis not in soltab.getAxesNames():
            logging.error('Axis \"' + axis + '\" not found.')
            return 1

    if makeMovie:
        prefix = prefix + '__tmp__'

    if os.path.dirname(prefix) != '' and not os.path.exists(
            os.path.dirname(prefix)):
        logging.debug('Creating ' + os.path.dirname(prefix) + '.')
        os.makedirs(os.path.dirname(prefix))

    if refAnt == '': refAnt = None
    elif refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      soltab.getAxisValues('ant')[1])
        refAnt = soltab.getAxisValues('ant')[1]

    minZ, maxZ = minmax

    solset = soltab.getSolset()
    soltabsToAdd = [
        solset.getSoltab(soltabName) for soltabName in soltabsToAdd
    ]

    cmesh = False
    if len(axesInPlot) == 2:
        cmesh = True
        # not color possible in 3D
        axisInCol = []
    elif len(axesInPlot) != 1:
        logging.error('Axes must be a len 1 or 2 array.')
        return 1
    # end input check

    # all axes that are not iterated by anything else
    axesInFile = soltab.getAxesNames()
    for axis in axisInTable + axesInPlot + axisInCol + axisDiff:
        axesInFile.remove(axis)

    # set subplots scheme
    if axisInTable != []:
        Nplots = soltab.getAxisLen(axisInTable[0])
    else:
        Nplots = 1

    # prepare antennas coord in makeAntPlot case
    if makeAntPlot:
        if axesInPlot != ['ant']:
            logging.error(
                'If makeAntPlot is selected the "Axes" values must be "ant"')
            return 1
        antCoords = [[], []]
        for ant in soltab.getAxisValues(
                'ant'):  # select only user-selected antenna in proper order
            antCoords[0].append(+1 * soltab.getSolset().getAnt()[ant][1])
            antCoords[1].append(-1 * soltab.getSolset().getAnt()[ant][0])

    else:
        antCoords = []

    datatype = soltab.getType()

    # start processes for multi-thread
    mpm = multiprocManager(ncpu, _plot)

    # compute dataCube size
    shape = []
    if axisInTable != []: shape.append(soltab.getAxisLen(axisInTable[0]))
    else: shape.append(1)
    if axisInCol != []: shape.append(soltab.getAxisLen(axisInCol[0]))
    else: shape.append(1)
    if cmesh:
        shape.append(soltab.getAxisLen(axesInPlot[1]))
        shape.append(soltab.getAxisLen(axesInPlot[0]))
    else:
        shape.append(soltab.getAxisLen(axesInPlot[0]))

    # will contain the data to pass to each thread to make 1 image
    dataCube = np.ma.zeros(shape=shape, fill_value=np.nan)

    # cycle on files
    if makeMovie: pngs = []  # store png filenames
    for vals, coord, selection in soltab.getValuesIter(
            returnAxes=axisDiff + axisInTable + axisInCol + axesInPlot):

        # set filename
        filename = ''
        for axis in axesInFile:
            filename += axis + str(coord[axis]) + '_'
        filename = filename[:-1]  # remove last _
        if prefix + filename == '': filename = 'plot'

        # axis vals (they are always the same, regulat arrays)
        xvals = coord[axesInPlot[0]]
        # if plotting antenna - convert to number
        if axesInPlot[0] == 'ant':
            xvals = np.arange(len(xvals))

        # if plotting time - convert in h/min/s
        xlabelunit = ''
        if axesInPlot[0] == 'time':
            if xvals[-1] - xvals[0] > 3600:
                xvals = (xvals - xvals[0]) / 3600.  # hrs
                xlabelunit = ' [hr]'
            elif xvals[-1] - xvals[0] > 60:
                xvals = (xvals - xvals[0]) / 60.  # mins
                xlabelunit = ' [min]'
            else:
                xvals = (xvals - xvals[0])  # sec
                xlabelunit = ' [s]'
        # if plotting freq convert in MHz
        elif axesInPlot[0] == 'freq':
            xvals = xvals / 1.e6  # MHz
            xlabelunit = ' [MHz]'

        if cmesh:
            # axis vals (they are always the same, regular arrays)
            yvals = coord[axesInPlot[1]]
            # same as above but for y-axis
            if axesInPlot[1] == 'ant':
                yvals = np.arange(len(yvals))

            if len(xvals) <= 1 or len(yvals) <= 1:
                logging.error(
                    '3D plot must have more then one value per axes.')
                mpm.wait()
                return 1

            ylabelunit = ''
            if axesInPlot[1] == 'time':
                if yvals[-1] - yvals[0] > 3600:
                    yvals = (yvals - yvals[0]) / 3600.  # hrs
                    ylabelunit = ' [hr]'
                elif yvals[-1] - yvals[0] > 60:
                    yvals = (yvals - yvals[0]) / 60.  # mins
                    ylabelunit = ' [min]'
                else:
                    yvals = (yvals - yvals[0])  # sec
                    ylabelunit = ' [s]'
            elif axesInPlot[1] == 'freq':  # Mhz
                yvals = yvals / 1.e6
                ylabelunit = ' [MHz]'
        else:
            yvals = None
            if datatype == 'clock':
                datatype = 'Clock'
                ylabelunit = ' (s)'
            elif datatype == 'tec':
                datatype = 'dTEC'
                ylabelunit = ' (TECU)'
            elif datatype == 'rotationmeasure':
                datatype = 'dRM'
                ylabelunit = r' (rad m$^{-2}$)'
            elif datatype == 'tec3rd':
                datatype = r'dTEC$_3$'
                ylabelunit = r' (rad m$^{-3}$)'
            else:
                ylabelunit = ''

        # cycle on tables
        soltab1Selection = soltab.selection  # save global selection and subselect only axex to iterate
        soltab.selection = selection
        titles = []

        for Ntab, (vals, coord, selection) in enumerate(
                soltab.getValuesIter(returnAxes=axisDiff + axisInCol +
                                     axesInPlot)):

            # set tile
            titles.append('')
            for axis in coord:
                if axis in axesInFile + axesInPlot + axisInCol: continue
                titles[Ntab] += axis + ':' + str(coord[axis]) + ' '
            titles[Ntab] = titles[Ntab][:-1]  # remove last ' '

            # cycle on colors
            soltab2Selection = soltab.selection
            soltab.selection = selection
            for Ncol, (vals, weight, coord, selection) in enumerate(
                    soltab.getValuesIter(returnAxes=axisDiff + axesInPlot,
                                         weight=True,
                                         reference=refAnt)):

                # differential plot
                if axisDiff != []:
                    # find ordered list of axis
                    names = [
                        axis for axis in soltab.getAxesNames()
                        if axis in axisDiff + axesInPlot
                    ]
                    if axisDiff[0] not in names:
                        logging.error("Axis to differentiate (%s) not found." %
                                      axisDiff[0])
                        mpm.wait()
                        return 1
                    if len(coord[axisDiff[0]]) != 2:
                        logging.error(
                            "Axis to differentiate (%s) has too many values, only 2 is allowed."
                            % axisDiff[0])
                        mpm.wait()
                        return 1

                    # find position of interesting axis
                    diff_idx = names.index(axisDiff[0])
                    # roll to first place
                    vals = np.rollaxis(vals, diff_idx, 0)
                    vals = vals[0] - vals[1]
                    weight = np.rollaxis(weight, diff_idx, 0)
                    weight[0][weight[1] == 0] = 0
                    weight = weight[0]
                    del coord[axisDiff[0]]

                # add tables if required (e.g. phase/tec)
                for soltabToAdd in soltabsToAdd:
                    logging.warning('soltabsToAdd not implemented. Ignoring.')
#                    newCoord = {}
#                    for axisName in coord.keys():
#                        # prepare selected on present axes
#                        if axisName in soltabToAdd.getAxesNames():
#                            if type(coord[axisName]) is np.ndarray:
#                                newCoord[axisName] = coord[axisName]
#                            else:
#                                newCoord[axisName] = [coord[axisName]] # avoid being interpreted as regexp, faster
#
#                    soltabToAdd.setSelection(**newCoord)
#                    valsAdd = np.squeeze(soltabToAdd.getValues(retAxesVals=False, weight=False, reference=refAnt))
#
#                    # add missing axes
#                    print ('shape:', vals.shape)
#                    for axisName in coord.keys():
#                        if not axisName in soltabToAdd.getAxesNames():
#                            # find axis positions
#                            axisPos = soltab.getAxesNames().index(axisName)
#                            # create a new axes for the table to add and duplicate the values
#                            valsAdd = np.expand_dims(valsAdd, axisPos)
#                            print ('shape to add:', valsAdd.shape)
#
#                    if soltabToAdd.getType() == 'clock':
#                        valsAdd = 2. * np.pi * valsAdd * coord['freq']
#                    elif soltabToAdd.getType() == 'tec':
#                        valsAdd = -8.44797245e9 * valsAdd / coord['freq']
#                    else:
#                        logging.warning('Only Clock or TEC can be added to solutions. Ignoring: '+soltabToAdd.getType()+'.')
#                        continue
#
#                    if valsAdd.shape != vals.shape:
#                        logging.error('Cannot combine the table '+soltabToAdd.getType()+' with '+soltab.getType()+'. Wrong shape.')
#                        mpm.wait()
#                        return 1
#
#                    vals += valsAdd

# normalize
                if (soltab.getType() == 'phase'
                        or soltab.getType() == 'scalarphase'):
                    vals = normalize_phase(vals)
                if (soltab.getType() == 'rotation'):
                    vals = np.mod(vals + np.pi / 2., np.pi) - np.pi / 2.

                # is user requested axis in an order that is different from h5parm, we need to transpose
                if cmesh:
                    if soltab.getAxesNames().index(
                            axesInPlot[0]) < soltab.getAxesNames().index(
                                axesInPlot[1]):
                        vals = vals.T
                        weight = weight.T

                # unwrap if required
                if (soltab.getType() == 'phase'
                        or soltab.getType() == 'scalarphase') and doUnwrap:
                    if len(axesInPlot) == 1:
                        vals = unwrap(vals)
                    else:
                        flags = np.array((weight == 0), dtype=bool)
                        if not (flags == True).all():
                            vals = unwrap_2d(vals, flags, coord[axesInPlot[0]],
                                             coord[axesInPlot[1]])

                dataCube[Ntab, Ncol] = vals
                sel1 = np.where(weight == 0.)
                sel2 = np.where(np.isnan(vals))
                if cmesh:
                    dataCube[Ntab, Ncol, sel1[0], sel1[1]] = np.ma.masked
                    dataCube[Ntab, Ncol, sel2[0], sel2[1]] = np.ma.masked
                else:
                    dataCube[Ntab, Ncol, sel1[0]] = np.ma.masked
                    dataCube[Ntab, Ncol, sel2[0]] = np.ma.masked

            soltab.selection = soltab2Selection
            ### end cycle on colors

        # if dataCube too large (> 500 MB) do not go parallel
        if np.array(dataCube).nbytes > 1024 * 1024 * 500:
            logging.debug('Big plot, parallel not possible.')
            _plot(Nplots, NColFig, figSize, markerSize, cmesh, axesInPlot,
                  axisInTable, xvals, yvals, xlabelunit, ylabelunit, datatype,
                  prefix + filename, titles, log, dataCube, minZ, maxZ,
                  plotFlag, makeMovie, antCoords, None)
        else:
            mpm.put([
                Nplots, NColFig, figSize, markerSize, cmesh, axesInPlot,
                axisInTable, xvals, yvals, xlabelunit, ylabelunit, datatype,
                prefix + filename, titles, log,
                np.ma.copy(dataCube), minZ, maxZ, plotFlag, makeMovie,
                antCoords
            ])
        if makeMovie: pngs.append(prefix + filename + '.png')

        soltab.selection = soltab1Selection
        ### end cycle on tables
    mpm.wait()

    if makeMovie:

        def long_substr(strings):
            """
            Find longest common substring
            """
            substr = ''
            if len(strings) > 1 and len(strings[0]) > 0:
                for i in range(len(strings[0])):
                    for j in range(len(strings[0]) - i + 1):
                        if j > len(substr) and all(strings[0][i:i + j] in x
                                                   for x in strings):
                            substr = strings[0][i:i + j]
            return substr

        movieName = long_substr(pngs)
        assert movieName != ''  # need a common prefix, use prefix keyword in case
        logging.info('Making movie: ' + movieName)
        # make every movie last 20 sec, min one second per slide
        fps = np.ceil(len(pngs) / 200.)
        ss="mencoder -ovc lavc -lavcopts vcodec=mpeg4:vpass=1:vbitrate=6160000:mbd=2:keyint=132:v4mv:vqmin=3:lumi_mask=0.07:dark_mask=0.2:"+\
                "mpeg_quant:scplx_mask=0.1:tcplx_mask=0.1:naq -mf type=png:fps="+str(fps)+" -nosound -o "+movieName.replace('__tmp__','')+".mpg mf://"+movieName+"*  > mencoder.log 2>&1"
        os.system(ss)
        #for png in pngs: os.system('rm '+png)

    return 0
Ejemplo n.º 23
0
def run(soltab, axesToClip=None, clipLevel=5., log=False, mode='median'):
    """
    Clip solutions around the median by a factor specified by the user.
    WEIGHT: flag compliant, putting weights into median is tricky

    Only the weights are modified: clipped data points get weight 0, the
    solution values themselves are never written back.

    Parameters
    ----------
    axesToClip : list of str, optional
        axes along which to calculate the median (e.g. [time,freq]). Must be
        non-empty in "median" mode; in "above"/"below" mode an empty list (or
        None) means all axes of the soltab. Axes not present in the soltab are
        ignored with a warning. The caller's list is not modified.

    clipLevel : float, optional
        In "median" mode, the number of standard deviations away from the
        median beyond which data are flagged; in "above"/"below" mode, the
        threshold itself. By default 5.

    log : bool, optional
        clip is done in log10 space, by default False. In "above"/"below"
        mode this means clipLevel is compared with log10 of the values.

    mode : str, optional
        if "median" then flag at rms*clipLevel times away from the median, if it is "above",
        then flag all values above clipLevel, if it is "below" then flag all values below clipLevel.
        By default median.

    Returns
    -------
    int
        0 on success, 1 on invalid input.
    """

    import numpy as np

    def percentFlagged(w):
        """Return the percentage of entries of ``w`` that are zero (flagged)."""
        return 100. * (w.size - np.count_nonzero(w)) / float(w.size)

    logging.info("Clipping soltab: " + soltab.name)

    # input check (work on a copy: the list is filtered below and must not
    # leak back into the caller's argument)
    axesToClip = [] if axesToClip is None else list(axesToClip)
    if len(axesToClip) == 0 and mode == 'median':
        logging.error("Please specify axes to clip.")
        return 1
    elif len(axesToClip) == 0:
        axesToClip = soltab.getAxesNames()

    if mode not in ('median', 'above', 'below'):
        logging.error("Mode can be only: median, above or below.")
        return 1

    if mode == 'median' and soltab.getType() == 'amplitude' and not log:
        logging.warning(
            'Amplitude solution tab detected and log=False. Amplitude solution tables should be treated in log space.'
        )

    # some checks: keep only the axes present in the soltab. A new list is
    # built because deleting by index while iterating shifts later indices.
    axesNames = soltab.getAxesNames()
    for axis in axesToClip:
        if axis not in axesNames:
            logging.warning('Axis \"' + axis + '\" not found. Ignoring.')
    axesToClip = [axis for axis in axesToClip if axis in axesNames]

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=axesToClip, weight=True):

        initPercent = percentFlagged(weights)

        # skip all flagged
        if (weights == 0).all():
            continue

        if log: vals = np.log10(vals)

        if mode == 'median':
            unflagged = (weights != 0)
            valmedian = np.nanmedian(vals[unflagged])
            rms = np.nanstd(vals[unflagged])
            np.putmask(weights, np.abs(vals - valmedian) > rms * clipLevel, 0)

        elif mode == 'above':
            np.putmask(weights, vals > clipLevel, 0)

        elif mode == 'below':
            np.putmask(weights, vals < clipLevel, 0)

        # write back only the weights (flags); vals are left untouched
        soltab.setValues(weights, selection, weight=True)

        logging.debug('Percentage of data flagged (%s): %.3f%% -> %.3f%%' %
                      (removeKeys(coord, axesToClip), initPercent,
                       percentFlagged(weights)))

    soltab.addHistory('CLIP (over %s with %s sigma cut)' %
                      (axesToClip, clipLevel))

    soltab.flush()

    return 0
Ejemplo n.º 24
0
def run(soltab,
        soltabOut='tec000',
        refAnt='',
        maxResidualFlag=2.5,
        maxResidualProp=1.):
    """
    Bruteforce TEC extraction from phase solutions.

    For each antenna and time slot, the wrapped model
    phase = -8.44797245e9 * TEC / freq is fitted to the unflagged phases with a
    1-D brute-force grid search over TEC followed by a least-squares
    refinement. The result is stored in a new 'tec' soltab with axes
    ['ant', 'time'] in the same solset.

    Parameters
    ----------
    soltabOut : str, optional
        output table name (same solset), by default "tec000".

    refAnt : str, optional
        Reference antenna, by default the first.

    maxResidualFlag : float, optional
        Max average residual in radians before flagging datapoint, by default 2.5 If 0: no check.

    maxResidualProp : float, optional
        Max average residual in radians before stop propagating solutions, by default 1. If 0: no check.

    Returns
    -------
    int
        0 on success, 1 if the soltab is not of type phase or has fewer than
        10 frequency channels.
    """
    import numpy as np
    import scipy.optimize

    # debugging aid: set doPlot to True to dump PNGs of the fit for debugAnts
    # (every 50th time slot)
    doPlot = False
    debugAnts = ('RS509LBA', 'RS210LBA')

    def mod(d):
        """Wrap phases into [-pi, pi)."""
        return np.mod(d + np.pi, 2. * np.pi) - np.pi

    # cost for the brute-force grid search: sum of absolute wrapped residuals
    drealbrute = lambda d, freq, y: np.sum(
        np.abs(mod(-8.44797245e9 * d / freq) - y))
    # residual vector for the least-squares refinement
    dreal = lambda d, freq, y: mod(-8.44797245e9 * d[0] / freq) - y

    def debugPlot(freq, phaseComb, tecFit, ant, t):
        """Save a diagnostic PNG with the phases and the fitted TEC model."""
        import sys
        if 'matplotlib' not in sys.modules:
            import matplotlib as mpl
            mpl.rc('figure.subplot',
                   left=0.05,
                   bottom=0.05,
                   right=0.95,
                   top=0.95,
                   wspace=0.22,
                   hspace=0.22)
            mpl.use("Agg")
        import matplotlib.pyplot as plt

        fig = plt.figure()
        fig.subplots_adjust(wspace=0)
        ax = fig.add_subplot(111)

        # plot tec fit (unwrapped and wrapped model) and the data
        plotd = lambda d, freq: -8.44797245e9 * d / freq
        ax.plot(freq, plotd(tecFit, freq), "-", color='purple')
        ax.plot(freq, mod(plotd(tecFit, freq)), ":", color='purple')
        ax.plot(freq, phaseComb, 'o', color='purple')

        residual = mod(plotd(tecFit, freq) - phaseComb)
        ax.plot(freq, residual, '.', color='orange')

        ax.set_xlabel('freq')
        ax.set_ylabel('phase')

        logging.warning('Save pic: ' + str(t) + '_' + ant + '.png')
        plt.savefig(str(t) + '_' + ant + '.png', bbox_inches='tight')
        plt.close(fig)

    logging.info("Find TEC for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    # checked before creating the output table, so that a failure does not
    # leave an empty TEC soltab behind
    if soltab.getAxisLen('freq') < 10:
        logging.error(
            'Delay estimation needs at least 10 frequency channels, preferably distributed over a wide range.'
        )
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt not in ('', 'closest') and refAnt not in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        # report the antenna that is actually used (the first selected one)
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = ants[0]

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')

    # create new table
    solset = soltab.getSolset()
    shape = (soltab.getAxisLen('ant'), soltab.getAxisLen('time'))
    soltabout = solset.makeSoltab(
        soltype='tec',
        soltabName=soltabOut,
        axesNames=['ant', 'time'],
        axesVals=[soltab.getAxisValues(axisName) for axisName in ['ant', 'time']],
        vals=np.zeros(shape=shape),
        weights=np.ones(shape=shape))
    soltabout.addHistory('Created by TEC operation from %s.' % soltab.name)

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['freq', 'time'], weight=True, reference=refAnt):

        # reorder axes so that vals/weights are indexed as [freq, time]
        vals = reorderAxes(vals, soltab.getAxesNames(), ['freq', 'time'])
        weights = reorderAxes(weights, soltab.getAxesNames(), ['freq', 'time'])

        fitd = np.zeros(len(times))
        fitweights = np.ones(len(times))  # all unflagged to start
        # TEC search window and grid size; narrowed around the previous
        # solution when the fit residual is below maxResidualProp
        ranges = (-0.5, 0.5)
        Ns = 1000

        # the reference antenna keeps zero TEC with full weight
        if coord['ant'] != refAnt:

            if np.all(weights == 0.):
                logging.warning('Skipping flagged antenna: ' + coord['ant'])
                fitweights[:] = 0
            else:

                for t in range(len(times)):

                    # apply flags
                    idx = (weights[:, t] != 0.)
                    freq = np.copy(coord['freq'])[idx]
                    phaseComb = vals[idx, t]

                    if len(freq) < 10:
                        fitweights[t] = 0
                        logging.warning(
                            'No valid data found for delay fitting for antenna: '
                            + coord['ant'] + ' at timestamp ' + str(t))
                        continue

                    # if more than 1/3 of chans are flagged
                    if (len(idx) - len(freq)) / float(len(idx)) > 1 / 3.:
                        logging.debug(
                            'High number of filtered out data points for the timeslot %i: %i/%i'
                            % (t, len(idx) - len(freq), len(idx)))

                    # brute force on the TEC grid, then least-squares refinement
                    fitresultd = scipy.optimize.brute(drealbrute,
                                                      ranges=(ranges, ),
                                                      Ns=Ns,
                                                      args=(freq, phaseComb))
                    fitresultd, _ = scipy.optimize.leastsq(
                        dreal, fitresultd, args=(freq, phaseComb))
                    best_residual = np.nanmean(
                        np.abs(
                            mod(-8.44797245e9 * fitresultd[0] / freq) -
                            phaseComb))

                    fitd[t] = fitresultd[0]
                    if maxResidualFlag == 0 or best_residual < maxResidualFlag:
                        fitweights[t] = 1
                        if maxResidualProp == 0 or best_residual < maxResidualProp:
                            # good fit: search a narrow window around it next
                            ranges = (fitresultd[0] - 0.05,
                                      fitresultd[0] + 0.05)
                            Ns = 100
                        else:
                            ranges = (-0.5, 0.5)
                            Ns = 1000
                    else:
                        # high residual, flag and reset initial guess
                        logging.warning('Bad solution for ant: ' +
                                        coord['ant'] + ' (time: ' + str(t) +
                                        ', resdiual: ' + str(best_residual) +
                                        ').')
                        fitweights[t] = 0
                        ranges = (-0.5, 0.5)
                        Ns = 1000

                    if doPlot and coord['ant'] in debugAnts and t % 50 == 0:
                        print("Plotting")
                        debugPlot(freq, phaseComb, fitresultd[0], coord['ant'],
                                  t)

                # NOTE(review): the logged value is the mean of 2 * fitd; the
                # origin of the factor 2 is not explained here -- confirm.
                logging.info('%s: average tec: %f TECU' %
                             (coord['ant'], np.mean(2 * fitd)))

        # write this antenna's row of the output table
        soltabout.setSelection(ant=coord['ant'])
        soltabout.setValues(fitd)
        soltabout.setValues(fitweights, weight=True)

    return 0
Ejemplo n.º 25
0
def run(soltab, soltabOut='tec000', refAnt=''):
    """
    Estimate TEC from ph solutions assuming no wrap for solution at t=0

    Works on a single frequency channel. For each antenna/direction the
    phases are unwrapped (losoto.lib_unwrap.unwrap) and converted with
    TEC = phase * freq / (-8.44797245e9). The result is stored in a new 'tec'
    soltab with axes ['ant', 'time', 'dir'] in the same solset.

    Parameters
    ----------
    soltabOut : str, optional
        output table name (same solset), by default "tec000".

    refAnt : str, optional
        Reference antenna, by default the first.

    Returns
    -------
    int
        0 on success, 1 if the soltab is not of type phase or the selection
        does not contain exactly one frequency.
    """

    import numpy as np
    from losoto.lib_unwrap import unwrap

    logging.info("Find TEC for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    # it works with phase at only 1 freq; checked before creating the output
    # table (an assert would be stripped under python -O)
    if soltab.getAxisLen('freq') != 1:
        logging.error(
            'TEC estimation from phase requires exactly one frequency channel.'
        )
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt not in ('', 'closest') and refAnt not in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        # report the antenna that is actually used (the first selected one)
        logging.warning('Reference antenna ' + refAnt + ' not found. Using: ' +
                        ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = ants[0]

    # create new table
    solset = soltab.getSolset()
    shape = (soltab.getAxisLen('ant'), soltab.getAxisLen('time'),
             soltab.getAxisLen('dir'))
    soltabout = solset.makeSoltab(
        soltype='tec',
        soltabName=soltabOut,
        axesNames=['ant', 'time', 'dir'],
        axesVals=[
            soltab.getAxisValues(axisName)
            for axisName in ['ant', 'time', 'dir']
        ],
        vals=np.zeros(shape=shape),
        weights=np.ones(shape=shape))
    soltabout.addHistory('Created by TEC operation from %s.' % soltab.name)

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['freq', 'time'], weight=True, reference=refAnt):

        if coord['ant'] != refAnt:

            if np.all(weights == 0.):
                # weights are already all zero, so the antenna stays flagged
                # in the output table (vals are written unconverted)
                logging.warning('Skipping flagged antenna: ' + coord['ant'])
            else:
                # unwrap (freq axis has length 1, so squeeze leaves time only),
                # then convert phase to TEC
                vals = np.reshape(unwrap(np.squeeze(vals)), vals.shape)
                vals *= coord['freq'] / (-8.44797245e9)
                logging.info('%s: average tec: %f TECU' %
                             (coord['ant'], np.mean(vals)))

        # write this antenna/direction into the output table
        soltabout.setSelection(ant=coord['ant'], dir=coord['dir'])
        soltabout.setValues(vals)
        soltabout.setValues(weights, weight=True)

    return 0