Ejemplo n.º 1
0
def run(soltab):
    """
    Take absolute value. Needed before smooth if amplitudes are negative!
    WEIGHT: no need to be weight compliant

    Parameters
    ----------
    soltab : soltab obj
        Solution table.

    Returns
    -------
    int
        Always 0.
    """

    import numpy as np

    logging.info("Taking ABSolute value of soltab: " + soltab.name)

    vals = soltab.getValues(retAxesVals=False)
    count = np.count_nonzero(vals < 0)

    # Percentage over all points of the table. The previous denominator
    # (number of non-zero points) raised ZeroDivisionError on all-zero tables
    # and excluded zero-valued points from the total.
    total = np.size(vals)
    percent = 100. * count / total if total else 0.

    logging.info('Abs: %i points initially negative (%f %%)' %
                 (count, percent))

    # writing back the solutions
    soltab.setValues(np.abs(vals))

    soltab.addHistory('ABSolute value taken')

    return 0
Ejemplo n.º 2
0
def run(soltab, refAnt='', refDir=''):
    """
    Reference phase-like solutions to an antenna and/or a direction.

    Parameters
    ----------
    refAnt : str, optional
        Reference antenna for phases or "best" to use the least flagged antenna
        (largest summed weight). Empty string does not change phases. Default: ''.
    refDir : str, optional
        Reference direction for phases. Empty string does not change phases. Default: ''.

    Returns
    -------
    int
        0 on success (also when both refAnt and refDir are empty and nothing
        is done), 1 on error.
    """

    if soltab.getType() not in ['phase', 'scalarphase', 'rotation', 'tec', 'clock', 'tec3rd', 'rotationmeasure']:
        logging.error('Reference possible only for phase, scalarphase, clock, tec, tec3rd, rotation and rotationmeasure solution tables. Ignore referencing.')
        return 1

    if refAnt == 'best':
        weights = soltab.getValues(retAxesVals=False, weight=True)
        # collapse every axis except 'ant' -> one total weight per antenna
        sumAxes = tuple(i for i, axisName in enumerate(soltab.getAxesNames()) if axisName != 'ant')
        weights = np.sum(weights, axis=sumAxes, dtype=float)  # np.float was removed in NumPy 1.24
        # least flagged antenna = largest total weight (first one on ties)
        refAnt = soltab.getAxisValues('ant')[np.argmax(weights)]
        logging.info('Using %s for reference antenna.' % refAnt)

    elif refAnt != '' and refAnt not in soltab.getAxisValues('ant', ignoreSelection=True):
        logging.error('Reference antenna '+refAnt+' not found.')
        return 1

    if refDir != '' and refDir not in soltab.getAxisValues('dir', ignoreSelection=True):
        logging.error('Reference direction '+refDir+' not found.')
        return 1

    # Nothing to reference against: leave the table untouched (previously this
    # reached setValues with an unbound 'vals').
    if refAnt == '' and refDir == '':
        logging.warning('No reference antenna or direction given: solutions left unchanged.')
        return 0

    # get reference direction if needed
    if refDir != '' and refAnt != '':
        logging.info('Referencing on both dir and ant.')
        soltab.setSelection(dir=refDir, ant=refAnt)
        valsRef = soltab.getValues(retAxesVals=False)
        soltab.clearSelection()
        vals = soltab.getValues(retAxesVals=False)
        dirAxis = soltab.getAxesNames().index('dir')
        antAxis = soltab.getAxesNames().index('ant')
        logging.debug('len vals ref %s' % str(valsRef.shape))
        # expand the single (dir, ant) reference slice over all dirs and ants
        valsRef = np.repeat(valsRef, axis=dirAxis, repeats=len(soltab.getAxisValues('dir')))
        valsRef = np.repeat(valsRef, axis=antAxis, repeats=len(soltab.getAxisValues('ant')))
        vals = vals - valsRef

    elif refDir != '':
        soltab.setSelection(dir=refDir)
        valsRefDir = soltab.getValues(retAxesVals=False)
        soltab.clearSelection()
        vals = soltab.getValues(retAxesVals=False)
        dirAxis = soltab.getAxesNames().index('dir')
        vals = vals - np.repeat(valsRefDir, axis=dirAxis, repeats=len(soltab.getAxisValues('dir')))

    # use automatic antenna referencing
    else:
        vals = soltab.getValues(retAxesVals=False, refAnt=refAnt)

    soltab.setValues(vals)

    soltab.addHistory('REFERENCED (to antenna: %s)' % (refAnt))
    return 0
Ejemplo n.º 3
0
def run(soltab, dataVal=-999.):
    """
    Overwrite every selected solution value with a single constant.
    WEIGHT: flag compliant, no need for weight

    Parameters
    ----------
    dataVal : float, optional
        Value written to all selected solutions. The default sentinel (-999.)
        means "choose automatically": 1 for amplitude soltabs, 0 for any other type.

    Returns
    -------
    int
        Always 0.
    """

    logging.info("Resetting soltab: " + soltab.name)

    # neutral value per soltab type, used only when no explicit value is given
    neutralByType = {'amplitude': 1.}
    solType = soltab.getType()
    newVal = neutralByType.get(solType, 0.) if dataVal == -999. else dataVal

    soltab.setValues(newVal)
    soltab.addHistory('RESET')
    return 0
Ejemplo n.º 4
0
def run(soltab, soltabOutG=None, soltabOutD=None):
    """
    Split a linear full-polarisation soltab (pol = XX, XY, YX, YY) into a
    G (diagonal) table and a D (leakage) table.

    The G table is a copy of the input with the off-diagonal terms (XY, YX)
    set to 0. The D table holds the off-diagonal terms normalised by the
    diagonal ones (XY/XX and YX/YY for amplitudes, XY-XX and YX-YY for phases)
    with the diagonal set to 1 (amplitude) or 0 (phase); see Hamaker+ 96,
    appendix D.

    Parameters
    ----------
    soltabOutG : str, optional
        Output table name (diagonal component). By default choose next available from table type.

    soltabOutD : str, optional
        Output table name (leakage component). By default choose next available from table type.

    Returns
    -------
    int
        0 on success, 1 on unsupported input.
    """

    logging.info('Split leakage tables %s -> %s + %s' %
                 (soltab.name, soltabOutG, soltabOutD))

    solType = soltab.getType()
    if solType not in ('amplitude', 'phase'):
        logging.error(
            'SPLITLEAK can work only on amplitude/phase soltabs. Found: %s.' %
            solType)
        return 1
    # Compare as plain lists: an element-wise numpy comparison fails (and
    # raises on recent numpy) when the pol axis does not have exactly 4 entries.
    if 'pol' not in soltab.getAxesNames() or \
            list(soltab.getAxisValues('pol')) != ['XX', 'XY', 'YX', 'YY']:
        logging.error('Pol in unusual order or not linear: not implemented.')
        return 1

    solset = soltab.getSolset()
    axesNames = soltab.getAxesNames()

    def _clone(name):
        # Copy of the input table; values/weights are re-read for every clone
        # so the two output tables never share array buffers.
        return solset.makeSoltab(soltype=solType, soltabName=name, axesNames=axesNames,
            axesVals=[soltab.getAxisValues(axisName) for axisName in axesNames],
            vals=soltab.getValues(retAxesVals=False),
            weights=soltab.getValues(weight=True, retAxesVals=False))

    ### G component: keep the diagonal, zero the off-diagonal
    soltabOutG = _clone(soltabOutG)
    soltabOutG.setSelection(pol=['XY', 'YX'])
    soltabOutG.setValues(0.)

    ### D component: divide offdiag by diag, then set diag to 1 (see Hamaker+ 96, appendix D)
    soltabOutD = _clone(soltabOutD)
    soltabOutD.setSelection(pol=['XX', 'YY'])
    valsDiag = np.copy(soltabOutD.getValues(retAxesVals=False))
    soltabOutD.setValues(1. if solType == 'amplitude' else 0.)

    soltabOutD.setSelection(pol=['XY', 'YX'])
    valsOffdiag = soltabOutD.getValues(retAxesVals=False)
    if solType == 'amplitude':
        soltabOutD.setValues(valsOffdiag / valsDiag)
    else:  # phase
        soltabOutD.setValues(valsOffdiag - valsDiag)

    soltabOutG.addHistory('SPLITLEAK: G component of %s' % (soltab.name))
    soltabOutD.addHistory('SPLITLEAK: D component of %s' % (soltab.name))
    return 0
Ejemplo n.º 5
0
def run(soltab, axesToExt, size, percent=50., maxCycles=3, ncpu=0):
    """
    This operation for LoSoTo implements an extend-flag procedure.
    It can work in multi dimensional space: for each datum it checks whether the
    surrounding data are flagged above a certain percentage and, if so, flags
    that datum too. The size of the surrounding footprint can be tuned.
    WEIGHT: compliant

    Parameters
    ----------
    axesToExt : list of str
        Axes used to find close flags.

    size : list of int
        Size of the window (diameter), per axis. If 0 is given then the entire length of the axis is assumed.
        Must have the same length as axesToExt.

    percent : float, optional
        Percent of flagged data around the point to flag it, by default 50.

    maxCycles : int, optional
        Number of independent cycles of flag expansion, by default 3.

    ncpu : int, optional
        Number of CPU used, by default all available.

    Returns
    -------
    int
        0 on success, 1 on invalid input.
    """

    logging.info("Extending flag on soltab: " + soltab.name)

    # Input checks run before any worker process is started, so an invalid
    # request does not spin up a pool only to tear it down again.
    if axesToExt is None or len(axesToExt) == 0:
        logging.error("Please specify at least one axis to extend flag.")
        return 1

    axesNames = soltab.getAxesNames()
    for axisToExt in axesToExt:
        if axisToExt not in axesNames:
            logging.error('Axis \"' + axisToExt + '\" not found.')
            return 1

    # start processes for multi-thread
    mpm = multiprocManager(ncpu, _flag)

    # fill the queue (note that sf and sw cannot be put into a queue since they have file references)
    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=axesToExt, weight=True):
        mpm.put(
            [weights, coord, axesToExt, selection, percent, size, maxCycles])

    mpm.wait()

    logging.info('Writing solutions')
    # each worker result is a (weights, selection) pair written back as weights only
    for w, sel in mpm.get():
        soltab.setValues(w, sel, weight=True)

    soltab.addHistory('FLAG EXTENDED (over %s)' % (str(axesToExt)))
    return 0
Ejemplo n.º 6
0
def run(soltab, axisReplicate, fromCell, updateWeights=True):
    """
    Replace the values along a certain axis taking them from one specific axis cell.

    Parameters
    ----------
    axisReplicate : str
        Axis along which replicate the values.

    fromCell : str
        A cell value in axisReplicate from which to copy the data values. It is
        converted to the type of the axis values before the lookup.

    updateWeights : bool
        If False then weights are untouched, if True they are replicated like data. Default: True.

    Returns
    -------
    int
        0 on success, 1 on error.
    """
    import numpy as np

    if axisReplicate not in soltab.getAxesNames():
        logging.error('Cannot find axis %s.' % axisReplicate)
        return 1

    axisType = type(soltab.getAxisValues(axisReplicate)[0])
    try:
        fromCell = np.array([fromCell]).astype(axisType)[0]
    except (ValueError, TypeError, OverflowError):
        logging.error('Cannot convert to type %s the value in fromCell: %s.' % (str(axisType),fromCell))
        return 1

    if fromCell not in soltab.getAxisValues(axisReplicate):
        logging.error('Cannot find %s in %s.' % (fromCell, axisReplicate))
        return 1

    logging.info("Replicate axis on soltab: "+soltab.name)

    # number of cells to fill, keeping the selection active on entry into account
    axisReplicateLen = soltab.getAxisLen(axisReplicate, ignoreSelection=False)
    axisReplicatePos = soltab.getAxesNames().index(axisReplicate)
    old_selection = soltab.selection

    # NOTE(review): setSelection below may reset selections on the other axes;
    # confirm the shapes match old_selection when other axes are restricted.
    try:
        # read the single-cell slice and expand it along the replicated axis
        soltab.setSelection(**{axisReplicate: fromCell})
        vals = np.repeat(soltab.getValues(retAxesVals=False),
                         repeats=axisReplicateLen, axis=axisReplicatePos)
        if updateWeights:
            # weights are replicated exactly like the values
            weights = np.repeat(soltab.getValues(retAxesVals=False, weight=True),
                                repeats=axisReplicateLen, axis=axisReplicatePos)
    finally:
        # always restore the caller's selection, even if reading failed
        soltab.selection = old_selection

    # write back
    soltab.setValues(vals)
    if updateWeights:
        soltab.setValues(weights, weight=True)

    soltab.addHistory('REPLICATEONAXIS (over axis %s)' % (axisReplicate))
    return 0
Ejemplo n.º 7
0
def run(soltab, axisDelete, fromCell):
    """
    Delete an axis only keeping the values of a certain slice.

    The original soltab is deleted and a new soltab with the same type and
    name is created without axisDelete, holding the values and weights of the
    chosen slice.

    Parameters
    ----------
    axisDelete : str
        Axis to delete.

    fromCell : str
        A cell value in axisDelete from which to keep the data values. If it is the string
        "first"/"last" then uses the first/last element of the axis.

    Returns
    -------
    int
        0 on success, 1 on error.
    """
    import numpy as np

    if axisDelete not in soltab.getAxesNames():
        logging.error('Cannot find axis %s.' % axisDelete)
        return 1

    if fromCell == 'first':
        fromCell = soltab.getAxisValues(axisDelete)[0]
    elif fromCell == 'last':
        fromCell = soltab.getAxisValues(axisDelete)[-1]

    axisType = type(soltab.getAxisValues(axisDelete)[0])
    try:
        fromCell = np.array([fromCell]).astype(axisType)[0]
    except (ValueError, TypeError, OverflowError):
        logging.error('Cannot convert to type %s the value in fromCell: %s.' % (str(axisType),fromCell))
        return 1

    if fromCell not in soltab.getAxisValues(axisDelete):
        logging.error('Cannot find %s in %s.' % (fromCell, axisDelete))
        return 1

    logging.info("Delete axis on soltab: "+soltab.name)

    # select the single slice to keep, then drop the (now length-1) axis
    soltab.setSelection(**{axisDelete:fromCell})
    axisDeleteIdx = soltab.getAxesNames().index(axisDelete)
    fromCellIdx = list(soltab.getAxisValues(axisDelete)).index(fromCell)
    vals = np.take(soltab.getValues(retAxesVals=False), fromCellIdx, axisDeleteIdx)
    weight = np.take(soltab.getValues(retAxesVals=False, weight=True), fromCellIdx, axisDeleteIdx)
    solset = soltab.getSolset()
    sttype = soltab.getType()
    stname = soltab.name

    # build a new list rather than mutating the one returned by getAxesNames()
    axes = [ax for ax in soltab.getAxesNames() if ax != axisDelete]
    axesVals = [soltab.getAxisValues(ax) for ax in axes]
    soltab.delete()

    st = solset.makeSoltab(sttype, stname, axesNames=axes, axesVals=axesVals, vals=vals, weights=weight)

    st.addHistory('DELETEAXIS (over axis %s from cell %s)' % (axisDelete, str(fromCell)))
    return 0
Ejemplo n.º 8
0
def run(soltab, axesToNorm, normVal=1., log=False):
    """
    Normalize the solutions to a given value
    WEIGHT: Weights compliant

    For every chunk returned along axesToNorm, the mean of the unflagged
    (weight != 0) values is brought to normVal; flagged values are left untouched.

    Parameters
    ----------
    axesToNorm : array of str
        Axes along which compute the normalization.

    normVal : float, optional
        Number to normalize to vals = vals * (normVal/valsMean), by default 1.

    log : bool, optional
        Normalization is done in log10 space (the mean of log10(vals) is
        shifted to log10(normVal)), by default False.

    Returns
    -------
    int
        0 on success, 1 if an axis is not found.
    """
    import numpy as np

    logging.info("Normalizing soltab: " + soltab.name)

    # input check
    axesNames = soltab.getAxesNames()
    for normAxis in axesToNorm:
        if normAxis not in axesNames:
            logging.error('Normalization axis ' + normAxis + ' not found.')
            return 1

    if soltab.getType() == 'amplitude' and not log:
        logging.warning(
            'Amplitude solution tab detected and log=False. Amplitude solution tables should be treated in log space.'
        )

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=axesToNorm, weight=True):

        if np.all(weights == 0): continue  # skip flagged selections
        good = weights != 0

        if log:
            # flagged zeros become -inf here and map back to 0 below; silence the warning
            with np.errstate(divide='ignore', invalid='ignore'):
                vals = np.log10(vals)
            valsMean = np.nanmean(vals[good])
            shift = np.log10(normVal) - valsMean
            vals[good] += shift
            factor = 10**shift  # equivalent linear rescaling, for the log message
            vals = 10**vals
        else:
            valsMean = np.nanmean(vals[good])
            factor = normVal / valsMean
            vals[good] *= factor

        logging.debug("Rescaling by: " + str(factor))

        # writing back the solutions
        soltab.setValues(vals, selection)

    soltab.flush()
    soltab.addHistory('NORM (on axis %s)' % (axesToNorm))

    return 0
Ejemplo n.º 9
0
def _calculate_piercepoints(station_positions, source_positions):
    """
    Returns array of piercepoint locations

    One pierce point is produced per (source, station) pair, grouped by
    source: rows [i*N_stations, (i+1)*N_stations) all hold the projected
    position of source i. Station positions are only used for their count,
    and the third coordinate of every pierce point is 0.

    Parameters
    ----------
    station_positions : array
        Array of station positions (only its first dimension, the number of
        stations, is used)
    source_positions : array
        Array of source positions; column 0 is RA and column 1 is Dec, in
        radians (they are converted to degrees before projection)

    Returns
    -------
    pp : array
        Array of pierce points, shape (N_sources * N_stations, 3)
    midRA : float
        Reference RA for WCS system (deg)
    midDec : float
        Reference Dec for WCS system (deg)

    """
    import numpy as np

    logging.info('Calculating screen pierce-point locations...')
    N_sources = source_positions.shape[0]
    N_stations = station_positions.shape[0]

    xyz = np.zeros((N_sources, 3))
    ra_deg = source_positions.T[0] * 180.0 / np.pi
    dec_deg = source_positions.T[1] * 180.0 / np.pi
    # projected (x, y) per source from _getxy -- presumably a plane projection
    # centred on (midRA, midDec); confirm against _getxy
    xy, midRA, midDec = _getxy(ra_deg, dec_deg)
    xyz[:, 0] = xy[0]
    xyz[:, 1] = xy[1]

    # repeat each source row once per station (source-major order), replacing
    # the former Python double loop with a single vectorized call
    pp = np.repeat(xyz, N_stations, axis=0)

    return pp, midRA, midDec
Ejemplo n.º 10
0
def run(soltab, axesToNorm, normVal=1.):
    """
    Rescale solutions so that, per chunk along axesToNorm, the mean of the
    unflagged values equals normVal.
    WEIGHT: Weights compliant

    Parameters
    ----------
    axesToNorm : array of str
        Axes along which compute the normalization.

    normVal : float, optional
        Target mean: vals = vals * (normVal/valsMean) on unflagged points, by default 1.

    Returns
    -------
    int
        0 on success, 1 if one of axesToNorm is not an axis of the soltab.
    """
    import numpy as np

    logging.info("Normalizing soltab: " + soltab.name)

    # every requested axis must exist; report the first missing one
    knownAxes = soltab.getAxesNames()
    missing = [ax for ax in axesToNorm if ax not in knownAxes]
    if missing:
        logging.error('Normalization axis ' + missing[0] + ' not found.')
        return 1

    chunks = soltab.getValuesIter(returnAxes=axesToNorm, weight=True)
    for chunkVals, chunkWeights, chunkCoord, chunkSel in chunks:
        unflagged = chunkWeights != 0
        if not unflagged.any():
            continue  # fully flagged chunk: nothing to rescale

        factor = normVal / np.nanmean(chunkVals[unflagged])
        chunkVals[unflagged] *= factor
        logging.debug(str(chunkCoord))
        logging.debug("Rescaling by: " + str(factor))

        soltab.setValues(chunkVals, chunkSel)

    soltab.flush()
    soltab.addHistory('NORM (on axis %s)' % (axesToNorm))

    return 0
Ejemplo n.º 11
0
def run(soltab, soltabOut=''):
    """
    Copy a soltab (type, axes, values and weights) into a new soltab of the same solset.

    Parameters
    ----------
    soltabOut : str, optional
        Output table name. By default choose next available from table type.

    Returns
    -------
    int
        Always 0.
    """

    # an empty name is passed on as None so the solset picks the next free name
    targetName = None if soltabOut == '' else soltabOut

    axesNames = soltab.getAxesNames()
    newTab = soltab.getSolset().makeSoltab(
        soltype=soltab.getType(),
        soltabName=targetName,
        axesNames=axesNames,
        axesVals=[soltab.getAxisValues(name) for name in axesNames],
        vals=soltab.getValues(retAxesVals=False),
        weights=soltab.getValues(weight=True, retAxesVals=False),
    )

    logging.info('Duplicate %s -> %s' % (soltab.name, newTab.name))

    newTab.addHistory('DUPLICATE from table %s' % (soltab.name))
    return 0
Ejemplo n.º 12
0
def run(soltab, soltabOut='', overwrite=False):
    """
    Duplicate a table

    Parameters
    ----------
    soltabOut : str, optional
        Output table name. By default choose next available from table type.
    overwrite : bool, optional
        Overwrite soltabOut if it already exists? This also works when
        soltabOut is the input table itself: its content is read before the
        deletion.

    Returns
    -------
    int
        Always 0.
    """

    if soltabOut == '':
        soltabOut = None

    solset = soltab.getSolset()

    # Read everything from the source BEFORE any deletion: if soltabOut names
    # the source table itself, deleting first would leave nothing to copy.
    srcName = soltab.name
    soltype = soltab.getType()
    axesNames = soltab.getAxesNames()
    axesVals = [soltab.getAxisValues(axisName) for axisName in axesNames]
    vals = soltab.getValues(retAxesVals=False)
    weights = soltab.getValues(weight=True, retAxesVals=False)

    if overwrite and soltabOut in solset.getSoltabNames():
        logging.info('Overwriting soltabOut {}'.format(soltabOut))
        solset.getSoltab(soltabOut).delete()

    soltabout = solset.makeSoltab(soltype=soltype,
                                  soltabName=soltabOut,
                                  axesNames=axesNames,
                                  axesVals=axesVals,
                                  vals=vals,
                                  weights=weights)
    # parmdbType=soltab.obj._v_attrs['parmdb_type'] # deprecated

    logging.info('Duplicate %s -> %s' % (srcName, soltabout.name))

    soltabout.addHistory('DUPLICATE from table %s' % (srcName))
    return 0
Ejemplo n.º 13
0
def run(soltab, doUnwrap=False, refAnt='', plotName='', ndiv=1):
    """
    Find the structure function from phase solutions of core stations.

    Only antennas matching 'CS*' are used. For each iteration chunk the two
    polarisations are averaged, phase differences between all antenna pairs
    are averaged over frequency, their time variance is rescaled to 150 MHz,
    and a line is fitted to log10(variance) vs log10(baseline length).

    Parameters
    ----------
    doUnwrap : bool, optional
        If True, unwrap each antenna's phases in 2D (freq x time) before
        computing the differences. This requires a reference antenna.
        By default False.

    refAnt : str, optional
        Reference antenna passed to getValuesIter ('closest' is passed through
        unchanged). An unknown antenna, or an empty string together with
        doUnwrap=True, falls back to the second antenna of the table. An empty
        string without unwrapping means no referencing. By default ''.

    plotName : str, optional
        Plot file name ('.png' is appended if missing), by default no plot.

    ndiv : int, optional
        Number of consecutive time chunks for which the structure function is
        computed separately. The first (ntimes % ndiv) timeslots are dropped
        so that the time axis splits evenly. By default 1.

    Returns
    -------
    int
        0 on success, 1 if the soltab is not of type phase.
    """
    import numpy as np
    from losoto.lib_unwrap import unwrap, unwrap_2d

    logging.info("Find structure function for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        # soltab.name (the Soltab object has no _v_name attribute)
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase.")
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[1])
        refAnt = ants[1]
    if refAnt == '' and doUnwrap:
        logging.error('Unwrap requires reference antenna. Using: ' + ants[1])
        refAnt = ants[1]
    if refAnt == '': refAnt = None

    soltab.setSelection(ant='CS*', update=True)

    posAll = soltab.getSolset().getAnt()

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['freq', 'pol', 'ant', 'time'],
            weight=True,
            reference=refAnt):

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(),
                           ['pol', 'ant', 'freq', 'time'])
        weights = reorderAxes(weights, soltab.getAxesNames(),
                              ['pol', 'ant', 'freq', 'time'])

        # order positions
        pos = np.array([list(posAll[ant]) for ant in coord['ant']])

        # avg pols (vector average of the unit phasors over the pol axis)
        vals = np.cos(vals) + 1.j * np.sin(vals)
        vals = np.nansum(vals, axis=0)
        vals = np.angle(vals)
        # a point is flagged if either of the first two pols is flagged
        flags = np.array((weights[0] == 0) | (weights[1] == 0), dtype=bool)

        # unwrap
        if doUnwrap:
            # remove mean to facilitate unwrapping
            for a, ant in enumerate(coord['ant']):
                if not (flags[a, :, :] == True).all() and ant != refAnt:
                    logging.debug('Unwrapping: ' + ant)
                    mean = np.angle(np.nanmean(np.exp(1j * vals[a].flatten())))
                    vals[a] -= mean
                    vals[a] = np.mod(vals[a] + np.pi, 2 * np.pi) - np.pi
                    vals[a, :, :] = unwrap_2d(vals[a, :, :], flags[a, :, :],
                                              coord['freq'], coord['time'])

        logging.debug('Computing differential values...')
        t1 = np.ma.array(vals, mask=flags)  # mask flagged data
        dph = t1[np.newaxis] - t1[:, np.newaxis]  # ant x ant x freq x time
        D = pos[np.newaxis] - pos[:, np.newaxis]  # ant x ant x 3
        D2 = np.triu(
            np.sqrt(np.sum(D**2, axis=-1))
        )  # calc distance and keep only upper triangle larger than 0
        myselect = (D2 > 0)

        if not doUnwrap:
            logging.debug('Re-normalising...')
            dph = np.mod(dph + np.pi, 2 * np.pi) - np.pi
            avgdph = np.ma.average(
                dph,
                axis=2)  # avg in freq (can do because is between -pi and pi)
            #one extra step to remove most(all) phase wraps, phase wraps disturbe the averaging...
            dph = np.remainder(
                dph -
                np.ma.average(avgdph, axis=-1)[:, :, np.newaxis, np.newaxis] +
                np.pi, 2 * np.pi) + np.ma.average(
                    avgdph,
                    axis=-1)[:, :, np.newaxis,
                             np.newaxis] - np.pi  #center around the avg value

        logging.debug('Computing sructure function...')
        avgdph = np.ma.average(dph, axis=2)  # avg in freq to reduce noise

        variances = []
        pars = []
        avgdph = avgdph[
            ..., avgdph.shape[-1] %
            ndiv:]  # remove a few timeslots to make the array divisible by np.split
        for i, avgdphSplit in enumerate(np.split(avgdph, ndiv, axis=-1)):
            variance = np.ma.var(avgdphSplit, axis=-1) * (
                np.average(coord['freq']) /
                150.e6)**2  # get time variance and rescale to 150 MHz

            # linear regression in log-log space: log10(var) = par[0]*log10(d) + par[1]
            mask = variance[myselect].mask
            A = np.vstack([
                np.log10(D2[myselect][~mask]),
                np.ones(len(D2[myselect][~mask]))
            ])
            # explicit rcond avoids the NumPy FutureWarning about its default
            par = np.linalg.lstsq(A.T, np.log10(variance[myselect][~mask]),
                                  rcond=None)[0]
            S0 = 10**(-1 * par[1] / par[0])
            logging.info(r't%i: beta=%.2f - R_diff=%.2f km' %
                         (i, par[0], S0 / 1.e3))
            variances.append(variance)
            pars.append(par)

        if plotName != '':
            if plotName.split('.')[-1] != 'png': plotName += '.png'  # add png

            if not 'matplotlib' in sys.modules:
                import matplotlib as mpl
                mpl.use("Agg")
            import matplotlib.pyplot as plt

            fig = plt.figure()
            fig.subplots_adjust(wspace=0)
            ax = fig.add_subplot(111)
            ax1 = ax.twinx()

            for i, variance in enumerate(variances):
                if len(variances) > 1:
                    color = plt.cm.jet(
                        i / float(len(variances) - 1))  # from 0 to 1
                else:
                    color = 'black'
                ax.plot(D2[myselect] / 1.e3,
                        variance[myselect],
                        marker='o',
                        linestyle='',
                        color=color,
                        markeredgecolor='none',
                        label='T')

                # regression
                par = pars[i]
                x = D2[myselect]
                S0 = 10**(-1 * par[1] / par[0])
                if color == 'black':
                    color = 'red'  # in case of single color, use red line that is more visible
                ax1.plot(x.flatten() / 1.e3,
                         par[0] * np.log10(x.flatten()) + par[1],
                         linestyle='-',
                         color=color,
                         label=r'$\beta=%.2f$ - $R_{\rm diff}=%.2f$ km' %
                         (par[0], S0 / 1.e3))

            ax.set_xlabel('Distance (km)')
            ax.set_ylabel(r'Phase variance @150 MHz (rad$^2$)')
            ax.set_xscale('log')
            ax.set_yscale('log')

            # y limits span the points of every time chunk, not just the last one
            allVar = np.ma.concatenate([v[myselect] for v in variances])
            ymin = np.min(allVar)
            ymax = np.max(allVar)
            ax.set_xlim(xmin=0.1, xmax=3)
            ax.set_ylim(ymin, ymax)
            ax1.set_ylim(np.log10(ymin), np.log10(ymax))
            ax1.legend(loc='lower right', frameon=False)
            ax1.set_yticks([])

            logging.warning('Save pic: %s' % plotName)
            plt.savefig(plotName, bbox_inches='tight')
            plt.close(fig)  # release the figure; one is created per iteration

    return 0
Ejemplo n.º 14
0
def run(soltab,
        mode,
        maxFlaggedFraction=0.5,
        nSigma=5.0,
        maxStddev=None,
        ampRange=None,
        telescope='lofar',
        skipInternational=False,
        refAnt='',
        soltabExport='',
        ncpu=0):
    """
    This operation for LoSoTo implements a station-flagging procedure. Flags are time-independent.
    WEIGHT: compliant

    Parameters
    ----------
    mode: str
        Fitting algorithm: bandpass or resid (case-insensitive). Bandpass mode clips amplitudes relative
        to a model bandpass (only LOFAR is currently supported). Resid mode clips residual phases or
        log(amplitudes).

    maxFlaggedFraction : float, optional
        This sets the maximum allowable fraction of flagged solutions above which the entire station is flagged.

    nSigma : float, optional
        This sets the number of standard deviations considered when outlier clipping is done.

    maxStddev : float, optional
        Maximum allowable standard deviation when outlier clipping is done. For phases, this value
        should be in radians, for amplitudes in log(amp). If None (or not positive), a value of 0.1 rad
        is used for phases and 0.01 for amplitudes.

    ampRange : array, optional
        2-element array of the median amplitude level to be acceptable, ampRange[0]: lower limit, ampRange[1]: upper limit.
        If None or [0, 0], a reasonable range for typical observations is used.

    telescope : str, optional
        Specifies the telescope if mode = 'bandpass'.

    skipInternational : bool, optional
        If True, skip flagging of international LOFAR stations (only used if telescope = 'lofar')

    refAnt : str, optional
        If mode = resid, this sets the reference antenna for phase solutions, by default None.
        'nearest' references each antenna to its closest antenna that is not fully flagged.

    soltabExport : str, optional
        Soltab to export station flags to. Note: exported flags are not time- or frequency-dependent.

    ncpu : int, optional
        Number of cpu to use, by default all available.

    Returns
    -------
    int
        0 on success, 1 if the input is invalid.
    """

    logging.info("Flagging on soltab: " + soltab.name)

    # input check
    if refAnt == '':
        refAnt = None
    if soltabExport == '':
        soltabExport = None
    if mode is None or mode.lower() not in ['bandpass', 'resid']:
        logging.error('Mode must be one of bandpass or resid')
        return 1
    # The validation above is case-insensitive, so normalise the value as well:
    # otherwise e.g. 'Bandpass' would be accepted but silently run the resid branch
    mode = mode.lower()
    solType = soltab.getType()
    if maxStddev is None or maxStddev <= 0.0:
        if solType == 'phase':
            maxStddev = 0.1  # in radians
        else:
            maxStddev = 0.01  # in log10(amp)

    # Work internally with axis order [time, ant, freq, pol(, dir)]
    axis_names = soltab.getAxesNames()
    if ('freq' not in axis_names or 'pol' not in axis_names
            or 'time' not in axis_names or 'ant' not in axis_names):
        logging.error("Currently, flagstation requires the following axes: "
                      "freq, pol, time, and ant.")
        return 1
    has_dir = 'dir' in axis_names
    work_order = ['time', 'ant', 'freq', 'pol'] + (['dir'] if has_dir else [])
    perm = [axis_names.index(name) for name in work_order]
    # Inverse permutation, used to restore the soltab's own axis order on write-back.
    # (Transposing again by `perm` is only correct when `perm` is its own inverse.)
    inv_perm = np.argsort(perm)
    vals_arraytmp = soltab.val[:].transpose(perm)
    weights_arraytmp = soltab.weight[:].transpose(perm)

    # Check for NaN solutions and flag
    weights_arraytmp[np.isnan(vals_arraytmp)] = 0.0

    # When there is no 'dir' axis, add a trailing length-1 axis (as a view, so writes
    # reach the *_arraytmp arrays) and handle both layouts with one per-direction loop
    if has_dir:
        vals_5d, weights_5d = vals_arraytmp, weights_arraytmp
        ndir = len(soltab.dir)
    else:
        vals_5d = vals_arraytmp[..., np.newaxis]
        weights_5d = weights_arraytmp[..., np.newaxis]
        ndir = 1
    ant_names = soltab.ant[:]
    nant = len(ant_names)

    def _flag_all_stations(worker, make_args, skip_station=None):
        # One multiprocManager per direction and one job per (non-skipped) station.
        # Each job yields (station index, new [time, freq, pol] weights).
        for d in range(ndir):
            mpm = multiprocManager(ncpu, worker)
            for s in range(nant):
                if skip_station is not None and skip_station(s):
                    continue
                mpm.put(make_args(s, d))
            mpm.wait()
            for (s, w) in mpm.get():
                weights_5d[:, s, :, :, d] = w

    if mode == 'bandpass':
        if solType != 'amplitude':
            logging.error(
                "Soltab must be of type amplitude for bandpass mode.")
            return 1

        freqs = soltab.freq[:]
        skip_intl = skipInternational and telescope.lower() == 'lofar'

        def _is_skipped(s):
            # Stations whose name contains neither 'CS' nor 'RS' are treated as international
            return (skip_intl and 'CS' not in ant_names[s]
                    and 'RS' not in ant_names[s])

        _flag_all_stations(
            _flag_bandpass,
            lambda s, d: [freqs, vals_5d[:, s, :, :, d], weights_5d[:, s, :, :, d],
                          telescope, nSigma, ampRange, maxFlaggedFraction,
                          maxStddev, False, ant_names, s],
            _is_skipped)
        history = ('FLAGSTATION (mode=bandpass, telescope={0}, maxFlaggedFraction={1}, '
                   'nSigma={2})'.format(telescope, maxFlaggedFraction, nSigma))
    else:
        if solType not in ['phase', 'amplitude']:
            logging.error(
                "Soltab must be of type phase or amplitude for resid mode.")
            return 1

        # Subtract reference phases
        if refAnt is not None:
            if solType != 'phase':
                logging.error(
                    'Reference possible only for phase solution tables. Ignoring referencing.'
                )
            elif refAnt == 'nearest':
                all_ants = soltab.getAxisValues('ant', ignoreSelection=True).tolist()
                bad_ants = soltab._getFullyFlaggedAnts()
                # Subtract the neighbours' *original* phases: reading references from
                # the array being modified would make later antennas reference
                # already-referenced values
                vals_orig = vals_arraytmp.copy()
                for i, antToRef in enumerate(soltab.getAxisValues('ant')):
                    antDists = soltab.getSolset().getAntDist(antToRef)  # this is a dict
                    for badAnt in bad_ants:
                        antDists.pop(badAnt, None)  # remove bad ants
                    antDists.pop(antToRef, None)  # an antenna cannot be its own reference
                    if not antDists:
                        logging.warning('No usable reference antenna for {0}; '
                                        'not referencing it.'.format(antToRef))
                        continue
                    reference = min(antDists, key=antDists.get)  # closest antenna
                    refInd = all_ants.index(reference)
                    vals_arraytmp[:, i] -= vals_orig[:, refInd]
            else:
                ants = soltab.getAxisValues('ant')
                if refAnt not in ants:
                    logging.warning('Reference antenna ' + refAnt +
                                    ' not found. Using: ' + ants[0])
                    refAnt = ants[0]
                refInd = ants.tolist().index(refAnt)
                vals_arrayref = vals_arraytmp[:, refInd].copy()
                for i in range(nant):
                    vals_arraytmp[:, i] -= vals_arrayref

        _flag_all_stations(
            _flag_resid,
            lambda s, d: [vals_5d[:, s, :, :, d], weights_5d[:, s, :, :, d],
                          solType, nSigma, maxFlaggedFraction, maxStddev,
                          ant_names, s])
        history = ('FLAGSTATION (mode=resid, maxFlaggedFraction={0}, '
                   'nSigma={1})'.format(maxFlaggedFraction, nSigma))

    # Make sure that fully flagged stations have all pols flagged
    for s in range(nant):
        for p in range(len(soltab.pol)):
            if np.all(weights_arraytmp[:, s, :, p] == 0.0):
                weights_arraytmp[:, s, :, :] = 0.0
                break

    # Write new weights back in the soltab's own axis order
    soltab.setValues(weights_arraytmp.transpose(inv_perm), weight=True)
    soltab.addHistory(history)

    if soltabExport is not None:
        # Transfer station flags to soltabExport
        solset = soltab.getSolset()
        soltabexp = solset.getSoltab(soltabExport)
        axis_namesexp = soltabexp.getAxesNames()

        for stat in soltabexp.ant:
            if stat in soltab.ant:
                s = soltab.ant[:].tolist().index(stat)
                if 'pol' in axis_namesexp:
                    for pol in soltabexp.pol:
                        if pol in soltab.pol:
                            soltabexp.setSelection(ant=stat, pol=pol)
                            p = soltab.pol[:].tolist().index(pol)
                            if np.all(weights_arraytmp[:, s, :, p] == 0):
                                soltabexp.setValues(np.zeros(
                                    soltabexp.weight.shape),
                                                    weight=True)
                else:
                    soltabexp.setSelection(ant=stat)
                    if np.all(weights_arraytmp[:, s, :, :] == 0):
                        soltabexp.setValues(np.zeros(soltabexp.weight.shape),
                                            weight=True)
        soltabexp.addHistory('WEIGHT imported by FLAGSTATION from ' +
                             soltab.name + '.')

    return 0
Ejemplo n.º 15
0
def _flag_resid(vals, weights, soltype, nSigma, maxFlaggedFraction, maxStddev,
                ants, s, outQueue):
    """
    Flags bad residuals by setting the corresponding weights to 0.0

    Phases are clipped after removing their weighted circular mean (to avoid wraps
    near +/- pi). Amplitudes are clipped in log10 space around 0 (no mean is removed),
    i.e. around unit amplitude. NOTE(review): this presumably expects residual
    amplitudes close to 1; confirm with callers.

    Outliers are clipped iteratively (at most 5 passes) until the number of flagged
    points stops changing.

    Parameters
    ----------
    vals : array
        Array of values for one station as [time, freq, pol]. Not modified.

    weights : array
        Array of weights for the same station as [time, freq, pol]. Modified in place.

    soltype : str
        Type of solutions: phase or amplitude

    nSigma : float
        Number of sigma for flagging. vals outside of nSigma*stddev are flagged

    maxFlaggedFraction : float
        Maximum allowable fraction of newly flagged solutions (relative to the initially
        unflagged ones). Polarizations with a higher fraction are completely flagged

    maxStddev : float
        Maximum allowable standard deviation (caps the clipping threshold)

    ants : list
        List of station names

    s : int
        Station index

    outQueue : queue-like
        Object with a ``put`` method; receives ``[s, weights]`` when done.

    Returns
    -------
    None
        The station index and modified weights array are put on outQueue.
    """
    # Skip fully flagged stations
    if np.all(weights == 0.0):
        outQueue.put([s, weights])
        return

    maxiter = 5
    npols = vals.shape[2]  # number of polarizations
    for pol in range(npols):
        pol_vals = vals[:, :, pol]
        # weights_orig is a view into `weights`, so these flags reach the output too
        weights_orig = weights[:, :, pol]
        if soltype == 'phase':
            bad_sols = np.isnan(pol_vals)
        else:
            bad_sols = np.logical_or(np.isnan(pol_vals), pol_vals <= 0.0)
        weights_orig[bad_sols] = 0.0
        if np.all(weights_orig == 0.0):
            # Skip fully flagged polarizations
            continue
        flagged = weights_orig == 0.0
        nsols_unflagged = np.count_nonzero(~flagged)

        if soltype == 'phase':
            # Remove the weighted circular mean to avoid wrapping issues near +/- pi
            mean = np.angle(np.nansum(weights_orig * np.exp(1j * pol_vals)))
            vals_flagged = normalize_phase(pol_vals - mean)
        else:
            # log10(amplitude) in a new array so the caller's vals stay untouched;
            # non-positive amplitudes are already flagged and are zeroed below
            with np.errstate(divide='ignore', invalid='ignore'):
                vals_flagged = np.log10(pol_vals)
        vals_flagged[flagged] = 0.0

        # Iteratively fit and flag until the number of flagged points converges
        weights_copy = weights_orig.copy()
        nflag_prev = -1
        for _ in range(maxiter):
            stdev_all = np.sqrt(
                np.average(vals_flagged**2, weights=weights_copy))
            stdev = min(maxStddev, stdev_all)
            bad = np.where(np.abs(vals_flagged) > nSigma * stdev)
            nflag = len(bad[0])
            if nflag in (0, nsols_unflagged, nflag_prev):
                # Nothing to flag, everything would be flagged, or converged
                break
            nflag_prev = nflag
            weights_copy = weights_orig.copy()  # reset flags to original ones
            weights_copy[bad] = 0

        # Check whether station is bad (high flagged fraction). If
        # so, flag all frequencies of this polarization
        frac_flagged = float(nflag) / float(nsols_unflagged)
        if frac_flagged > maxFlaggedFraction:
            # Station has high fraction of initially unflagged solutions that are now flagged
            logging.info('Flagged {0} (pol {1}) due to high flagged fraction '
                         '({2:.2f})'.format(ants[s], pol, frac_flagged))
            weights[:, :, pol] = 0.0
        else:
            # Station is OK, flag bad points only
            nflagged_orig = np.count_nonzero(weights_orig == 0.0)
            nflagged_new = np.count_nonzero(weights_copy == 0.0)
            weights[:, :, pol] = weights_copy
            prcnt = float(nflagged_new - nflagged_orig) / float(
                weights_orig.size) * 100.0
            logging.info(
                'Flagged {0:.1f}% of solutions for {1} (pol {2})'.format(
                    prcnt, ants[s], pol))

    outQueue.put([s, weights])
Ejemplo n.º 16
0
def _flag_bandpass(freqs, amps, weights, telescope, nSigma, ampRange,
                   maxFlaggedFraction, maxStddev, plot, ants, s, outQueue):
    """
    Flags bad amplitude solutions relative to median bandpass (in log space) by setting
    the corresponding weights to 0.0

    Note: A median over the time axis is done before flagging, so the flags are not time-
    dependent

    Parameters
    ----------
    freqs : array
        Array of frequencies

    amps : array
        Array of amplitudes for one station as [time, freq, pol]

    weights : array
        Array of weights for the same station as [time, freq, pol]. Modified in place.

    telescope : str, optional
        Specifies the telescope for the bandpass model

    nSigma : float
        Number of sigma for flagging. Amplitudes outside of nSigma*stddev are flagged

    maxFlaggedFraction : float
        Maximum allowable fraction of flagged frequencies. Polarizations with higher
        fractions will be completely flagged

    maxStddev : float
        Maximum allowable standard deviation

    plot : bool
        If True, the bandpass with flags and best-fit line is plotted for each station

    ants : list
        List of station names

    s : int
        Station index

    ampRange : array
        2-element array of the median amplitude level to be acceptable, ampRange[0]: lower limit, ampRange[1]: upper limit.
        If None or [0, 0], a range is chosen from the overall amplitude level.

    outQueue : queue-like
        Object with a ``put`` method; receives ``[s, weights]`` when done.

    Returns
    -------
    None
        The station index and modified weights array are put on outQueue.
    """
    import warnings

    def _B(x, k, i, t, extrap, invert):
        if k == 0:
            if extrap:
                if invert:
                    return -1.0
                else:
                    return 1.0
            else:
                return 1.0 if t[i] <= x < t[i + 1] else 0.0
        if t[i + k] == t[i]:
            c1 = 0.0
        else:
            c1 = (x - t[i]) / (t[i + k] - t[i]) * _B(x, k - 1, i, t, extrap,
                                                     invert)
        if t[i + k + 1] == t[i + 1]:
            c2 = 0.0
        else:
            c2 = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * _B(
                x, k - 1, i + 1, t, extrap, invert)
        return c1 + c2

    def _bspline(x, t, c, k):
        # Evaluate the degree-k B-spline with knots t and coefficients c at x, with
        # extrapolation enabled for the first/last basis function outside the knots.
        # NOTE(review): `invert` is never set to True, so extrapolation below t[k]
        # uses +1.0; confirm this is intended.
        n = len(t) - k - 1
        assert (n >= k + 1) and (len(c) >= n)
        invert = False
        extrap = [False] * n
        if x >= t[n]:
            extrap[-1] = True
        elif x < t[k]:
            extrap[0] = True
            invert = False
        return sum(c[i] * _B(x, k, i, t, e, invert)
                   for i, e in zip(list(range(n)), extrap))

    def _bandpass_LBA(freq, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12,
                      c13):
        """
        Defines the functional form of the LBA bandpass in terms of splines of degree 3

        The spline fit was done using LSQUnivariateSpline() on the median bandpass between
        30 MHz and 78 MHz. The knots were set by hand to achieve a good fit with a
        minimum number of parameters.

        Parameters
        ----------
        freq : array
            Array of frequencies

        c1-c13 : float
            Spline coefficients

        Returns
        -------
        bandpass : list
            List of bandpass values as function of frequency
        """
        knots = np.array([
            30003357.0, 30003357.0, 30003357.0, 30003357.0, 40000000.0,
            50000000.0, 55000000.0, 56000000.0, 60000000.0, 62000000.0,
            63000000.0, 64000000.0, 70000000.0, 77610779.0, 77610779.0,
            77610779.0, 77610779.0
        ])
        coeffs = np.array(
            [c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13])
        return [_bspline(f, knots, coeffs, 3) for f in freq]

    def _bandpass_HBA_low(freq, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10):
        """
        Defines the functional form of the HBA-low bandpass in terms of splines of degree
        3

        The spline fit was done using LSQUnivariateSpline() on the median bandpass between
        120 MHz and 188 MHz. The knots were set by hand to achieve a good fit with a
        minimum number of parameters.

        Parameters
        ----------
        freq : array
            Array of frequencies

        c1-c10 : float
            Spline coefficients

        Returns
        -------
        bandpass : list
            List of bandpass values as function of frequency
        """
        knots = np.array([
            1.15e+08, 1.15e+08, 1.15e+08, 1.15e+08, 1.30e+08, 1.38e+08,
            1.48e+08, 1.60e+08, 1.68e+08, 1.78e+08, 1.90e+08, 1.90e+08,
            1.9e+08, 1.9e+08
        ])
        coeffs = np.array([c1, c2, c3, c4, c5, c6, c7, c8, c9, c10])
        return [_bspline(f, knots, coeffs, 3) for f in freq]

    def _fit_bandpass(freq, logamp, sigma, band, do_fit=True):
        """
        Fits amplitudes with one of the bandpass functions

        The initial coefficients were determined from a LSQUnivariateSpline() fit on the
        median bandpass of the appropriate band. The allowable fitting ranges were set by
        hand through testing on a number of observations (to allow the bandpass function
        to adjust for the differences between stations but not to fit to RFI, etc.).

        Parameters
        ----------
        freq : array
            Array of frequencies

        logamp : array
            Array of log10(amplitudes)

        sigma : array
            Array of per-point uncertainties passed to curve_fit (weights = 1/sigma**2)

        band : str
            Band name ('hba_low' or 'lba')

        do_fit : bool, optional
            If True, the fitting is done. If False, the unmodified model bandpass is
            returned

        Returns
        -------
        fit_parms, bandpass : list, list
            List of best-fit parameters (None if not fitted), List of bandpass values as
            function of frequency
        """
        from scipy.optimize import curve_fit

        if band.lower() == 'hba_low':
            bandpass_function = _bandpass_HBA_low
            init_coeffs = np.array([
                -0.01460369, 0.05062699, 0.02827004, 0.03738518, -0.05729109,
                0.02303295, -0.03550487, -0.0803113, -0.2394929, -0.358301
            ])
            bounds_deltas_lower = [
                0.06, 0.05, 0.04, 0.04, 0.04, 0.04, 0.1, 0.1, 0.2, 0.5
            ]
            bounds_deltas_upper = [
                0.06, 0.1, 0.1, 0.1, 0.04, 0.04, 0.04, 0.04, 0.05, 0.06
            ]
        elif band.lower() == 'lba':
            bandpass_function = _bandpass_LBA
            init_coeffs = np.array([
                -0.22654016, -0.1950495, -0.07763014, 0.10002095, 0.32797671,
                0.46900048, 0.47155583, 0.31945897, 0.29072278, 0.08064795,
                -0.15761538, -0.36020451, -0.51163338
            ])
            bounds_deltas_lower = [
                0.25, 0.2, 0.05, 0.05, 0.05, 0.1, 0.1, 0.16, 0.2, 0.15, 0.15,
                0.25, 0.3
            ]
            bounds_deltas_upper = [
                0.4, 0.3, 0.15, 0.05, 0.05, 0.05, 0.08, 0.05, 0.08, 0.15, 0.15,
                0.25, 0.35
            ]
        else:
            logging.error('The "{}" band is not supported'.format(band))
            return None, None

        if do_fit:
            lower = [c - b for c, b in zip(init_coeffs, bounds_deltas_lower)]
            upper = [c + b for c, b in zip(init_coeffs, bounds_deltas_upper)]
            param_bounds = (lower, upper)
            try:
                popt, pcov = curve_fit(bandpass_function,
                                       freq,
                                       logamp,
                                       sigma=sigma,
                                       bounds=param_bounds,
                                       method='dogbox',
                                       ftol=1e-3,
                                       xtol=1e-3,
                                       gtol=1e-3)
                return popt, bandpass_function(freq, *tuple(popt))
            except RuntimeError:
                logging.error('Fitting failed.')
                return None, bandpass_function(freq, *tuple(init_coeffs))
        else:
            return None, bandpass_function(freq, *tuple(init_coeffs))

    # Check that telescope and band is supported. Skip flagging if not
    if telescope.lower() == 'lofar':
        # Determine which band we're in
        if np.median(freqs) < 180e6 and np.median(freqs) > 110e6:
            band = 'hba_low'
        elif np.median(freqs) < 90e6:
            band = 'lba'
        else:
            logging.warning(
                'The median frequency of {} Hz is outside of the currently supported LOFAR bands '
                '(LBA and HBA-low). Flagging will be skipped'.format(
                    np.median(freqs)))
            outQueue.put([s, weights])
            return
    else:
        logging.warning(
            "Only telescope = 'lofar' is currently supported for bandpass mode. "
            "Flagging will be skipped")
        outQueue.put([s, weights])
        return

    # Skip fully flagged stations
    if np.all(weights == 0.0):
        outQueue.put([s, weights])
        return

    # Build arrays for fitting
    flagged = np.where(np.logical_or(weights == 0.0, np.isnan(amps)))
    amps_flagged = amps.copy()
    amps_flagged[flagged] = np.nan
    sigma = weights.copy()
    sigma[flagged] = 1.0
    sigma = np.sqrt(1.0 / sigma)
    sigma[flagged] = 1e8

    # Set range of allowed values for the median. list() makes the [0, 0] check work
    # for arrays/tuples too (array == list is elementwise and ambiguous in a bool test)
    if ampRange is None or list(ampRange) == [0.0, 0.0]:
        # Use sensible values depending on correlator
        if np.nanmedian(amps_flagged) > 1.0:
            # new correlator
            ampRange = [50.0, 325.0]
        else:
            # old correlator
            ampRange = [0.0004, 0.0018]
    median_min = ampRange[0]
    median_max = ampRange[-1]

    # Iterate over polarizations
    maxiter = 5
    npols = amps.shape[2]
    for pol in range(npols):
        # Skip fully flagged polarizations
        if np.all(weights[:, :, pol] == 0.0):
            continue

        # Take median over time and divide out the median offset
        with warnings.catch_warnings():
            # Filter NaN warnings -- we deal with NaNs below
            warnings.filterwarnings('ignore',
                                    r'All-NaN (slice|axis) encountered')
            amps_div = np.nanmedian(amps_flagged[:, :, pol], axis=0)
            median_val = np.nanmedian(amps_div)
        amps_div /= median_val
        sigma_div = np.median(sigma[:, :, pol], axis=0)
        sigma_orig = sigma_div.copy()
        unflagged = np.where(~np.isnan(amps_div))
        nsols_unflagged = len(unflagged[0])
        median_flagged = np.where(np.isnan(amps_div))
        amps_div[median_flagged] = 1.0
        sigma_div[median_flagged] = 1e8
        median_flagged = np.where(amps_div <= 0.0)
        amps_div[median_flagged] = 1.0
        sigma_div[median_flagged] = 1e8

        # Before doing the fitting, renormalize and flag any solutions that deviate from
        # the model bandpass by a large factor to avoid biasing the first fit
        _, bp_sp = _fit_bandpass(freqs,
                                 np.log10(amps_div),
                                 sigma_div,
                                 band,
                                 do_fit=False)
        normval = np.median(np.log10(amps_div) -
                            bp_sp)  # value to normalize model to data
        amps_div /= 10**normval
        bad = np.where(np.abs(np.array(bp_sp) - np.log10(amps_div)) > 0.2)
        sigma_div[bad] = 1e8
        if np.all(sigma_div > 1e7):
            logging.info('Flagged {0} (pol {1}) due to poor match to '
                         'baseline bandpass model'.format(ants[s], pol))
            weights[:, :, pol] = 0.0
            # Go on with the remaining polarizations (previously this returned early
            # and left them unprocessed)
            continue

        # Iteratively fit and flag until the number of flagged channels converges
        nflag_prev = -1
        for _ in range(maxiter):
            bp_sp = _fit_bandpass(freqs, np.log10(amps_div), sigma_div,
                                  band)[1]
            resid = np.asarray(bp_sp) - np.log10(amps_div)
            stdev_all = np.sqrt(
                np.average(resid**2, weights=(1 / sigma_div)**2))
            stdev = min(maxStddev, stdev_all)
            bad = np.where(np.abs(resid) > nSigma * stdev)
            nflag = len(bad[0])
            if nflag in (0, nsols_unflagged, nflag_prev):
                # Nothing to flag, everything would be flagged, or converged
                break
            nflag_prev = nflag
            sigma_div = sigma_orig.copy()  # reset flags to original ones
            sigma_div[bad] = 1e8

        if plot:
            import matplotlib.pyplot as plt
            plt.plot(freqs, bp_sp, 'g-', lw=3)
            plt.plot(freqs, np.log10(amps_div), 'o', c='g')
            plt.plot(freqs[bad], np.log10(amps_div)[bad], 'o', c='r')
            plt.show()

        # Check whether entire station is bad (high stdev or high flagged fraction). If
        # so, flag all frequencies of this polarization
        frac_flagged = float(len(bad[0])) / float(nsols_unflagged)
        if stdev_all > nSigma * maxStddev:
            # Station has high stddev relative to median bandpass
            logging.info('Flagged {0} (pol {1}) due to high stddev '
                         '({2})'.format(ants[s], pol, stdev_all))
            weights[:, :, pol] = 0.0
        elif frac_flagged > maxFlaggedFraction:
            # Station has high fraction of initially unflagged solutions that are now flagged
            logging.info('Flagged {0} (pol {1}) due to high flagged fraction '
                         '({2:.2f})'.format(ants[s], pol, frac_flagged))
            weights[:, :, pol] = 0.0
        else:
            bad_chans = np.where(sigma_div > 1e3)
            nflagged_orig = np.count_nonzero(weights[:, :, pol] == 0.0)
            weights[:, bad_chans[0], pol] = 0.0
            nflagged_new = np.count_nonzero(weights[:, :, pol] == 0.0)
            # Median of *this* polarization's remaining amplitudes (indexing the 3-D
            # amps with a 2-D np.where would mix in all polarizations)
            pol_amps = amps[:, :, pol][weights[:, :, pol] > 0.0]
            median_val = np.nanmedian(pol_amps) if pol_amps.size else np.nan
            if median_val < median_min or median_val > median_max:
                # Station has extreme median value
                logging.info(
                    'Flagged {0} (pol {1}) due to extreme median value '
                    '({2})'.format(ants[s], pol, median_val))
                weights[:, :, pol] = 0.0
            else:
                # Station is OK, flag bad points only
                prcnt = float(nflagged_new - nflagged_orig) / float(
                    np.prod(weights.shape[:-1])) * 100.0
                logging.info(
                    'Flagged {0:.1f}% of solutions for {1} (pol {2})'.format(
                        prcnt, ants[s], pol))

    outQueue.put([s, weights])
Ejemplo n.º 17
0
def run( soltab, soltabOut='phasediff', maxResidual=1., fitOffset=False, average=False, replace=False, minFreq=0, refAnt='' ):
    """
    Estimate polarization misalignment as delay.

    For each antenna and time slot, the phase difference between the two
    parallel-hand polarizations (RR-LL or XX-YY) is fitted as a linear
    function of frequency: delay*freq + offset. The offset is either free
    (fitOffset=True) or snapped to the nearest multiple of 2*pi and dropped.
    The model is written to a new soltab: the first polarization (RR or XX)
    gets zero phase, the second (LL or YY) gets minus the modelled phase
    difference wrapped to [-pi, pi).

    Parameters
    ----------
    soltabOut : str, optional
        output table name (same solset), by default "phasediff".

    maxResidual : float, optional
        Maximum acceptable mean absolute residual of the fit in radians before flagging, by default 1. If 0: No check.

    fitOffset : bool, optional
        Assume that together with a delay each station has also a differential phase offset (important for old LBA observations). By default False.

    average : bool, optional
        Mean-average in time the resulting delays/offset. By default False.

    replace : bool, optional
        Only used when average is True. If True, flagged time slots receive the averaged value and are
        unflagged (unless all time slots were flagged). If False, they keep their own fit and stay flagged.
        By default False.

    minFreq : float, optional
        minimum frequency [Hz] to use in estimating the PA. By default, 0 (all freqs).

    refAnt : str, optional
        Reference antenna. If empty or not found, the second antenna of the table is used;
        "closest" is passed on unchanged to the value iterator.

    Returns
    -------
    int
        0 on success, 1 if the soltab is not a phase table or has no RR/LL or XX/YY polarization pair.
    """
    import sys
    import numpy as np

    logging.info("Finding polarization align for soltab: "+soltab.name)

    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of "+soltab.name+" is of type "+solType+", should be phase. Ignoring.")
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues('ant', ignoreSelection = True):
        logging.warning('Reference antenna '+refAnt+' not found. Using: '+ants[1])
        refAnt = ants[1]
    if refAnt == '': refAnt = ants[1]

    # Choose the pair of parallel-hand polarizations to compare. Both members
    # must be present, otherwise the indexes below would be undefined. The
    # check is done before the output table is created so that an error
    # leaves no stray table behind.
    pols = soltab.getAxisValues('pol')
    if 'RR' in pols and 'LL' in pols:
        pol1, pol2 = 'RR', 'LL'
    elif 'XX' in pols and 'YY' in pols:
        pol1, pol2 = 'XX', 'YY'
    else:
        logging.error('Cannot reference to known polarisation.')
        return 1

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')

    # create new table (initialized as a copy of the input values/weights)
    solset = soltab.getSolset()
    soltabout = solset.makeSoltab(soltype = soltab.getType(), soltabName = soltabOut, axesNames=soltab.getAxesNames(),
                      axesVals=[soltab.getAxisValues(axisName) for axisName in soltab.getAxesNames()],
                      vals=soltab.getValues(retAxesVals = False), weights=soltab.getValues(weight = True, retAxesVals = False))
    soltabout.addHistory('Created by POLALIGN operation from %s.' % soltab.name)

    for vals, weights, coord, selection in soltab.getValuesIter(returnAxes=['freq','pol','time'], weight=True, refAnt=refAnt):

        # reorder axes
        vals = reorderAxes( vals, soltab.getAxesNames(), ['pol','freq','time'] )
        weights = reorderAxes( weights, soltab.getAxesNames(), ['pol','freq','time'] )

        coord1 = np.where(coord['pol'] == pol1)[0][0]
        coord2 = np.where(coord['pol'] == pol2)[0][0]

        if (weights == 0.).all():
            logging.warning('Skipping flagged antenna: '+coord['ant'])
            weights[:] = 0
        else:

            fit_delays = []
            fit_offset = []
            fit_weights = []
            for t, _ in enumerate(times):

                # apply flags
                idx       = ( (weights[coord1,:,t] != 0.) & (weights[coord2,:,t] != 0.) & (coord['freq'] > minFreq) )
                freq      = np.copy(coord['freq'])[idx]
                phase1    = vals[coord1,:,t][idx]
                phase2    = vals[coord2,:,t][idx]

                if len(freq) < 30:
                    fit_weights.append(0.)
                    fit_delays.append(0.)
                    fit_offset.append(0.)
                    logging.debug('Not enough unflagged point for the timeslot '+str(t))
                    continue

                # if more than 1/2 of chans are flagged
                if (len(idx) - len(freq))/float(len(idx)) > 1/2.:
                    logging.debug('High number of filtered out data points for the timeslot %i: %i/%i' % (t, len(idx) - len(freq), len(idx)) )

                # wrap the difference to [-pi, pi), then unwrap it along frequency
                phase_diff = phase1 - phase2
                phase_diff = np.mod(phase_diff + np.pi, 2.*np.pi) - np.pi
                phase_diff = np.unwrap(phase_diff)

                # linear fit: phase_diff = delay*freq + offset
                A = np.vstack([freq, np.ones(len(freq))]).T
                fitresultdelay = np.linalg.lstsq(A, phase_diff.T, rcond=None)[0]
                if not fitOffset:
                    # get the closest n*(2pi) to the intercept and refit with only 1 parameter
                    numjumps = np.around(fitresultdelay[1]/(2*np.pi))
                    A = np.reshape(freq, (-1,1)) # no b
                    phase_diff -= numjumps * 2 * np.pi
                    fitresultdelay = np.linalg.lstsq(A, phase_diff.T, rcond=None)[0]
                    fitresultdelay = [fitresultdelay[0],0.] # set offset to 0 to keep the rest of the script equal

                # mean absolute residual of the fit
                residual = np.mean(np.abs( fitresultdelay[0]*freq + fitresultdelay[1] - phase_diff ))

                fit_delays.append(fitresultdelay[0])
                fit_offset.append(fitresultdelay[1])
                if maxResidual == 0 or residual < maxResidual:
                    fit_weights.append(1.)
                else:
                    # high residual, flag
                    logging.debug('Bad solution for ant: '+coord['ant']+' (time: '+str(t)+', residual: '+str(residual)+') -> ignoring.')
                    fit_weights.append(0.)

                # Debug plot (disabled)
                doplot = False
                if doplot and t%10==0 and (coord['ant'] == 'W04'):
                    if not 'matplotlib' in sys.modules:
                        import matplotlib as mpl
                        mpl.rc('figure.subplot',left=0.05, bottom=0.05, right=0.95, top=0.95,wspace=0.22, hspace=0.22 )
                        mpl.use("Agg")
                    import matplotlib.pyplot as plt

                    fig = plt.figure()
                    fig.subplots_adjust(wspace=0)
                    ax = fig.add_subplot(111)

                    # plot delay fit
                    plotdelay = lambda delay, offset, freq: np.mod( delay*freq + offset + np.pi, 2.*np.pi) - np.pi
                    ax.plot(freq, fitresultdelay[0]*freq + fitresultdelay[1], "-", color='purple', label=r'delay:%f$\nu$ (ns) + %f ' % (fitresultdelay[0]*1e9,fitresultdelay[1]) )

                    ax.plot(freq, np.mod(phase1 + np.pi, 2.*np.pi) - np.pi, 'ob' )
                    ax.plot(freq, np.mod(phase2 + np.pi, 2.*np.pi) - np.pi, 'og' )
                    ax.plot(freq, phase_diff, '.', color='purple' )

                    residual = np.mod(plotdelay(fitresultdelay[0], fitresultdelay[1], freq)-phase_diff + np.pi,2.*np.pi)-np.pi
                    ax.plot(freq, residual, '.', color='yellow')

                    ax.set_xlabel('freq')
                    ax.set_ylabel('phase')

                    logging.warning('Save pic: '+str(t)+'_'+coord['ant']+'.png')
                    fig.legend(loc='upper left')
                    plt.savefig(coord['ant']+'_'+str(t)+'.png', bbox_inches='tight')
                    del fig
            # end cycle in time

            fit_weights = np.array(fit_weights)
            fit_delays = np.array(fit_delays)
            fit_offset = np.array(fit_offset)

            # avg in time
            if average:
                # keep the individual fits of the flagged slots in case they must be restored
                fit_delays_bkp = fit_delays[ fit_weights == 0 ]
                fit_offset_bkp = fit_offset[ fit_weights == 0 ]
                np.putmask(fit_delays, fit_weights == 0, np.nan)
                np.putmask(fit_offset, fit_weights == 0, np.nan)
                fit_delays[:] = np.nanmean(fit_delays)
                # angle mean
                fit_offset[:] = np.angle( np.nansum( np.exp(1j*fit_offset) ) / np.count_nonzero(~np.isnan(fit_offset)) )

                if replace:
                    fit_weights[ fit_weights == 0 ] = 1.
                    fit_weights[ np.isnan(fit_delays) ] = 0. # every slot was flagged: cannot extrapolate a value
                else:
                    fit_delays[ fit_weights == 0 ] = fit_delays_bkp
                    fit_offset[ fit_weights == 0 ] = fit_offset_bkp

            logging.info('%s: average delay: %f ns (offset: %f)' % ( coord['ant'], np.mean(fit_delays)*1e9, np.mean(fit_offset)))
            for t, _ in enumerate(times):
                vals[coord1,:,t] = 0
                phase = np.mod(fit_delays[t]*coord['freq'] + fit_offset[t] + np.pi, 2.*np.pi) - np.pi
                vals[coord2,:,t] = -1.*phase
                weights[coord1,:,t] = fit_weights[t]
                weights[coord2,:,t] = fit_weights[t]

        # reorder axes back to the original order, needed for setValues
        vals = reorderAxes( vals, ['pol','freq','time'], [ax for ax in soltab.getAxesNames() if ax in ['pol','freq','time']] )
        weights = reorderAxes( weights, ['pol','freq','time'], [ax for ax in soltab.getAxesNames() if ax in ['pol','freq','time']] )
        soltabout.setSelection(**coord)
        soltabout.setValues( vals )
        soltabout.setValues( weights, weight=True )

    return 0
Ejemplo n.º 18
0
def _run_antenna(vals, vals_e, vals_init, weights, selection, tec_jump,
                 antname, outQueue):
    import itertools, random

    extend = 1  # number of point to eaxtend each block

    class Block(object):
        """
        Implement blocks of contiguous series of good datapoints
        """
        def __init__(self,
                     jump_idx_init,
                     jump_idx_end,
                     vals,
                     vals_e,
                     tec_jump,
                     type='poly'):
            self.id = int(np.mean([jump_idx_init, jump_idx_end]))  # not used
            self.tec_jump = tec_jump  # empirically calculated expected jump size
            self.jump_idx_init = jump_idx_init  # indef of first element of the block
            self.jump_idx_end = jump_idx_end  # index of last element of the block
            self.idx = range(jump_idx_init, jump_idx_end)  # list of indexes
            self.idx_exp = range(
                jump_idx_init - extend, jump_idx_end + extend
            )  # add 2 at the edges. These are used to calculate matches with other blocks
            self.vals = vals
            self.vals_e = vals_e
            self.len = len(vals)  # lenght of the block

            # fill expected values outside the block boundaries
            self.expand(type)

        def get_vals(self, idx):
            """
            Return values and errors.
            Idx is already the local index of the block.
            """
            return self.vals_exp[idx], self.vals_e_exp[idx]

        def expand(self, type='nearest', order=1, size=4):
            """
            predict values+err outside the edges
            TODO: for now it's just nearest
            """
            if self.len == 1:
                type = 'nearest'  # if 1 point, then nearest-neighbor
            if (self.len - 1) < order:
                order = self.len - 1  # 1st order needts 2 ponts, 2nd order 3 and so on...
            if self.len < size:
                size = self.len  # avoid using more points than available

            if type == 'nearest':
                self.vals_exp = np.concatenate( \
                        ( np.full((extend), self.vals[0]) ,
                          self.vals,\
                          np.full((extend), self.vals[-1])
                        ) )
                # errors are between 0.1 and 1
                #self.vals_e_exp = np.concatenate( \
                #        ( np.linspace(2,0.5,1) ,
                #          self.vals_e,\
                #          np.linspace(0.5,2,1)
                #        ) )
                self.vals_e_exp = np.concatenate( \
                        ( [1.]*extend ,
                          self.vals_e,\
                          [1.]*extend
                        ) )

            elif type == 'poly':
                vals_init = self.vals[:size]
                vals_e_init = self.vals_e[:size]
                idx_init = self.idx[:size]
                vals_end = self.vals[-1 * size:]
                vals_e_end = self.vals_e[-1 * size:]
                idx_end = self.idx[-1 * size:]
                p_init = np.poly1d(
                    np.polyfit(idx_init, vals_init, order, w=1. / vals_e_init))
                p_end = np.poly1d(
                    np.polyfit(idx_end, vals_end, order, w=1. / vals_e_end))
                self.vals_exp = np.concatenate( \
                        ( p_init(self.idx_exp[:extend]) ,
                          self.vals,\
                          p_end(self.idx_exp[-1*extend:])
                        ) )
                # errors are between 100% and 200%
                self.vals_e_exp = np.concatenate( \
                        ( np.linspace(2,1,extend) ,
                          self.vals_e,\
                          np.linspace(1,2,extend)
                        ) )

        def jump(self, jump=0):
            """
            Return values of this block after applying a jump
            """
            self.vals_exp += self.tec_jump * jump


#    def distance(block1, block2):
#        """
#        Estimate a distance between two blocks taking advantage of blocks predicted edges
#        """
#        # find indexes of vals_exp that are common to both blocks
#        common_idx_block1 = [i for i, x in enumerate(block1.idx_exp) if x in block2.idx_exp]
#        common_idx_block2 = [i for i, x in enumerate(block2.idx_exp) if x in block1.idx_exp]
#        if len(common_idx_block1) == 0: return 0 # no overlapping
#
#        v1, v1_e = block1.get_vals(idx=common_idx_block1)
#        v2, v2_e = block2.get_vals(idx=common_idx_block2)
#
#        num = np.sum( v1_e * v2_e * (1+abs(v1 - v2))**(1/3.))
#        den = np.sum( v1_e * v2_e )
#
#        return num/den
#
#    def global_potential_old(blocks):
#        """
#        Calculate a "potential" of the time serie summing up the distances between blocks
#        TODO: rewrite this to go through each point, not each block
#        """
#        potential = 0
#        for block1, block2 in zip(blocks[:-1], blocks[1:]):
#            potential += distance(block1, block2)
#
#        return potential

    def global_potential(blocks, idxs):
        """
        Calculate a "potential" of the time serie going thorugh each point
        """
        potentials = 0
        for idx in idxs:
            vals = []
            vals_w = []
            for block in blocks:
                if idx in block.idx_exp:
                    idx_block = block.idx_exp.index(idx)
                    vals.append(block.vals_exp[idx_block])
                    vals_w.append(1. / block.vals_e_exp[idx_block])

            if len(vals) == 1: continue  # skip inside block
            potential_n = 0
            potential_d = 0
            for idx1, idx2 in itertools.combinations(range(len(vals)), 2):
                potential_n += (vals_w[idx1] * vals_w[idx2] * 1 /
                                (vals[idx1] - vals[idx2])**2.)
                potential_d += vals_w[idx1] * vals_w[idx2]

            potentials += potential_n / potential_d

        return potentials

    ###############################################################

    for i in range(1000):

        # find blocks
        vals_diff = np.diff(vals)
        vals_diff = np.concatenate(
            ([100], list(vals_diff), [100]))  # add edges
        jumps_idx = np.where(np.abs(vals_diff) > tec_jump *
                             (2 / 3.))[0]  # find jumps

        # no more jumps
        if len(jumps_idx) == 2: break

        # make block objects
        blocks = []
        for jump_idx_init, jump_idx_end in zip(jumps_idx[:-1], jumps_idx[1:]):
            blocks.append( Block(jump_idx_init, jump_idx_end, \
                    vals[jump_idx_init:jump_idx_end],  vals_e[jump_idx_init:jump_idx_end], tec_jump=tec_jump) )

        # move the small blocks first
        lens = [block.len for block in blocks]
        if len(set(lens)) > 1:
            max_len = sorted(set(lens))[1]
            best_block = blocks[random.choice(
                np.where(np.array(lens) <= max_len)
                [0])]  # random block from the two smallest length
        else:
            best_block = random.choice(blocks)  # random block

        #best_block = random.choice(blocks) # random block
        potentials = []
        jump_range = [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]
        for jump_size in jump_range:
            best_block.jump(jump_size)
            idx_touse = range(best_block.idx_exp[0], best_block.idx_exp[extend+1]) + \
                        range(best_block.idx_exp[-1*extend-1], best_block.idx_exp[-1]+1) # check only $extend idx around this block
            #print('block:',best_block.idx_exp,idx_touse)
            #print ('Best block (',best_block.id,') idx:',best_block.idx_exp)
            potentials.append(global_potential(blocks, set(idx_touse)))
            best_block.jump(-1 * jump_size)  # return to normality

        #print "potentials:", potentials

        # find best jump
        idx = potentials.index(max(potentials))
        best_jump = jump_range[idx]
        logging.debug(
            '(Ant %s - Cycle %i - #jumps: %i) - Best jump (%i) on block: %i (len %i)'
            % (antname, i, len(jumps_idx), best_jump, best_block.id,
               best_block.len))

        # recreate vals with the updated block value
        vals[best_block.idx] = best_block.vals + tec_jump * best_jump

        plot = False
        if plot:
            print("Preapare plot")
            best_block.vals_exp += tec_jump * best_jump
            import matplotlib as mpl
            mpl.use("Agg")
            import matplotlib.pyplot as plt
            fig = plt.figure(figsize=(16, 8))
            ax = fig.add_subplot(111)
            ax.plot(vals_init, 'k.')
            for block in blocks:
                if block is best_block: continue
                #ax.plot(block.idx_exp, block.vals_exp, 'b,')
                ax.errorbar(block.idx_exp[-1:],
                            block.vals_exp[-1:],
                            block.vals_e_exp[-1:] / 30.,
                            color='blue',
                            ecolor='blue',
                            marker=',',
                            linestyle='')
                ax.errorbar(block.idx_exp[:1],
                            block.vals_exp[:1],
                            block.vals_e_exp[:1] / 30.,
                            color='blue',
                            ecolor='blue',
                            marker=',',
                            linestyle='')
            #ax.plot(coord['time'], vals, 'r.')
            ax.errorbar(range(len(vals)),
                        vals,
                        vals_e / 30.,
                        color='green',
                        ecolor='red',
                        marker='.',
                        linestyle='')
            ax.set_xlim(0, len(vals))
            ax.set_ylim(-0.5, 0.5)
            fig.savefig('jump_%s_%03i.png' % (antname, i), bbox_inches='tight')
            fig.clf()

    # check that this is the closest value to the global minimum
    # (this assumes that the majority of the points are correct)
    zeros = []
    jumps = []
    for jump in range(-100, 101):
        #print('TEST jump %i' % jump)
        #print((vals_init - (vals + tec_jump*jump)) )[:10]
        #print( np.where( np.abs(vals_init - (vals + tec_jump*jump)) < 1e-5 )[0] )
        zeros.append(
            len(
                np.where(np.abs(vals_init -
                                (vals + tec_jump * jump)) < 1e-5)[0]))
        jumps.append(jump)
    idx = zeros.index(max(zeros))
    logging.info('%s: Rescaling all values by %i jumps.' %
                 (antname, jumps[idx]))
    vals += tec_jump * jumps[idx]

    outQueue.put([vals, weights, selection])
Ejemplo n.º 19
0
def run(soltab, soltabOutTEC='tec000', soltabOutBP='phase000', refAnt=''):
    """
    Isolate BP from TEC in a calibrator (uGMRT data)

    NOTE: the bandpass/TEC separation itself is not implemented yet. The
    operation creates the two output tables, a TEC table (ant, time) and a
    bandpass phase table (ant, freq, pol). Both are filled with zero
    solutions and unit weights.

    Parameters
    ----------
    soltabOutTEC : str, optional
        output TEC table name (same solset), by default "tec000".

    soltabOutBP : str, optional
        output bandpass table name (same solset), by default "phase000".

    refAnt : str, optional
        Reference antenna, by default the first antenna ("closest" is accepted).

    Returns
    -------
    int
        0 on success, 1 if the input soltab is not of type phase.
    """
    import numpy as np

    logging.info("Find BANDPASS+TEC for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = ants[0]

    nAnt = soltab.getAxisLen('ant')
    nTime = soltab.getAxisLen('time')
    nFreq = soltab.getAxisLen('freq')
    nPol = soltab.getAxisLen('pol')

    # create new tables
    solset = soltab.getSolset()
    soltaboutTEC = solset.makeSoltab(soltype = 'tec', soltabName = soltabOutTEC, axesNames=['ant','time'],
                      axesVals=[soltab.getAxisValues(axisName) for axisName in ['ant','time']],
                      vals=np.zeros(shape=(nAnt, nTime)),
                      weights=np.ones(shape=(nAnt, nTime)) )
    soltaboutTEC.addHistory('Created by BANDPASSTEC operation from %s.' %
                            soltab.name)

    soltaboutBP = solset.makeSoltab(soltype = 'phase', soltabName = soltabOutBP, axesNames=['ant','freq','pol'],
                      axesVals=[soltab.getAxisValues(axisName) for axisName in ['ant','freq', 'pol']],
                      vals=np.zeros(shape=(nAnt, nFreq, nPol)),
                      weights=np.ones(shape=(nAnt, nFreq, nPol)) )
    soltaboutBP.addHistory('Created by BANDPASSTEC operation from %s.' %
                           soltab.name)

    # separate bandpass/tec
    # TODO: not implemented. The separation should read the input with
    # soltab.getValues(retAxesVals=True) and fill the arrays below, whose
    # shapes match the respective output tables.
    logging.warning('Bandpass/TEC separation is not implemented yet: output tables contain zero solutions.')
    fitTEC = np.zeros(shape=(nAnt, nTime))
    fitTECweights = np.ones(shape=(nAnt, nTime))
    fitBP = np.zeros(shape=(nAnt, nFreq, nPol))
    fitBPweights = np.ones(shape=(nAnt, nFreq, nPol))

    # write solutions back
    soltaboutTEC.setValues(fitTEC)
    soltaboutBP.setValues(fitBP)
    soltaboutTEC.setValues(fitTECweights, weight=True)
    soltaboutBP.setValues(fitBPweights, weight=True)

    return 0
Ejemplo n.º 20
0
def run(soltab, soltabOut='rotationmeasure000', refAnt='', maxResidual=1.):
    """
    Faraday rotation extraction from either a rotation table or a circular phase (of which the operation get the polarisation difference).

    For every antenna other than the reference one and every time slot, a
    rotation measure RM is fitted so that 2*RM*wavelength**2 matches either
    the phase difference RR-LL (phase tables) or twice the rotation angle
    (rotation tables). Results are written to a new "rotationmeasure" soltab
    with axes (ant, time). A time slot gets weight 0 when it has fewer than
    20 usable channels or a mean absolute residual of at least maxResidual.

    Parameters
    ----------

    soltabOut : str, optional
        output table name (same solset), by default "rotationmeasure000".

    refAnt : str, optional
        Reference antenna, by default the first.

    maxResidual : float, optional
        Max average residual in radians before flagging datapoint, by default 1. If 0: no check.

    Returns
    -------
    int
        0 on success; 1 on unsupported soltab type or polarizations, or fewer than 10 frequency channels.
    """
    import sys
    import numpy as np
    import scipy.optimize

    # distance between the model phase 2*RM*wav^2 and the data on the unit circle
    rmwavcomplex = lambda RM, wav, y: abs(
        np.cos(2. * RM[0] * wav * wav) - np.cos(y)) + abs(
            np.sin(2. * RM[0] * wav * wav) - np.sin(y))
    c = 2.99792458e8

    logging.info("Find FR for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType == 'phase':
        returnAxes = ['pol', 'freq', 'time']
        pols = soltab.getAxisValues('pol')
        if 'RR' in pols and 'LL' in pols:
            coord_rr = np.where(pols == 'RR')[0][0]
            coord_ll = np.where(pols == 'LL')[0][0]
        elif 'XX' in pols and 'YY' in pols:
            logging.warning(
                'Linear polarization detected, LoSoTo assumes XX->RR and YY->LL.'
            )
            coord_rr = np.where(pols == 'XX')[0][0]
            coord_ll = np.where(pols == 'YY')[0][0]
        else:
            logging.error(
                "Cannot proceed with Faraday estimation with polarizations: " +
                str(pols))
            return 1
    elif solType == 'rotation':
        returnAxes = ['freq', 'time']
    else:
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase or rotation. Ignoring.")
        return 1

    # checked before creating the output table so that an error leaves no stray table
    if len(soltab.getAxisValues('freq')) < 10:
        logging.error(
            'Faraday rotation estimation needs at least 10 frequency channels, preferably distributed over a wide range.'
        )
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = ants[0]

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')

    # create new table
    solset = soltab.getSolset()
    soltabout = solset.makeSoltab('rotationmeasure',
                                  soltabName=soltabOut,
                                  axesNames=['ant', 'time'],
                                  axesVals=[ants, times],
                                  vals=np.zeros((len(ants), len(times))),
                                  weights=np.ones((len(ants), len(times))))
    soltabout.addHistory('Created by FARADAY operation from %s.' % soltab.name)

    # NOTE(review): confirm that getValuesIter accepts the 'reference'
    # keyword; other operations may pass the reference antenna as 'refAnt'.
    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=returnAxes, weight=True, reference=refAnt):

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(), returnAxes)
        weights = reorderAxes(weights, soltab.getAxesNames(), returnAxes)
        weights[np.isnan(vals)] = 0.

        fitrm = np.zeros(len(times))
        fitweights = np.ones(len(times))  # all unflagged to start
        fitrmguess = 0.001  # good guess

        if not coord['ant'] == refAnt:
            logging.debug('Working on ant: ' + coord['ant'] + '...')

            if (weights == 0.).all():
                logging.warning('Skipping flagged antenna: ' + coord['ant'])
                fitweights[:] = 0
            else:

                for t, _ in enumerate(times):

                    if solType == 'phase':
                        idx = ((weights[coord_rr, :, t] != 0.) &
                               (weights[coord_ll, :, t] != 0.))
                        freq = np.copy(coord['freq'])[idx]
                        phase_rr = vals[coord_rr, :, t][idx]
                        phase_ll = vals[coord_ll, :, t][idx]
                        # RR-LL to be consistent with BBS/NDPPP
                        phase_diff = (
                            phase_rr - phase_ll
                        )  # not divide by 2 otherwise jump problem, then later fix this
                    else:  # rotation table
                        idx = (weights[:, t] != 0.)
                        freq = np.copy(coord['freq'])[idx]
                        phase_diff = 2. * vals[:, t][
                            idx]  # a rotation is between -pi and +pi

                    if len(freq) < 20:
                        fitweights[t] = 0
                        logging.warning(
                            'No valid data found for Faraday fitting for antenna: '
                            + coord['ant'] + ' at timestamp ' + str(t))
                        continue

                    # if more than 1/4 of chans are flagged
                    if (len(idx) - len(freq)) / float(len(idx)) > 1 / 4.:
                        logging.debug(
                            'High number of filtered out data points for the timeslot %i: %i/%i'
                            % (t, len(idx) - len(freq), len(idx)))

                    wav = c / freq

                    fitresultrm_wav, success = scipy.optimize.leastsq(
                        rmwavcomplex, [fitrmguess], args=(wav, phase_diff))
                    # mean absolute residual, wrapped to [-pi, pi)
                    residual = np.nanmean(
                        np.abs(
                            np.mod((2. * fitresultrm_wav * wav * wav) -
                                   phase_diff + np.pi, 2. * np.pi) - np.pi))

                    if maxResidual == 0 or residual < maxResidual:
                        # a good fit seeds the next time slot
                        fitrmguess = fitresultrm_wav[0]
                        weight = 1
                    else:
                        # high residual, flag
                        logging.warning('Bad solution for ant: ' +
                                        coord['ant'] + ' (time: ' + str(t) +
                                        ', resdiaul: ' + str(residual) + ').')
                        weight = 0

                    fitrm[t] = fitresultrm_wav[0]
                    fitweights[t] = weight

                    # Debug plot (disabled)
                    doplot = False
                    if doplot and coord['ant'] == 'RS310LBA' and t % 10 == 0:
                        print("Plotting")
                        if not 'matplotlib' in sys.modules:
                            import matplotlib as mpl
                            mpl.rc('font', size=8)
                            mpl.rc('figure.subplot',
                                   left=0.05,
                                   bottom=0.05,
                                   right=0.95,
                                   top=0.95,
                                   wspace=0.22,
                                   hspace=0.22)
                            mpl.use("Agg")
                        import matplotlib.pyplot as plt

                        fig = plt.figure()
                        fig.subplots_adjust(wspace=0)
                        ax = fig.add_subplot(111)

                        # plot rm fit
                        plotrm = lambda RM, wav: np.mod(
                            (2. * RM * wav * wav) + np.pi, 2. * np.pi
                        ) - np.pi  # notice the factor of 2
                        ax.plot(freq,
                                plotrm(fitresultrm_wav, c / freq[:]),
                                "-",
                                color='purple')

                        if solType == 'phase':
                            ax.plot(
                                freq,
                                np.mod(phase_rr + np.pi, 2. * np.pi) - np.pi,
                                'ob')
                            ax.plot(
                                freq,
                                np.mod(phase_ll + np.pi, 2. * np.pi) - np.pi,
                                'og')
                        ax.plot(freq,
                                np.mod(phase_diff + np.pi, 2. * np.pi) - np.pi,
                                '.',
                                color='purple')

                        residual = np.mod(
                            plotrm(fitresultrm_wav, c / freq[:]) - phase_diff +
                            np.pi, 2. * np.pi) - np.pi
                        ax.plot(freq, residual, '.', color='yellow')

                        ax.set_xlabel('freq')
                        ax.set_ylabel('phase')
                        ax.set_ylim(ymin=-np.pi, ymax=np.pi)

                        logging.warning('Save pic: ' + str(t) + '_' +
                                        coord['ant'] + '.png')
                        plt.savefig(str(t) + '_' + coord['ant'] + '.png',
                                    bbox_inches='tight')
                        del fig

        # NOTE(review): the selection is one antenna x all times on an
        # (ant, time) table, while expand_dims(..., axis=1) yields shape
        # (ntime, 1); confirm setValues accepts this layout.
        soltabout.setSelection(ant=coord['ant'], time=coord['time'])
        soltabout.setValues(np.expand_dims(fitrm, axis=1))
        soltabout.setValues(np.expand_dims(fitweights, axis=1), weight=True)

    return 0
Ejemplo n.º 21
0
def run(soltab, axesToClip=None, clipLevel=5., log=False, mode='median'):
    """
    Clip solutions around the median by a factor specified by the user.
    WEIGHT: flag compliant, putting weights into median is tricky

    Only the weights are modified: clipped points get weight 0.

    Parameters
    ----------
    soltab : soltab obj
        Solution table.

    axesToClip : list of str
        axes along which to calculate the median (e.g. [time,freq]).
        Axes not present in the soltab are ignored with a warning. After that
        filtering, at least one axis is required when mode is "median"; for
        the other modes an empty list (or None) means all axes of the soltab.
        The caller's list is never modified.

    clipLevel : float, optional
        factor above/below median at which to clip, by default 5.

    log : bool, optional
        clip is done in log10 space, by default False.

    mode : str, optional
        if "median" then flag at rms*clipLevel times away from the median, if it is "above",
        then flag all values above clipLevel, if it is "below" then flag all values below clipLevel.
        By default median.

    Returns
    -------
    int
        0 on success, 1 on invalid input.
    """

    import numpy as np

    def percentFlagged(w):
        # percentage of points in w whose weight is exactly zero (flagged)
        return 100. * (w.size - np.count_nonzero(w)) / float(w.size)

    logging.info("Clipping soltab: " + soltab.name)

    # Drop (with a warning) axes the soltab does not have. A new list is built
    # so that no element is skipped and the caller's list is left untouched.
    axesNames = soltab.getAxesNames()
    validAxes = []
    for axis in (axesToClip or []):
        if axis in axesNames:
            validAxes.append(axis)
        else:
            logging.warning('Axis \"' + axis + '\" not found. Ignoring.')
    axesToClip = validAxes

    # input check
    if not axesToClip:
        if mode == 'median':
            logging.error("Please specify axes to clip.")
            return 1
        axesToClip = list(soltab.getAxesNames())

    if mode not in ('median', 'above', 'below'):
        logging.error("Mode can be only: median, above or below.")
        return 1

    if mode == 'median' and soltab.getType() == 'amplitude' and not log:
        logging.warning(
            'Amplitude solution tab detected and log=False. Amplitude solution tables should be treated in log space.'
        )

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=axesToClip, weight=True):

        initPercent = percentFlagged(weights)

        # skip all flagged
        if (weights == 0).all():
            continue

        # vals is only used to build the flag mask and is never written back,
        # so the log10 transform does not need to be undone afterwards
        if log:
            vals = np.log10(vals)

        if mode == 'median':
            unflagged = weights != 0
            valmedian = np.nanmedian(vals[unflagged])
            rms = np.nanstd(vals[unflagged])
            np.putmask(weights, np.abs(vals - valmedian) > rms * clipLevel, 0)

        elif mode == 'above':
            np.putmask(weights, vals > clipLevel, 0)

        elif mode == 'below':
            np.putmask(weights, vals < clipLevel, 0)

        # writing back the (updated) weights only
        soltab.setValues(weights, selection, weight=True)

        logging.debug('Percentage of data flagged (%s): %.3f%% -> %.3f%%' %
                      (removeKeys(coord, axesToClip), initPercent,
                       percentFlagged(weights)))

    soltab.addHistory('CLIP (over %s with %s sigma cut)' %
                      (axesToClip, clipLevel))

    soltab.flush()

    return 0
Ejemplo n.º 22
0
def run(soltab, axesInPlot, axisInTable='', axisInCol='', axisDiff='', NColFig=0, figSize=[0,0], markerSize=2, minmax=[0,0], log='', \
               plotFlag=False, doUnwrap=False, refAnt='', soltabsToAdd='', makeAntPlot=False, makeMovie=False, prefix='', ncpu=0):
    """
    This operation for LoSoTo implements basic plotting
    WEIGHT: flag-only compliant, no need for weight

    One image is produced for every combination of values of the axes that are
    not plotted, tabulated, coloured or differentiated; those axes (and their
    values) make up the file name. Images are drawn by ``_plot``, through the
    ``multiprocManager`` queue unless the data cube of an image exceeds 500 MB,
    in which case ``_plot`` is called directly in this process.

    Parameters
    ----------
    soltab : soltab obj
        Solution table.

    axesInPlot : array of str
        1- or 2-element array which says the coordinates to plot (2 for 3D plots).

    axisInTable : str, optional
        the axis to plot on a page - e.g. ant to get all antenna's on one file. By default ''.

    axisInCol : str, optional
        The axis to plot in different colours - e.g. pol to get correlations with different colors. By default ''.
        Ignored (reset) when axesInPlot has 2 elements.

    axisDiff : str, optional
        This must be a len=2 axis and the plot will have the differential value - e.g. 'pol' to plot XX-YY. By default ''.

    NColFig : int, optional
        Number of columns in a multi-table image. By default is automatically chosen.

    figSize : array of int, optional
        Size of the image [x,y], if one of the values is 0, then it is automatically chosen. By default automatic set.

    markerSize : int, optional
        Size of the markers in the 2D plot. By default 2.

    minmax : array of float, optional
        Min max value for the independent variable (0 means automatic). By default 0.

    log : str, optional
        Use Log='XYZ' to set which axes to put in Log. By default ''.

    plotFlag : bool, optional
        Whether to plot also flags as red points in 2D plots. By default False.

    doUnwrap : bool, optional
        Unwrap phases (only for phase/scalarphase tables). By default False.

    refAnt : str, optional
        Reference antenna for phases, or 'closest'. The default '' means no referencing.
        If the antenna is not found, an error is logged and the second antenna of the
        current selection is used instead.

    soltabsToAdd : list of str, optional
        Tables to "add" (e.g. 'sol000/tec000'), it works only for tec and clock to be added to phases.
        NOTE: adding is currently not implemented: the tables are looked up, but only a
        warning is logged and they are ignored. By default '' (nothing to add).

    makeAntPlot : bool, optional
        Make a plot containing antenna coordinates in x,y and in color the value to plot, axesInPlot must be [ant]. By default False.

    makeMovie : bool, optional
        Make a movie summing up all the produced plots, by default False.
        The movie is built with the external ``mencoder`` program (via os.system).

    prefix : str, optional
        Prefix to add before the self-generated filename, by default ''.
        Its directory part is created if missing.

    ncpu : int, optional
        Number of cpus, by default all available.

    Returns
    -------
    int
        0 on success, 1 on invalid input.
    """
    import os, random
    import numpy as np
    from losoto.lib_unwrap import unwrap, unwrap_2d

    logging.info("Plotting soltab: " + soltab.name)

    # input check

    # str2list
    if axisInTable == '': axisInTable = []
    else: axisInTable = [axisInTable]
    if axisInCol == '': axisInCol = []
    else: axisInCol = [axisInCol]
    if axisDiff == '': axisDiff = []
    else: axisDiff = [axisDiff]

    if len(set(axisInTable + axesInPlot + axisInCol +
               axisDiff)) != len(axisInTable + axesInPlot + axisInCol +
                                 axisDiff):
        logging.error('Axis defined multiple times.')
        return 1

    # just because we use lists, check that they are 1-d
    if len(axisInTable) > 1 or len(axisInCol) > 1 or len(axisDiff) > 1:
        logging.error(
            'Too many TableAxis/ColAxis/DiffAxis, they must be at most one each.'
        )
        return 1

    # NOTE(review): axisInTable is not validated here; an unknown table axis
    # would only fail later, in axesInFile.remove() below.
    for axis in axesInPlot + axisInCol + axisDiff:
        if axis not in soltab.getAxesNames():
            logging.error('Axis \"' + axis + '\" not found.')
            return 1

    # temporary marker in png names, stripped again from the movie file name
    if makeMovie:
        prefix = prefix + '__tmp__'

    if os.path.dirname(prefix) != '' and not os.path.exists(
            os.path.dirname(prefix)):
        logging.debug('Creating ' + os.path.dirname(prefix) + '.')
        os.makedirs(os.path.dirname(prefix))

    # unknown reference antenna: fall back to the second antenna (index 1)
    if refAnt == '': refAnt = None
    elif refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      soltab.getAxisValues('ant')[1])
        refAnt = soltab.getAxisValues('ant')[1]

    minZ, maxZ = minmax

    # resolve table names to soltab objects (empty for the default '')
    solset = soltab.getSolset()
    soltabsToAdd = [
        solset.getSoltab(soltabName) for soltabName in soltabsToAdd
    ]

    cmesh = False
    if len(axesInPlot) == 2:
        cmesh = True
        # not color possible in 3D
        axisInCol = []
    elif len(axesInPlot) != 1:
        logging.error('Axes must be a len 1 or 2 array.')
        return 1
    # end input check

    # all axes that are not iterated by anything else
    # (one output file per combination of their values)
    axesInFile = soltab.getAxesNames()
    for axis in axisInTable + axesInPlot + axisInCol + axisDiff:
        axesInFile.remove(axis)

    # set subplots scheme
    if axisInTable != []:
        Nplots = soltab.getAxisLen(axisInTable[0])
    else:
        Nplots = 1

    # prepare antennas coord in makeAntPlot case
    if makeAntPlot:
        if axesInPlot != ['ant']:
            logging.error(
                'If makeAntPlot is selected the "Axes" values must be "ant"')
            return 1
        # x = +coord[1], y = -coord[0] of each antenna position from getAnt()
        antCoords = [[], []]
        for ant in soltab.getAxisValues(
                'ant'):  # select only user-selected antenna in proper order
            antCoords[0].append(+1 * soltab.getSolset().getAnt()[ant][1])
            antCoords[1].append(-1 * soltab.getSolset().getAnt()[ant][0])

    else:
        antCoords = []

    datatype = soltab.getType()

    # start processes for multi-thread
    mpm = multiprocManager(ncpu, _plot)

    # compute dataCube size
    # shape: [table, colour, (second plot axis,) first plot axis]
    shape = []
    if axisInTable != []: shape.append(soltab.getAxisLen(axisInTable[0]))
    else: shape.append(1)
    if axisInCol != []: shape.append(soltab.getAxisLen(axisInCol[0]))
    else: shape.append(1)
    if cmesh:
        shape.append(soltab.getAxisLen(axesInPlot[1]))
        shape.append(soltab.getAxisLen(axesInPlot[0]))
    else:
        shape.append(soltab.getAxisLen(axesInPlot[0]))

    # will contain the data to pass to each thread to make 1 image
    dataCube = np.ma.zeros(shape=shape, fill_value=np.nan)

    # cycle on files
    if makeMovie: pngs = []  # store png filenames
    for vals, coord, selection in soltab.getValuesIter(
            returnAxes=axisDiff + axisInTable + axisInCol + axesInPlot):

        # set filename
        filename = ''
        for axis in axesInFile:
            filename += axis + str(coord[axis]) + '_'
        filename = filename[:-1]  # remove last _
        if prefix + filename == '': filename = 'plot'

        # axis vals (they are always the same, regulat arrays)
        xvals = coord[axesInPlot[0]]
        # if plotting antenna - convert to number
        if axesInPlot[0] == 'ant':
            xvals = np.arange(len(xvals))

        # if plotting time - convert in h/min/s
        xlabelunit = ''
        if axesInPlot[0] == 'time':
            if xvals[-1] - xvals[0] > 3600:
                xvals = (xvals - xvals[0]) / 3600.  # hrs
                xlabelunit = ' [hr]'
            elif xvals[-1] - xvals[0] > 60:
                xvals = (xvals - xvals[0]) / 60.  # mins
                xlabelunit = ' [min]'
            else:
                xvals = (xvals - xvals[0])  # sec
                xlabelunit = ' [s]'
        # if plotting freq convert in MHz
        elif axesInPlot[0] == 'freq':
            xvals = xvals / 1.e6  # MHz
            xlabelunit = ' [MHz]'

        if cmesh:
            # axis vals (they are always the same, regular arrays)
            yvals = coord[axesInPlot[1]]
            # same as above but for y-axis
            if axesInPlot[1] == 'ant':
                yvals = np.arange(len(yvals))

            if len(xvals) <= 1 or len(yvals) <= 1:
                logging.error(
                    '3D plot must have more then one value per axes.')
                mpm.wait()
                return 1

            ylabelunit = ''
            if axesInPlot[1] == 'time':
                if yvals[-1] - yvals[0] > 3600:
                    yvals = (yvals - yvals[0]) / 3600.  # hrs
                    ylabelunit = ' [hr]'
                elif yvals[-1] - yvals[0] > 60:
                    yvals = (yvals - yvals[0]) / 60.  # mins
                    ylabelunit = ' [min]'
                else:
                    yvals = (yvals - yvals[0])  # sec
                    ylabelunit = ' [s]'
            elif axesInPlot[1] == 'freq':  # Mhz
                yvals = yvals / 1.e6
                ylabelunit = ' [MHz]'
        else:
            # 1D plot: the y axis shows the table values, label them by type
            # (datatype is overwritten here and reused for later files)
            yvals = None
            if datatype == 'clock':
                datatype = 'Clock'
                ylabelunit = ' (s)'
            elif datatype == 'tec':
                datatype = 'dTEC'
                ylabelunit = ' (TECU)'
            elif datatype == 'rotationmeasure':
                datatype = 'dRM'
                ylabelunit = r' (rad m$^{-2}$)'
            elif datatype == 'tec3rd':
                datatype = r'dTEC$_3$'
                ylabelunit = r' (rad m$^{-3}$)'
            else:
                ylabelunit = ''

        # cycle on tables
        soltab1Selection = soltab.selection  # save global selection and subselect only axex to iterate
        soltab.selection = selection
        titles = []

        for Ntab, (vals, coord, selection) in enumerate(
                soltab.getValuesIter(returnAxes=axisDiff + axisInCol +
                                     axesInPlot)):

            # set tile
            titles.append('')
            for axis in coord:
                if axis in axesInFile + axesInPlot + axisInCol: continue
                titles[Ntab] += axis + ':' + str(coord[axis]) + ' '
            titles[Ntab] = titles[Ntab][:-1]  # remove last ' '

            # cycle on colors
            soltab2Selection = soltab.selection
            soltab.selection = selection
            for Ncol, (vals, weight, coord, selection) in enumerate(
                    soltab.getValuesIter(returnAxes=axisDiff + axesInPlot,
                                         weight=True,
                                         reference=refAnt)):

                # differential plot
                if axisDiff != []:
                    # find ordered list of axis
                    names = [
                        axis for axis in soltab.getAxesNames()
                        if axis in axisDiff + axesInPlot
                    ]
                    if axisDiff[0] not in names:
                        logging.error("Axis to differentiate (%s) not found." %
                                      axisDiff[0])
                        mpm.wait()
                        return 1
                    if len(coord[axisDiff[0]]) != 2:
                        logging.error(
                            "Axis to differentiate (%s) has too many values, only 2 is allowed."
                            % axisDiff[0])
                        mpm.wait()
                        return 1

                    # find position of interesting axis
                    diff_idx = names.index(axisDiff[0])
                    # roll to first place
                    vals = np.rollaxis(vals, diff_idx, 0)
                    vals = vals[0] - vals[1]
                    # a difference point is flagged if either input is flagged
                    weight = np.rollaxis(weight, diff_idx, 0)
                    weight[0][weight[1] == 0] = 0
                    weight = weight[0]
                    del coord[axisDiff[0]]

                # add tables if required (e.g. phase/tec)
                # NOTE: not implemented; the code below is kept disabled
                for soltabToAdd in soltabsToAdd:
                    logging.warning('soltabsToAdd not implemented. Ignoring.')
#                    newCoord = {}
#                    for axisName in coord.keys():
#                        # prepare selected on present axes
#                        if axisName in soltabToAdd.getAxesNames():
#                            if type(coord[axisName]) is np.ndarray:
#                                newCoord[axisName] = coord[axisName]
#                            else:
#                                newCoord[axisName] = [coord[axisName]] # avoid being interpreted as regexp, faster
#
#                    soltabToAdd.setSelection(**newCoord)
#                    valsAdd = np.squeeze(soltabToAdd.getValues(retAxesVals=False, weight=False, reference=refAnt))
#
#                    # add missing axes
#                    print ('shape:', vals.shape)
#                    for axisName in coord.keys():
#                        if not axisName in soltabToAdd.getAxesNames():
#                            # find axis positions
#                            axisPos = soltab.getAxesNames().index(axisName)
#                            # create a new axes for the table to add and duplicate the values
#                            valsAdd = np.expand_dims(valsAdd, axisPos)
#                            print ('shape to add:', valsAdd.shape)
#
#                    if soltabToAdd.getType() == 'clock':
#                        valsAdd = 2. * np.pi * valsAdd * coord['freq']
#                    elif soltabToAdd.getType() == 'tec':
#                        valsAdd = -8.44797245e9 * valsAdd / coord['freq']
#                    else:
#                        logging.warning('Only Clock or TEC can be added to solutions. Ignoring: '+soltabToAdd.getType()+'.')
#                        continue
#
#                    if valsAdd.shape != vals.shape:
#                        logging.error('Cannot combine the table '+soltabToAdd.getType()+' with '+soltab.getType()+'. Wrong shape.')
#                        mpm.wait()
#                        return 1
#
#                    vals += valsAdd

                # normalize
                # (rotation angles are wrapped into [-pi/2, pi/2))
                if (soltab.getType() == 'phase'
                        or soltab.getType() == 'scalarphase'):
                    vals = normalize_phase(vals)
                if (soltab.getType() == 'rotation'):
                    vals = np.mod(vals + np.pi / 2., np.pi) - np.pi / 2.

                # is user requested axis in an order that is different from h5parm, we need to transpose
                if cmesh:
                    if soltab.getAxesNames().index(
                            axesInPlot[0]) < soltab.getAxesNames().index(
                                axesInPlot[1]):
                        vals = vals.T
                        weight = weight.T

                # unwrap if required
                if (soltab.getType() == 'phase'
                        or soltab.getType() == 'scalarphase') and doUnwrap:
                    if len(axesInPlot) == 1:
                        vals = unwrap(vals)
                    else:
                        flags = np.array((weight == 0), dtype=bool)
                        if not (flags == True).all():
                            vals = unwrap_2d(vals, flags, coord[axesInPlot[0]],
                                             coord[axesInPlot[1]])

                # store values and mask flagged (weight 0) and NaN points
                dataCube[Ntab, Ncol] = vals
                sel1 = np.where(weight == 0.)
                sel2 = np.where(np.isnan(vals))
                if cmesh:
                    dataCube[Ntab, Ncol, sel1[0], sel1[1]] = np.ma.masked
                    dataCube[Ntab, Ncol, sel2[0], sel2[1]] = np.ma.masked
                else:
                    dataCube[Ntab, Ncol, sel1[0]] = np.ma.masked
                    dataCube[Ntab, Ncol, sel2[0]] = np.ma.masked

            soltab.selection = soltab2Selection
            ### end cycle on colors

        # if dataCube too large (> 500 MB) do not go parallel
        # NOTE(review): the trailing None presumably fills _plot's last
        # (output-queue) argument that multiprocManager supplies otherwise -
        # confirm against _plot's signature.
        if np.array(dataCube).nbytes > 1024 * 1024 * 500:
            logging.debug('Big plot, parallel not possible.')
            _plot(Nplots, NColFig, figSize, markerSize, cmesh, axesInPlot,
                  axisInTable, xvals, yvals, xlabelunit, ylabelunit, datatype,
                  prefix + filename, titles, log, dataCube, minZ, maxZ,
                  plotFlag, makeMovie, antCoords, None)
        else:
            # copy: dataCube is reused and overwritten for the next file
            mpm.put([
                Nplots, NColFig, figSize, markerSize, cmesh, axesInPlot,
                axisInTable, xvals, yvals, xlabelunit, ylabelunit, datatype,
                prefix + filename, titles, log,
                np.ma.copy(dataCube), minZ, maxZ, plotFlag, makeMovie,
                antCoords
            ])
        if makeMovie: pngs.append(prefix + filename + '.png')

        soltab.selection = soltab1Selection
        ### end cycle on tables
    mpm.wait()

    if makeMovie:

        def long_substr(strings):
            """
            Find longest common substring
            """
            substr = ''
            if len(strings) > 1 and len(strings[0]) > 0:
                for i in range(len(strings[0])):
                    for j in range(len(strings[0]) - i + 1):
                        if j > len(substr) and all(strings[0][i:i + j] in x
                                                   for x in strings):
                            substr = strings[0][i:i + j]
            return substr

        movieName = long_substr(pngs)
        assert movieName != ''  # need a common prefix, use prefix keyword in case
        logging.info('Making movie: ' + movieName)
        # make every movie last 20 sec, min one second per slide
        # NOTE(review): the code below does not match that intent:
        # fps = ceil(N/200) gives 1 fps up to 200 frames, 2 fps up to 400, ...
        fps = np.ceil(len(pngs) / 200.)
        ss="mencoder -ovc lavc -lavcopts vcodec=mpeg4:vpass=1:vbitrate=6160000:mbd=2:keyint=132:v4mv:vqmin=3:lumi_mask=0.07:dark_mask=0.2:"+\
                "mpeg_quant:scplx_mask=0.1:tcplx_mask=0.1:naq -mf type=png:fps="+str(fps)+" -nosound -o "+movieName.replace('__tmp__','')+".mpg mf://"+movieName+"*  > mencoder.log 2>&1"
        # NOTE(review): shell command built from file names; a prefix with
        # shell metacharacters would be interpreted by the shell.
        os.system(ss)
        #for png in pngs: os.system('rm '+png)

    return 0
Ejemplo n.º 23
0
def run(soltab, soltabsToSub, ratio=False):
    """
    Subtract/divide two tables or a clock/tec/tec3rd/rm from a phase.

    Each table in soltabsToSub is applied in turn to soltab, in place. Flags
    (zero weights) of the subtracted table are propagated to soltab.

    Parameters
    ----------
    soltab : soltab obj
        Solution table that is modified.

    soltabsToSub : list of str
        List of soltabs to subtract

    ratio : bool, optional
        Return the ratio instead of subtracting, by default False.
        Concretely the stored value is (soltab - sub) / sub. This is only used
        for tables that are not clock/tec/tec3rd/rotationmeasure.

    Returns
    -------
    int
        0 on success. 1 if a clock/tec/rm/tec3rd table is to be removed from a
        non-phase soltab; processing stops there (despite the "Skipping it"
        message), and tables earlier in the list have already been applied.
    """
    import numpy as np

    logging.info("Subtract soltab: " + soltab.name)

    solset = soltab.getSolset()
    for soltabToSub in soltabsToSub:
        soltabsub = solset.getSoltab(soltabToSub)

        # NOTE(review): the message says "Skipping it" but the function returns
        # here, so no further tables are processed and no history is added.
        if soltab.getType() != 'phase' and (
                soltabsub.getType() == 'tec' or soltabsub.getType() == 'clock'
                or soltabsub.getType() == 'rotationmeasure'
                or soltabsub.getType() == 'tec3rd'):
            logging.warning(
                soltabToSub +
                ' is of type clock/tec/rm and should be subtracted from a phase. Skipping it.'
            )
            return 1
        logging.info('Subtracting table: ' + soltabToSub)

        # a major speed up if tables are assumed with same axes, check that (should be the case in almost any case)
        for i, axisName in enumerate(soltabsub.getAxesNames()):
            # also armonise selection by copying only the axes present in the outtable and in the right order
            soltabsub.selection[i] = soltab.selection[
                soltab.getAxesNames().index(axisName)]
            assert (soltabsub.getAxisValues(axisName) == soltab.getAxisValues(
                axisName)).all()  # table not conform

        # the subtracted table lacks some axes if the selected value shapes differ
        # (this reads the full values of both tables just to compare shapes)
        if soltab.getValues(retAxesVals=False,
                            weight=False).shape != soltabsub.getValues(
                                retAxesVals=False, weight=False).shape:
            hasMissingAxes = True
        else:
            hasMissingAxes = False

        if soltabsub.getType() == 'clock' or soltabsub.getType(
        ) == 'tec' or soltabsub.getType() == 'tec3rd' or soltabsub.getType(
        ) == 'rotationmeasure' or hasMissingAxes:

            # NOTE(review): this branch requires soltab to have a 'freq' axis
            # (see the index('freq') call below), also when it is entered only
            # because of hasMissingAxes.
            freq = soltab.getAxisValues('freq')
            vals = soltab.getValues(retAxesVals=False, weight=False)
            weights = soltab.getValues(retAxesVals=False, weight=True)
            #print 'vals', vals.shape

            # valsSub doesn't have freq
            valsSub = soltabsub.getValues(retAxesVals=False, weight=False)
            weightsSub = soltabsub.getValues(retAxesVals=False, weight=True)
            #print 'valsSub', valsSub.shape

            # add missing axes and move it to the last position
            # (np.resize prepends the missing axes, repeating the values)
            expand = [
                soltab.getAxisLen(ax) for ax in soltab.getAxesNames()
                if ax not in soltabsub.getAxesNames()
            ]
            #print "expand:", expand
            valsSub = np.resize(valsSub, expand + list(valsSub.shape))
            weightsSub = np.resize(weightsSub, expand + list(weightsSub.shape))
            #print 'valsSub missing axes', valsSub.shape

            # reorder axes to match soltab
            names = [
                ax for ax in soltab.getAxesNames()
                if ax not in soltabsub.getAxesNames()
            ] + soltabsub.getAxesNames()
            #print names, soltab.getAxesNames()
            valsSub = reorderAxes(valsSub, names, soltab.getAxesNames())
            weightsSub = reorderAxes(weightsSub, names, soltab.getAxesNames())
            weights[weightsSub == 0] = 0  # propagate flags
            #print 'valsSub reorder', valsSub.shape

            # put freq axis at the end
            # (so that multiplying by the 1-D freq array broadcasts along it)
            idxFreq = soltab.getAxesNames().index('freq')
            vals = np.swapaxes(vals, idxFreq, len(vals.shape) - 1)
            valsSub = np.swapaxes(valsSub, idxFreq, len(valsSub.shape) - 1)
            #print 'vals reshaped', valsSub.shape

            # a multiplication will go along the last axis of the array
            # phase contributions: clock -> 2*pi*nu*clock,
            # tec -> -8.44797245e9*tec/nu, tec3rd -> -1e21*tec3rd/nu^3
            if soltabsub.getType() == 'clock':
                vals -= 2. * np.pi * valsSub * freq

            elif soltabsub.getType() == 'tec':
                vals -= -8.44797245e9 * valsSub / freq

            elif soltabsub.getType() == 'tec3rd':
                vals -= -1.e21 * valsSub / np.power(freq, 3)

            elif soltabsub.getType() == 'rotationmeasure':
                # put pol axis at the beginning
                idxPol = soltab.getAxesNames().index('pol')
                if idxPol == len(vals.shape) - 1:
                    idxPol = idxFreq  # maybe freq swapped with pol
                vals = np.swapaxes(vals, idxPol, 0)
                valsSub = np.swapaxes(valsSub, idxPol, 0)
                #print 'vals reshaped 2', valsSub.shape

                # phase = lambda^2 * RM, with lambda = c / nu
                wav = 2.99792458e8 / freq
                ph = wav * wav * valsSub
                #if coord['pol'] == 'XX' or coord['pol'] == 'RR':
                pols = soltab.getAxisValues('pol')
                assert len(pols) == 2  # full jons not supported
                # the sign applied to each polarization depends on its order
                if (pols[0] == 'XX' and pols[1] == 'YY') or \
                   (pols[0] == 'RR' and pols[1] == 'LL'):
                    vals[0] -= ph[0]
                    vals[1] += ph[1]
                else:
                    vals[0] += ph[0]
                    vals[1] -= ph[1]

                vals = np.swapaxes(vals, 0, idxPol)

            else:
                if ratio:
                    vals = (vals - valsSub) / valsSub
                else:
                    vals -= valsSub

            # move freq axis back
            vals = np.swapaxes(vals, len(vals.shape) - 1, idxFreq)

            soltab.setValues(vals)
            soltab.setValues(weights, weight=True)
        else:
            # same shape: element-wise operation on the full selected values
            if ratio:
                soltab.setValues((soltab.getValues(retAxesVals=False) -
                                  soltabsub.getValues(retAxesVals=False)) /
                                 soltabsub.getValues(retAxesVals=False))
            else:
                soltab.setValues(
                    soltab.getValues(retAxesVals=False) -
                    soltabsub.getValues(retAxesVals=False))
            # propagate flags of the subtracted table
            weight = soltab.getValues(retAxesVals=False, weight=True)
            weight[soltabsub.getValues(retAxesVals=False, weight=True) ==
                   0] = 0
            soltab.setValues(weight, weight=True)

    soltab.addHistory('RESIDUALS by subtracting tables ' +
                      ' '.join(soltabsToSub))

    return 0
Ejemplo n.º 24
0
def run(soltab, refAnt='', soltabError='', ncpu=0):
    """
    Remove jumps from TEC solutions.
    WEIGHT: uses the errors.

    Parameters
    ----------
    soltab : soltab obj
        TEC solution table (type 'tec').

    soltabError : str, optional
        The table name with solution errors. By default it has the same name of soltab with "error" in place of "tec".

    refAnt : str, optional
        Reference antenna for phases, or 'closest'. By default None (no referencing).
        An unknown antenna falls back to the first antenna of the selection.

    ncpu : int, optional
        Number of cpus given to the multiprocessing manager, by default 0.

    Returns
    -------
    int
        0 on success (also when no jump is found), 1 on error.
    """

    import numpy as np

    def getPhaseWrapBase(freqs):
        """
        freqs: frequency grid of the data
        return the step size from a local minima (2pi phase wrap) to the others [0]: TEC, [1]: clock
        """
        freqs = np.array(freqs)
        nF = freqs.shape[0]
        # builtin float: the `np.float` alias was removed in NumPy 1.24
        A = np.zeros((nF, 2), dtype=float)
        A[:, 1] = freqs * 2 * np.pi * 1e-9
        A[:, 0] = -8.44797245e9 / freqs
        # least-squares solution of A @ steps = 2*pi for every frequency
        steps = np.dot(np.dot(np.linalg.inv(np.dot(A.T, A)), A.T),
                       2 * np.pi * np.ones((nF, ), dtype=float))
        return steps

    if soltab.getType() != 'tec':
        logging.error('TECJUMP works only on tec solutions.')
        return 1

    if soltabError == '': soltabError = soltab.name.replace('tec', 'error')
    solset = soltab.getSolset()
    # the lookup itself must be inside the try, otherwise a missing table
    # raises instead of being reported
    try:
        soltab_e = solset.getSoltab(soltabError)
    except Exception:
        logging.error('Cannot find error solution table %s.' % soltabError)
        return 1
    vals_e_all = soltab_e.getValues(retAxesVals=False, weight=False)

    logging.info("Removing TEC jumps from soltab: " + soltab.name)

    ants = soltab.getAxisValues('ant')
    if refAnt != '' and refAnt != 'closest' and not refAnt in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        logging.warning('Reference antenna ' + refAnt + ' not found. Using: ' +
                        ants[0])
        refAnt = ants[0]
    if refAnt == '': refAnt = None

    # Get the theoretical tec jump
    tec_jump_theory = abs(getPhaseWrapBase([42.308e6, 42.308e6 + 23828e6])[0])

    # Estimate the actual jump as the median of all consecutive-time
    # differences (all antennas) lying strictly between 1 and 1.5 times the
    # theoretical jump
    vals = soltab.getValues(retAxesVals=False, reference=refAnt)
    timeAxis = soltab.getAxesNames().index('time')
    jumps = np.diff(vals, axis=timeAxis)
    jumps = jumps[(jumps > tec_jump_theory * 1) & (jumps < tec_jump_theory * 1.5)]
    if jumps.size == 0:
        logging.info('TEC jump - theoretical: %.5f TECU - NO JUMP FOUND' %
                     (tec_jump_theory))
        return 0
    tec_jump = np.nanmedian(jumps)

    logging.info('TEC jump - theoretical: %.5f TECU - estimated: %.5f TECU' %
                 (tec_jump_theory, tec_jump))

    mpm = multiprocManager(ncpu, _run_antenna)

    for vals, weights, coord, selection in soltab.getValuesIter(
            returnAxes=['time'], weight=True, reference=refAnt):

        # skip all flagged
        if (weights == 0).all(): continue
        # skip reference
        if (vals[(weights != 0)] == 0).all(): continue

        logging.info('Working on ant: %s' % (coord['ant']))

        # 1d linear interp to fill flagged data
        vals[np.where(weights == 0)] = np.interp(
            np.where(weights == 0)[0],
            np.where(weights != 0)[0], vals[np.where(weights != 0)])

        vals_init = np.copy(vals)  # backup for final check
        vals_e = np.squeeze(vals_e_all[selection])
        # flagged points get a unit error
        vals_e[np.where(weights == 0)] = 1.

        mpm.put([
            vals, vals_e, vals_init, weights, selection, tec_jump, coord['ant']
        ])

    mpm.wait()
    for (vals, weights, selection) in mpm.get():
        # set back to 0 the values for flagged data
        vals[weights == 0] = 0
        soltab.setValues(vals, selection)
    soltab.addHistory('TECJUMP')

    return 0
Ejemplo n.º 25
0
def run(soltab,
        tecsoltabOut='tec000',
        clocksoltabOut='clock000',
        offsetsoltabOut='phase_offset000',
        tec3rdsoltabOut='tec3rd000',
        flagBadChannels=True,
        flagCut=5.,
        chi2cut=3000.,
        combinePol=False,
        removePhaseWraps=True,
        fit3rdorder=False,
        circular=False,
        reverse=False,
        invertOffset=False,
        nproc=10):
    """
    Separate phase solutions into Clock and TEC.
    The Clock and TEC values are stored in the specified output soltab with type 'clock', 'tec', 'tec3rd'.

    Parameters
    ----------
    soltab : soltab obj
        Input solution table; must be of type 'phase'.

    tecsoltabOut : str, optional
        Name of the output 'tec' soltab, by default 'tec000'.

    clocksoltabOut : str, optional
        Name of the output 'clock' soltab, by default 'clock000'.

    offsetsoltabOut : str, optional
        Name of the output 'phase' soltab holding the fitted phase offsets,
        by default 'phase_offset000'.

    tec3rdsoltabOut : str, optional
        Name of the output 'tec3rd' soltab (only written if fit3rdorder is
        True), by default 'tec3rd000'.

    flagBadChannels : bool, optional
        Detect and remove bad channels before fitting, by default True.

    flagCut : float, optional
        Threshold passed to the fitter as ``flagcut`` (presumably used when
        flagging bad channels -- see _fitClockTEC.doFit), by default 5.

    chi2cut : float, optional
        Chi^2 threshold passed to the fitter as ``chi2cut``, by default 3000.

    combinePol : bool, optional
        Find a combined polarization solution, by default False.

    removePhaseWraps : bool, optional
        Detect and remove phase wraps, by default True.

    fit3rdorder : bool, optional
        Fit a 3rd order ionospheric component (useful < 40 MHz). By default False.

    circular : bool, optional
        Assume circular polarization with FR not removed. By default False.

    reverse : bool, optional
        Reverse the time axis. By default False.

    invertOffset : bool, optional
        Invert (reverse the sign of) the phase offsets. By default False. Set to True
        if you want to use them with the residuals operation.

    nproc : int, optional
        Number of processes passed to the fitter as ``n_proc``, by default 10.

    Returns
    -------
    int
        0 on success, 1 if the input soltab is not suitable.

    Notes
    -----
    Output weights are 1 where the fitted TEC is > -5 and 0 elsewhere (TEC and
    clock are zeroed there). Clock values from the fitter are multiplied by
    1e-9 before storing (presumably ns -> s; confirm in _fitClockTEC).
    """
    import numpy as np
    from ._fitClockTEC import doFit

    logging.info("Clock/TEC separation on soltab: " + soltab.name)

    # some checks
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is: " + solType +
                        " should be phase. Ignoring.")
        return 1

    # Collect station properties (first 3 coordinates of each selected antenna)
    solset = soltab.getSolset()
    station_dict = solset.getAnt()
    stations = soltab.getAxisValues('ant')
    # builtin float: the np.float alias was removed in NumPy 1.24
    station_positions = np.zeros((len(stations), 3), dtype=float)
    for i, station_name in enumerate(stations):
        station_positions[i, :] = station_dict[station_name][:3]

    returnAxes = ['ant', 'freq', 'pol', 'time']
    # the returned axes, in the order they have in the soltab (loop invariant)
    axes = [i for i in soltab.getAxesNames() if i in returnAxes]
    history = 'CREATE (by CLOCKTECFIT operation)'

    for vals, flags, coord, selection in soltab.getValuesIter(
            returnAxes=returnAxes, weight=True):

        if len(coord['ant']) < 2:
            logging.error(
                'Clock/TEC separation needs at least 2 antennas selected.')
            return 1
        if len(coord['freq']) < 10:
            logging.error(
                'Clock/TEC separation needs at least 10 frequency channels, preferably distributed over a wide range'
            )
            return 1

        freqs = coord['freq']
        stations = coord['ant']
        times = coord['time']

        # reverse time axis
        if reverse:
            timeIdx = axes.index('time')
            vals = np.flip(vals, axis=timeIdx)
            flags = np.flip(flags, axis=timeIdx)

        # weights of 0 mark flagged data
        result = doFit(vals, flags == 0, freqs, stations, station_positions, axes,
                       flagBadChannels=flagBadChannels, flagcut=flagCut,
                       chi2cut=chi2cut, combine_pol=combinePol,
                       removePhaseWraps=removePhaseWraps,
                       fit3rdorder=fit3rdorder, circular=circular, n_proc=nproc)
        if fit3rdorder:
            clock, tec, offset, tec3rd = result
        else:
            clock, tec, offset = result
            tec3rd = None
        if reverse:
            # restore the original time order (time is the first axis of the results)
            clock = clock[::-1, :]
            tec = tec[::-1, :]
            if tec3rd is not None:
                tec3rd = tec3rd[::-1, :]
        if invertOffset:
            offset *= -1.0

        # TEC values <= -5 are treated as failed fits: zero them and weight 0
        weights = tec > -5
        tec[np.logical_not(weights)] = 0
        clock[np.logical_not(weights)] = 0
        weights = np.float16(weights)

        if combinePol or 'pol' not in soltab.getAxesNames():
            tf_st = solset.makeSoltab('tec',
                                      soltabName=tecsoltabOut,
                                      axesNames=['time', 'ant'],
                                      axesVals=[times, stations],
                                      vals=tec[:, :, 0],
                                      weights=weights[:, :, 0])
            tf_st.addHistory(history)
            tf_st = solset.makeSoltab('clock',
                                      soltabName=clocksoltabOut,
                                      axesNames=['time', 'ant'],
                                      axesVals=[times, stations],
                                      vals=clock[:, :, 0] * 1e-9,
                                      weights=weights[:, :, 0])
            tf_st.addHistory(history)
            tf_st = solset.makeSoltab('phase',
                                      soltabName=offsetsoltabOut,
                                      axesNames=['ant'],
                                      axesVals=[stations],
                                      vals=offset[:, 0],
                                      weights=np.ones_like(offset[:, 0],
                                                           dtype=np.float16))
            tf_st.addHistory(history)
            if fit3rdorder:
                tf_st = solset.makeSoltab('tec3rd',
                                          soltabName=tec3rdsoltabOut,
                                          axesNames=['time', 'ant'],
                                          axesVals=[times, stations],
                                          vals=tec3rd[:, :, 0],
                                          weights=weights[:, :, 0])
                tf_st.addHistory(history)
        else:
            # NOTE(review): pol labels are hard-coded to XX/YY even when the
            # input uses other labels (e.g. circular) -- confirm intended.
            pols = ['XX', 'YY']
            tf_st = solset.makeSoltab('tec',
                                      soltabName=tecsoltabOut,
                                      axesNames=['time', 'ant', 'pol'],
                                      axesVals=[times, stations, pols],
                                      vals=tec,
                                      weights=weights)
            tf_st.addHistory(history)
            tf_st = solset.makeSoltab('clock',
                                      soltabName=clocksoltabOut,
                                      axesNames=['time', 'ant', 'pol'],
                                      axesVals=[times, stations, pols],
                                      vals=clock * 1e-9,
                                      weights=weights)
            tf_st.addHistory(history)
            tf_st = solset.makeSoltab('phase',
                                      soltabName=offsetsoltabOut,
                                      axesNames=['ant', 'pol'],
                                      axesVals=[stations, pols],
                                      vals=offset,
                                      weights=np.ones_like(offset,
                                                           dtype=np.float16))
            tf_st.addHistory(history)
            if fit3rdorder:
                tf_st = solset.makeSoltab('tec3rd',
                                          soltabName=tec3rdsoltabOut,
                                          axesNames=['time', 'ant', 'pol'],
                                          axesVals=[times, stations, pols],
                                          vals=tec3rd,
                                          weights=weights)
                tf_st.addHistory(history)
    return 0
Ejemplo n.º 26
0
def run(soltab, opt1, opt2=[1., 2., 3.], opt3=0):
    """
    Generic unspecified step for easy expansion.

    Parameters
    ----------
    soltab : soltab obj
        Solution table to work on.

    opt1 : float
        Is a mandatory parameter.

    opt2 : list of float, optional
        Is optional, by default [1.,2.,3.]

    opt3 : int, optional
        Is optional, by default 0.

    Returns
    -------
    int
        0 if everything went fine.
    """
    # NOTE: opt2 has a mutable (shared) default list; it is never modified
    # here. Do not mutate it in place if you extend this step.

    # load specific libs
    import numpy as np

    # initial logging
    logging.info("Working on soltab: " + soltab.name)

    # check input
    # ...

    axisNames = soltab.getAxesNames()
    logging.info("Axis names are: " + str(axisNames))

    solType = soltab.getType()
    logging.info("Soltab type is: " + solType)

    soltab.setSelection(ant=soltab.getAxisValues('ant')[0])
    logging.info("Selection is: " + str(soltab.selection))

    # find axis values
    logging.info("Antennas (no selection) are: " +
                 str(soltab.getAxisValues('ant', ignoreSelection=True)))
    logging.info("Antennas (with selection) are: " +
                 str(soltab.getAxisValues('ant')))
    # but one can also use (selection is active here!)
    logging.info("Antennas (other method) are: " + str(soltab.ant))
    logging.info("Frequencies are: " + str(soltab.freq))
    logging.info("Directions are: " + str(soltab.dir))
    logging.info("Polarizations are: " + str(soltab.pol))
    # try to access a non-existent axis
    soltab.getAxisValues('nonexistantaxis')

    # now get all values given this selection
    logging.info("Get data using soltab.val")
    val = soltab.val
    # reuse the array just read instead of reading soltab.val again
    logging.debug('shape of val: ' + str(val.shape))
    logging.info("$ val is " + str(val[0, 0, 0, 0, 100]))
    weight = soltab.weight
    time = soltab.time
    thisTime = soltab.time[100]

    # another way to get the data is using the getValues()
    logging.info("Get data using getValues()")
    grid, axes = soltab.getValues()
    # axis names
    logging.info("Axes: " + str(soltab.getAxesNames()))
    # axis shape
    print(axes)
    print([soltab.getAxisLen(axis)
           for axis in axes])  # not ordered, is a dict!
    # data array shape (same of axis shape)
    logging.info("Shape of values: " + str(grid.shape))

    # reset selection
    soltab.setSelection()
    logging.info('Reset selection to \'\'')
    logging.info("Antennas are: " + str(soltab.ant))
    logging.info("Frequencies are: " + str(soltab.freq))
    logging.info("Directions are: " + str(soltab.dir))
    logging.info("Polarizations are: " + str(soltab.pol))

    # finally getValuesIter allows iterating across all possible combinations of a set of axes
    logging.info('Iteration on time/freq')
    for vals, coord, selection in soltab.getValuesIter(
            returnAxes=['time', 'freq']):
        # writing back the solutions
        soltab.setValues(vals, selection)
    logging.info('Iteration on time')
    for vals, coord, selection in soltab.getValuesIter(returnAxes=['time']):
        # writing back the solutions
        soltab.setValues(vals, selection)
    logging.info('Iteration on dir after selection to 1 dir')
    soltab.setSelection(dir='pointing')
    # iterate on the soltab itself (was the undefined name `t`)
    for vals, coord, selection in soltab.getValuesIter(returnAxes=['dir']):
        # writing back the solutions
        soltab.setValues(vals, selection)

    return 0  # if everything went fine, otherwise 1
Ejemplo n.º 27
0
def _plot(Nplots, NColFig, figSize, markerSize, cmesh, axesInPlot, axisInTable,
          xvals, yvals, xlabelunit, ylabelunit, datatype, filename, titles,
          log, dataCube, minZ, maxZ, plotFlag, makeMovie, antCoords, outQueue):
    """
    Draw a grid of ``Nplots`` panels and save it as ``filename + '.png'``.

    Parameters
    ----------
    Nplots : int
        Number of panels.
    NColFig : int
        Number of panel columns; 0 means automatic (ceil(sqrt(Nplots))).
    figSize : list of 2 numbers
        Figure size in inches; 0 entries are replaced (in place) by an
        automatic size.
    markerSize : number
        Marker size for the 2D scatter plots.
    cmesh : bool
        If True each line of a panel is drawn as an image (imshow) over
        x/y; otherwise as a scatter plot (or antenna plot).
    axesInPlot : list of str
        Plotted axis names: [0] is x, [1] (cmesh only) is y.
    axisInTable : str
        Not used here.
    xvals, yvals : numpy arrays
        x (and, for cmesh, y) coordinates. For time plots without cmesh,
        ``xvals`` is modified in place to collapse large time gaps.
    xlabelunit, ylabelunit : str
        Strings appended to the axis labels.
    datatype : str
        Soltab type; selects the automatic range and colour map.
    filename : str
        Output path without the '.png' extension.
    titles : list of str
        One title per panel.
    log : str
        May contain 'X', 'Y' and/or 'Z' to use a log scale on that axis
        ('Z' only affects cmesh images).
    dataCube : numpy masked array
        Indexed as dataCube[panel][line]; masked points are flagged data.
    minZ, maxZ : float
        Value range; if both are 0 the range is computed from the data.
    plotFlag : bool
        If True also plot flagged (masked) points.
    makeMovie : bool
        Only used to choose the automatic figure size.
    antCoords : list
        If non-empty (and not cmesh), draw an antenna-position scatter using
        antCoords[0] and antCoords[1] as coordinates.
    outQueue :
        Not used here.
    """
    import numpy as np
    from itertools import cycle

    # find common min and max if not set
    flat = dataCube.filled(np.nan).flatten()
    if np.isnan(flat).all() or np.all(flat == 0):
        minZ = -0.1
        maxZ = 0.1
    elif minZ == 0 and maxZ == 0:
        if datatype == 'phase':
            minZ = np.nanmin(flat)
            maxZ = np.nanmax(flat)
        elif datatype == 'amplitude' and len(axesInPlot) > 1:
            # get rid of nans (problem in "<" below)
            flat[np.isnan(flat)] = np.nanmedian(flat)
            # median + 3 sigma, sigma computed ignoring points > 100x the median
            maxZ = np.nanmedian(flat) + 3 * np.nanstd(flat[
                (flat / np.nanmedian(flat)) < 100])
            maxZ = np.nanmin([np.nanmax(flat), maxZ])
            minZ = np.nanmin(flat)
        else:
            minZ = np.nanmin(flat)
            maxZ = np.nanmax(flat)

        # prevent same min/max: pad symmetrically so the range is not inverted
        # for negative values and is non-empty at 0
        if minZ == maxZ:
            pad = 0.01 * abs(minZ) if minZ != 0 else 0.1
            minZ -= pad
            maxZ += pad

        # add some space for clock plots (soltab type is 'clock')
        if datatype.lower() == 'clock':
            minZ -= 1e-8
            maxZ += 1e-8

        logging.info("Autoset min: %f, max:%f" % (minZ, maxZ))

    # if user-defined number of col use that
    Nc = NColFig if NColFig != 0 else int(np.ceil(np.sqrt(Nplots)))
    # builtin float: the np.float alias was removed in NumPy 1.24
    Nr = int(np.ceil(float(Nplots) / Nc))

    if figSize[0] == 0:
        figSize[0] = 5 + 2 * Nc if makeMovie else 10 + 3 * Nc
    if figSize[1] == 0:
        figSize[1] = 4 + 1 * Nr if makeMovie else 8 + 2 * Nr

    # squeeze=False: axa is always a 2D (Nr, Nc) array, whatever the layout
    figgrid, axa = plt.subplots(Nr,
                                Nc,
                                sharex=True,
                                sharey=True,
                                figsize=figSize,
                                squeeze=False)
    figgrid.subplots_adjust(hspace=0, wspace=0)
    axlist = axa.flatten()

    # axes label: x on the bottom row, y on the first column
    for ax in axa[-1, :]:
        ax.set_xlabel(axesInPlot[0] + xlabelunit, fontsize=20)
    ylabel = axesInPlot[1] if cmesh else datatype
    for ax in axa[:, 0]:
        ax.set_ylabel(ylabel + ylabelunit, fontsize=20)

    # if gaps in time, collapse and add a black vertical line on separation points
    timeGaps = axesInPlot[0] == 'time' and not cmesh
    jumps = []
    gap = 0.
    if timeGaps:
        delta = np.abs(xvals[:-1] - xvals[1:])
        # jump if larger than 100 times the median step
        jumps = np.where(delta > 100 * np.median(delta))[0]
        # remove jumps
        for j in jumps:
            xvals[j + 1:] -= delta[j]
        gap = xvals[-1] / 100  # 1%
        for j in jumps:
            xvals[j + 1:] += gap

    # log colour scale for images: convert the limits ONCE (converting them
    # inside the plotting loop re-applied log10 to already-converted values)
    if cmesh and 'Z' in log:
        minZ = np.log10(1e-6) if minZ <= 0 else np.log10(minZ)
        maxZ = np.log10(maxZ)

    im = None
    for Ntab, title in enumerate(titles):

        ax = axlist[Ntab]
        ax.text(.5,
                .9,
                title,
                horizontalalignment='center',
                fontsize=14,
                transform=ax.transAxes)

        # add vertical lines and numbers at jumps (numbers are the jump sizes)
        if timeGaps and not np.all(np.isnan(dataCube[Ntab].filled(np.nan))):
            flat = dataCube[Ntab].filled(np.nan).flatten()
            for j in jumps:
                ax.axvline(xvals[j] + gap / 2., color='k')
            if minZ != 0:
                texty = minZ + np.abs(np.nanmin(flat)) * 0.01
            else:
                texty = np.nanmin(
                    dataCube[Ntab]) + np.abs(np.nanmin(flat)) * 0.01
            for j in jumps:
                ax.text(xvals[j] + gap / 2., texty, '%.0f' % delta[j],
                        fontsize=10)

        # set log scales if activated
        if 'X' in log: ax.set_xscale('log')
        if 'Y' in log: ax.set_yscale('log')

        colors = cycle(
            ['#377eb8', '#b88637', '#4daf4a', '#984ea3', '#ffff33', '#f781bf'])
        Nlines = len(dataCube[Ntab])
        for Ncol, data in enumerate(dataCube[Ntab]):

            # set color, use defined colors if a few lines, otherwise a continuum colormap
            if Nlines <= 6:
                color = next(colors)
                colorFlag = '#e41a1c'
            else:
                color = plt.cm.jet(Ncol / float(Nlines - 1))  # from 0 to 1
                colorFlag = 'k'

            vals = dataCube[Ntab][Ncol]
            if np.ma.getmask(vals).all():
                continue

            # 3D cmesh plot
            if cmesh:
                # stretch the imshow output to fill the plot size
                bbox = ax.get_window_extent().transformed(
                    figgrid.dpi_scale_trans.inverted())
                aspect = ((xvals[-1] - xvals[0]) * bbox.height) / (
                    (yvals[-1] - yvals[0]) * bbox.width)
                if 'Z' in log:
                    vals = np.log10(vals)

                if datatype == 'phase' or datatype == 'rotation':
                    cmap = plt.cm.jet
                else:
                    try:
                        cmap = plt.cm.viridis
                    except AttributeError:
                        cmap = plt.cm.rainbow

                # ugly fix to enforce min/max as imshow has some problems with very large numbers
                if not np.isnan(vals).all():
                    vals.data[
                        vals.filled(np.nanmedian(vals.data)) > maxZ] = maxZ
                    vals.data[
                        vals.filled(np.nanmedian(vals.data)) < minZ] = minZ

                im = ax.imshow(vals.filled(np.nan), origin='lower',
                               interpolation="none", cmap=cmap, norm=None,
                               extent=[xvals[0], xvals[-1], yvals[0], yvals[-1]],
                               aspect=aspect, vmin=minZ, vmax=maxZ)

            # make an antenna plot
            elif len(antCoords) > 0:
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.axes.get_xaxis().set_ticks([])
                ax.axes.get_yaxis().set_ticks([])
                # marker diameter 5-15 pt (area in pt**2) - assumes vals are between 0 and 1!
                areas = (5 + vals * 10)**2
                ax.scatter(antCoords[0],
                           antCoords[1],
                           c=vals,
                           s=areas,
                           cmap=plt.cm.jet,
                           vmin=-0.5,
                           vmax=0.5)
                size = np.max([
                    np.max(antCoords[0]) - np.min(antCoords[0]),
                    np.max(antCoords[1]) - np.min(antCoords[1])
                ]) * 1.1  # make img squared
                ax.set_xlim(left=np.median(antCoords[0]) - size / 2.,
                            right=np.median(antCoords[0]) + size / 2.)
                ax.set_ylim(bottom=np.median(antCoords[1]) - size / 2.,
                            top=np.median(antCoords[1]) + size / 2.)

            # 2D scatter plot
            else:
                # flagged data are automatically masked
                ax.plot(xvals[~vals.mask],
                        vals[~vals.mask],
                        'o',
                        color=color,
                        markersize=markerSize,
                        markeredgecolor='none')
                if plotFlag:
                    # plot flagged points
                    ax.plot(xvals[vals.mask],
                            vals.data[vals.mask],
                            'o',
                            color=colorFlag,
                            markersize=markerSize,
                            markeredgecolor='none')
                ax.set_xlim(left=min(xvals), right=max(xvals))
                ax.set_ylim(bottom=minZ, top=maxZ)

    if im is not None:
        # add a color bar to show scale
        figgrid.colorbar(im,
                         ax=axa.ravel().tolist(),
                         use_gridspec=True,
                         fraction=0.02,
                         pad=0.005,
                         aspect=35)

    logging.info("Saving " + filename + '.png')
    try:
        figgrid.savefig(filename + '.png', bbox_inches='tight')
    except Exception:
        # fall back to a plain tight layout if bbox_inches='tight' fails
        figgrid.tight_layout()
        figgrid.savefig(filename + '.png')
    plt.close(figgrid)
Ejemplo n.º 28
0
def run(soltab,
        outsoltab,
        order=12,
        beta=5.0 / 3.0,
        niter=2,
        nsigma=5.0,
        refAnt=-1,
        scale_order=True,
        scale_dist=None,
        min_order=5,
        adjust_order=True):
    """
    Fits station screens to input soltab (type 'phase' or 'amplitude' only).

    The results of the fit are stored in the soltab parent solset in "outsoltab"
    and the residual values (actual - screen) are stored in "outsoltabresid".
    These values are the screen amplitude values per station per pierce point
    per solution interval. The pierce point locations are stored in an auxiliary
    array in the output soltabs.

    Screens can be plotted with the PLOTSCREEN operation.

    Parameters
    ----------
    soltab: solution table
        Soltab containing amplitude solutions
    outsoltab: str
        Name of output soltab
    order : int, optional
        Order of screen (i.e., number of KL base vectors to keep). If the order
        is scaled by dist (scale_order = True), the order is calculated as
        order * sqrt(dist/scale_dist)
    beta: float, optional
        Power-law index for amp structure function (5/3 => pure Kolmogorov
        turbulence)
    niter: int, optional
        Number of iterations to do when determining weights
    nsigma: float, optional
        Number of sigma above which directions are flagged
    refAnt: str or int, optional
        Index (if int) or name (if str) of reference station (-1 => no ref)
    scale_order : bool, optional
        If True, scale the screen order with sqrt of distance/scale_dist to the
        reference station
    scale_dist : float, optional
        Distance used to normalize the distances used to scale the screen order.
        If None, the max distance is used
    adjust_order : bool, optional
        If True, adjust the screen order to obtain a reduced chi^2 of approx.
        unity
    min_order : int, optional
        The minimum allowed order if adjust_order = True.

    """
    import numpy as np
    from numpy import newaxis

    # Get screen type
    screen_type = soltab.getType()
    if screen_type not in ['phase', 'amplitude', 'tec']:
        logging.error(
            'Screens can only be fit to soltabs of type "phase", "tec", or "amplitude".'
        )
        return 1
    logging.info('Using solution table {0} to calculate {1} screens'.format(
        soltab.name, screen_type))

    # Load values, etc.
    r_full = np.array(soltab.val)
    weights_full = soltab.weight[:]
    times = np.array(soltab.time)
    freqs = soltab.freq[:]
    axis_names = soltab.getAxesNames()
    freq_ind = axis_names.index('freq')
    dir_ind = axis_names.index('dir')
    time_ind = axis_names.index('time')
    ant_ind = axis_names.index('ant')
    if 'pol' in axis_names:
        is_scalar = False
        pol_ind = axis_names.index('pol')
        N_pols = len(soltab.pol[:])
        r_full = r_full.transpose(
            [dir_ind, time_ind, freq_ind, ant_ind, pol_ind])
        weights_full = weights_full.transpose(
            [dir_ind, time_ind, freq_ind, ant_ind, pol_ind])
    else:
        is_scalar = True
        N_pols = 1
        r_full = r_full.transpose([dir_ind, time_ind, freq_ind, ant_ind])
        r_full = r_full[:, :, :, :, newaxis]
        weights_full = weights_full.transpose(
            [dir_ind, time_ind, freq_ind, ant_ind])
        weights_full = weights_full[:, :, :, :, newaxis]

    # Collect station and source names and positions and times, making sure
    # that they are ordered correctly.
    solset = soltab.getSolset()
    source_names = soltab.dir[:]
    source_dict = solset.getSou()
    source_positions = []
    for source in source_names:
        source_positions.append(source_dict[source])
    station_names = soltab.ant[:]
    if type(station_names) is not list:
        station_names = station_names.tolist()
    station_dict = solset.getAnt()
    station_positions = []
    for station in station_names:
        station_positions.append(station_dict[station])
    N_sources = len(source_names)
    N_times = len(times)
    N_stations = len(station_names)
    N_freqs = len(freqs)
    N_piercepoints = N_sources

    # Set ref station
    if type(refAnt) is str:
        if N_stations == 1:
            refAnt = -1
        elif refAnt in station_names:
            refAnt = station_names.index(refAnt)
        else:
            refAnt = -1

    if scale_order:
        dist = []
        if refAnt == -1:
            station_order = [order] * N_stations
        else:
            for s in range(len(station_names)):
                dist.append(
                    _get_ant_dist(station_positions[s],
                                  station_positions[refAnt]))
            if scale_dist is None:
                scale_dist = max(dist)
            logging.info('Using variable order (with max order = {0} '
                         'and scaling dist = {1} m)'.format(order, scale_dist))
            station_order = []
            for s in range(len(station_names)):
                station_order.append(
                    max(min_order,
                        min(order,
                            int(order * np.sqrt(dist[s] / scale_dist)))))
    else:
        station_order = [order] * len(station_names)
        logging.info('Using order = {0}'.format(order))

    # Initialize various arrays and parameters
    screen = np.zeros((N_sources, N_stations, N_times, N_freqs, N_pols))
    residual = np.zeros((N_sources, N_stations, N_times, N_freqs, N_pols))
    screen_order = np.zeros((N_stations, N_times, N_freqs, N_pols))
    r_0 = 100
    target_redchi2 = 1.0

    # Calculate full piercepoint arrays
    pp_list = []
    full_matrices = []
    for s in range(N_stations):
        pp_s, midRA, midDec = _calculate_piercepoints(
            np.array([station_positions[s]]), np.array(source_positions))
        pp_list.append(pp_s)
        full_matrices.append(_calculate_svd(pp_s, r_0, beta, N_piercepoints))

    # Fit station screens
    for freq_ind in range(N_freqs):
        for pol_ind in range(N_pols):
            r = r_full[:, :, freq_ind, :,
                       pol_ind]  # order is now [dir, time, ant]
            r = r.transpose([0, 2, 1])  # order is now [dir, ant, time]
            weights = weights_full[:, :, freq_ind, :, pol_ind]
            weights = weights.transpose([0, 2, 1])

            # Fit screens
            for s, stat in enumerate(station_names):
                if s == refAnt and (screen_type == 'phase'
                                    or screen_type == 'tec'):
                    # skip reference station (phase- or tec-type only)
                    continue
                if np.all(np.isnan(r[:, s, :])) or np.all(weights[:,
                                                                  s, :] == 0):
                    # skip fully flagged stations
                    continue
                screen_order[s, :, freq_ind, pol_ind] = station_order[s]
                rr = np.reshape(r[:, s, :], [N_piercepoints, N_times])
                pp = pp_list[s]

                # Iterate:
                # 1. fit screens
                # 2. flag nsigma outliers
                # 3. refit with new weights
                # 4. repeat for niter
                station_weights = weights[:, s, :]
                init_station_weights = weights[:, s, :].copy(
                )  # preserve initial weights
                for iterindx in range(niter):
                    if iterindx > 0:
                        # Flag outliers
                        if screen_type == 'phase' or screen_type == 'tec':
                            # Use residuals
                            screen_diff = residual[:, s, :, freq_ind, pol_ind]
                        elif screen_type == 'amplitude':
                            # Use log residuals
                            screen_diff = np.log10(rr) - np.log10(
                                np.abs(rr -
                                       residual[:, s, :, freq_ind, pol_ind]))
                        station_weights = _flag_outliers(
                            init_station_weights, screen_diff, nsigma,
                            screen_type)

                    # Fit the screens
                    norderiter = 1
                    if adjust_order:
                        if iterindx > 0:
                            norderiter = 4
                    for tindx, t in enumerate(times):
                        N_unflagged = np.where(
                            station_weights[:, tindx] > 0.0)[0].size
                        if N_unflagged == 0:
                            continue
                        if screen_order[s, tindx, freq_ind,
                                        pol_ind] > N_unflagged - 1:
                            screen_order[s, tindx, freq_ind,
                                         pol_ind] = N_unflagged - 1
                        hit_upper = False
                        hit_lower = False
                        hit_upper2 = False
                        hit_lower2 = False
                        sign = 1.0
                        for oindx in range(norderiter):
                            skip_fit = False
                            if iterindx > 0:
                                if np.all(station_weights[:, tindx] ==
                                          prev_station_weights[:, tindx]):
                                    if not adjust_order:
                                        # stop fitting if weights did not change
                                        break
                                    elif oindx == 0:
                                        # Skip the fit for first iteration, as it is the same as the prev one
                                        skip_fit = True
                            if not np.all(station_weights[:, tindx] ==
                                          0.0) and not skip_fit:
                                scr, res = _fit_screen(
                                    [stat], source_names, full_matrices[s],
                                    pp[:, :], rr[:, tindx],
                                    station_weights[:, tindx],
                                    int(screen_order[s, tindx, freq_ind,
                                                     pol_ind]), r_0, beta,
                                    screen_type)
                                screen[:, s, tindx, freq_ind, pol_ind] = scr[:,
                                                                             0]
                                residual[:, s, tindx, freq_ind,
                                         pol_ind] = res[:, 0]

                            if hit_lower2 or hit_upper2:
                                break

                            if adjust_order and iterindx > 0:
                                if screen_type == 'phase':
                                    redchi2 = _circ_chi2(
                                        residual[:, s, tindx, freq_ind,
                                                 pol_ind],
                                        station_weights[:, tindx]) / (
                                            N_unflagged -
                                            screen_order[s, tindx, freq_ind,
                                                         pol_ind])
                                elif screen_type == 'amplitude':
                                    # Use log residuals
                                    screen_diff = np.log10(
                                        rr[:, tindx]) - np.log10(
                                            np.abs(rr[:, tindx] -
                                                   residual[:, s, tindx,
                                                            freq_ind,
                                                            pol_ind]))
                                    redchi2 = np.sum(
                                        np.square(screen_diff) *
                                        station_weights[:, tindx]) / (
                                            N_unflagged -
                                            screen_order[s, tindx, freq_ind,
                                                         pol_ind])
                                else:
                                    redchi2 = np.sum(
                                        np.square(residual[:, s, tindx,
                                                           freq_ind, pol_ind])
                                        * station_weights[:, tindx]) / (
                                            N_unflagged -
                                            screen_order[s, tindx, freq_ind,
                                                         pol_ind])
                                if oindx > 0:
                                    if redchi2 > 1.0 and prev_redchi2 < redchi2:
                                        sign *= -1
                                    if redchi2 < 1.0 and prev_redchi2 > redchi2:
                                        sign *= -1
                                prev_redchi2 = redchi2
                                order_factor = (
                                    N_unflagged -
                                    screen_order[s, tindx, freq_ind,
                                                 pol_ind])**0.2
                                target_order = float(
                                    screen_order[s, tindx, freq_ind, pol_ind]
                                ) - sign * order_factor * (target_redchi2 -
                                                           redchi2)
                                target_order = max(station_order[s],
                                                   target_order)
                                target_order = min(int(round(target_order)),
                                                   N_unflagged - 1)
                                if target_order <= 0:
                                    target_order = min(station_order[s],
                                                       N_unflagged - 1)
                                if target_order == screen_order[
                                        s, tindx, freq_ind,
                                        pol_ind]:  # don't fit again if order is the same as last one
                                    break
                                if target_order == N_unflagged - 1:  # check whether we've been here before. If so, break
                                    if hit_upper:
                                        hit_upper2 = True
                                    hit_upper = True
                                if target_order == station_order[
                                        s]:  # check whether we've been here before. If so, break
                                    if hit_lower:
                                        hit_lower2 = True
                                    hit_lower = True
                                screen_order[s, tindx, freq_ind,
                                             pol_ind] = target_order
                    prev_station_weights = station_weights.copy()
                weights[:, s, :] = station_weights
            weights_full[:, :, freq_ind, :, pol_ind] = weights.transpose(
                [0, 2, 1])  # order is now [dir, time, ant]

    # Write the results to the output solset
    dirs_out = source_names
    times_out = times
    ants_out = station_names
    freqs_out = freqs

    # Store screen values
    vals = screen.transpose(
        [2, 3, 1, 0, 4])  # order is now ['time', 'freq', 'ant', 'dir', 'pol']
    weights = weights_full.transpose(
        [1, 2, 3, 0, 4])  # order is now ['time', 'freq', 'ant', 'dir', 'pol']
    if is_scalar:
        screen_st = solset.makeSoltab(
            '{}screen'.format(screen_type),
            outsoltab,
            axesNames=['time', 'freq', 'ant', 'dir'],
            axesVals=[times_out, freqs_out, ants_out, dirs_out],
            vals=vals[:, :, :, :, 0],
            weights=weights[:, :, :, :, 0])
        vals = residual.transpose([2, 3, 1, 0, 4])
        weights = np.zeros(vals.shape)
        for d in range(N_sources):
            # Store the screen order as the weights of the residual soltab
            weights[:, :, :, d, :] = screen_order.transpose(
                [1, 2, 0, 3])  # order is now [time, ant, freq, pol]
        resscreen_st = solset.makeSoltab(
            '{}screenresid'.format(screen_type),
            outsoltab + 'resid',
            axesNames=['time', 'freq', 'ant', 'dir'],
            axesVals=[times_out, freqs_out, ants_out, dirs_out],
            vals=vals[:, :, :, :, 0],
            weights=weights[:, :, :, :, 0])
    else:
        pols_out = soltab.pol[:]
        screen_st = solset.makeSoltab(
            '{}screen'.format(screen_type),
            outsoltab,
            axesNames=['time', 'freq', 'ant', 'dir', 'pol'],
            axesVals=[times_out, freqs_out, ants_out, dirs_out, pols_out],
            vals=vals,
            weights=weights)
        vals = residual.transpose([2, 3, 1, 0, 4])
        weights = np.zeros(vals.shape)
        for d in range(N_sources):
            # Store the screen order as the weights of the residual soltab
            weights[:, :, :, d, :] = screen_order.transpose(
                [1, 2, 0, 3])  # order is now [time, ant, freq, pol]
        resscreen_st = solset.makeSoltab(
            '{}screenresid'.format(screen_type),
            outsoltab + 'resid',
            axesNames=['time', 'freq', 'ant', 'dir', 'pol'],
            axesVals=[times_out, freqs_out, ants_out, dirs_out, pols_out],
            vals=vals,
            weights=weights)

    # Store beta, r_0, height, and order as attributes of the screen soltabs
    screen_st.obj._v_attrs['beta'] = beta
    screen_st.obj._v_attrs['r_0'] = r_0
    screen_st.obj._v_attrs['height'] = 0.0
    screen_st.obj._v_attrs['midra'] = midRA
    screen_st.obj._v_attrs['middec'] = midDec

    # Store piercepoint table. Note that it does not conform to the axis
    # shapes, so we cannot use makeSoltab()
    solset.obj._v_file.create_array('/' + solset.name + '/' +
                                    screen_st.obj._v_name,
                                    'piercepoint',
                                    obj=pp)

    screen_st.addHistory('CREATE (by STATIONSCREEN operation)')
    resscreen_st.addHistory('CREATE (by STATIONSCREEN operation)')

    return 0
Ejemplo n.º 29
0
def run(soltab,
        soltabOut='tec000',
        refAnt='',
        maxResidualFlag=2.5,
        maxResidualProp=1.):
    """
    Bruteforce TEC extraction from phase solutions.

    For every antenna other than the reference and every time slot, the
    unflagged phases are fitted with the wrapped dispersive model
    ``mod(-8.44797245e9 * tec / freq)``. The fit is a brute-force grid search
    followed by a least-squares refinement. When a fit is good enough (see
    ``maxResidualProp``), the grid for the next time slot is narrowed around
    it; otherwise the full grid is used again.

    Parameters
    ----------
    soltab : soltab obj
        Input solution table. It must be of type 'phase' and have a 'freq'
        axis with at least 10 channels.

    soltabOut : str, optional
        Output table name (same solset), by default "tec000".

    refAnt : str, optional
        Reference antenna, by default the first. An unknown name falls back
        to the first antenna; 'closest' is passed through unchanged.

    maxResidualFlag : float, optional
        Max average residual in radians before flagging datapoint, by default 2.5. If 0: no check.

    maxResidualProp : float, optional
        Max average residual in radians before stop propagating solutions, by default 1. If 0: no check.

    Returns
    -------
    int
        0 on success, 1 if the input is not a phase table or has fewer than
        10 frequency channels.
    """
    import numpy as np
    import scipy.optimize

    # Full search window, used when no good previous solution is available,
    # and the narrow window used to propagate a good solution to the next slot.
    wideRanges, wideNs = (-0.5, 0.5), 1000
    narrowHalfWidth, narrowNs = 0.05, 100

    def mod(d):
        """Wrap phases into [-pi, pi)."""
        return np.mod(d + np.pi, 2. * np.pi) - np.pi

    def model(d, freq):
        """Unwrapped dispersive phase for TEC value(s) ``d`` at ``freq``."""
        return -8.44797245e9 * d / freq

    def drealbrute(d, freq, y):
        """Scalar cost minimised by the brute-force grid search."""
        return np.sum(np.abs(mod(model(d, freq)) - y))

    def dreal(d, freq, y):
        """Residual vector refined by ``scipy.optimize.leastsq``."""
        return mod(model(d[0], freq)) - y

    logging.info("Find TEC for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    # Check the channel count before creating the output table, so a failure
    # does not leave an empty soltab behind.
    if 'freq' not in soltab.getAxesNames() or soltab.getAxisLen('freq') < 10:
        logging.error(
            'Delay estimation needs at least 10 frequency channels, preferably distributed over a wide range.'
        )
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt not in ('', 'closest') and refAnt not in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        # Report the antenna that is actually used as fallback.
        logging.error('Reference antenna ' + refAnt + ' not found. Using: ' +
                      ants[0])
        refAnt = ants[0]
    if refAnt == '':
        refAnt = ants[0]

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')
    nTimes = len(times)

    # create new table
    solset = soltab.getSolset()
    outShape = (soltab.getAxisLen('ant'), soltab.getAxisLen('time'))
    soltabout = solset.makeSoltab(
        soltype='tec',
        soltabName=soltabOut,
        axesNames=['ant', 'time'],
        axesVals=[soltab.getAxisValues(axisName) for axisName in ['ant', 'time']],
        vals=np.zeros(shape=outShape),
        weights=np.ones(shape=outShape))
    soltabout.addHistory('Created by TEC operation from %s.' % soltab.name)

    def fitAntenna(vals, weights, allFreqs, ant):
        """Fit TEC per time slot for one antenna; return (tec, weights) arrays."""
        fitd = np.zeros(nTimes)
        fitweights = np.ones(nTimes)  # all unflagged to start
        ranges, Ns = wideRanges, wideNs

        for t in range(nTimes):
            # apply flags
            idx = (weights[:, t] != 0.)
            freq = allFreqs[idx]
            phaseComb = vals[idx, t]

            if len(freq) < 10:
                fitweights[t] = 0
                logging.warning(
                    'No valid data found for delay fitting for antenna: '
                    + ant + ' at timestamp ' + str(t))
                continue

            # if more than 1/3 of chans are flagged
            nFlagged = len(idx) - len(freq)
            if nFlagged / float(len(idx)) > 1 / 3.:
                logging.debug(
                    'High number of filtered out data points for the timeslot %i: %i/%i'
                    % (t, nFlagged, len(idx)))

            # brute-force grid search, then least-squares refinement
            fitresultd = scipy.optimize.brute(drealbrute,
                                              ranges=(ranges, ),
                                              Ns=Ns,
                                              args=(freq, phaseComb))
            fitresultd, _ = scipy.optimize.leastsq(dreal,
                                                   fitresultd,
                                                   args=(freq, phaseComb))
            best_residual = np.nanmean(
                np.abs(mod(model(fitresultd[0], freq)) - phaseComb))

            fitd[t] = fitresultd[0]
            if maxResidualFlag == 0 or best_residual < maxResidualFlag:
                fitweights[t] = 1
                if maxResidualProp == 0 or best_residual < maxResidualProp:
                    # good solution: search close to it at the next slot
                    ranges = (fitresultd[0] - narrowHalfWidth,
                              fitresultd[0] + narrowHalfWidth)
                    Ns = narrowNs
                else:
                    ranges, Ns = wideRanges, wideNs
            else:
                # high residual, flag and reset initial guess
                logging.warning('Bad solution for ant: ' + ant + ' (time: ' +
                                str(t) + ', residual: ' +
                                str(best_residual) + ').')
                fitweights[t] = 0
                ranges, Ns = wideRanges, wideNs

        logging.info('%s: average tec: %f TECU' % (ant, np.mean(2 * fitd)))
        return fitd, fitweights

    for vals, weights, coord, _ in soltab.getValuesIter(
            returnAxes=['freq', 'time'], weight=True, refAnt=refAnt):

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(), ['freq', 'time'])
        weights = reorderAxes(weights, soltab.getAxesNames(), ['freq', 'time'])

        if coord['ant'] == refAnt:
            # reference antenna: zero TEC with full weight
            fitd, fitweights = np.zeros(nTimes), np.ones(nTimes)
        elif np.all(weights == 0.):
            logging.warning('Skipping flagged antenna: ' + coord['ant'])
            fitd, fitweights = np.zeros(nTimes), np.zeros(nTimes)
        else:
            fitd, fitweights = fitAntenna(vals, weights,
                                          np.asarray(coord['freq']),
                                          coord['ant'])

        # write this antenna's row of the output table
        soltabout.setSelection(ant=coord['ant'])
        soltabout.setValues(fitd)
        soltabout.setValues(fitweights, weight=True)

    return 0
Ejemplo n.º 30
0
def run(soltab, soltabOut='tec000', refAnt=''):
    """
    Bruteforce global delay extraction from phase solutions.

    For every antenna other than the reference, a single delay is fitted to
    all unflagged phases over frequency and time. The delay is constant in
    time, and the fit minimises the mean absolute wrapped difference between
    the data and the model ``2 * pi * delay * freq``. The search uses
    ``scipy.optimize.brute`` over [-1e-7, 1e-7] with its default local
    finisher. The result is repeated over all time slots and written to a new
    soltab of type 'clock'.

    Parameters
    ----------
    soltab : soltab obj
        Input solution table. It must be of type 'phase' and have a 'freq'
        axis with at least 5 channels.

    soltabOut : str, optional
        Output table name (same solset), by default "tec000".

    refAnt : str, optional
        Reference antenna, by default the first. An unknown name falls back
        to the first antenna; 'closest' is passed through unchanged.

    Returns
    -------
    int
        0 on success, 1 if the input is not a phase table or has fewer than
        5 frequency channels.
    """
    import numpy as np
    import scipy.optimize

    # Brute-force search grid for the delay (the log reports delay * 1e9 as ns).
    ranges = (-1e-7, 1e-7)
    Ns = 1001

    def mod(d):
        """Wrap phases into [-pi, pi)."""
        return np.mod(d + np.pi, 2. * np.pi) - np.pi

    def cost_f(d, freq, y):
        """Mean absolute wrapped residual of delay ``d`` against y[freq, time]; NaNs ignored."""
        # The model does not depend on time: broadcast it along the time axis.
        phase = mod(2 * np.pi * d * freq)[:, np.newaxis]
        dist = np.abs(mod(phase - y))
        ngood = np.count_nonzero(~np.isnan(dist))
        if ngood == 0:
            return 0.
        return np.nansum(dist) / ngood

    logging.info("Find global DELAY for soltab: " + soltab.name)

    # input check
    solType = soltab.getType()
    if solType != 'phase':
        logging.warning("Soltab type of " + soltab.name + " is of type " +
                        solType + ", should be phase. Ignoring.")
        return 1

    # Check the channel count before creating the output table, so a failure
    # does not leave an empty soltab behind.
    if 'freq' not in soltab.getAxesNames() or soltab.getAxisLen('freq') < 5:
        logging.error(
            'Delay estimation needs at least 5 frequency channels, preferably distributed over a wide range.'
        )
        return 1

    ants = soltab.getAxisValues('ant')
    if refAnt not in ('', 'closest') and refAnt not in soltab.getAxisValues(
            'ant', ignoreSelection=True):
        # Report the antenna that is actually used as fallback.
        logging.warning('Reference antenna ' + refAnt + ' not found. Using: ' +
                        ants[0])
        refAnt = ants[0]
    if refAnt == '':
        refAnt = ants[0]

    # times and ants needs to be complete or selection is much slower
    times = soltab.getAxisValues('time')
    nTimes = len(times)

    # create new table
    solset = soltab.getSolset()
    outShape = (soltab.getAxisLen('ant'), soltab.getAxisLen('time'))
    soltabout = solset.makeSoltab(
        soltype='clock',
        soltabName=soltabOut,
        axesNames=['ant', 'time'],
        axesVals=[soltab.getAxisValues(axisName) for axisName in ['ant', 'time']],
        vals=np.zeros(shape=outShape),
        weights=np.ones(shape=outShape))
    soltabout.addHistory('Created by GLOBALDELAY operation from %s.' %
                         soltab.name)

    for vals, weights, coord, _ in soltab.getValuesIter(
            returnAxes=['freq', 'time'], weight=True, refAnt=refAnt):

        # reorder axes
        vals = reorderAxes(vals, soltab.getAxesNames(), ['freq', 'time'])
        weights = reorderAxes(weights, soltab.getAxesNames(), ['freq', 'time'])

        # The reference antenna keeps delay 0 with full weight.
        delay_fitresult = np.zeros(nTimes)
        weights_fitresult = np.ones(nTimes)

        if coord['ant'] != refAnt:
            if np.all(weights == 0.):
                logging.warning('Skipping flagged antenna: ' + coord['ant'])
                weights_fitresult[:] = 0
            else:
                freq = np.asarray(coord['freq'])
                # Flagged points (weight 0) become NaN so cost_f ignores them.
                data = np.where(weights == 0., np.nan, vals)

                # brute force
                fit = scipy.optimize.brute(cost_f,
                                           ranges=(ranges, ),
                                           Ns=Ns,
                                           args=(freq, data))
                # brute returns a length-1 array in 1-D; take the scalar.
                delay = float(np.ravel(fit)[0])
                delay_fitresult[:] = delay

                logging.info('%s: average delay: %f ns' %
                             (coord['ant'], delay * 1e9))

        # write this antenna's row of the output table
        soltabout.setSelection(ant=coord['ant'])
        soltabout.setValues(delay_fitresult)
        soltabout.setValues(weights_fitresult, weight=True)

    return 0