Example #1
0
def compute_Pkp(sky_map, kp):
    N = kp.shape[0]
    Nf, Np = sky_map.shape
    Pkp = np.zeros(N)
    for pi in mpiutil.mpirange(Np):
        gkp = (cd_span / N) * ndft(cd, sky_map[:, pi], kp)  # K (h Mpc^-1)^-1
        # according to the definition <\delta(k) \delta(k')^*> = (2\pi)^3 \delta(k - k') P(k), but for 1D, <\delta(k) \delta(k')^*> = (2\pi) \delta(k - k') P(k)
        Pkp += np.abs(gkp)**2 / (2.0 * np.pi)  # K^2 (Mpc / h)
    Pkp /= Np  # K^2 (Mpc / h)
    Pkp *= 1.0e6  # mK^2 (Mpc / h)

    tmp = mpiutil.gather_array(Pkp.reshape(1, -1), axis=0, root=0)
    if mpiutil.rank0:
        print 'tmp shape: ', tmp.shape
        Pkp = np.sum(tmp, axis=0)
        print 'Pkp shape: ', Pkp.shape

    return Pkp  # only rank0 has the correct Pkp
result = mpiutil.parallel_map(lambda x: x * x, glist, root=0)
if rank == 0:
    print 'result = %s' % result

# split_all
separator(sec, 'split_all')
print 'rank %d has: %s' % (rank, mpiutil.split_all(6))

# split_local
separator(sec, 'split_local')
print 'rank %d has: %s' % (rank, mpiutil.split_local(6))

# gather_array
separator(sec, 'gather_array')
if rank == 0:
    local_ary = np.array([[0, 1], [6, 7]])
elif rank == 1:
    local_ary = np.array([[2], [8]])
elif rank == 2:
    local_ary = np.array([[3], [9]])
if rank == 3:
    local_ary = np.array([[4, 5], [10, 11]])
global_ary = mpiutil.gather_array(local_ary, axis=1, root=0)
if rank == 0:
    print 'global_ary = %s' % global_ary

# scatter_array
separator(sec, 'scatter_array')
local_ary = mpiutil.scatter_array(global_ary, axis=1, root=0)
print 'rank %d has local_ary = %s' % (rank, local_ary)
Example #3
0
    def process(self, ts):
        """Build the telescope model for `ts` and generate its beam transfer matrices.

        A dish array or a cylinder telescope (selected by ``ts.is_dish`` /
        ``ts.is_cylinder``) is configured from the task parameters and the
        timestream attributes. ``beamtransfer.BeamTransfer(beam_dir, ...)``
        then generates the beam transfer matrices. `ts`, redistributed along
        the baseline axis, is finally handed to the base-class ``process``.

        Raises
        ------
        RuntimeError
            If the timestream is neither a dish nor a cylinder array.
        """

        tsys = self.params['tsys']
        accuracy_boost = self.params['accuracy_boost']
        l_boost = self.params['l_boost']
        bl_range = self.params['bl_range']
        auto_correlations = self.params['auto_correlations']
        beam_dir = output_path(self.params['beam_dir'])
        noise_weight = self.params['noise_weight']

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        ts.redistribute('baseline')

        lat = ts.attrs['sitelat']
        # lon = ts.attrs['sitelon']
        lon = 0.0
        # lon = np.degrees(ts['ra_dec'][0, 0]) # the first ra
        local_origin = False
        freqs = ts.freq[:]  # MHz
        band_width = ts.attrs['freqstep']  # MHz
        try:
            ndays = ts.attrs['ndays']
        except KeyError:
            ndays = 1
        feedpos = ts['feedpos'][:]

        if ts.is_dish:
            from tlpipe.map.drift.telescope import tl_dish

            # dish pointing taken from the first time sample, converted
            # from radians to degrees
            az, alt = ts['az_alt'][0]
            pointing = [np.degrees(az), np.degrees(alt), 0.0]
            dish_width = ts.attrs['dishdiam']
            tel = tl_dish.TlUnpolarisedDishArray(lat, lon, freqs, band_width,
                                                 tsys, ndays, accuracy_boost,
                                                 l_boost, bl_range,
                                                 auto_correlations,
                                                 local_origin, dish_width,
                                                 feedpos, pointing)
        elif ts.is_cylinder:
            from tlpipe.map.drift.telescope import tl_cylinder

            # factor = 1.2 # suppose an illumination efficiency, keep same with that in timestream_common
            factor = 0.79  # for xx
            # factor = 0.88 # for yy
            cyl_width = factor * ts.attrs['cywid']
            tel = tl_cylinder.TlUnpolarisedCylinder(
                lat, lon, freqs, band_width, tsys, ndays, accuracy_boost,
                l_boost, bl_range, auto_correlations, local_origin, cyl_width,
                feedpos)
        else:
            raise RuntimeError('Unknown array type %s' % ts.attrs['telescope'])

        # beamtransfer
        bt = beamtransfer.BeamTransfer(beam_dir, tel, noise_weight, True)
        bt.generate()

        return super(GenBeam, self).process(ts)
Example #4
0
def hough_transform(I,
                    time,
                    freq,
                    dl,
                    dh,
                    nd=None,
                    t0l=None,
                    t0h=None,
                    nt0=None,
                    threshold=3.0,
                    comm=None):
    """Hough transform algorithm.

    Every retained pixel ``(t, f)`` of `I` votes, for each trial ``d``, into
    the accumulator cell of time offset ``t0 = t - d / f**2``. A dispersed
    signal following ``t = t0 + d / f**2`` therefore piles up in one
    ``(t0, d)`` cell.

    Parameters
    ----------
    I : 2D np.ndarray
        Input data image, rows are time and columns are frequency, both in
        ascending order.
    time : 1D np.ndarray
        The corresponding time of the rows of `I`, in ms.
    freq : 1D np.ndarray
        The corresponding frequency of the columns of `I`, in GHz.
    dl : float
        Lowest d = 4.15 DM.
    dh : float
        Highest d = 4.15 DM.
    nd : integer, optional
        Number of d. If None it is derived from the time resolution.
    t0l : float, optional
        Lowest time offset t0, in ms.
    t0h : float, optional
        Highest time offset t0, in ms.
    nt0 : integer, optional
        Number of time offset t0. Defaults to `nd`.
    threshold : float
        How many sigmas to truncate the data. If positive, only pixels whose
        deviation from the median exceeds ``threshold`` times the MAD-based
        sigma estimate vote; otherwise every pixel votes.
    comm : MPI communicator
        MPI communicator. Required if executed parallelly by using MPI.

    Returns
    -------
    A : 2D np.ndarray or None
        The ``(nt0, nd)`` accumulator. When run on more than one process,
        only rank 0 of `comm` receives the summed accumulator; other ranks
        get None.
    dl, dh, t0l, t0h : float
        The d and t0 ranges actually used.
    """

    try:
        from caput import mpiutil
    except ImportError:
        # no mpiutil, can not use MPI to speed up
        comm = None

    fl, fh = freq[0], freq[-1]
    tl, th = time[0], time[-1]
    Nt = time.shape[0]
    # only used to choose a default number of d values
    dt = (th - tl) / Nt

    # mask data based on a given threshold
    med = np.median(I)
    if threshold > 0.0:
        abs_diff = np.abs(I - med)
        mad = np.median(abs_diff) / 0.6745
        Im = np.where(abs_diff > threshold * mad, I - med,
                      np.nan)  # subtract median
    else:
        Im = I - med

    inds = np.where(np.isfinite(Im))  # inds of non-masked vals

    if nd is None:
        dd = fl**2 * dt
        nd = int(np.ceil((dh - dl) / dd))
    d = np.linspace(dl, dh, nd)
    # compute the range of t0
    if t0l is None:
        t0l = -fl**-2 * dh + tl
    if t0h is None:
        t0h = -fh**-2 * dl + th
    assert t0h > t0l, "Must have t0h (= %g) > t0l (= %g)" % (t0h, t0l)
    if nt0 is None:
        nt0 = nd
    dt0 = (t0h - t0l) / nt0

    # initialize the accumulator
    A = np.zeros((nt0, nd))  # the accumulator
    d_cols = np.arange(nd)

    # accumulate the accumulator
    pixels = list(zip(inds[0], inds[1]))
    if comm is None or comm.size == 1:
        linds = pixels
    else:
        linds = mpiutil.mpilist(pixels, comm=comm)
    for ti, fi in linds:
        # use the actual sample coordinates instead of fl + fi * df and
        # tl + ti * dt, which are off by one grid step at the upper edge
        t0v = -float(freq[fi])**-2 * d + time[ti]
        t0i = np.clip(np.around((t0v - t0l) / dt0).astype(int), 0, nt0 - 1)
        # each d column appears exactly once, so there are no repeated
        # (t0i, di) pairs and the buffered fancy-index `+=` is exact
        A[t0i, d_cols] += Im[ti, fi]

    # gather and accumulate A
    if not (comm is None or comm.size == 1):
        Ac = mpiutil.gather_array(A.reshape((1, nt0, nd)), comm=comm)
        # gather_array collects onto rank 0 of `comm`; test the rank within
        # that communicator, not the global mpiutil.rank0
        if comm.rank == 0:
            A = Ac.sum(axis=0)
        else:
            A = None

    return A, dl, dh, t0l, t0h
Example #5
0
    def process(self, ts):
        """Calibrate `ts` with the periodically switched-on noise source.

        For every noise-source ON segment, frequency and polarization, the
        visibility difference (mean of ON) - (masked mean of the preceding
        `num_mean` OFF samples) is assembled into a feed x feed matrix. This
        matrix is decomposed by ``rpca_decomp.decompose`` into a rank-1 part
        plus outliers, and a per-feed gain is taken from the leading
        eigenvector of the rank-1 part. Depending on the parameters:

        - the extracted visibilities are saved (``save_ns_vis``);
        - the gains are interpolated to all time samples and divided out of
          ``ts.local_vis`` (``apply_gain``);
        - the gains are saved (``save_gain``).

        Raises
        ------
        RuntimeError
            For a non-linear polarization type, an invalid `num_mean`,
            unpaired ON start/end indices, or when the output file cannot be
            written after repeated attempts.
        """

        assert isinstance(ts, Timestream), '%s only works for Timestream object' % self.__class__.__name__

        num_mean = self.params['num_mean']
        save_ns_vis = self.params['save_ns_vis']
        ns_vis_file = self.params['ns_vis_file']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        gain_file = self.params['gain_file']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']
        tag_output_iter = self.params['tag_output_iter']

        if save_ns_vis or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ns_eigcal for pol_type: %s' % pol_type)

            ts.redistribute('baseline')

            nt = ts.local_vis.shape[0]
            freq = ts.freq[:]
            pol = [ ts.pol_dict[p] for p in ts['pol'][:] ] # as string
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            feedno = ts['feedno'][:].tolist()

            nf = ts.local_freq.shape[0]
            npol = ts.local_pol.shape[0]
            nlb = ts.local_bl.shape[0]
            nfeed = len(feedno)

            if num_mean <= 0:
                raise RuntimeError('Invalid num_mean = %s' % num_mean)
            ns_on = ts['ns_on'][:]
            ns_on = np.where(ns_on, 1, 0)
            diff_ns = np.diff(ns_on)
            on_si = np.where(diff_ns==1)[0] + 1 # start inds of ON
            on_ei = np.where(diff_ns==-1)[0] + 1 # (end inds + 1) of ON
            # drop an end that has no matching start (ON at the first sample)
            if on_ei[0] < on_si[0]:
                on_ei = on_ei[1:]
            # drop a start that has no matching end (ON at the last sample)
            if on_si[-1] > on_ei[-1]:
                on_si = on_si[:-1]

            if on_si[0] < num_mean+1: # not enough off data in the beginning to use
                on_si = on_si[1:]
                on_ei = on_ei[1:]

            if len(on_si) != len(on_ei):
                raise RuntimeError('len(on_si) != len(on_ei)')
            num_on = len(on_si)
            # cal inds are the center inds of ON; floor division keeps them
            # integer indices under both Python 2 and Python 3
            cal_inds = (on_si + on_ei) // 2


            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = [] # indices that shold be conj
            mis = [] # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [] # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)


            tfp_inds = list(itertools.product(range(num_on), range(nf), range(npol)))
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lon_off = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    si_on, ei_on = on_si[ti], on_ei[ti]
                    # mean of ON - mean of OFF
                    if ei_on - si_on > 3:
                        # does not use the two ends if there are more than three ONs
                        lon_off[ii] = np.mean(ts.local_vis[si_on+1:ei_on-1, fi, pi], axis=0) - np.ma.mean(np.ma.array(ts.local_vis[si_on-num_mean-1:si_on-1, fi, pi], mask=ts.local_vis_mask[si_on-num_mean-1:si_on-1, fi, pi]), axis=0)
                    else:
                        lon_off[ii] = np.mean(ts.local_vis[si_on:ei_on, fi, pi], axis=0) - np.ma.mean(np.ma.array(ts.local_vis[si_on-num_mean-1:si_on-1, fi, pi], mask=ts.local_vis_mask[si_on-num_mean-1:si_on-1, fi, pi]), axis=0)

                # gather on_off from all process for separate bls
                on_off = mpiutil.gather_array(lon_off, axis=1, root=ri, comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei] # inds for this process
                    this_on_off = on_off
            del tfp_inds
            del lon_off
            tfp_len = len(tfp_linds)


            cnan = complex(np.nan, np.nan) # complex nan
            if save_ns_vis:
                # save the extracted noise source vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)
                # save sky vis
                lsky_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)
                # save outlier vis
                lotl_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lgain = np.zeros((tfp_len, nfeed), dtype=ts.vis.dtype)
                lgain_mask = np.zeros((tfp_len, nfeed), dtype=bool)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)

            if show_progress and mpiutil.rank0:
                pg = progress.Progress(len(tfp_linds), step=progress_step)

            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)

                Vmat.flat[mis] = this_on_off[ii]
                Vmat.flat[mis_conj] = this_on_off[ii, bis_conj].conj()

                if save_ns_vis:
                    lsky_vis[ii] = Vmat

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(np.abs(diff)>3.0*rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=100, threshold='hard', tol=1.0e-6, debug=False)
                if save_ns_vis:
                    lsrc_vis[ii] = V0
                    lotl_vis[ii] = S

                if apply_gain or save_gain:
                    # largest eigenvalue/eigenvector only
                    e, U = la.eigh(V0, eigvals=(nfeed-1, nfeed-1))
                    g = U[:, -1] * e[-1]**0.5
                    # g = U[:, -1] * nfeed**0.5 # to make g_i g_j^* ~ 1
                    if g[0].real < 0:
                        g *= -1.0 # make all g[0] phase 0, instead of pi
                    lgain[ii] = g
                    ### maybe does not flag abnormal values here to simplify the programming, the flag can be down in ps_cal
                    gabs = np.abs(g)
                    gmed = np.median(gabs)
                    gabs_diff = np.abs(gabs - gmed)
                    gmad = np.median(gabs_diff) / 0.6745
                    lgain_mask[ii, np.where(gabs_diff>3.0*gmad)[0]] = True # mask invalid feeds

            if save_ns_vis:
                if tag_output_iter:
                    ns_vis_file = output_path(ns_vis_file, iteration=self.iteration)
                else:
                    ns_vis_file = output_path(ns_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(ns_vis_file, 'w') as f:
                        # allocate space
                        shp = (num_on, nf, npol, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis', shp, dtype=lotl_vis.dtype)
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time_inds'] = cal_inds
                        except RuntimeError:
                            # too large for an attribute: store as a dataset
                            f.create_dataset('time_inds', data=cal_inds)
                            f.attrs['time_inds'] = '/time_inds'
                        f.attrs['freq'] = ts.freq
                        f.attrs['pol'] = ts.pol
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(ns_vis_file, 'r+') as f:
                                    for ii, (ti, fi, pi) in enumerate(tfp_linds):
                                        f['sky_vis'][ti, fi, pi] = lsky_vis[ii]
                                        f['src_vis'][ti, fi, pi] = lsrc_vis[ii]
                                        f['outlier_vis'][ti, fi, pi] = lotl_vis[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    # all retries failed; report the file that was being written
                    raise RuntimeError('Could not open file: %s...' % ns_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis

                mpiutil.barrier()

            if apply_gain or save_gain:
                gain = mpiutil.gather_array(lgain, axis=0, root=None, comm=ts.comm)
                gain_mask = mpiutil.gather_array(lgain_mask, axis=0, root=None, comm=ts.comm)
                del lgain
                del lgain_mask
                gain = gain.reshape(num_on, nf, npol, nfeed)
                gain_mask = gain_mask.reshape(num_on, nf, npol, nfeed)

                # normalize gain to make its amp ~ 1
                gain_med = np.ma.median(np.ma.array(np.abs(gain), mask=gain_mask))
                gain /= gain_med

                # phi = np.angle(gain)

                # delta_phi = np.zeros((num_on, nf, npol))

                # # get phase change
                # for ti in range(1, num_on):
                #     delta_phi[ti] = np.ma.mean(np.ma.array(phi[ti], mask=gain_mask[ti]) - np.ma.array(phi[ti-1], mask=gain_mask[ti-1]), axis=2)

                # # save original gain
                # gain_original = gain.copy()
                # # compensate phase changes
                # gain *= np.exp(1.0J * delta_phi[:, :, :, np.newaxis])

                gain_alltimes = np.full((nt, nf, npol, nfeed), cnan, dtype=gain.dtype)
                gain_alltimes_mask = np.zeros((nt, nf, npol, nfeed), dtype=bool)

                # interpolate to all time points
                for fi in range(nf):
                    for pi in range(npol):
                        for di in range(nfeed):
                            valid_inds = np.where(np.logical_not(gain_mask[:, fi, pi, di]))[0]
                            if len(valid_inds) < 0.75 * num_on:
                                # no enough points to do good interpolation
                                gain_alltimes_mask[:, fi, pi, di] = True
                            else:
                                # gain_alltimes[:, fi, pi, di] = InterpolatedUnivariateSpline(cal_inds[valid_inds], gain[valid_inds, fi, pi, di].real)(np.arange(nt)) + 1.0J * InterpolatedUnivariateSpline(cal_inds[valid_inds], gain[valid_inds, fi, pi, di].imag)(np.arange(nt))
                                # interpolate amp and phase to avoid abrupt changes
                                amp = InterpolatedUnivariateSpline(cal_inds[valid_inds], np.abs(gain[valid_inds, fi, pi, di]))(np.arange(nt))
                                phs = InterpolatedUnivariateSpline(cal_inds[valid_inds], np.unwrap(np.angle(gain[valid_inds, fi, pi, di])))(np.arange(nt))
                                gain_alltimes[:, fi, pi, di] = amp * np.exp(1.0J * phs)

                # apply gain to vis
                for fi in range(nf):
                    for pi in range(npol):
                        for bi, (fd1, fd2) in enumerate(ts['blorder'].local_data):
                            g1 = gain_alltimes[:, fi, pi, feedno.index(fd1)]
                            g1_mask = gain_alltimes_mask[:, fi, pi, feedno.index(fd1)]
                            g2 = gain_alltimes[:, fi, pi, feedno.index(fd2)]
                            g2_mask = gain_alltimes_mask[:, fi, pi, feedno.index(fd2)]
                            g12 = g1 * np.conj(g2)
                            g12_mask = np.logical_or(g1_mask, g2_mask)

                            if fd1 == fd2:
                                # auto-correlation should be real
                                ts.local_vis[:, fi, pi, bi] /= g12.real
                            else:
                                ts.local_vis[:, fi, pi, bi] /= g12
                            ts.local_vis_mask[:, fi, pi, bi] = np.logical_or(ts.local_vis_mask[:, fi, pi, bi], g12_mask)


                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file, iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            # dset = f.create_dataset('gain', data=gain_original) # gain without phase compensation
                            dset = f.create_dataset('gain', data=gain) # gain without phase compensation
                            # f.create_dataset('delta_phi', data=delta_phi)
                            f.create_dataset('gain_mask', data=gain_mask)
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time_inds'] = cal_inds
                            except RuntimeError:
                                f.create_dataset('time_inds', data=cal_inds)
                                dset.attrs['time_inds'] = '/time_inds'
                            dset.attrs['freq'] = ts.freq
                            dset.attrs['pol'] = ts.pol
                            dset.attrs['feed'] = np.array(feedno)
                            dset.attrs['gain_med'] = gain_med # record the normalization factor
                            # save gain_alltimes
                            dset = f.create_dataset('gain_alltimes', data=gain_alltimes)
                            f.create_dataset('gain_alltimes_mask', data=gain_alltimes_mask)
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time_inds'] = np.arange(nt)
                            except RuntimeError:
                                f.create_dataset('all_time_inds', data=np.arange(nt))
                                dset.attrs['time_inds'] = '/all_time_inds'
                            dset.attrs['freq'] = ts.freq
                            dset.attrs['pol'] = ts.pol
                            dset.attrs['feed'] = np.array(feedno)

                            f.create_dataset('time', data=ts.local_time)

                    mpiutil.barrier()


        return super(NsCal, self).process(ts)
Example #6
0
    def load_tod_excl_main_data(self):
        """Load time ordered attributes and datasets (exclude the main data) from all files.

        When no ``sec1970`` dataset is present, this also generates the
        time-ordered datasets ``sec1970``, ``jul_date``, ``local_hour``,
        ``az_alt`` and ``ra_dec``. They are built from each file's
        ``sec1970`` start-time attribute and the ``inttime`` attribute of the
        first file. The method also records the ``continuous`` /
        ``break_inds``, ``same_pointing`` and ``same_dec`` attributes.
        """

        super(TimestreamCommon, self).load_tod_excl_main_data()

        if 'sec1970' not in self.iterkeys():
            # generate sec1970
            int_time = self.infiles[0].attrs['inttime']
            sec1970s = []
            nts = []
            # each process reads the start time and length of its share of the files
            for fh in mpiutil.mpilist(self.infiles, method='con', comm=self.comm):
                sec1970s.append(fh.attrs['sec1970'])
                nts.append(fh[self.main_data_name].shape[0])
            sec1970 = np.zeros(sum(nts), dtype=np.float64) # precision float32 is not enough
            cum_nts = np.cumsum([0] + nts)
            for idx, (nt, sec) in enumerate(zip(nts, sec1970s)):
                # sample i of a file is at the file start time plus i integration times
                sec1970[cum_nts[idx]:cum_nts[idx+1]] = np.array([ sec + i*int_time for i in xrange(nt)], dtype=np.float64) # precision float32 is not enough
            # gather local sec1970
            sec1970 = mpiutil.gather_array(sec1970, root=None, comm=self.comm)
            # select the corresponding section
            sec1970 = sec1970[self.main_data_start:self.main_data_stop][self.main_data_select[0]]

            # determine if it is continuous in time. This is done on the full,
            # not yet distributed array so that gaps falling between two
            # processes are detected and `break_inds` are global indices.
            sec_diff = np.diff(sec1970)
            break_inds = np.where(sec_diff>1.5*int_time)[0]

            # if time is just the distributed axis, load sec1970 distributed
            if 'time' == self.main_data_axes[self.main_data_dist_axis]:
                sec1970 = mpiarray.MPIArray.from_numpy_array(sec1970)
            self.create_main_time_ordered_dataset('sec1970', data=sec1970)
            # create attrs of this dset
            self['sec1970'].attrs["unit"] = 'second'
            if len(break_inds) > 0:
                self['sec1970'].attrs["continuous"] = False
                self['sec1970'].attrs["break_inds"] = break_inds + 1
            else:
                self['sec1970'].attrs["continuous"] = True

            # generate julian date
            jul_date = np.array([ date_util.get_juldate(datetime.fromtimestamp(s), tzone=self.infiles[0].attrs['timezone']) for s in sec1970 ], dtype=np.float64) # precision float32 is not enough
            if 'time' == self.main_data_axes[self.main_data_dist_axis]:
                jul_date = mpiarray.MPIArray.wrap(jul_date, axis=0)
            # if time is just the distributed axis, load jul_date distributed
            self.create_main_time_ordered_dataset('jul_date', data=jul_date)
            # create attrs of this dset
            self['jul_date'].attrs["unit"] = 'day'

            # generate local time in hour from 0 to 24.0
            def _hour(t):
                # 1 hour = 3600 s = 3.6e9 microseconds
                return t.hour + t.minute/60.0 + t.second/3600.0 + t.microsecond/3.6e9
            local_hour = np.array([ _hour(datetime.fromtimestamp(s).time()) for s in sec1970 ], dtype=np.float64)
            if 'time' == self.main_data_axes[self.main_data_dist_axis]:
                local_hour = mpiarray.MPIArray.wrap(local_hour, axis=0)
            # if time is just the distributed axis, load local_hour distributed
            self.create_main_time_ordered_dataset('local_hour', data=local_hour)
            # create attrs of this dset
            self['local_hour'].attrs["unit"] = 'hour'

            # generate az, alt
            az_alt = np.zeros((self['sec1970'].local_data.shape[0], 2), dtype=np.float32) # radians
            if self.is_dish:
                # antpointing = rt['antpointing'][-1, :, :] # degree
                # pointingtime = rt['pointingtime'][-1, :, :] # degree
                az_alt[:, 0] = 0.0 # az
                az_alt[:, 1] = np.pi/2 # alt
            elif self.is_cylinder:
                az_alt[:, 0] = np.pi/2 # az
                az_alt[:, 1] = np.pi/2 # alt
            else:
                raise RuntimeError('Unknown antenna type %s' % self.attrs['telescope'])

            # generate ra, dec of the antenna pointing
            aa = self.array
            ra_dec = np.zeros_like(az_alt) # radians
            for ti in xrange(az_alt.shape[0]):
                az, alt = az_alt[ti]
                az, alt = ephem.degrees(az), ephem.degrees(alt)
                aa.set_jultime(self['jul_date'].local_data[ti])
                ra_dec[ti] = aa.radec_of(az, alt) # in radians, a point in the sky above the observer

            if self.main_data_dist_axis == 0:
                az_alt = mpiarray.MPIArray.wrap(az_alt, axis=0)
                ra_dec = mpiarray.MPIArray.wrap(ra_dec, axis=0)
            # if time is just the distributed axis, create distributed datasets
            self.create_main_time_ordered_dataset('az_alt', data=az_alt)
            self['az_alt'].attrs['unit'] = 'radian'
            self.create_main_time_ordered_dataset('ra_dec', data=ra_dec)
            self['ra_dec'].attrs['unit'] = 'radian'

            # determine if it is the same pointing
            if self.main_data_dist_axis == 0:
                az_alt = az_alt.local_array
                ra_dec = ra_dec.local_array
            # gather local az_alt
            az_alt = mpiutil.gather_array(az_alt, root=None, comm=self.comm)
            if np.allclose(az_alt[:, 0], az_alt[0, 0]) and np.allclose(az_alt[:, 1], az_alt[0, 1]):
                self['az_alt'].attrs['same_pointing'] = True
            else:
                self['az_alt'].attrs['same_pointing'] = False
            # determine if it is the same dec
            # gather local ra_dec
            ra_dec = mpiutil.gather_array(ra_dec, root=None, comm=self.comm)
            if np.allclose(ra_dec[:, 1], ra_dec[0, 1]):
                self['ra_dec'].attrs['same_dec'] = True
            else:
                self['ra_dec'].attrs['same_dec'] = False
Example #7
0
    def read_input(self):
        """Method for (maybe iteratively) reading data from input data files."""

        days = self.params['days']
        extra_inttime = self.params['extra_inttime']
        drop_days = self.params['drop_days']
        mode = self.params['mode']
        dist_axis = self.params['dist_axis']

        ngrp = len(self.input_grps)

        if self.next_grp:
            if mpiutil.rank0 and ngrp > 1:
                print 'Start file group %d of %d...' % (self.grp_cnt, ngrp)
            self.restart_iteration()  # re-start iteration for each group
            self.next_grp = False
            self.abs_start = None
            self.abs_stop = None

        input_files = self.input_grps[self.grp_cnt]
        start = self.start[self.grp_cnt]
        stop = self.stop[self.grp_cnt]

        if self.int_time is None:
            # NOTE: here assume all files have the same int_time
            with h5py.File(self.input_files[0], 'r') as f:
                self.int_time = f.attrs['inttime']

        if self.abs_start is None or self.abs_stop is None:
            tmp_tod = self._Tod_class(input_files, mode, start, stop,
                                      dist_axis)
            self.abs_start = tmp_tod.main_data_start
            self.abs_stop = tmp_tod.main_data_stop
            del tmp_tod

        iteration = self.iteration if self.iterable else 0
        this_start = self.abs_start + np.int(
            np.around(iteration * days * const.sday / self.int_time))
        this_stop = min(
            self.abs_stop, self.abs_start + np.int(
                np.around(
                    (iteration + 1) * days * const.sday / self.int_time)) +
            2 * extra_inttime)
        if this_stop >= self.abs_stop:
            self.next_grp = True
            self.grp_cnt += 1
        if this_start >= this_stop:
            self.next_grp = True
            self.grp_cnt += 1
            return None

        this_span = self.int_time * (this_stop - this_start)  # in unit second
        if this_span < drop_days * const.sday:
            if mpiutil.rank0:
                print 'Not enough span time, drop it...'
            return None
        elif (this_stop - this_start) <= extra_inttime:  # use int comparision
            if mpiutil.rank0:
                print 'Not enough span time (less than `extra_inttime`), drop it...'
            return None

        tod = self._Tod_class(input_files, mode, this_start, this_stop,
                              dist_axis)

        tod = self.data_select(tod)

        tod.load_all()  # load in all data

        if self.start_ra is None:  # the first iteration
            ra_dec = mpiutil.gather_array(tod['ra_dec'].local_data, root=None)
            self.start_ra = ra_dec[extra_inttime, 0]
        tod.vis.attrs['start_ra'] = self.start_ra  # used for re_order

        return tod
Example #8
0
    def process(self, ts):
        """Build m-mode timestreams from ``ts`` and make a full-sky map.

        Unless ``use_existed_beam`` is set, a telescope object is built from
        the attributes of ``ts`` (dish or cylinder array). When not
        ``simulate``, the visibilities of the requested ``pol`` are selected,
        optionally interpolated over masked samples (``interp``), averaged onto
        ``2 * tel.mmax + 1`` equally spaced phi bins (``time_avg``), optionally
        zeroed where ``ts['local_hour']`` lies inside ``mask_time_range``,
        reordered to ``tel.allpairs`` and averaged over redundant baselines.
        The beam transfer matrices are then generated (or reused from
        ``beam_dir``), the per-frequency timestream files are written unless a
        ``COMPLETED`` marker already exists, and a (dirty) map is made.

        Parameters
        ----------
        ts : Timestream
            Input data. Not used when ``use_existed_beam`` is set.

        Returns
        -------
        tstream
            The object returned by ``timestream.simulate`` (when ``simulate``)
            or the ``timestream.Timestream`` built on ``ts_dir`` / ``ts_name``.

        Raises
        ------
        RuntimeError
            If the array is neither a dish nor a cylinder array, or if the
            per-frequency timestream files have to be written but no
            visibilities were prepared (``use_existed_beam`` without
            ``simulate`` and without an existing ``COMPLETED`` marker).
        ValueError
            For an invalid ``pol``, ``interp`` or ``time_avg``.
        """

        mask_daytime = self.params['mask_daytime']
        mask_time_range = self.params['mask_time_range']
        tsys = self.params['tsys']
        accuracy_boost = self.params['accuracy_boost']
        l_boost = self.params['l_boost']
        bl_range = self.params['bl_range']
        auto_correlations = self.params['auto_correlations']
        time_avg = self.params['time_avg']
        pol = self.params['pol']
        interp = self.params['interp']
        beam_dir = output_path(self.params['beam_dir'])
        use_existed_beam = self.params['use_existed_beam']
        gen_inv = self.params['gen_invbeam']
        noise_weight = self.params['noise_weight']
        ts_dir = output_path(self.params['ts_dir'])
        ts_name = self.params['ts_name']
        no_m_zero = self.params['no_m_zero']
        simulate = self.params['simulate']
        input_maps = self.params['input_maps']
        prior_map = self.params['prior_map']
        add_noise = self.params['add_noise']
        dirty_map = self.params['dirty_map']
        nbin = self.params['nbin']
        method = self.params['method']
        normalize = self.params['normalize']
        threshold = self.params['threshold']
        eps = self.params['epsilon']
        correct_order = self.params['correct_order']

        if use_existed_beam:
            # the telescope is taken from the saved beam transfer below
            tel = None
        else:
            assert isinstance(ts, Timestream), '%s only works for Timestream object' % self.__class__.__name__

            ts.redistribute('baseline')

            lat = ts.attrs['sitelat']
            # NOTE: lon is fixed to 0.0 rather than taken from ts.attrs['sitelon']
            lon = 0.0
            local_origin = False
            freqs = ts.freq[:] # MHz
            nfreq = freqs.shape[0]
            band_width = ts.attrs['freqstep'] # MHz
            try:
                ndays = ts.attrs['ndays']
            except KeyError:
                ndays = 1
            feeds = ts['feedno'][:]
            bl_order = mpiutil.gather_array(ts.local_bl, axis=0, root=None, comm=ts.comm)
            bls = [ tuple(bl) for bl in bl_order ]
            az, alt = ts['az_alt'][0]
            az = np.degrees(az)
            alt = np.degrees(alt)
            pointing = [az, alt, 0.0]
            feedpos = ts['feedpos'][:]

            if ts.is_dish:
                from tlpipe.map.drift.telescope import tl_dish

                dish_width = ts.attrs['dishdiam']
                tel = tl_dish.TlUnpolarisedDishArray(lat, lon, freqs, band_width, tsys, ndays, accuracy_boost, l_boost, bl_range, auto_correlations, local_origin, dish_width, feedpos, pointing)
            elif ts.is_cylinder:
                from tlpipe.map.drift.telescope import tl_cylinder

                # factor = 1.2 # suppose an illumination efficiency, keep same with that in timestream_common
                factor = 0.79 # for xx
                # factor = 0.88 # for yy
                cyl_width = factor * ts.attrs['cywid']
                tel = tl_cylinder.TlUnpolarisedCylinder(lat, lon, freqs, band_width, tsys, ndays, accuracy_boost, l_boost, bl_range, auto_correlations, local_origin, cyl_width, feedpos)
            else:
                raise RuntimeError('Unknown array type %s' % ts.attrs['telescope'])

            if not simulate:
                # select the corresponding vis and vis_mask
                # (for xx / yy these are views into ts.local_vis / ts.local_vis_mask)
                if pol == 'xx':
                    local_vis = ts.local_vis[:, :, 0, :]
                    local_vis_mask = ts.local_vis_mask[:, :, 0, :]
                elif pol == 'yy':
                    local_vis = ts.local_vis[:, :, 1, :]
                    local_vis_mask = ts.local_vis_mask[:, :, 1, :]
                elif pol == 'I':
                    xx_vis = ts.local_vis[:, :, 0, :]
                    xx_vis_mask = ts.local_vis_mask[:, :, 0, :]
                    yy_vis = ts.local_vis[:, :, 1, :]
                    yy_vis_mask = ts.local_vis_mask[:, :, 1, :]

                    # where exactly one of xx / yy is masked use the other one,
                    # otherwise (both or neither masked) use their mean
                    only_xx_masked = np.logical_and(xx_vis_mask, np.logical_not(yy_vis_mask))
                    only_yy_masked = np.logical_and(yy_vis_mask, np.logical_not(xx_vis_mask))
                    local_vis = np.where(only_xx_masked, yy_vis,
                                         np.where(only_yy_masked, xx_vis, 0.5 * (xx_vis + yy_vis)))
                    local_vis_mask = xx_vis_mask | yy_vis_mask
                else:
                    raise ValueError('Invalid pol: %s' % pol)

                if interp != 'none':
                    for fi in xrange(local_vis.shape[1]):
                        for bi in xrange(local_vis.shape[2]):
                            # interpolate for local_vis
                            true_inds = np.where(local_vis_mask[:, fi, bi])[0] # masked inds
                            if len(true_inds) > 0:
                                false_inds = np.where(~local_vis_mask[:, fi, bi])[0] # un-masked inds
                                if len(false_inds) > 0.1 * local_vis.shape[0]:
                                    if interp in ('linear', 'nearest'):
                                        itp_real = interp1d(false_inds, local_vis[false_inds, fi, bi].real, kind=interp, fill_value='extrapolate', assume_sorted=True)
                                        itp_imag = interp1d(false_inds, local_vis[false_inds, fi, bi].imag, kind=interp, fill_value='extrapolate', assume_sorted=True)
                                    elif interp == 'rbf':
                                        itp_real = Rbf(false_inds, local_vis[false_inds, fi, bi].real, smooth=10)
                                        itp_imag = Rbf(false_inds, local_vis[false_inds, fi, bi].imag, smooth=10)
                                    else:
                                        raise ValueError('Unknown interpolation method: %s' % interp)
                                    local_vis[true_inds, fi, bi] = itp_real(true_inds) + 1.0J * itp_imag(true_inds) # the interpolated vis
                                else:
                                    # too few un-masked points (<= 10%) to interpolate from
                                    local_vis[:, fi, bi] = 0 # TODO: may need to take special care

                # average data onto phi_size equally spaced phi bins
                nt = ts['sec1970'].shape[0]
                phi_size = 2*tel.mmax + 1

                phi = np.linspace(0, 2*np.pi, phi_size, endpoint=False)
                vis = np.zeros((phi_size,)+local_vis.shape[1:], dtype=local_vis.dtype)

                # number of samples the time axis is rolled by before averaging
                # (only non-zero for 'avg'); the daytime mask below follows it
                roll_len = 0
                if time_avg == 'avg':
                    nt_m = float(nt) / phi_size
                    # roll data to have phi=0 near the first
                    roll_len = int(np.around(0.5*nt_m))
                    local_vis[:] = np.roll(local_vis[:], roll_len, axis=0)
                    if interp == 'none':
                        local_vis_mask[:] = np.roll(local_vis_mask[:], roll_len, axis=0)

                    # each time index repeated phi_size times, cut into phi_size
                    # chunks of nt entries: chunk idx lists (with multiplicity)
                    # the time samples averaged into phi bin idx
                    repeat_inds = np.repeat(np.arange(nt), phi_size)
                    num, start, end = mpiutil.split_m(nt*phi_size, phi_size)

                    # average over time
                    for idx in xrange(phi_size):
                        inds, weight = unique(repeat_inds[start[idx]:end[idx]], return_counts=True)
                        if interp == 'none':
                            vis[idx] = average(np.ma.array(local_vis[inds], mask=local_vis_mask[inds]), axis=0, weights=weight) # time mean
                        else:
                            vis[idx] = average(local_vis[inds], axis=0, weights=weight) # time mean
                elif time_avg == 'fft':
                    if interp == 'none':
                        raise ValueError('Can not do fft average without first interpolation')
                    # keep only the central 2*mmax+1 Fourier modes and transform back
                    Vm = np.fft.fftshift(np.fft.fft(local_vis, axis=0), axes=0)
                    vis[:] = np.fft.ifft(np.fft.ifftshift(Vm[nt//2-tel.mmax:nt//2+tel.mmax+1], axes=0), axis=0) / (1.0 * nt / phi_size)
                else:
                    raise ValueError('Unknown time_avg: %s' % time_avg)

                del local_vis
                del local_vis_mask

                # mask daytime data
                if mask_daytime:
                    # assumes ts['local_hour'] holds one value per time sample -- TODO confirm
                    local_hour = ts['local_hour'][:]
                    # the comparisons must be parenthesized: `&` binds tighter than `>=` / `<=`
                    is_day = (local_hour >= mask_time_range[0]) & (local_hour <= mask_time_range[1])
                    # keep the mask aligned with the (possibly rolled) data
                    is_day = np.roll(is_day, roll_len)
                    # same layout as repeat_inds: row idx holds the samples of phi bin idx
                    day_inds = np.where(np.repeat(is_day, phi_size).reshape(phi_size, nt).any(axis=1))[0]
                    vis[day_inds] = 0

                del ts # no longer need ts

                # redistribute vis to time axis
                vis = mpiarray.MPIArray.wrap(vis, axis=2).redistribute(0).local_array

                allpairs = tel.allpairs
                redundancy = tel.redundancy
                nrd = len(redundancy)

                # reorder bls according to allpairs; look baselines up in a dict
                # (first occurrence wins, as with list.index) instead of a
                # linear bls.index search for every pair
                bl_inds = {}
                for b_ind, bl in enumerate(bls):
                    bl_inds.setdefault(bl, b_ind)
                vis_tmp = np.zeros_like(vis)
                for ind, (a1, a2) in enumerate(allpairs):
                    fa1, fa2 = feeds[a1], feeds[a2]
                    if (fa1, fa2) in bl_inds:
                        vis_tmp[:, :, ind] = vis[:, :, bl_inds[(fa1, fa2)]]
                    elif (fa2, fa1) in bl_inds:
                        # only the reversed baseline is in the data: conjugate it
                        vis_tmp[:, :, ind] = vis[:, :, bl_inds[(fa2, fa1)]].conj()
                    else:
                        raise ValueError('Baseline (%s, %s) is not in the data' % (fa1, fa2))

                del vis

                # average over redundancy
                vis_stream = np.zeros(vis_tmp.shape[:-1]+(nrd,), dtype=vis_tmp.dtype)
                red_bin = np.cumsum(np.insert(redundancy, 0, 0)) # redundancy bin
                for ind in xrange(nrd):
                    vis_stream[:, :, ind] = np.sum(vis_tmp[:, :, red_bin[ind]:red_bin[ind+1]], axis=2) / redundancy[ind]

                del vis_tmp

        # beamtransfer
        bt = beamtransfer.BeamTransfer(beam_dir, tel, noise_weight, True)
        if not use_existed_beam:
            bt.generate()
        if tel is None:
            tel = bt.telescope

        if simulate:
            ndays = 733
            tstream = timestream.simulate(bt, ts_dir, ts_name, input_maps, ndays, add_noise=add_noise)
        else:
            # timestream and map-making
            tstream = timestream.Timestream(ts_dir, ts_name, bt, no_m_zero)
            parent_path = os.path.dirname(tstream._fdir(0))

            if os.path.exists(parent_path + '/COMPLETED'):
                if mpiutil.rank0:
                    print('Use existed timestream_f files in %s' % parent_path)
            else:
                if use_existed_beam:
                    # vis_stream (and nfreq, phi, ...) are only built from ts
                    # when use_existed_beam is not set
                    raise RuntimeError('No timestream_f files in %s and none can be written when use_existed_beam is set' % parent_path)

                for fi in mpiutil.mpirange(nfreq):
                    # Make directory if required
                    if not os.path.exists(tstream._fdir(fi)):
                        os.makedirs(tstream._fdir(fi))

                # create memh5 object and write data to temporary file
                vis_h5 = memh5.MemGroup(distributed=True)
                vis_h5.create_dataset('/timestream', data=mpiarray.MPIArray.wrap(vis_stream, axis=0))
                tmp_file = parent_path +'/vis_stream_temp.hdf5'
                vis_h5.to_hdf5(tmp_file, hints=False)
                del vis_h5

                # re-organize data as need for tstream
                # make load even among nodes
                for fi in mpiutil.mpirange(nfreq, method='rand'):
                    # read the needed data from the temporary file
                    with h5py.File(tmp_file, 'r') as f:
                        vis_fi = f['/timestream'][:, fi, :]
                    # Write file contents
                    with h5py.File(tstream._ffile(fi), 'w') as f:
                        # Timestream data, stored as (nrd, phi_size)
                        f.create_dataset('/timestream', data=vis_fi.T)
                        f.create_dataset('/phi', data=phi)

                        # Telescope layout data
                        f.create_dataset('/feedmap', data=tel.feedmap)
                        f.create_dataset('/feedconj', data=tel.feedconj)
                        f.create_dataset('/feedmask', data=tel.feedmask)
                        f.create_dataset('/uniquepairs', data=tel.uniquepairs)
                        f.create_dataset('/baselines', data=tel.baselines)

                        # Telescope frequencies
                        f.create_dataset('/frequencies', data=freqs)

                        # Write metadata
                        f.attrs['beamtransfer_path'] = os.path.abspath(bt.directory)
                        f.attrs['ntime'] = phi_size

                mpiutil.barrier()

                # remove temp file
                if mpiutil.rank0:
                    os.remove(tmp_file)
                    # mark all frequencies tstream files are saved correctly
                    open(parent_path + '/COMPLETED', 'a').close()

        tstream.generate_mmodes()
        nside = hputil.nside_for_lmax(tel.lmax, accuracy_boost=tel.accuracy_boost)
        if dirty_map:
            tstream.mapmake_full(nside, 'map_full_dirty.hdf5', nbin, dirty=True, method=method, normalize=normalize, threshold=threshold)
        else:
            tstream.mapmake_full(nside, 'map_full.hdf5', nbin, dirty=False, method=method, normalize=normalize, threshold=threshold, eps=eps, correct_order=correct_order, prior_map_file=prior_map)

        return tstream
Exemple #9
0
    def process(self, ts):
        """Calibrate ``ts`` with the eigen-decomposition of a point-source visibility matrix.

        Around the first transit of ``calibrator`` found in the data (within
        about +/- ``span`` seconds, clamped to the data range), for every time
        sample with the noise source off, every pol in (xx, yy) and every
        local frequency, the (nfeed, nfeed) visibility matrix is divided by
        the source model (flux * beam response * geometric phase). The gain
        vector is ``sqrt(lambda_max)`` times the eigenvector of the largest
        eigenvalue. The time-mean gain is applied to ``ts`` via ``cal`` and, if
        ``save_gain``, the eigenvalues and gains are written to ``gain_file``.

        Raises
        ------
        ValueError
            If the calibrator specification does not give exactly one source,
            or if ``ts`` has no 'xx' or no 'yy' pol.
        RuntimeError
            If the data do not contain the next transit of the calibrator.
        """

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        span = self.params['span']
        save_gain = self.params['save_gain']
        gain_file = self.params['gain_file']

        ts.redistribute('frequency')

        lfreq = ts.local_freq[:] # local freq
        nlfreq = len(lfreq)

        feedno = ts['feedno'][:].tolist()
        pol = ts['pol'][:].tolist()
        bls = [ tuple(b) for b in ts.bl[:] ]
        jul_date = ts['jul_date'][:]

        # pols to calibrate, in the order of the pol axis of eigval / gain
        cal_pols = ['xx', 'yy']
        # their positions along the pol axis of the data (may differ from 0, 1)
        data_pol_inds = [ pol.index(p) for p in cal_pols ]

        # calibrator
        srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog)
        cat = a.src.get_catalog(srclist, cutoff, catalogs)
        if len(cat) != 1:
            raise ValueError('Allow only one calibrator')
        s = list(cat.values())[0]
        if mpiutil.rank0:
            print('Calibrating for source %s with strength %s Jy measured at %s GHz with index %s' % (calibrator, s._jys, s.mfreq, s.index))

        # get transit time of calibrator
        aa = ts.array
        aa.set_jultime(jul_date[0]) # the first obs time point
        next_transit = aa.next_transit(s)
        transit_time = a.phs.ephem2juldate(next_transit) # Julian date
        if transit_time > jul_date[-1]:
            # reported shifted by +8 hours (local time)
            local_next_transit = ephem.Date(next_transit + 8.0 * ephem.hour)
            raise RuntimeError('Data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))

        # the first transit index
        transit_inds = [ np.searchsorted(jul_date, transit_time) ]
        # find all other transit indices, searching from one more day each time
        day = 1
        while True:
            aa.set_jultime(jul_date[0] + 1.0 * day)
            transit_time = a.phs.ephem2juldate(aa.next_transit(s)) # Julian date
            if transit_time > jul_date[-1]:
                break
            transit_inds.append(np.searchsorted(jul_date, transit_time))
            day += 1

        if mpiutil.rank0:
            print(transit_inds)

        ### now only use the first transit point to do the cal
        ### may need to improve in the future
        transit_ind = transit_inds[0]
        int_time = ts.attrs['inttime'] # second
        half_width = int(span / int_time) # samples on each side of the transit
        # clamp to the data: a negative start would silently wrap around to the
        # end of the data, and an end past the data would index beyond it
        start_ind = max(transit_ind - half_width, 0)
        end_ind = min(transit_ind + half_width, len(jul_date))

        nt = end_ind - start_ind
        nfeed = len(feedno)
        eigval = np.empty((nt, nfeed, len(cal_pols), nlfreq), dtype=np.float64)
        eigval[:] = np.nan
        gain = np.empty((nt, nfeed, len(cal_pols), nlfreq), dtype=np.complex128)
        gain[:] = complex(np.nan, np.nan)

        vis_dtype = ts.main_data.dtype
        for ind, ti in enumerate(range(start_ind, end_ind)):
            # when noise on, just pass; nothing to compute without local freqs
            if ts['ns_on'][ti] or nlfreq == 0:
                continue
            aa.set_jultime(jul_date[ti])
            s.compute(aa)
            # get fluxes vs. freq of the calibrator
            Sc = s.get_jys()
            # get the topocentric coordinate of the calibrator at the current time
            s_top = s.get_crds('top', ncrd=3)
            aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
            # opi indexes the pol axis of eigval / gain, pi that of the data
            for opi, (pol_name, pi) in enumerate(zip(cal_pols, data_pol_inds)):
                aa.set_active_pol(pol_name)
                # model-normalized visibility matrices, one (nfeed, nfeed) slice per local freq
                Vmat = np.zeros((nlfreq, nfeed, nfeed), dtype=vis_dtype)
                for i, ai in enumerate(feedno):
                    for j, aj in enumerate(feedno):
                        # uvw and beam response are returned for all local
                        # freqs at once: compute them once per (i, j) instead
                        # of once per freq
                        uij = aa.gen_uvw(i, j, src='z')[:, 0, :nlfreq] # (rj - ri)/lambda
                        bmij = aa.bm_response(i, j).reshape(-1)[:nlfreq]
                        model = Sc[:nlfreq] * bmij * np.exp(2.0J * np.pi * np.dot(s_top, uij))
                        try:
                            bi = bls.index((ai, aj))
                        except ValueError:
                            # only the reversed baseline is in the data
                            bi = bls.index((aj, ai))
                            Vmat[:, i, j] = np.conj(ts.local_vis[ti, :, pi, bi] / model)
                        else:
                            Vmat[:, i, j] = ts.local_vis[ti, :, pi, bi] / model

                # Eigen decomposition
                Vmat = np.where(np.isfinite(Vmat), Vmat, 0)
                for fi in range(nlfreq):
                    e, U = eigh(Vmat[fi])
                    eigval[ind, :, opi, fi] = e[::-1] # descending order
                    # max eigen-val
                    lbd = e[-1] # lambda
                    # the gain vector for this freq: only the eigen-vector
                    # corresponding to the maximum eigen-val
                    gain[ind, :, opi, fi] = np.sqrt(lbd) * U[:, -1]

        # apply gain to vis
        # get the time mean gain
        tgain = np.ma.mean(np.ma.masked_invalid(gain), axis=0) # time mean
        tgain = mpiutil.gather_array(tgain, axis=-1, root=None)

        ts.redistribute('baseline')
        ts.pol_and_bl_data_operate(cal, tgain=tgain)

        # save gain if required:
        if save_gain:
            gain_file = output_path(gain_file)
            eigval = mpiarray.MPIArray.wrap(eigval, axis=3)
            gain = mpiarray.MPIArray.wrap(gain, axis=3)
            mem_gain = memh5.MemGroup(distributed=True)
            mem_gain.create_dataset('eigval', data=eigval)
            mem_gain.create_dataset('gain', data=gain)
            # add attrs
            # NOTE: the key is spelled 'jul_data' (not 'jul_date'); kept
            # unchanged for compatibility with existing gain files
            mem_gain.attrs['jul_data'] = jul_date[start_ind:end_ind]
            mem_gain.attrs['feed'] = np.array(feedno)
            mem_gain.attrs['pol'] = np.array(cal_pols)
            mem_gain.attrs['freq'] = ts.freq[:] # freq should be common

            # save to file
            mem_gain.to_hdf5(gain_file, hints=False)

        ts.add_history(self.history)

        return ts
Exemple #10
0
    def process(self, rt):

        assert isinstance(rt, RawTimestream), '%s only works for RawTimestream object currently' % self.__class__.__name__

        if not 'ns_on' in rt.iterkeys():
            raise RuntimeError('No noise source info, can not do noise source calibration')

        local_bl_size, local_bl_start, local_bl_end = mpiutil.split_local(len(rt['blorder']))

        rt.redistribute('baseline')

        num_mean = self.params['num_mean']
        phs_only = self.params['phs_only']
        save_gain = self.params['save_gain']
        tag_output_iter = self.params['tag_output_iter']
        gain_file = self.params['gain_file']
        bl_incl = self.params['bl_incl']
        bl_excl = self.params['bl_excl']
        freq_incl = self.params['freq_incl']
        freq_excl = self.params['freq_excl']
        ns_stable = self.params['ns_stable']
        last_transit2stable = self.params['last_transit2stable']

#===================================
        normal_gain_file = self.params['normal_gain_file']
        rt.gain_file = self.params['gain_file']
        absolute_gain_filename = self.params['absolute_gain_filename']
        use_center_data = self.params['use_center_data']
        if use_center_data and rt['ns_on'].attrs['on_time'] <= 2:
            warnings.warn('The period %d <= 2, cannot get rid of the beginning and ending points. Use the whole average automatically!')
            use_center_data = False
#===================================
        # apply abs gain to data if ps_first
        if rt.ps_first:
            if absolute_gain_filename is None:
                raise NoPsGainFile('Absent parameter absolute_gain_file. In ps_first mode, absolute_gain_filename is required to process!')
            if not os.path.exists(output_path(absolute_gain_filename)):
                raise NoPsGainFile('No absolute gain file %s, do the ps calibration first!'%output_path(absolute_gain_filename))
            with h5py.File(output_path(absolute_gain_filename, 'r')) as agfilein:
                if not ns_stable:
#                    try: # if there is transit in this batch of data, build up transit amp and phs
#                        rt.normal_index = np.where(np.abs(rt['jul_date'] - agfilein.attrs['transit_time']) < 2.e-6)[0][0] # to avoid numeric error. 1.e-6 julian date is about 0.1s
                    if rt['sec1970'][0] <= agfilein.attrs['transit_time'] and rt['sec1970'][-1] >= agfilein.attrs['transit_time']:
                        rt.normal_index = np.int64(np.around((agfilein.attrs['transit_time'] - rt['sec1970'][0])/rt.attrs['inttime']))
                        if mpiutil.rank0:
                            print('Detected transit, time index %d, build up transit normalization gain file!'%rt.normal_index)
                        build_normal_file = True
                        rt.normal_phs = np.zeros((rt['vis'].shape[1], local_bl_size))
                        rt.normal_phs.fill(np.nan)
                        if not phs_only:
                            rt.normal_amp = np.zeros((rt['vis'].shape[1], local_bl_size))
                            rt.normal_amp.fill(np.nan)
#                    except IndexError: # no transit in this batch of data, load transit amp and phs from file
                    else: 
                        rt.normal_index = None # no transit flag
                        build_normal_file = False
                        if not os.path.exists(output_path(normal_gain_file)):
                            time_info = (aipy.phs.juldate2ephem(agfilein.attrs['transit_jul'] + 8./24.), aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.), aipy.phs.juldate2ephem(rt['jul_date'][-1] + 8./24.))
                            raise TransitGainNotRecorded('The transit %s is not in time range %s to %s and no transit normalization gain was recorded!'%time_info)
                        else:
                            if mpiutil.rank0:
                                print('No transit, use existing transit normalization gain file!')
                            with h5py.File(output_path(normal_gain_file), 'r') as filein:
                                rt.normal_phs = filein['phs'][:, local_bl_start:local_bl_end]
                                if not phs_only:
                                    rt.normal_amp = filein['amp'][:, local_bl_start:local_bl_end]
                else:
                    rt.normal_index = None # no transit flag
                    rt.normal_phs = np.zeros((rt['vis'].shape[1], local_bl_size))
                    rt.normal_phs.fill(np.nan)
                    if not phs_only:
                        rt.normal_amp = np.zeros((rt['vis'].shape[1], local_bl_size))
                        rt.normal_amp.fill(np.nan)
                    try:
                        stable_time = rt['pointingtime'][-1,-1]
                    except KeyError:
                        stable_time = rt['transitsource'][-1, 0]
                    if stable_time > 0: # the time when the antenna will be pointing at the target region was recorded, use it
                        stable_time += 300 # plus 5 min to ensure stable
                    else:
                        stable_time = rt['transitsource'][-2, 0] + last_transit2stable
                    stable_time_jul = datetime.utcfromtimestamp(stable_time)
                    stable_time_jul = ephem.julian_date(stable_time_jul) # convert to julian date, convenient to display
                    if not os.path.exists(output_path(normal_gain_file)):
                        if mpiutil.rank0:
                            print('Normalization gain file has not been built yet. Try to build it!')
                            print('Recorded transit Time: %s'%aipy.phs.juldate2ephem(agfilein.attrs['transit_jul'] + 8./24.))
                            print('Last transit Time: %s'%aipy.phs.juldate2ephem(ephem.julian_date(datetime.utcfromtimestamp(rt['transitsource'][-2,0])) + 8./24.))
                            print('First time point of this data: %s'%aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.))
                        build_normal_file = True
#                        if rt['jul_date'][0] < stable_time:
                        if rt.attrs['sec1970'][0] < stable_time:
                            raise BeforeStableTime('The beginning time point is %s, but only after %s, will the noise source be stable. Abort the noise calibration!'%(aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.), aipy.phs.juldate2ephem(stable_time_jul + 8./24.)))
                    else:
                        if mpiutil.rank0:
                            print('Use existing normalization gain file!')
                        build_normal_file = False
                        with h5py.File(output_path(normal_gain_file), 'r') as filein:
                            rt.normal_phs = filein['phs'][:, local_bl_start:local_bl_end]
                            if not phs_only:
                                rt.normal_amp = filein['amp'][:, local_bl_start:local_bl_end]
                # apply absolute gain
                absgain = agfilein['gain'][:]
                polfeed, _ = bl2pol_feed_inds(rt.local_bl, agfilein['gain'].attrs['feed'][:], agfilein['gain'].attrs['pol'][:])
                for ii, (ipol, ifeed, jpol, jfeed) in enumerate(polfeed):
#                    rt.local_vis[:,:,ii] /= (absgain[None,:,ipol,ifeed-1] * absgain[None,:,jpol,jfeed-1].conj())
                    rt.local_vis[:,:,ii] /= (absgain[None,:,ipol,ifeed] * absgain[None,:,jpol,jfeed].conj())
#===================================
        nt = rt.local_vis.shape[0]
        if num_mean <= 0:
            raise RuntimeError('Invalid num_mean = %s' % num_mean)
        ns_on = rt['ns_on'][:]
        ns_on = np.where(ns_on, 1, 0)
        diff_ns = np.diff(ns_on)
        inds = np.where(diff_ns==1)[0] # NOTE: these are inds just 1 before the first ON
        if not rt.FRB_cal: # for FRB there might be just one noise point, avoid waste
            if inds[0]-1 < 0: # no off data in the beginning to use
                inds = inds[1:]
            if inds[-1]+2 > nt-1: # no on data in the end to use
                inds = inds[:-1]

        if save_gain:
            num_inds = len(inds)
            shp = (num_inds,)+rt.local_vis.shape[1:]
            dtype = rt.local_vis.real.dtype
            # create dataset to record ns_cal_time_inds
            rt.create_time_ordered_dataset('ns_cal_time_inds', inds)
            # create dataset to record ns_cal_phase
            ns_cal_phase = np.empty(shp, dtype=dtype)
            ns_cal_phase[:] = np.nan
            ns_cal_phase = mpiarray.MPIArray.wrap(ns_cal_phase, axis=2, comm=rt.comm)
            rt.create_freq_and_bl_ordered_dataset('ns_cal_phase', ns_cal_phase, axis_order=(None, 1, 2))
            rt['ns_cal_phase'].attrs['unit'] = 'radians'
            if not phs_only:
                # create dataset to record ns_cal_amp
                ns_cal_amp = np.empty(shp, dtype=dtype)
                ns_cal_amp[:] = np.nan
                ns_cal_amp = mpiarray.MPIArray.wrap(ns_cal_amp, axis=2, comm=rt.comm)
                rt.create_freq_and_bl_ordered_dataset('ns_cal_amp', ns_cal_amp, axis_order=(None, 1, 2))

        if bl_incl == 'all':
            bls_plt = [ tuple(bl) for bl in rt.bl ]
        else:
            bls_plt = [ bl for bl in bl_incl if not bl in bl_excl ]

        if freq_incl == 'all':
            freq_plt = range(rt.freq.shape[0])
        else:
            freq_plt = [ fi for fi in freq_incl if not fi in freq_excl ]

        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']

        if rt.ps_first:
            rt.freq_and_bl_data_operate(self.cal, full_data=True, show_progress=show_progress, progress_step=progress_step, keep_dist_axis=False, num_mean=num_mean, inds=inds, bls_plt=bls_plt, freq_plt=freq_plt, build_normal_file = build_normal_file)
        else:
            rt.freq_and_bl_data_operate(self.cal, full_data=True, show_progress=show_progress, progress_step=progress_step, keep_dist_axis=False, num_mean=num_mean, inds=inds, bls_plt=bls_plt, freq_plt=freq_plt)

        if save_gain:
            interp_mask_ratio = mpiutil.allreduce(np.sum(rt.interp_mask_count))/1./mpiutil.allreduce(np.size(rt.interp_mask_count)) * 100.
            if interp_mask_ratio > 50. and rt.normal_index is not None:
                warnings.warn('%.1f%% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!'%interp_mask_ratio, NotEnoughPointToInterpolateWarning)
            if interp_mask_ratio > 80.:
                rt.interp_all_masked = True
            # gather bl_order to rank0
            bl_order = mpiutil.gather_array(rt['blorder'].local_data, axis=0, root=0, comm=rt.comm)
            # gather ns_cal_phase / ns_cal_amp to rank 0
            ns_cal_phase = mpiutil.gather_array(rt['ns_cal_phase'].local_data, axis=2, root=0, comm=rt.comm)
            phs_unit = rt['ns_cal_phase'].attrs['unit']
            rt.delete_a_dataset('ns_cal_phase', reserve_hint=False)
            if not phs_only:
                ns_cal_amp = mpiutil.gather_array(rt['ns_cal_amp'].local_data, axis=2, root=0, comm=rt.comm)
                rt.delete_a_dataset('ns_cal_amp', reserve_hint=False)

            if tag_output_iter:
                gain_file = output_path(gain_file, iteration=self.iteration)
            else:
                gain_file = output_path(gain_file)

            if rt.ps_first:
                phase_finite_count = mpiutil.allreduce(np.isfinite(rt.normal_phs).sum())
                if not phs_only:
                    amp_finite_count = mpiutil.allreduce(np.isfinite(rt.normal_amp).sum())
            if mpiutil.rank0:
                if rt.ps_first and build_normal_file:
                    if phase_finite_count == 0:
                        raise AllMasked('All values are masked when calculating phase!')
                    if not phs_only:
                        if amp_finite_count == 0:
                            raise AllMasked('All values are masked when calculating amplitude!')
                    with h5py.File(output_path(normal_gain_file), 'w') as tapfilein:
                        tapfilein.create_dataset('amp',(rt['vis'].shape[1], rt['vis'].shape[2]))
                        tapfilein.create_dataset('phs',(rt['vis'].shape[1], rt['vis'].shape[2]))
                with h5py.File(gain_file, 'w') as f:
                    # save time
                    f.create_dataset('time', data=rt['jul_date'][:])
                    f['time'].attrs['unit'] = 'Julian date'
                    # save freq
                    f.create_dataset('freq', data=rt['freq'][:])
                    f['freq'].attrs['unit'] = rt['freq'].attrs['unit']
                    # save bl
                    f.create_dataset('bl_order', data=bl_order)
                    # save ns_cal_time_inds
                    f.create_dataset('ns_cal_time_inds', data=rt['ns_cal_time_inds'][:])
                    # save ns_cal_phase
                    f.create_dataset('ns_cal_phase', data=ns_cal_phase)
                    f['ns_cal_phase'].attrs['unit'] = phs_unit
                    f['ns_cal_phase'].attrs['dim'] = '(time, freq, bl)'
                    if not phs_only:
                        # save ns_cal_amp
                        f.create_dataset('ns_cal_amp', data=ns_cal_amp)

                    # save channo
                    f.create_dataset('channo', data=rt['channo'][:])
                    f['channo'].attrs['dim'] = rt['channo'].attrs['dimname']
                    if rt.exclude_bad:
                        f['channo'].attrs['badchn'] = rt['channo'].attrs['badchn']

                    if not (absolute_gain_filename is None):
                        if not os.path.exists(output_path(absolute_gain_filename)):
                            raise NoPsGainFile('No absolute gain file %s, do the ps calibration first!'%output_path(absolute_gain_filename))
                        with h5py.File(output_path(absolute_gain_filename,'r')) as abs_gain:
                            new_gain = uni_gain(abs_gain, f, phs_only = phs_only)
                            f.create_dataset('uni_gain', data = new_gain)
                            f['uni_gain'].attrs['dim'] = '(time, freq, bl)'

            if rt.ps_first and build_normal_file:
                if mpiutil.rank0:
                    print('Start write normalization gain into %s'%output_path(normal_gain_file))
                for i in range(10):
                    try:
                        for ii in range(mpiutil.size):
                            if ii == mpiutil.rank:
                                with h5py.File(output_path(normal_gain_file), 'r+') as fileout:
                                    fileout['phs'][:,local_bl_start:local_bl_end] = rt.normal_phs[:,:]
                                    if not phs_only:
                                        fileout['amp'][:,local_bl_start:local_bl_end] = rt.normal_amp[:,:]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
            rt.delete_a_dataset('ns_cal_time_inds', reserve_hint=False)

        return super(NsCal, self).process(rt)
Exemple #11
0
    def process(self, rt):
        """Detect the noise-source on/off pattern and mask noise-on samples.

        The on/off state is inferred from the time series of the mean (over
        the last axis of the per-baseline data, presumably frequency) of the
        real part of one auto-correlation. The requested ``channel`` is tried
        first if present, otherwise the first auto-correlation. The remaining
        auto-correlations are tried in turn if a channel is unusable.

        A time sample is flagged as noise-on when its mean exceeds a threshold
        placed at the largest, clearly isolated jump in the sorted positive
        means.

        Results are stored in a new ``ns_on`` time-ordered dataset with
        ``period``, ``on_time`` and ``off_time`` attributes (counted in time
        samples). The visibility mask is set for every noise-on sample. When
        ``mask_near > 0``, the mask is also set for that many neighbouring
        samples on each side.

        Parameters
        ----------
        rt : RawTimestream
            Timestream to operate on; it is redistributed along axis 0
            (time) by this method.

        Returns
        -------
        The result of the parent class ``process`` called with ``rt``.

        Raises
        ------
        RuntimeError
            If the data has no auto-correlation, no separation threshold can
            be found, or the detected on/off sequence contains no complete
            on-run or off-run.
        """

        assert isinstance(
            rt, RawTimestream
        ), '%s only works for RawTimestream object currently' % self.__class__.__name__

        channel = self.params['channel']
        mask_near = max(0, int(self.params['mask_near']))

        rt.redistribute(0)  # make time the dist axis

        # inds for auto-correlations (both feeds of the baseline are the same)
        auto_inds = np.where(rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()
        if len(auto_inds) == 0:
            raise RuntimeError(
                'No auto-correlation in the data to detect the noise source')
        channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
        if channel is not None and channel in channels:
            bl_ind = auto_inds[channels.index(channel)]
        else:
            bl_ind = auto_inds[0]
            if channel is not None and mpiutil.rank0:
                print('Warning: Required channel %d doen not in the data, use channel %d instead' % (
                    channel, rt.bl[bl_ind, 0]))
        # move the chosen channel to the front so it is tried first
        auto_inds.remove(bl_ind)
        auto_inds.insert(0, bl_ind)

        # global (not local) vis shape; constant over the loop
        vis_shp = rt.vis.shape
        for bl_ind in auto_inds:
            this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
            vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                              mask=rt.local_vis_mask[:, :, bl_ind])
            # number of not masked vals summed over all processes
            total_cnt = mpiutil.allreduce(vis.count())
            ratio = float(total_cnt) / (vis_shp[0] * vis_shp[1])  # ratio of un-masked vals
            if ratio < 0.5:  # too many masked vals
                if mpiutil.rank0:
                    warnings.warn(
                        'Too many masked values for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            # per-time mean (fully masked samples -> 0), gathered to every
            # process so all ranks take the same decisions below
            tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0),
                                           root=None,
                                           comm=rt.comm)
            tt_mean_sort = np.sort(tt_mean)
            tt_mean_sort = tt_mean_sort[tt_mean_sort > 0]
            # ratios of consecutive sorted means; a sharp jump separates the
            # noise-off level from the noise-on level
            ttms_div = tt_mean_sort[1:] / tt_mean_sort[:-1]
            if ttms_div.size < 3:
                # an isolated interior jump needs at least 3 ratios
                continue
            ind = np.argmax(ttms_div)
            # the jump must have neighbours on both sides and dominate them
            if 0 < ind < ttms_div.size - 1 and ttms_div[ind] > 2.0 * max(
                    ttms_div[ind - 1], ttms_div[ind + 1]):
                # geometric mean of the two values straddling the jump
                sep = np.sqrt(tt_mean_sort[ind] * tt_mean_sort[ind + 1])
                break
        else:
            raise RuntimeError(
                'Failed to get the threshold to separate ns signal out')

        ns_on = tt_mean > sep  # bool over the full time axis

        # Lengths of consecutive constant-valued runs of ns_on. Every run
        # except the last is counted (the last one may be cut by the end of
        # the data); the first run is kept.
        edges = np.flatnonzero(ns_on[1:] != ns_on[:-1]) + 1
        run_starts = np.concatenate(([0], edges))
        run_lens = np.diff(np.concatenate((run_starts, [ns_on.size])))
        run_vals = ns_on[run_starts]
        nTs = run_lens[:-1][run_vals[:-1]].tolist()  # completed on-runs
        nFs = run_lens[:-1][~run_vals[:-1]].tolist()  # completed off-runs
        if not nTs or not nFs:
            raise RuntimeError(
                'Failed to detect a complete noise source on/off cycle')
        on_time = Counter(nTs).most_common(1)[0][0]
        off_time = Counter(nFs).most_common(1)[0][0]
        period = on_time + off_time

        if 'noisesource' in rt.iterkeys():
            n_src = rt['noisesource'].shape[0]
            if n_src == 1:  # only 1 noise source
                ns_start, ns_stop, ns_cycle = rt['noisesource'][0, :]
                int_time = rt.attrs['inttime']
                true_on_time = np.round((ns_stop - ns_start) / int_time)
                true_period = np.round(ns_cycle / int_time)
                if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                    # NOTE: only a warning; the detected values are still used
                    if mpiutil.rank0:
                        warnings.warn(
                            'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it'
                            % (this_chan, on_time, true_on_time, period,
                               true_period))
            elif n_src >= 2:  # more than 1 noise source
                if mpiutil.rank0:
                    warnings.warn(
                        'More than 1 noise source, do not know how to deal with this currently'
                    )

        if mpiutil.rank0:
            print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                period, on_time, off_time))

        ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

        rt.create_main_time_ordered_dataset('ns_on', ns_on1)
        rt['ns_on'].attrs['period'] = period
        rt['ns_on'].attrs['on_time'] = on_time
        rt['ns_on'].attrs['off_time'] = off_time

        # set vis_mask corresponding to ns_on (local part of the dataset)
        on_inds = np.where(rt['ns_on'].local_data[:])[0]
        rt.local_vis_mask[on_inds] = True

        if mask_near > 0:
            # global time indices within mask_near samples of any noise-on one
            on_inds = np.flatnonzero(ns_on)
            shifts = np.arange(-mask_near, mask_near + 1)
            new_on_inds = np.unique((on_inds[:, None] + shifts[None, :]).ravel())

            # global time range held by this process
            if rt['vis_mask'].distributed:
                start = rt.vis_mask.local_offset[0]
                end = start + rt.vis_mask.local_shape[0]
            else:
                start = 0
                end = rt.vis_mask.shape[0]
            in_local = (new_on_inds >= start) & (new_on_inds < end)
            # convert global time indices to local ones for this process
            local_on_inds = new_on_inds[in_local] - start
            rt.local_vis_mask[local_on_inds] = True

        return super(Detect, self).process(rt)
Exemple #12
0
    def process(self, ts):

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        #        if mpiutil.rank0:
        #            saveVmat = []

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        vis_conj = self.params['vis_conj']
        zero_diag = self.params['zero_diag']
        span = self.params['span']
        reserve_high_gain = self.params['reserve_high_gain']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_src_vis = self.params['save_src_vis']
        src_vis_file = self.params['src_vis_file']
        subtract_src = self.params['subtract_src']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        save_phs_change = self.params['save_phs_change']
        gain_file = self.params['gain_file']
        temperature_convert = self.params['temperature_convert']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']
        srcdict = self.params['srcdict']
        max_iter = self.params['max_iter']

        #        if mpiutil.rank0:
        #            print(ts.keys())
        #            print(ts.attrs.keys())
        #
        #        import sys
        #        sys.exit()

        if save_src_vis or subtract_src or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ps_cal for pol_type: %s' %
                                   pol_type)

            ts.redistribute('baseline')

            feedno = ts['feedno'][:].tolist()
            pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
            gain_pd = {
                'xx': 0,
                'yy': 1,
                0: 'xx',
                1: 'yy'
            }  # for gain related op
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians
            # transitsource = ts['transitsource'][:]
            # transit_time = transitsource[-1, 0] # second, sec1970
            # int_time = ts.attrs['inttime'] # second

            # calibrator
            #see whether the source should be observed
            #only for single calibrator
            if ts.is_dish:
                obsd_sources = ts['transitsource'].attrs['srcname']
                obsd_sources = obsd_sources.split(',')
                #                srcdict = {'cas':'CassiopeiaA'}
                src_index = 0.1
                for src_short, src_long in srcdict.items():
                    #                    print(src_short,src_long,obsd_sources,calibrator)
                    if (src_short == calibrator) and (src_long
                                                      in obsd_sources):
                        src_index = obsd_sources.index(src_long)
                        break
                else:
                    raise Exception(
                        'the transit source is not included in the observation plan!'
                    )
                obs_transit_time = ts['transitsource'][src_index][0]

#            srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog)
#            cat = a.src.get_catalog(srclist, cutoff, catalogs)
#            assert(len(cat) == 1), 'Allow only one calibrator'
#            s = cat.values()[0]

# get the calibrator
            try:
                s = calibrators.get_src(calibrator)
            except KeyError:
                if mpiutil.rank0:
                    print 'Calibrator %s is unavailable, available calibrators are:'
                    for key, d in calibrators.src_data.items():
                        print '%8s  ->  %12s' % (key, d[0])
                raise RuntimeError('Calibrator %s is unavailable')
            if mpiutil.rank0:
                print 'Try to calibrate with %s...' % s.src_name
#            if mpiutil.rank0:
#                print 'Calibrating for source %s with' % calibrator,
#                print 'strength', s._jys, 'Jy',
#                print 'measured at', s.mfreq, 'GHz',
#                print 'with index', s.index

# get transit time of calibrator
# array
            aa = ts.array
            aa.set_jultime(ts['jul_date'][0])  # the first obs time point
            next_transit = aa.next_transit(s)
            next_transit = aa.next_transit(
                s) if not ts.is_dish else ephem.date(
                    datetime.utcfromtimestamp(obs_transit_time))
            transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
            # get time zone
            pattern = '[-+]?\d+'
            tz = re.search(pattern, ts.attrs['timezone']).group()
            tz = int(tz)

            local_next_transit = ephem.Date(
                next_transit + tz * ephem.hour)  # plus 8h to get Beijing time
            # if transit_time > ts['jul_date'][-1]:
            #========================================================
            if ts.is_dish:
                if (transit_time >
                        max(ts['jul_date'][-1],
                            ts['jul_date'][:].max())) or (transit_time < min(
                                ts['jul_date'][0], ts['jul_date'][:].min())):
                    #                    raise RuntimeError('dish data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))
                    raise NoTransit(
                        'Dish data does not contain local transit time %s of source %s'
                        % (local_next_transit, calibrator))
                transit_inds = [
                    np.searchsorted(ts['jul_date'][:], transit_time)
                ]
                peak_span = int(np.around(10. / ts.attrs['inttime']))
                peak_start = ts['jul_date'][transit_inds[0] - peak_span]
                peak_end = ts['jul_date'][transit_inds[0] + peak_span]
                tspan = (peak_end - peak_start) * 86400.
                if mpiutil.rank0:
                    print('transit peak: %s' % local_next_transit)
                    print('%d point (should be about 10s) previous: %s' %
                          (peak_span,
                           ephem.date(
                               a.phs.juldate2ephem(peak_start) +
                               tz * ephem.hour)))
                    print(
                        '%d point (should be about 10s) later: %s' %
                        (peak_span,
                         ephem.date(
                             a.phs.juldate2ephem(peak_end) + tz * ephem.hour)))
                if tspan > 30.:
                    warnings.warn(
                        'Peak of transit is not continuous in the data! May lead to poor performance!'
                    )
            else:
                if transit_time > max(ts['jul_date'][-1],
                                      ts['jul_date'][:].max()):
                    #                    raise RuntimeError('Cylinder data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))
                    raise NoTransit(
                        'Cylinder data does not contain local transit time %s of source %s'
                        % (local_next_transit, calibrator))

                # the first transit index
                transit_inds = [
                    np.searchsorted(ts['jul_date'][:], transit_time)
                ]
                # find all other transit indices
                aa.set_jultime(ts['jul_date'][0] + 1.0)
                transit_time = a.phs.ephem2juldate(
                    aa.next_transit(s))  # Julian date
                cnt = 2
                while (transit_time <= ts['jul_date'][-1]):
                    transit_inds.append(
                        np.searchsorted(ts['jul_date'][:], transit_time))
                    aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt)
                    # Julian date
                    #                transit_time = a.phs.ephem2juldate(aa.next_transit(s) if not ts.is_dish else ephem.date(datetime.utcfromtimestamp(obs_transit_time)))
                    transit_time = a.phs.ephem2juldate(aa.next_transit(s))
                    cnt += 1
#========================================================

            if mpiutil.rank0:
                print 'transit ind of %s: %s, time: %s' % (
                    s.src_name, transit_inds, local_next_transit)

            if (not ts.ps_first) and ts.interp_all_masked:
                raise NotEnoughPointToInterpolateError(
                    'More than 80% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!'
                )
            ### now only use the first transit point to do the cal
            ### may need to improve in the future
            transit_ind = transit_inds[0]
            int_time = ts.attrs['inttime']  # second
            start_ind = transit_ind - np.int(span / int_time)
            end_ind = transit_ind + np.int(
                span /
                int_time) + 1  # plus 1 to make transit_ind is at the center

            start_ind = max(0, start_ind)
            end_ind = min(end_ind, ts.vis.shape[0])

            if vis_conj:
                ts.local_vis[:] = ts.local_vis.conj()

            nt = end_ind - start_ind
            t_inds = range(start_ind, end_ind)
            freq = ts.freq[:]
            nf = len(freq)
            nlb = len(ts.local_bl[:])
            nfeed = len(feedno)
            tfp_inds = list(
                itertools.product(
                    t_inds, range(nf),
                    [pol.index('xx'), pol.index('yy')]))  # only for xx and yy
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    lvis[ii] = ts.local_vis[ti, fi, pi]
                    lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
                # gather vis from all process for separate bls
                gvis = mpiutil.gather_array(lvis,
                                            axis=1,
                                            root=ri,
                                            comm=ts.comm)
                gvis_mask = mpiutil.gather_array(lvis_mask,
                                                 axis=1,
                                                 root=ri,
                                                 comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei]  # inds for this process
                    this_vis = gvis
                    this_vis_mask = gvis_mask
            del tfp_inds
            del lvis
            del lvis_mask
            tfp_len = len(tfp_linds)

            cnan = complex(np.nan, np.nan)  # complex nan
            if save_src_vis or subtract_src:
                # save calibrator src vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed),
                                   cnan,
                                   dtype=ts.vis.dtype)
                if save_src_vis:
                    # save sky vis
                    lsky_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)
                    # save outlier vis
                    lotl_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = []  # indices that shold be conj
            mis = [
            ]  # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [
            ]  # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            # get flus of the calibrator in the observing frequencies
            Sc = s.get_jys(freq)
            if show_progress and mpiutil.rank0:
                pg = progress.Progress(tfp_len, step=progress_step)
            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)
                # when noise on, just pass
                if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                    continue
                # aa.set_jultime(ts['jul_date'][ti])
                # s.compute(aa)
                # get the topocentric coordinate of the calibrator at the current time
                # s_top = s.get_crds('top', ncrd=3)
                # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
                Vmat.flat[mis] = np.ma.array(
                    this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
                Vmat.flat[mis_conj] = np.ma.array(
                    this_vis[ii, bis_conj],
                    mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

                #                if mpiutil.rank0:
                #                    saveVmat += [Vmat.copy()]

                if save_src_vis:
                    lsky_vis[ii] = Vmat

                # set invalid val to 0
                invalid = ~np.isfinite(Vmat)  # a bool array
                # if too many masks

                if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2:
                    continue
                Vmat[invalid] = 0
                #                # if all are zeros
                if np.allclose(Vmat, 0.0):
                    continue

                # fill diagonal of Vmat to 0
                if zero_diag:
                    np.fill_diagonal(Vmat, 0)

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(
                    np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                #                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=100, threshold='hard', tol=1.0e-6, debug=False)
                V0, S = rpca_decomp.decompose(Vmat,
                                              rank=1,
                                              S=S0,
                                              max_iter=max_iter,
                                              threshold='hard',
                                              tol=1.0e-6,
                                              debug=False)
                #                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=1, threshold='hard', tol=1., debug=False)
                if save_src_vis or subtract_src:
                    lsrc_vis[ii] = V0
                    if save_src_vis:
                        lotl_vis[ii] = S

                # plot
                if plot_figs:
                    ind = ti - start_ind
                    # plot Vmat
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(Vmat.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(Vmat.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot V0
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(V0.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(V0.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                       pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(S.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(S.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot N
                    N = Vmat - V0 - S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(N.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(N.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()

                if apply_gain or save_gain:
                    #                    if mpiutil.rank0:
                    #                        np.save('Vmat',saveVmat)
                    e, U = la.eigh(V0 / Sc[fi], eigvals=(nfeed - 1, nfeed - 1))
                    g = U[:, -1] * e[-1]**0.5
                    if g[0].real < 0:
                        g *= -1.0  # make all g[0] phase 0, instead of pi
                    lGain[ii] = g

                    # plot Gain
                    liplot = max(start_ind, transit_ind - 1)
                    hiplot = min(end_ind, transit_ind + 1 + 1)
                    if plot_figs and ti >= liplot and ti <= hiplot:
                        #                    if ti >= liplot and ti <= hiplot:
                        ind = ti - start_ind
                        plt.figure()
                        plt.plot(feedno, g.real, 'b-', label='real')
                        plt.plot(feedno, g.real, 'bo')
                        plt.plot(feedno, g.imag, 'g-', label='imag')
                        plt.plot(feedno, g.imag, 'go')
                        plt.plot(feedno, np.abs(g), 'r-', label='abs')
                        plt.plot(feedno, np.abs(g), 'ro')
                        plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                        yl, yh = plt.ylim()
                        plt.ylim(yl, yh + (yh - yl) / 5)
                        plt.xlabel('Feed number')
                        plt.legend()
                        fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind,
                                                             fi, pol[pi])
                        print('plot %s' % fig_name)
                        if tag_output_iter:
                            fig_name = output_path(fig_name,
                                                   iteration=self.iteration)
                        else:
                            fig_name = output_path(fig_name)
                        plt.savefig(fig_name)
                        plt.close()

            # subtract the vis of calibrator from self.vis
            if subtract_src:
                nbl = len(bls)
                lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype)
                for bi, (fd1, fd2) in enumerate(bls):
                    b1, b2 = feedno.index(fd1), feedno.index(fd2)
                    lv[:, bi] = lsrc_vis[:, b1, b2]
                lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
                lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
                if 'ns_on' in ts.iterkeys():
                    lv[ts['ns_on']
                       [start_ind:
                        end_ind]] = 0  # avoid ns_on signal to become nan
                ts.local_vis[start_ind:end_ind, :,
                             pol.index('xx')] -= lv[:, :, 0]
                ts.local_vis[start_ind:end_ind, :,
                             pol.index('yy')] -= lv[:, :, 1]

                del lv

            if not save_src_vis:
                if subtract_src:
                    del lsrc_vis
            else:
                if tag_output_iter:
                    src_vis_file = output_path(src_vis_file,
                                               iteration=self.iteration)
                else:
                    src_vis_file = output_path(src_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(src_vis_file, 'w') as f:
                        # allocate space
                        shp = (nt, nf, 2, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis',
                                         shp,
                                         dtype=lotl_vis.dtype)
                        f.attrs['calibrator'] = calibrator
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time',
                                             data=ts.time[start_ind:end_ind])
                            f.attrs['time'] = '/time'
                        f.attrs['freq'] = freq
                        f.attrs['pol'] = np.array(['xx', 'yy'])
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(src_vis_file, 'r+') as f:
                                    for ii, (ti, fi,
                                             pi) in enumerate(tfp_linds):
                                        ti_ = ti - start_ind
                                        pi_ = gain_pd[pol[pi]]
                                        f['sky_vis'][ti_, fi,
                                                     pi_] = lsky_vis[ii]
                                        f['src_vis'][ti_, fi,
                                                     pi_] = lsrc_vis[ii]
                                        f['outlier_vis'][ti_, fi,
                                                         pi_] = lotl_vis[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' %
                                       src_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis

                mpiutil.barrier()

            if apply_gain or save_gain:
                # flag outliers in lGain along each feed
                lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
                for i in range(lGain.shape[0]):
                    valid_inds = np.where(np.isfinite(lGain[i]))[0]
                    if len(valid_inds) > 3:
                        vabs = np.abs(lGain[i, valid_inds])
                        vmed = np.median(vabs)
                        vabs_diff = np.abs(vabs - vmed)
                        vmad = np.median(vabs_diff) / 0.6745
                        if reserve_high_gain:
                            # reserve significantly higher ones, flag only significantly lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vmed - vabs > 3.0 * vmad, np.nan, vabs)
                        else:
                            # flag both significantly higher and lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vabs_diff > 3.0 * vmad, np.nan, vabs)

                # choose data slice near the transit time
                li = max(start_ind, transit_ind - 10) - start_ind
                hi = min(end_ind, transit_ind + 10 + 1) - start_ind
                # compute s_top for this time range
                n0 = np.zeros(((hi - li), 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0[ti] = s.get_crds('top', ncrd=3)
                if save_phs_change:
                    n0t = np.zeros((nt, 3))
                    for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                        aa.set_jultime(jt)
                        s.compute(aa)
                        n0t[ti] = s.get_crds('top', ncrd=3)

                # get the positions of feeds
                feedpos = ts['feedpos'][:]

                # wrap and redistribute Gain and flagged G_abs
                Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm)
                Gain = Gain.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)
                G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm)
                G_abs = G_abs.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)

                fpd_inds = list(
                    itertools.product(range(nf), range(2),
                                      range(nfeed)))  # only for xx and yy
                fpd_linds = mpiutil.mpilist(fpd_inds,
                                            method='con',
                                            comm=ts.comm)
                del fpd_inds
                # create data to save the solved gain for each feed
                lgain = np.full((len(fpd_linds), ), cnan,
                                dtype=Gain.dtype)  # gain for each feed
                if save_phs_change:
                    lphs = np.full((nt, len(fpd_linds)),
                                   np.nan,
                                   dtype=Gain.real.dtype
                                   )  # phase change with time for each feed

                # check for conj
                num_conj = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)
                        e_phs_conj = np.dot(ef[inds],
                                            Gi[inds] / y[inds]) / len(inds)
                        eac = np.abs(e_phs_conj)
                        if eac > ea:
                            num_conj += 1
                # reduce num_conj from all processes
                num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
                if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                    if mpiutil.rank0:
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                        print '!!!   Detect data should be their conjugate...   !!!'
                        print '!!!   Correct it automatically...                !!!'
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    mpiutil.barrier()
                    # correct vis
                    ts.local_vis[:] = ts.local_vis.conj()
                    # correct G
                    Gain.local_array[:] = Gain.local_array.conj()

                # solve for gain
#                feedplotG = []
#                savenum = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]

                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        mag = np.mean(y[inds])
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]

                        #                        feedplotG += [Gi[10]]
                        #                        print(feedplotG)

                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)

                        if np.abs(ea - 1.0) < 0.1:
                            # compute gain for this feed
                            lgain[ii] = mag * e_phs
                            if save_phs_change:
                                lphs[:, ii] = np.angle(
                                    np.exp(-2.0J * np.pi * np.dot(n0t, ui)) *
                                    Gain.local_array[:, ii])
                        else:
                            e_phs_conj = np.dot(ef[inds],
                                                Gi[inds] / y[inds]) / len(inds)
                            eac = np.abs(e_phs_conj)
                            if eac > ea:
                                if np.abs(eac - 1.0) < 0.01:
                                    print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                        feedno[di], fi, gain_pd[pi])
                            else:
                                print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                                    feedno[di], fi, gain_pd[pi], ea)

#                import matplotlib.pyplot as plt
#                import os.path
#                feedplotG = np.array(feedplotG)
#                if mpiutil.rank0:
#                    if os.path.isfile('feedplotG%d.npy'%savenum):
#                        savenum += 1
#                    np.save('feedplotG%d'%savenum,feedplotG)
#                    plt.plot(feedplotG)
#                    plt.savefig('Gfig%d'%savenum)

# gather local gain
                gain = mpiutil.gather_array(lgain,
                                            axis=0,
                                            root=None,
                                            comm=ts.comm)
                del lgain
                gain = gain.reshape(nf, 2, nfeed)
                if save_phs_change:
                    phs = mpiutil.gather_array(lphs,
                                               axis=1,
                                               root=0,
                                               comm=ts.comm)
                    del lphs
                    if mpiutil.rank0:
                        phs = phs.reshape(nt, nf, 2, nfeed)

                # apply gain to vis
                if apply_gain:
                    for fi in range(nf):
                        for pi in [pol.index('xx'), pol.index('yy')]:
                            pi_ = gain_pd[pol[pi]]
                            for bi, (fd1, fd2) in enumerate(
                                    ts['blorder'].local_data):
                                g1 = gain[fi, pi_, feedno.index(fd1)]
                                g2 = gain[fi, pi_, feedno.index(fd2)]
                                if np.isfinite(g1) and np.isfinite(g2):
                                    ts.local_vis[:, fi, pi,
                                                 bi] /= (g1 * np.conj(g2))
                                else:
                                    # mask the un-calibrated vis
                                    ts.local_vis_mask[:, fi, pi, bi] = True

                # save gain to file
                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file,
                                                iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            dset = f.create_dataset('Gain', (nt, nf, 2, nfeed),
                                                    dtype=Gain.dtype)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time'] = ts.time[start_ind:end_ind]
                            except RuntimeError:
                                f.create_dataset(
                                    'time', data=ts.time[start_ind:end_ind])
                                dset.attrs['time'] = '/time'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save gain
                            dset = f.create_dataset('gain', data=gain)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'freq, pol, feed'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save phs

                            if save_phs_change:
                                f.create_dataset('phs', data=phs)
                            # save transit index and transit time for the case do the ps cal first
                            f.attrs['transit_index'] = transit_inds[0]
                            #                            f.attrs['transit_time'] = ts['jul_date'][transit_inds[0]]
                            f.attrs['transit_jul'] = ts['jul_date'][
                                transit_inds[0]]
                            f.attrs['transit_time'] = ts['sec1970'][
                                transit_inds[0]]
                            #                            f.attrs['transit_time'] = ts.attrs['sec1970'] + ts.attrs['inttime']*transit_inds[0] # in sec1970 to improve numeric precision

                            if os.path.exists(output_path(
                                    ts.ns_gain_file)) and not ts.ps_first:
                                with h5py.File(output_path(ts.ns_gain_file),
                                               'r+') as ns_file:
                                    phs_only = not ('ns_cal_amp'
                                                    in ns_file.keys())
                                    #                                    exclude_bad = 'badchn' in ns_file['channo'].attrs.keys()
                                    new_gain = uni_gain(f,
                                                        ns_file,
                                                        phs_only=phs_only)
                                    ns_file.create_dataset('uni_gain',
                                                           data=new_gain)
                                    ns_file['uni_gain'].attrs[
                                        'dim'] = '(time, freq, bl)'

                    mpiutil.barrier()

                    # save Gain
                    for i in range(10):
                        try:
                            # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                            for ri in xrange(mpiutil.size):
                                if ri == mpiutil.rank:
                                    with h5py.File(gain_file, 'r+') as f:
                                        for ii, (ti, fi,
                                                 pi) in enumerate(tfp_linds):
                                            ti_ = ti - start_ind
                                            pi_ = gain_pd[pol[pi]]
                                            f['Gain'][ti_, fi, pi_] = lGain[ii]
                                mpiutil.barrier()
                            break
                        except IOError:
                            time.sleep(0.5)
                            continue
                    else:
                        raise RuntimeError('Could not open file: %s...' %
                                           gain_file)

                    mpiutil.barrier()

        # convert vis from intensity unit to temperature unit in K
        if temperature_convert:
            if 'unit' in ts.vis.attrs.keys() and ts.vis.attrs['unit'] == 'K':
                if mpiutil.rank0:
                    print 'vis is already in unit K, do nothing...'
            else:
                factor = 1.0e-26 * (const.c**2 / (2 * const.k_B *
                                                  (1.0e6 * freq)**2)
                                    )  # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
                ts.local_vis[:] *= factor[np.newaxis, :, np.newaxis,
                                          np.newaxis]
                ts.vis.attrs['unit'] = 'K'

        return super(PsCal, self).process(ts)
Exemple #13
0
    def process(self, ts):
        """Compute m-modes of the visibilities in `ts` and save them to disk.

        For each unique (redundant) baseline pair of the telescope model,
        the visibilities of all baselines in that redundant group (and of the
        selected polarisation(s)) are Fourier transformed along the RA axis
        over one sidereal day and summed. The per-process partial sums are
        reduced onto rank 0. Rank 0 then writes them into the m-mode
        directory of a drift-scan ``timestream.Timestream`` object. If a
        ``COMPLETED_M`` marker already exists, rank 0 adds them to the
        existing m-modes instead.

        Parameters
        ----------
        ts : Timestream
            Input timestream; it is redistributed along the time axis.
            Its visibility mask is not modified.

        Returns
        -------
        tstream : timestream.Timestream
            The drift-scan timestream object the m-modes were written to.

        Raises
        ------
        AssertionError
            If `ts` is not a ``Timestream``.
        RuntimeError
            If the array is neither a dish nor a cylinder array.
        ValueError
            If the ``pol`` parameter is not one of 'xx', 'yy' or 'I'.
        """

        tsys = self.params['tsys']
        accuracy_boost = self.params['accuracy_boost']
        l_boost = self.params['l_boost']
        bl_range = self.params['bl_range']
        auto_correlations = self.params['auto_correlations']
        pol = self.params['pol']
        beam_dir = output_path(self.params['beam_dir'])
        noise_weight = self.params['noise_weight']
        ts_dir = output_path(self.params['ts_dir'])
        ts_name = self.params['ts_name']
        no_m_zero = self.params['no_m_zero']

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        ts.redistribute('time')

        lat = ts.attrs['sitelat']
        # lon = ts.attrs['sitelon']
        lon = 0.0
        # lon = np.degrees(ts['ra_dec'][0, 0]) # the first ra
        local_origin = False
        freqs = ts.freq[:]  # MHz
        nfreq = freqs.shape[0]
        band_width = ts.attrs['freqstep']  # MHz
        try:
            ndays = ts.attrs['ndays']
        except KeyError:
            ndays = 1
        feeds = ts['feedno'][:]
        bls = [tuple(bl) for bl in ts.bl]
        az, alt = ts['az_alt'].local_data[
            0]  # assume fixed az, alt during the observation
        az = np.degrees(az)
        alt = np.degrees(alt)
        pointing = [az, alt, 0.0]
        feedpos = ts['feedpos'][:]

        # build the telescope model matching the array type
        if ts.is_dish:
            from tlpipe.map.drift.telescope import tl_dish

            dish_width = ts.attrs['dishdiam']
            tel = tl_dish.TlUnpolarisedDishArray(lat, lon, freqs, band_width,
                                                 tsys, ndays, accuracy_boost,
                                                 l_boost, bl_range,
                                                 auto_correlations,
                                                 local_origin, dish_width,
                                                 feedpos, pointing)
        elif ts.is_cylinder:
            from tlpipe.map.drift.telescope import tl_cylinder

            # factor = 1.2 # suppose an illumination efficiency, keep same with that in timestream_common
            factor = 0.79  # for xx
            # factor = 0.88 # for yy
            cyl_width = factor * ts.attrs['cywid']
            tel = tl_cylinder.TlUnpolarisedCylinder(
                lat, lon, freqs, band_width, tsys, ndays, accuracy_boost,
                l_boost, bl_range, auto_correlations, local_origin, cyl_width,
                feedpos)
        else:
            raise RuntimeError('Unknown array type %s' % ts.attrs['telescope'])

        allpairs = tel.allpairs
        redundancy = tel.redundancy
        # red_bin[qi]:red_bin[qi + 1] selects the pairs of unique pair qi
        red_bin = np.cumsum(np.insert(redundancy, 0, 0))  # redundancy bin
        unqpairs = tel.uniquepairs
        nuq = len(unqpairs)  # number of unique pairs

        # to save m-mode
        if mpiutil.rank0:
            # large array only in rank0 to save memory
            mmode = np.zeros((2 * tel.mmax + 1, nfreq, nuq),
                             dtype=np.complex128)
            # number of (unmasked) samples accumulated into each entry
            N = np.zeros((nfreq, nuq), dtype=int)

        # mmode of a specific unique pair
        mmodeqi = np.zeros((2 * tel.mmax + 1, nfreq), dtype=np.complex128)
        Nqi = np.zeros((nfreq), dtype=int)  # number of accumulated samples

        start_ra = ts.vis.attrs['start_ra']
        ra = mpiutil.gather_array(ts['ra_dec'].local_data[:, 0], root=None)
        ra = np.unwrap(ra)
        # Find the time index whose RA is closest to start_ra. searchsorted
        # gives the first index with ra >= start_ra, so the closest sample is
        # either that one or the one just before it.
        ind = np.searchsorted(ra, start_ra)
        if ind >= len(ra):
            ind = len(ra) - 1
        elif ind > 0 and (np.abs(ra[ind - 1] - start_ra) <
                          np.abs(ra[ind] - start_ra)):
            ind = ind - 1

        # get number of int_time in one sidereal day
        num_int = int(np.around(1.0 * const.sday / ts.attrs['inttime']))
        nt = ts.vis.shape[0]
        nt1 = min(num_int, nt - ind)

        inds = np.arange(nt)
        local_inds = mpiutil.scatter_array(inds, root=None)
        # local time samples outside the one-day window [ind, ind + nt1)
        out_of_day = np.logical_or(local_inds < ind, local_inds >= ind + nt1)

        local_phi = ts['ra_dec'].local_data[:, 0]
        # the Fourier transform matrix, rows m = -mmax, ..., mmax
        E = np.exp(-1.0J *
                   np.outer(np.arange(-tel.mmax, tel.mmax + 1), local_phi))

        # pols to consider
        pol_str = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
        if pol == 'xx' or pol == 'yy':
            pis = [pol_str.index(pol)]
        elif pol == 'I':
            pis = [pol_str.index('xx'), pol_str.index('yy')]
        else:
            raise ValueError('Invalid pol: %s' % pol)

        # compute mmodes for each unique pair
        for qi in range(nuq):
            mmodeqi[:] = 0
            Nqi[:] = 0
            this_pairs = allpairs[red_bin[qi]:red_bin[qi + 1]]
            for a1, a2 in this_pairs:
                # the baseline may be stored in either orientation; use the
                # conjugate visibility for the reversed one
                try:
                    b_ind = bls.index((feeds[a1], feeds[a2]))
                    conj = False
                except ValueError:
                    b_ind = bls.index((feeds[a2], feeds[a1]))
                    conj = True
                for pi in pis:
                    V = ts.local_vis[:, :, pi, b_ind]
                    if conj:
                        V = V.conj()
                    # copy the mask so that flagging out-of-day samples does
                    # not modify the mask stored in ts
                    M = ts.local_vis_mask[:, :, pi, b_ind].copy()
                    M[out_of_day, :] = True
                    V = np.where(M, 0, V)  # fill masked values with 0
                    mmodeqi += np.dot(E, V)
                    # count only the samples that actually contribute
                    Nqi += np.sum(np.logical_not(M), axis=0)

            mpiutil.barrier()

            # accumulate mmode from all processes by Reduce
            if mpiutil.size > 1:  # more than one processes
                if mpiutil.rank0:
                    # use IN_PLACE to reuse the mmodeqi and Nqi arrays
                    mpiutil.world.Reduce(mpiutil.IN_PLACE,
                                         mmodeqi,
                                         op=mpiutil.SUM,
                                         root=0)
                    mpiutil.world.Reduce(mpiutil.IN_PLACE,
                                         Nqi,
                                         op=mpiutil.SUM,
                                         root=0)
                else:
                    # the receive buffer is ignored on non-root ranks
                    mpiutil.world.Reduce(mmodeqi,
                                         mmodeqi,
                                         op=mpiutil.SUM,
                                         root=0)
                    mpiutil.world.Reduce(Nqi, Nqi, op=mpiutil.SUM, root=0)

            if mpiutil.rank0:
                mmode[:, :, qi] = mmodeqi
                N[:, qi] = Nqi

        del ts
        del E

        # beamtransfer
        bt = beamtransfer.BeamTransfer(beam_dir, tel, noise_weight, True)
        # timestream
        tstream = timestream.Timestream(ts_dir, ts_name, bt, no_m_zero)

        if mpiutil.rank0:
            # reshape mmode to separate positive (index 0) and negative
            # (index 1, stored conjugated) ms
            mmode1 = np.zeros((tel.mmax + 1, nfreq, 2, nuq), dtype=mmode.dtype)
            mmode1[0, :, 0] = mmode[tel.mmax]
            for mi in range(1, tel.mmax + 1):
                mmode1[mi, :, 0] = mmode[tel.mmax + mi]
                mmode1[mi, :, 1] = mmode[tel.mmax - mi].conj()

            del mmode

            # mmode1 is saved unnormalized; the sample counts N are saved
            # alongside it in count.hdf5
            # mmode1 /= N[np.newaxis, :, np.newaxis, :]

            # save mmode to file
            mmode_dir = tstream.output_directory + '/mmodes'
            if os.path.exists(mmode_dir + '/COMPLETED_M'):
                # update the already existing mmodes
                for mi in range(tel.mmax + 1):
                    with h5py.File(tstream._mfile(mi), 'r+') as f:
                        f['/mmode'][:] += mmode1[mi]
                with h5py.File(mmode_dir + '/count.hdf5', 'r+') as f:
                    f['count'][:] += N
            else:
                for mi in range(tel.mmax + 1):
                    # make directory for each m-mode
                    if not os.path.exists(tstream._mdir(mi)):
                        os.makedirs(tstream._mdir(mi))

                    # create the m-file and save the result.
                    with h5py.File(tstream._mfile(mi), 'w') as f:
                        f.create_dataset('/mmode', data=mmode1[mi])
                        f.attrs['m'] = mi

                with h5py.File(mmode_dir + '/count.hdf5', 'w') as f:
                    f.create_dataset('count', data=N)

                # Make file marker that the m's have been correctly generated:
                open(mmode_dir + '/COMPLETED_M', 'a').close()

                # save the tstream object
                tstream.save()

        mpiutil.barrier()

        return tstream
Exemple #14
0
    def process(self, rt):

        assert isinstance(rt, RawTimestream), '%s only works for RawTimestream object currently' % self.__class__.__name__

        if not 'ns_on' in rt.iterkeys():
            raise RuntimeError('No noise source info, can not do noise source calibration')

        rt.redistribute('baseline')

        num_mean = self.params['num_mean']
        phs_only = self.params['phs_only']
        save_gain = self.params['save_gain']
        tag_output_iter = self.params['tag_output_iter']
        gain_file = self.params['gain_file']
        bl_incl = self.params['bl_incl']
        bl_excl = self.params['bl_excl']
        freq_incl = self.params['freq_incl']
        freq_excl = self.params['freq_excl']

        nt = rt.local_vis.shape[0]
        if num_mean <= 0:
            raise RuntimeError('Invalid num_mean = %s' % num_mean)
        ns_on = rt['ns_on'][:]
        ns_on = np.where(ns_on, 1, 0)
        diff_ns = np.diff(ns_on)
        inds = np.where(diff_ns==1)[0] # NOTE: these are inds just 1 before the first ON
        if inds[0]-1 < 0: # no off data in the beginning to use
            inds = inds[1:]
        if inds[-1]+2 > nt-1: # no on data in the end to use
            inds = inds[:-1]

        if save_gain:
            num_inds = len(inds)
            shp = (num_inds,)+rt.local_vis.shape[1:]
            dtype = rt.local_vis.real.dtype
            # create dataset to record ns_cal_time_inds
            rt.create_time_ordered_dataset('ns_cal_time_inds', inds)
            # create dataset to record ns_cal_phase
            ns_cal_phase = np.empty(shp, dtype=dtype)
            ns_cal_phase[:] = np.nan
            ns_cal_phase = mpiarray.MPIArray.wrap(ns_cal_phase, axis=2, comm=rt.comm)
            rt.create_freq_and_bl_ordered_dataset('ns_cal_phase', ns_cal_phase, axis_order=(None, 1, 2))
            rt['ns_cal_phase'].attrs['unit'] = 'radians'
            if not phs_only:
                # create dataset to record ns_cal_amp
                ns_cal_amp = np.empty(shp, dtype=dtype)
                ns_cal_amp[:] = np.nan
                ns_cal_amp = mpiarray.MPIArray.wrap(ns_cal_amp, axis=2, comm=rt.comm)
                rt.create_freq_and_bl_ordered_dataset('ns_cal_amp', ns_cal_amp, axis_order=(None, 1, 2))

        if bl_incl == 'all':
            bls_plt = [ tuple(bl) for bl in rt.bl ]
        else:
            bls_plt = [ bl for bl in bl_incl if not bl in bl_excl ]

        if freq_incl == 'all':
            freq_plt = range(rt.freq.shape[0])
        else:
            freq_plt = [ fi for fi in freq_incl if not fi in freq_excl ]

        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']

        rt.freq_and_bl_data_operate(self.cal, full_data=True, show_progress=show_progress, progress_step=progress_step, keep_dist_axis=False, num_mean=num_mean, inds=inds, bls_plt=bls_plt, freq_plt=freq_plt)

        if save_gain:
            # gather bl_order to rank0
            bl_order = mpiutil.gather_array(rt['blorder'].local_data, axis=0, root=0, comm=rt.comm)
            # gather ns_cal_phase / ns_cal_amp to rank 0
            ns_cal_phase = mpiutil.gather_array(rt['ns_cal_phase'].local_data, axis=2, root=0, comm=rt.comm)
            phs_unit = rt['ns_cal_phase'].attrs['unit']
            rt.delete_a_dataset('ns_cal_phase', reserve_hint=False)
            if not phs_only:
                ns_cal_amp = mpiutil.gather_array(rt['ns_cal_amp'].local_data, axis=2, root=0, comm=rt.comm)
                rt.delete_a_dataset('ns_cal_amp', reserve_hint=False)

            if tag_output_iter:
                gain_file = output_path(gain_file, iteration=self.iteration)
            else:
                gain_file = output_path(gain_file)
            if mpiutil.rank0:
                with h5py.File(gain_file, 'w') as f:
                    # save time
                    f.create_dataset('time', data=rt['jul_date'][:])
                    f['time'].attrs['unit'] = 'Julian date'
                    # save freq
                    f.create_dataset('freq', data=rt['freq'][:])
                    f['freq'].attrs['unit'] = rt['freq'].attrs['unit']
                    # save bl
                    f.create_dataset('bl_order', data=bl_order)
                    # save ns_cal_time_inds
                    f.create_dataset('ns_cal_time_inds', data=rt['ns_cal_time_inds'][:])
                    # save ns_cal_phase
                    f.create_dataset('ns_cal_phase', data=ns_cal_phase)
                    f['ns_cal_phase'].attrs['unit'] = phs_unit
                    f['ns_cal_phase'].attrs['dim'] = '(time, freq, bl)'
                    if not phs_only:
                        # save ns_cal_amp
                        f.create_dataset('ns_cal_amp', data=ns_cal_amp)

            rt.delete_a_dataset('ns_cal_time_inds', reserve_hint=False)

        return super(NsCal, self).process(rt)
Exemple #15
0
    def process(self, ts):
        """Calibrate the data against a single point-source calibrator.

        For every time sample within ``span`` seconds either side of the
        first transit of the calibrator contained in the data, and for every
        local frequency of the xx and yy polarisations, the visibilities are
        divided by the calibrator's predicted response (flux * beam * fringe
        phase) to form a feed-by-feed matrix.  The gain vector for that
        sample is the eigenvector of the largest eigenvalue of the matrix
        scaled by the square root of that eigenvalue.  The time mean of the
        gains is passed to ``cal`` through ``pol_and_bl_data_operate``, and
        the per-sample gains and eigenvalues are optionally written to
        ``gain_file``.

        Parameters
        ----------
        ts : timestream object
            Data to calibrate; it is redistributed along 'frequency' for the
            solve and along 'baseline' for applying the gain.

        Returns
        -------
        The result of the parent class ``process`` called on ``ts``.

        Raises
        ------
        RuntimeError
            If the data does not contain a transit of the calibrator, or if
            the ``span`` window around the transit does not fit in the data.
        """

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        span = self.params['span']
        save_gain = self.params['save_gain']
        gain_file = self.params['gain_file']

        ts.redistribute('frequency')

        lfreq = ts.local_freq[:]  # local freq

        feedno = ts['feedno'][:].tolist()
        pol = ts['pol'][:].tolist()
        bls = [tuple(b) for b in ts.bl[:]]
        jul_date = ts['jul_date'][:]

        # calibrator
        srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog)
        cat = a.src.get_catalog(srclist, cutoff, catalogs)
        assert (len(cat) == 1), 'Allow only one calibrator'
        s = list(cat.values())[0]
        if mpiutil.rank0:
            print('Calibrating for source %s with strength %s Jy '
                  'measured at %s GHz with index %s' %
                  (calibrator, s._jys, s.mfreq, s.index))

        # get transit time of calibrator
        aa = ts.array
        aa.set_jultime(jul_date[0])  # the first obs time point
        next_transit = aa.next_transit(s)
        transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
        if transit_time > jul_date[-1]:
            # fixed +8 h offset from UTC for the reported local time
            local_next_transit = ephem.Date(next_transit + 8.0 * ephem.hour)
            raise RuntimeError(
                'Data does not contain local transit time %s of source %s' %
                (local_next_transit, calibrator))

        # indices of all transits in the data: the k-th extra search starts
        # k days after the first observation time
        transit_inds = [np.searchsorted(jul_date, transit_time)]
        cnt = 1
        while True:
            aa.set_jultime(jul_date[0] + 1.0 * cnt)
            transit_time = a.phs.ephem2juldate(aa.next_transit(s))  # Julian date
            if transit_time > jul_date[-1]:
                break
            transit_inds.append(np.searchsorted(jul_date, transit_time))
            cnt += 1

        if mpiutil.rank0:
            print('transit inds of %s: %s' % (calibrator, transit_inds))

        # only the first transit is used for the calibration for now
        transit_ind = transit_inds[0]
        int_time = ts.attrs['inttime']  # second
        half_width = int(span / int_time)
        start_ind = transit_ind - half_width
        end_ind = transit_ind + half_width

        # the window must lie inside the data: a negative index would
        # silently wrap around to the end of the time axis
        ntime = ts.local_vis.shape[0]
        if start_ind < 0:
            raise RuntimeError('start_ind: %d < 0' % start_ind)
        if end_ind > ntime:
            raise RuntimeError('end_ind: %d > %d' % (end_ind, ntime))

        nt = end_ind - start_ind
        nfeed = len(feedno)
        nlf = len(lfreq)
        # data pol indices of xx and yy; their position in this list (0 or 1)
        # is the pol index used in the eigval and gain arrays, whatever the
        # pol order of the data
        pol_inds = [pol.index('xx'), pol.index('yy')]
        eigval = np.empty((nt, nfeed, 2, nlf), dtype=np.float64)
        eigval[:] = np.nan
        gain = np.empty((nt, nfeed, 2, nlf), dtype=np.complex128)
        gain[:] = complex(np.nan, np.nan)

        # (i, j, baseline index, conjugate?) for every feed pair; a pair is
        # stored either as (ai, aj) or as the conjugate baseline (aj, ai)
        pairs = []
        for i, ai in enumerate(feedno):
            for j, aj in enumerate(feedno):
                try:
                    pairs.append((i, j, bls.index((ai, aj)), False))
                except ValueError:
                    pairs.append((i, j, bls.index((aj, ai)), True))

        has_ns = 'ns_on' in ts.iterkeys()
        # visibility matrix for a single time, pol, freq
        Vmat = np.zeros((nfeed, nfeed), dtype=ts.main_data.dtype)
        for ind, ti in enumerate(range(start_ind, end_ind)):
            # skip time points where the noise source is on
            if has_ns and ts['ns_on'][ti]:
                continue
            aa.set_jultime(jul_date[ti])
            s.compute(aa)
            # fluxes vs. freq of the calibrator
            Sc = s.get_jys()
            # topocentric coordinate of the calibrator at the current time
            s_top = s.get_crds('top', ncrd=3)
            aa.sim_cache(cat.get_crds('eq', ncrd=3))  # for bm_response and sim
            for pj, pi in enumerate(pol_inds):  # xx, yy
                aa.set_active_pol(pol[pi])
                # baseline vectors and beam responses do not depend on the
                # frequency loop below, so compute them once per pol
                uvw = {}
                bm = {}
                for i, j, _, _ in pairs:
                    uvw[i, j] = aa.gen_uvw(i, j, src='z')[:, 0, :]  # (rj - ri)/lambda
                    # reshape rather than squeeze so that a single local
                    # freq still yields an indexable array
                    bm[i, j] = aa.bm_response(i, j).reshape(-1)
                for fi in range(nlf):  # mpi among freq
                    for i, j, bi, conj in pairs:
                        # divide out the predicted calibrator response
                        v = ts.local_vis[ti, fi, pi, bi] / (
                            Sc[fi] * bm[i, j][fi] *
                            np.exp(2.0J * np.pi * np.dot(s_top, uvw[i, j][:, fi])))
                        Vmat[i, j] = np.conj(v) if conj else v
                    Vmat[~np.isfinite(Vmat)] = 0

                    # eigen decomposition; the gain vector is the eigenvector
                    # of the largest eigenvalue scaled by its square root
                    e, U = eigh(Vmat)
                    eigval[ind, :, pj, fi] = e[::-1]  # descending order
                    gain[ind, :, pj, fi] = np.sqrt(e[-1]) * U[:, -1]

        # time-mean gain over the valid (finite) solutions
        tgain = np.ma.mean(np.ma.masked_invalid(gain), axis=0)
        tgain = mpiutil.gather_array(tgain, axis=-1, root=None)

        # apply the gain to the visibilities
        ts.redistribute('baseline')
        ts.pol_and_bl_data_operate(cal, tgain=tgain)

        # save gain if required
        if save_gain:
            gain_file = output_path(gain_file)
            eigval = mpiarray.MPIArray.wrap(eigval, axis=3)
            gain = mpiarray.MPIArray.wrap(gain, axis=3)
            mem_gain = memh5.MemGroup(distributed=True)
            mem_gain.create_dataset('eigval', data=eigval)
            mem_gain.create_dataset('gain', data=gain)
            # add attrs; NOTE(review): 'jul_data' looks like a typo for
            # 'jul_date' but is kept unchanged in case readers depend on it
            mem_gain.attrs['jul_data'] = jul_date[start_ind:end_ind]
            mem_gain.attrs['feed'] = np.array(feedno)
            mem_gain.attrs['pol'] = np.array(['xx', 'yy'])
            mem_gain.attrs['freq'] = ts.freq[:]  # freq should be common

            # save to file
            mem_gain.to_hdf5(gain_file, hints=False)

        return super(PsCal, self).process(ts)
Exemple #16
0
    def process(self, ts):
        """Calibrate a Timestream against a point source (low-rank method).

        Steps, as implemented below:

        1. Locate the first transit of ``calibrator`` in the data and take
           the window of ``span`` seconds either side of it.
        2. Redistribute the (time, freq, xx/yy pol) samples of that window
           over the processes so that each one holds all baselines of its
           own samples.
        3. For each sample build the feed-by-feed visibility matrix divided
           by the calibrator flux (masked or non-finite entries set to 0),
           split it with ``rpca_decomp.decompose(rank=1)`` into a rank-1
           part ``V0`` and a sparse part ``S``, and take the leading
           eigenvector of ``V0`` scaled by the square root of its eigenvalue
           as the per-sample gain.
        4. For each (freq, pol) fit the profile function ``fc`` to the gain
           amplitudes near transit, and solve for one complex gain per feed
           by least squares against the normalised profile times the
           expected fringe phase.
        5. Divide the visibilities by ``g_i * conj(g_j)`` (masking baselines
           without a finite gain), optionally convert from Jy to K, and
           optionally save the per-sample and final gains to ``gain_file``.

        Parameters
        ----------
        ts : Timestream

        Returns
        -------
        The result of the parent class ``process`` called on ``ts``.

        Raises
        ------
        RuntimeError
            If the data does not contain a transit of the calibrator, or if
            the ``span`` window does not fit inside the data.
        """

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        span = self.params['span']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_gain = self.params['save_gain']
        gain_file = self.params['gain_file']
        temperature_convert = self.params['temperature_convert']

        def _out_name(name):
            # resolve an output file name, tagged with the iteration number
            # when requested
            if tag_output_iter:
                return output_path(name, iteration=self.iteration)
            return output_path(name)

        def _plot_complex(mat, tag, ind, fi, pi):
            # save the real (left) and imaginary (right) parts of mat
            plt.figure(figsize=(13, 5))
            for k, part in enumerate((mat.real, mat.imag)):
                plt.subplot(121 + k)
                plt.imshow(part,
                           aspect='equal',
                           origin='lower',
                           interpolation='nearest')
                plt.colorbar(shrink=1.0)
            plt.savefig(_out_name('%s_%s_%d_%d_%s.png' %
                                  (fig_prefix, tag, ind, fi, ts.pol_dict[pi])))
            plt.close()

        ts.redistribute('baseline')

        feedno = ts['feedno'][:].tolist()
        pol = [ts.pol_dict[p] for p in ts['pol'][:]]
        bl = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
        bls = [tuple(b) for b in bl]
        jul_date = ts['jul_date'][:]

        # calibrator
        srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog)
        cat = a.src.get_catalog(srclist, cutoff, catalogs)
        assert (len(cat) == 1), 'Allow only one calibrator'
        s = list(cat.values())[0]
        if mpiutil.rank0:
            print('Calibrating for source %s with strength %s Jy '
                  'measured at %s GHz with index %s' %
                  (calibrator, s._jys, s.mfreq, s.index))

        # get transit time of calibrator
        aa = ts.array
        aa.set_jultime(jul_date[0])  # the first obs time point
        next_transit = aa.next_transit(s)
        transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
        # local transit time, using the first signed integer of the
        # 'timezone' attribute as an offset in hours
        tz = int(re.search(r'[-+]?\d+', ts.attrs['timezone']).group())
        local_next_transit = ephem.Date(next_transit + tz * ephem.hour)
        if transit_time > jul_date[-1]:
            raise RuntimeError(
                'Data does not contain local transit time %s of source %s' %
                (local_next_transit, calibrator))

        # indices of all transits in the data: the k-th extra search starts
        # k days after the first observation time
        transit_inds = [np.searchsorted(jul_date, transit_time)]
        cnt = 1
        while True:
            aa.set_jultime(jul_date[0] + 1.0 * cnt)
            transit_time = a.phs.ephem2juldate(aa.next_transit(s))  # Julian date
            if transit_time > jul_date[-1]:
                break
            transit_inds.append(np.searchsorted(jul_date, transit_time))
            cnt += 1

        if mpiutil.rank0:
            print('transit ind of %s: %s, time: %s' %
                  (calibrator, transit_inds, local_next_transit))

        # only the first transit is used for the calibration for now
        transit_ind = transit_inds[0]
        int_time = ts.attrs['inttime']  # second
        half_width = int(span / int_time)
        start_ind = transit_ind - half_width
        end_ind = transit_ind + half_width + 1  # +1 so transit_ind is centred

        # check if data contain this range
        if start_ind < 0:
            raise RuntimeError('start_ind: %d < 0' % start_ind)
        if end_ind > ts.vis.shape[0]:
            raise RuntimeError('end_ind: %d > %d' % (end_ind, ts.vis.shape[0]))

        nt = end_ind - start_ind
        freq = ts.freq[:]
        nf = len(freq)
        nlb = len(ts.local_bl[:])
        nfeed = len(feedno)
        # data pol indices of xx and yy; their position in this list (0 or 1)
        # is the pol index of the Gain/gain arrays below
        pol_inds = [pol.index('xx'), pol.index('yy')]
        tfp_inds = list(itertools.product(range(start_ind, end_ind), range(nf),
                                          pol_inds))
        ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
        # baselines are distributed over the processes; gather the data so
        # that each process holds all baselines for its share of tfp_inds
        for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
            lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
            lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
            for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                lvis[ii] = ts.local_vis[ti, fi, pi]
                lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
            gvis = mpiutil.gather_array(lvis, axis=1, root=ri, comm=ts.comm)
            gvis_mask = mpiutil.gather_array(lvis_mask,
                                             axis=1,
                                             root=ri,
                                             comm=ts.comm)
            if ri == mpiutil.rank:
                tfp_linds = tfp_inds[si:ei]  # inds for this process
                this_vis = gvis
                this_vis_mask = gvis_mask
        del tfp_inds
        del lvis
        del lvis_mask
        lGain = np.empty((len(tfp_linds), nfeed), dtype=np.complex128)
        lGain[:] = complex(np.nan, np.nan)

        # (i, j, baseline index, conjugate?) for every feed pair; a pair is
        # stored either as (ai, aj) or as the conjugate baseline (aj, ai)
        pairs = []
        for i, ai in enumerate(feedno):
            for j, aj in enumerate(feedno):
                try:
                    pairs.append((i, j, bls.index((ai, aj)), False))
                except ValueError:
                    pairs.append((i, j, bls.index((aj, ai)), True))

        has_ns = 'ns_on' in ts.iterkeys()
        # visibility matrix for a single time, freq, pol
        Vmat = np.zeros((nfeed, nfeed), dtype=ts.vis.dtype)
        Sc = s.get_jys()
        for ii, (ti, fi, pi) in enumerate(tfp_linds):
            # skip time points where the noise source is on
            if has_ns and ts['ns_on'][ti]:
                continue
            # fill Vmat with the flux-normalised visibilities; entries that
            # are masked OR non-finite are excluded (set to 0) and counted
            mask_cnt = 0
            for i, j, bi, conj in pairs:
                if this_vis_mask[ii, bi] or not np.isfinite(this_vis[ii, bi]):
                    mask_cnt += 1
                    Vmat[i, j] = 0
                elif conj:
                    Vmat[i, j] = np.conj(this_vis[ii, bi] / Sc[fi])  # xx, yy
                else:
                    Vmat[i, j] = this_vis[ii, bi] / Sc[fi]  # xx, yy

            # skip samples with too many excluded entries
            if mask_cnt > 0.3 * nfeed**2:
                continue

            # initialize the outliers
            med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
            diff = Vmat - med
            S0 = np.where(np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
            # decompose Vmat into a rank-1 part V0 and a sparse part S
            V0, S = rpca_decomp.decompose(Vmat,
                                          rank=1,
                                          S=S0,
                                          max_iter=100,
                                          threshold='hard',
                                          tol=1.0e-6,
                                          debug=False)

            ind = ti - start_ind  # time index within the window
            if plot_figs:
                _plot_complex(Vmat, 'V', ind, fi, pi)
                _plot_complex(V0, 'V0', ind, fi, pi)
                _plot_complex(S, 'S', ind, fi, pi)
                _plot_complex(Vmat - V0 - S, 'N', ind, fi, pi)  # residual

            # gain = leading eigenvector of V0 scaled by sqrt(eigenvalue)
            # NOTE(review): the `eigvals` keyword is deprecated in newer SciPy
            # in favour of `subset_by_index`
            e, U = la.eigh(V0, eigvals=(nfeed - 1, nfeed - 1))
            g = U[:, -1] * e[-1]**0.5
            lGain[ii] = g

            # plot Gain
            if plot_figs:
                plt.figure()
                plt.plot(feedno, g.real, 'b-', label='real')
                plt.plot(feedno, g.real, 'bo')
                plt.plot(feedno, g.imag, 'g-', label='imag')
                plt.plot(feedno, g.imag, 'go')
                plt.plot(feedno, np.abs(g), 'r-', label='abs')
                plt.plot(feedno, np.abs(g), 'ro')
                plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                yl, yh = plt.ylim()
                plt.ylim(yl, yh + (yh - yl) / 5)
                plt.xlabel('Feed number')
                plt.legend()
                plt.savefig(_out_name('%s_ants_%d_%d_%s.png' %
                                      (fig_prefix, ind, fi, ts.pol_dict[pi])))
                plt.close()

        # gather Gain from each processes
        Gain = mpiutil.gather_array(lGain, axis=0, root=None, comm=ts.comm)
        Gain = Gain.reshape(nt, nf, 2, nfeed)

        # choose data slice near the transit time (integer centre index)
        c = nt // 2
        li = max(0, c - 100)
        hi = min(nt, c + 100 + 1)
        x = np.arange(li, hi)
        # topocentric direction of the calibrator over this range
        n0 = np.zeros(((hi - li), 3))
        for k, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
            aa.set_jultime(jt)
            s.compute(aa)
            n0[k] = s.get_crds('top', ncrd=3)

        # get the positions of feeds
        feedpos = ts['feedpos'][:]

        # (freq, pol) pairs solved by this process; distributed over the same
        # communicator that lgain is gathered over below
        local_fp_inds = mpiutil.mpilist(
            list(itertools.product(range(nf), range(2))), comm=ts.comm)
        lgain = np.zeros((len(local_fp_inds), nfeed),
                         dtype=Gain.dtype)  # gain for each feed
        lgain[:] = complex(np.nan, np.nan)

        for ii, (fi, pi) in enumerate(local_fp_inds):
            data = np.abs(Gain[li:hi, fi, pi, :]).T
            # flag outliers
            median = np.ma.median(data, axis=0)
            abs_diff = np.ma.abs(data - median[np.newaxis, :])
            mad = np.ma.median(abs_diff, axis=0) / 0.6745
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', 'invalid value encountered in greater')
                warnings.filterwarnings(
                    'ignore', 'invalid value encountered in greater_equal')
                warnings.filterwarnings(
                    'ignore', 'invalid value encountered in absolute')
                data = np.where(abs_diff > 3.0 * mad[np.newaxis, :], np.nan,
                                data)
            # profile fit for each feed with enough valid points
            for idx in range(nfeed):
                y = data[idx]
                inds = np.where(np.isfinite(y))[0]
                if len(inds) <= 0.75 * len(y):
                    continue
                # valid value closest to the transit as the initial amplitude
                cval = y[inds[np.argmin(np.abs(inds - c))]]
                try:
                    # the sinc-like fc was found to fit better than a gaussian
                    popt = optimize.curve_fit(fc,
                                              x[inds],
                                              y[inds],
                                              p0=(cval, c, 1.0e-2, 0))[0]
                except RuntimeError:
                    print('curve_fit failed for fi = %d, pol = %s, feed = %d' %
                          (fi, ['xx', 'yy'][pi], feedno[idx]))
                    continue

                An = y / fc(popt[1], *popt)  # the beam profile
                # position of this feed (relative to the first feed) in unit
                # of wavelength
                ui = (feedpos[idx] - feedpos[0]) * (1.0e6 * freq[fi]) / const.c
                exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                Ae = An * exp_factor
                Gi = Gain[li:hi, fi, pi, idx]
                # least-squares complex gain g minimising |Gi - g * Ae|^2
                lgain[ii, idx] = np.dot(Ae[inds].conj(), Gi[inds]) / np.dot(
                    Ae[inds].conj(), Ae[inds])

        # gather local gain
        gain = mpiutil.gather_array(lgain, axis=0, root=None, comm=ts.comm)
        gain = gain.reshape(nf, 2, nfeed)

        # apply gain to vis; pj is the gain pol slot, pi the data pol index
        for fi in range(nf):
            for pj, pi in enumerate(pol_inds):
                for bi, (fd1, fd2) in enumerate(ts['blorder'].local_data):
                    g1 = gain[fi, pj, feedno.index(fd1)]
                    g2 = gain[fi, pj, feedno.index(fd2)]
                    if np.isfinite(g1) and np.isfinite(g2):
                        ts.local_vis[:, fi, pi, bi] /= (g1 * np.conj(g2))
                    else:
                        # mask the un-calibrated vis
                        ts.local_vis_mask[:, fi, pi, bi] = True

        # convert vis from intensity unit to temperature unit in K:
        # T = 1e-26 * S[Jy] * c^2 / (2 k_B nu^2), with 1 Jy = 1e-26 W m^-2 Hz^-1
        if temperature_convert:
            factor = 1.0e-26 * (const.c**2 / (2 * const.k_B *
                                              (1.0e6 * freq)**2))
            ts.local_vis[:] *= factor[np.newaxis, :, np.newaxis, np.newaxis]
            ts.vis.attrs['unit'] = 'K'

        # save gain to file
        if mpiutil.rank0 and save_gain:
            with h5py.File(_out_name(gain_file), 'w') as f:
                # per-sample gains near the transit
                dset = f.create_dataset('Gain', data=Gain)
                dset.attrs['dim'] = 'time, freq, pol, feed'
                dset.attrs['time'] = ts.time[start_ind:end_ind]
                dset.attrs['freq'] = freq
                dset.attrs['pol'] = np.array(['xx', 'yy'])
                dset.attrs['feed'] = np.array(feedno)
                # final per-feed gains
                dset = f.create_dataset('gain', data=gain)
                dset.attrs['dim'] = 'freq, pol, feed'
                dset.attrs['freq'] = freq
                dset.attrs['pol'] = np.array(['xx', 'yy'])
                dset.attrs['feed'] = np.array(feedno)

        return super(PsCal, self).process(ts)
Exemple #17
0
    def process(self, ts):
        """Compute and optionally plot the fraction of masked (RFI) data.

        The visibility mask is summed over frequency and baselines to give a
        per-time count, and over time and baselines to give a per-frequency
        count.  For a 4-d (Timestream) mask only the first polarisation is
        used, and auto-correlations are left out when ``excl_auto`` is set.
        Time samples where the noise source is on are not counted.  The
        per-process counts are gathered and summed on rank 0, which plots
        them as percentages when ``plot_stats`` is set.

        Parameters
        ----------
        ts : RawTimestream or Timestream
            Data whose ``local_vis_mask`` is 3-d or 4-d; it is redistributed
            along the baseline axis.

        Returns
        -------
        The result of the parent class ``process`` called on ``ts``.

        Raises
        ------
        RuntimeError
            If ``ts.local_vis_mask`` is neither 3-d nor 4-d.
        """

        excl_auto = self.params['excl_auto']
        plot_stats = self.params['plot_stats']
        fig_prefix = self.params['fig_name']
        rotate_xdate = self.params['rotate_xdate']
        tag_output_iter = self.params['tag_output_iter']

        ts.redistribute('baseline')

        if ts.local_vis_mask.ndim == 3:  # RawTimestream
            full_mask = ts.local_vis_mask
        elif ts.local_vis_mask.ndim == 4:  # Timestream
            # suppose masks are the same for all 4 pols: use the first one
            full_mask = ts.local_vis_mask[:, :, 0]
        else:
            # wrap the shape tuple so %-formatting treats it as one value
            raise RuntimeError('Incorrect vis_mask shape %s' %
                               (ts.local_vis_mask.shape,))
        if excl_auto:
            bl = ts.local_bl
            vis_mask = full_mask[:, :, bl[:, 0] != bl[:, 1]].copy()
        else:
            vis_mask = full_mask.copy()
        nt, nf, lnb = vis_mask.shape

        # total number of bl
        nb = mpiutil.allreduce(lnb, comm=ts.comm)

        # un-mask ns-on positions
        if 'ns_on' in ts.iterkeys():
            vis_mask[ts['ns_on'][:]] = False

        # statistics along the time axis: masked count per time sample,
        # gathered from all processes and summed on rank0
        time_mask = np.sum(vis_mask, axis=(1, 2)).reshape(-1, 1)
        time_mask = mpiutil.gather_array(time_mask,
                                         axis=1,
                                         root=0,
                                         comm=ts.comm)
        if mpiutil.rank0:
            time_mask = np.sum(time_mask, axis=1)

        # statistics along the frequency axis: masked count per frequency
        freq_mask = np.sum(vis_mask, axis=(0, 2)).reshape(-1, 1)
        freq_mask = mpiutil.gather_array(freq_mask,
                                         axis=1,
                                         root=0,
                                         comm=ts.comm)
        if mpiutil.rank0:
            freq_mask = np.sum(freq_mask, axis=1)

        if plot_stats and mpiutil.rank0:

            def _fig_name(kind):
                # output file name, tagged with the iteration when requested
                name = '%s_%s.png' % (fig_prefix, kind)
                if tag_output_iter:
                    return output_path(name, iteration=self.iteration)
                return output_path(name)

            time_fig_name = _fig_name('time')

            # plot time_mask; a single figure is created and closed (an
            # extra plt.figure() here would leak an empty figure per call)
            fig, ax = plt.subplots()
            # NOTE(review): fixed +8 h offset from UTC (presumably Beijing
            # time); confirm, or derive it from the data's timezone
            x_vals = np.array([
                datetime.utcfromtimestamp(sec) + timedelta(hours=8)
                for sec in ts['sec1970'][:]
            ])
            xlabel = '%s' % x_vals[0].date()
            x_vals = mdates.date2num(x_vals)
            ax.plot(x_vals, 100 * time_mask / float(nf * nb))
            ax.xaxis_date()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
            if rotate_xdate:
                # set the x-axis tick labels to diagonal so it fits better
                fig.autofmt_xdate()
            else:
                # reduce the number of tick locators
                ax.xaxis.set_major_locator(MaxNLocator(nbins=6))
                ax.xaxis.set_minor_locator(AutoMinorLocator(2))

            ax.set_xlabel(xlabel)
            ax.set_ylabel(r'RFI (%)')
            fig.savefig(time_fig_name)
            plt.close(fig)

            freq_fig_name = _fig_name('freq')

            # plot freq_mask
            plt.figure()
            plt.plot(ts.freq[:], 100 * freq_mask / float(nt * nb))
            plt.grid(True)
            plt.xlabel(r'$\nu$ / MHz')
            plt.ylabel(r'RFI (%)')
            plt.savefig(freq_fig_name)
            plt.close()

        return super(Stats, self).process(ts)
            rot_map,
            rot=(0, 90, 0),
            xsize=400,
            half_sky=True,
            return_projected_map=True)[72:328,
                                       72:328]  # rot to make NCP at the center
        dataset[li, j, :, :, 0] = rec_img.data  # only data of the masked array

# save dataset of rank0
# Only rank 0 writes its own (local, not yet gathered) dataset to disk here.
if mpiutil.rank0:
    if not os.path.isdir('./training_dataset_256x256_aug'):
        os.mkdir('./training_dataset_256x256_aug')
    np.save('./training_dataset_256x256_aug/dataset_rank0.npy', dataset)

# gather dataset to rank 0
# Collective call: every rank must reach this line. The per-rank arrays are
# joined along axis 0 and, with root=0, only rank 0 ends up holding the full
# array (other ranks should not rely on `dataset` after this point).
dataset = mpiutil.gather_array(dataset, axis=0, root=0)

if mpiutil.rank0:
    # Flatten to N * m single-channel n x n images.
    # NOTE(review): assumes N, m, n (defined earlier in the script) match the
    # gathered array's total size -- TODO confirm.
    dataset = dataset.reshape((N * m, n, n, 1))
    inds = np.arange(N * m)
    # inds[1:] is a view into inds, so shuffling it permutes inds in place
    # while keeping index 0 at the front.
    np.random.shuffle(inds[1:])  # leave 0 (the true NP) unchanged
    dataset = dataset[inds]  # randomly shuffle the datasets

    print dataset.shape

    # Fixed-size split: first 12000 samples, next 6000, remainder.
    # NOTE(review): assumes N * m > 18000 -- otherwise val/test may be empty.
    train = dataset[:12000]  # train dataset
    val = dataset[12000:18000]  # validation dataset
    test = dataset[18000:]  # test dataset

    # NOTE(review): this directory was already created above; the code that
    # consumes train/val/test is not visible in this snippet.
    if not os.path.isdir('./training_dataset_256x256_aug'):
        os.mkdir('./training_dataset_256x256_aug')
Exemple #19
0
    def process(self, rt):
        """Detect the noise-source switching pattern and mask its "on" samples.

        Auto-correlation baselines are tried one by one, starting with the
        channel given by ``self.params['channel']`` (or the first
        auto-correlation when that parameter is ``None`` or the channel is
        absent). For each baseline the time stream is averaged over the last
        axis of the (time, freq) visibility slice and differenced; jumps more
        than ``self.params['sigma']`` standard deviations beyond the mean of
        the positive (negative) differences mark switch-on (switch-off)
        samples. The first baseline giving a consistent period, on_time and
        off_time is used; baselines that fail are skipped with a warning.

        The pattern is stored in the time-ordered dataset ``'ns_on'`` (with
        ``period``, ``on_time`` and ``off_time`` attributes) and the matching
        samples of ``rt.local_vis_mask`` are set to True, widened by
        ``self.params['mask_near']`` samples on each side when that is > 0.

        Parameters
        ----------
        rt : RawTimestream
            Input data; it is redistributed so that axis 0 (time) is the
            distributed axis.

        Returns
        -------
        The result of the parent class's ``process(rt)``.

        Raises
        ------
        AssertionError
            If ``rt`` is not a ``RawTimestream``.
        RuntimeError
            If there is no auto-correlation baseline, or no baseline yields a
            consistent noise-source pattern.
        """

        assert isinstance(
            rt, RawTimestream
        ), '%s only works for RawTimestream object currently' % self.__class__.__name__

        channel = self.params['channel']
        sigma = self.params['sigma']
        mask_near = max(0, int(self.params['mask_near']))

        rt.redistribute(0)  # make time the dist axis

        def _mode(values):
            # Most frequent element of `values`, or None when `values` is
            # empty (Counter.most_common(1) returns [] for empty input).
            ranked = Counter(values).most_common(1)
            return ranked[0][0] if ranked else None

        def _thin_edges(inds, step):
            # Walk `inds` in order, keeping the first element and then every
            # element lying more than 1 sample (in direction `step`, +1 or -1)
            # beyond the last *kept* element, so adjacent detections of one
            # edge collapse to a single index.
            if len(inds) == 0:
                return inds
            kept = [inds[0]]
            for ind in inds[1:]:
                if step * (ind - kept[-1]) > 1:
                    kept.append(ind)
            return np.array(kept)

        auto_inds = np.where(
            rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()  # inds for auto-correlations
        if not auto_inds:
            raise RuntimeError(
                'No auto-correlation found in the data, can not detect noise source signal')
        channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
        if channel is not None and channel in channels:
            bl_ind = auto_inds[channels.index(channel)]
        else:
            bl_ind = auto_inds[0]
            if channel is not None and mpiutil.rank0:
                print('Warning: Required channel %d doen not in the data, use channel %d instead' % (
                    channel, rt.bl[bl_ind, 0]))
        # move the chosen channel to the first
        auto_inds.remove(bl_ind)
        auto_inds = [bl_ind] + auto_inds

        # Every `continue` below is decided from all-rank data (allreduce /
        # gathered tt_mean), so all ranks take the same path through the
        # collective calls in this loop.
        for bl_ind in auto_inds:
            this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
            vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                              mask=rt.local_vis_mask[:, :, bl_ind])
            total_cnt = mpiutil.allreduce(vis.count())  # number of not masked vals
            vis_shp = rt.vis.shape
            ratio = float(total_cnt) / np.prod(
                (vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
            if ratio < 0.5:  # too many masked vals
                if mpiutil.rank0:
                    warnings.warn(
                        'Too many masked values for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            # averaged time stream; root=None presumably gathers it to every
            # rank, as every rank uses tt_mean below
            tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0),
                                           root=None)
            df = np.diff(tt_mean, axis=-1)

            # switch-on: sample right after a large positive jump
            pdf = np.where(df > 0, df, 0)
            pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0] + 1
            pinds = _thin_edges(pinds, 1)

            # switch-off: sample right after a large negative jump, thinned
            # from the end so the last index of each run is kept
            ndf = np.where(df < 0, df, 0)
            ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0] + 1
            ninds = _thin_edges(ninds[::-1], -1)[::-1]

            # at least two edges of each kind are needed to measure a period
            if len(pinds) < 2 or len(ninds) < 2:
                if mpiutil.rank0:
                    warnings.warn(
                        'Too few noise source on/off signals detected for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            pT = _mode(np.diff(pinds))  # period of pinds
            nT = _mode(np.diff(ninds))  # period of ninds
            if pT != nT:  # failed to detect correct period
                if mpiutil.rank0:
                    warnings.warn(
                        'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                        % (this_chan, pT, nT))
                continue
            period = pT  # > 0, since thinned pinds are strictly increasing

            # all (off - on) index differences, folded into one period
            dinds = (ninds.reshape(-1, 1) - pinds).flatten()
            on_time = _mode(dinds[dinds > 0] % period)
            off_time = _mode(-dinds[dinds < 0] % period)
            if on_time is None or off_time is None:
                if mpiutil.rank0:
                    warnings.warn(
                        'Failed to detect on/off time for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            if period != on_time + off_time:  # incorrect detect
                if mpiutil.rank0:
                    warnings.warn(
                        'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                        % (this_chan, period, on_time, off_time))
                continue

            if 'noisesource' in rt.iterkeys():
                if rt['noisesource'].shape[0] == 1:  # only 1 noise source
                    start, stop, cycle = rt['noisesource'][0, :]
                    int_time = rt.attrs['inttime']
                    true_on_time = np.round((stop - start) / int_time)
                    true_period = np.round(cycle / int_time)
                    # inconsistant with the record in the data
                    # NOTE(review): rejects only when BOTH on_time and period
                    # disagree with the record; confirm `or` was not intended.
                    if on_time != true_on_time and period != true_period:
                        if mpiutil.rank0:
                            warnings.warn(
                                'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it'
                                % (this_chan, on_time, true_on_time,
                                   period, true_period))
                        continue
                elif rt['noisesource'].shape[0] >= 2:  # more than 1 noise source
                    if mpiutil.rank0:
                        warnings.warn(
                            'More than 1 noise source, do not know how to deal with this currently'
                        )

            # break if succeed
            break

        else:
            raise RuntimeError('Failed to detect noise source signal')

        if mpiutil.rank0:
            print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                period, on_time, off_time))
        # tile the on/off pattern over the whole time span, then shift it to
        # the most common switch-on phase
        num_period = int(np.ceil(len(tt_mean) / float(period)))
        tmp_ns_on = np.array(([True] * on_time + [False] * off_time) *
                             num_period)[:len(tt_mean)]
        on_start = _mode(pinds % period)
        ns_on = np.roll(tmp_ns_on, on_start)

        ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

        rt.create_main_time_ordered_dataset('ns_on', ns_on1)
        rt['ns_on'].attrs['period'] = period
        rt['ns_on'].attrs['on_time'] = on_time
        rt['ns_on'].attrs['off_time'] = off_time

        # set vis_mask corresponding to ns_on (this rank's time slice)
        on_inds = np.where(rt['ns_on'].local_data[:])[0]
        rt.local_vis_mask[on_inds] = True

        if mask_near > 0:
            # widen every global "on" index by mask_near samples on each side
            on_inds = np.where(ns_on)[0]
            offsets = np.arange(-mask_near, mask_near + 1)
            new_on_inds = np.unique((on_inds[:, np.newaxis] + offsets).ravel())

            if rt['vis_mask'].distributed:
                start = rt.vis_mask.local_offset[0]
                end = start + rt.vis_mask.local_shape[0]
            else:
                start = 0
                end = rt.vis_mask.shape[0]
            # keep the global indices held by this rank (which also drops
            # out-of-range ones) and convert them to local indices
            in_local = (new_on_inds >= start) & (new_on_inds < end)
            rt.local_vis_mask[new_on_inds[in_local] - start] = True

        return super(Detect, self).process(rt)
    def process(self, rt):

        assert isinstance(
            rt, RawTimestream
        ), '%s only works for RawTimestream object currently' % self.__class__.__name__

        channel = self.params['channel']
        sigma = self.params['sigma']
        mask_near = max(0, int(self.params['mask_near']))
        #================================================
        rt.FRB_cal = self.params['FRB_cal']
        ns_arr_file = self.params['ns_arr_file']
        ns_prop_file = self.params['ns_prop_file']
        num_noise = self.params['num_noise']
        change_reference_tol = self.params['change_reference_tol']
        abnormal_ns_save_file = 'ns_cal/abnormal_ns.npz'
        #================================================

        rt.redistribute(0)  # make time the dist axis

        #        time_span = rt.local_vis.shape[0]
        #        total_time_span = mpiutil.allreduce(time_span)
        total_time_span = rt['sec1970'].shape[0]

        auto_inds = np.where(
            rt.bl[:, 0] == rt.bl[:,
                                 1])[0].tolist()  # inds for auto-correlations
        channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
        if channel is not None:
            if channel in channels:
                bl_ind = auto_inds[channels.index(channel)]
            else:
                bl_ind = auto_inds[0]
                if mpiutil.rank0:
                    print 'Warning: Required channel %d doen not in the data, use channel %d instead' % (
                        channel, rt.bl[bl_ind, 0])
        else:
            bl_ind = auto_inds[0]
        # move the chosen channel to the first
        auto_inds.remove(bl_ind)
        auto_inds = [bl_ind] + auto_inds
        if rt.FRB_cal:
            #            ns_arr_end_time = mpiutil.bcast(rt['jul_date'][total_time_span - 1], root = mpiutil.size - 1)
            ns_arr_end_time = rt.attrs['sec1970'][0] + rt.attrs['inttime'] * (
                total_time_span - 1)
            if os.path.exists(output_path(ns_arr_file)) and not os.path.exists(
                    output_path(ns_prop_file)):
                filein = np.load(output_path(ns_arr_file))
                # add 0 to avoid a single float or int become an array
                ns_arr_pinds = filein['on_inds'] + 0
                ns_arr_ninds = filein['off_inds'] + 0
                ns_arr_pinds = list(ns_arr_pinds)
                ns_arr_ninds = list(ns_arr_ninds)
                ns_arr_len = filein['time_len'] + 0
                ns_arr_num = filein['ns_num'] + 0
                ns_arr_bl = filein['auto_chn'] + 0
                ns_arr_start_time = filein['start_time']
                #                overlap_index_span = np.around(np.float128(filein['end_time'] - mpiutil.bcast(rt['jul_date'][0], root = 0))*24.*3600./rt.attrs['inttime']) + 1
                overlap_index_span = np.around(
                    np.float128(filein['end_time'] - rt.attrs['sec1970'][0]) /
                    rt.attrs['inttime']) + 1
                if overlap_index_span > 0:
                    raise OverlapData(
                        'Overlap of data occured when trying to build up noise property file! In julian date, the end time of previous data is %.10f, while the start time of this data is %.10f. The overlap span in index is %d.'
                        % (filein['end_time'], rt.attrs['sec1970'][0],
                           overlap_index_span))
            elif not os.path.exists(output_path(ns_prop_file)):
                ns_arr_pinds = []
                ns_arr_ninds = []
                ns_arr_len = 0
                ns_arr_num = -1  # to distinguish the first process

                #                ns_arr_start_time = mpiutil.bcast(rt['jul_date'][0], root = 0)
                ns_arr_start_time = rt.attrs['sec1970'][0]

        if rt.FRB_cal and os.path.exists(output_path(ns_prop_file)):
            if mpiutil.rank0:
                print('Use existing ns property file %s to do calibration.' %
                      output_path(ns_prop_file))
            ns_prop_data = np.load(output_path(ns_prop_file))
            period = ns_prop_data['period'] + 0
            on_time = ns_prop_data['on_time'] + 0
            off_time = ns_prop_data['off_time'] + 0
            reference_time = ns_prop_data[
                'reference_time'] + 0  # in sec1970, is a start point of noise
            if 'lost_count' in ns_prop_data.files:
                lost_count_before = ns_prop_data['lost_count']
            else:
                lost_count_before = 0
            if 'added_count' in ns_prop_data.files:
                added_count_before = ns_prop_data['added_count']
            else:
                added_count_before = 0

#            this_time_start = mpiutil.bcast(rt['jul_date'][0], root=0)
#            skip_inds = int(np.around((this_time_start - reference_time)*86400.0/rt.attrs['inttime'])) # the number of index between the reference time and the start point
            this_time_start = rt.attrs['sec1970'][0]
            skip_inds = int(
                np.around(
                    (this_time_start - reference_time) / rt.attrs['inttime'])
            )  # the number of index between the reference time and the start point
            on_start = period - skip_inds % period
            #            if total_time_span < period:
            #                raise Exception('Time span of data %d is shorter than period %d!'%(total_time_span, period))
            if total_time_span <= on_start + on_time:
                raise NoNoisePoint(
                    'Calculated from previous data, this data contains no noise point or does not contain complete noise signal!'
                )
            # check whether there are lost points
            # only consider that the case that there are one lost point and only consider the on points
#================================================
            abnormal_count = -1
            lost_one_point = 0
            added_one_point = 0
            abnormal_list = []
            lost_one_list = []
            added_one_list = []
            lost_one_pos = []
            added_one_pos = []
            complete_period_num = (
                total_time_span - on_start - on_time -
                1) // period  # the number of periods in the data
            on_points = [
                on_start + i * period for i in range(complete_period_num + 1)
            ]
            off_points = [
                on_start + on_time + i * period
                for i in range(complete_period_num + 1)
            ]
            for bl_ind in auto_inds:
                this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
                vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                                  mask=rt.local_vis_mask[:, :, bl_ind])
                cnt = vis.count()  # number of not masked vals
                total_cnt = mpiutil.allreduce(cnt)
                vis_shp = rt.vis.shape
                ratio = float(total_cnt) / np.prod(
                    (vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
                if ratio < 0.5:  # too many masked vals
                    continue

                if abnormal_count < 0:
                    abnormal_count = 0
                tt_mean = mpiutil.gather_array(
                    np.ma.mean(vis, axis=-1).filled(0),
                    root=None)  # mean for all freq, for a specific auto bl
                df = np.diff(tt_mean, axis=-1)
                pdf = np.where(df > 0, df, 0)
                pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0]
                pinds = pinds + 1
                #====================================
                if len(pinds) == 0:  # no raise, might be badchn, continue
                    continue
#====================================
                pinds1 = [pinds[0]]
                for pi in pinds[1:]:
                    if pi - pinds1[-1] > 1:
                        pinds1.append(pi)
                pinds = np.array(pinds1)

                ndf = np.where(df < 0, df, 0)
                ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0]
                ninds = ninds + 1
                ninds = ninds[::-1]
                #====================================
                if len(ninds) == 0:  # no raise, might be badchn, continue
                    continue
#====================================
                ninds1 = [ninds[0]]
                for ni in ninds[1:]:
                    if ni - ninds1[-1] < -1:
                        ninds1.append(ni)
                ninds = np.array(ninds1[::-1])
                cmp_signal, cmp_res = cmp_cd(on_points, off_points, pinds,
                                             ninds, period, 2. / 3)
                if cmp_signal == 'ok':  # normal
                    continue
                elif cmp_signal == 'abnormal':  # abnormal
                    abnormal_count += 1
                    abnormal_list += [this_chan]
                elif cmp_signal == 'lost':  # lost point
                    lost_one_point += 1
                    lost_one_list += [this_chan]
                    lost_one_pos += [cmp_res]
                    continue
                elif cmp_signal == 'add':  # added point
                    added_one_point += 1
                    added_one_list += [this_chan]
                    added_one_pos += [cmp_res]
                    continue
                else:
                    raise Exception('Unknown comparison signal!')

#                if on_start in pinds or on_start + on_time in ninds:
#                    # to avoid the effect of interference
#                    continue
#                elif on_start - 1 in pinds or on_start + on_time - 1 in ninds:
#                    lost_one_point += 1
#                    lost_one_list += [this_chan]
#                    continue
#                elif on_start - 1 < 0 and on_start + period - 1 in pinds:
#                    lost_one_point += 1
#                    lost_one_list += [this_chan]
#                    continue
#                else:
#                    abnormal_count += 1
#                    abnormal_list += [this_chan]

            if abnormal_count < 0:
                raise NoNoisePoint(
                    'No noise points are detected from this data or the data contains too many masked points!'
                )
            elif abnormal_count > len(auto_inds) / 3:
                if mpiutil.rank0:
                    np.savez(output_path(abnormal_ns_save_file),
                             on_inds=pinds,
                             off_inds=ninds,
                             on_start=on_start,
                             period=period,
                             on_time=on_time)
                mpiutil.barrier()
                raise AbnormalPoints(
                    'Something rather than one lost point happened. The expected start point is %d, period %d, on_time %d, but the pinds and ninds are: '
                    % (on_start, period, on_time), pinds, ninds)
            elif lost_one_point > 2 * len(auto_inds) / 3:
                uniques, counts = np.unique(lost_one_pos, return_counts=True)
                maxcount = np.argmax(counts)
                if counts[maxcount] > 2 * len(auto_inds) / 3:
                    lost_position = uniques[maxcount]
                else:
                    raise AbnormalPoints(
                        'More than 2/3 baselines have detected lost points but do not have a universal lost position, some unexpected error happened!\nChannels that probably lost one point and the position: %s'
                        % str(zip(lost_one_list, lost_one_pos)))
                if mpiutil.rank0:
                    if lost_position == 0:
                        warnings.warn(
                            'One lost point before the data is detected!',
                            LostPointBegin)
                    else:
                        warnings.warn(
                            'One lost point before index %d is detected!' %
                            on_points[lost_position], LostPointMiddle)
#                on_start = on_start - 1
                if lost_position == 0:
                    lost_count_before += 1
                elif lost_count_before == 0:
                    lost_count_before += 1
                on_points = np.array(on_points)
                on_points[lost_position:] = on_points[lost_position:] - 1
                off_points = np.array(off_points)
                off_points[lost_position:] = off_points[lost_position:] - 1
                if on_points[0] < 0:
                    on_points = on_points[1:]
                    off_points = off_points[1:]
                if mpiutil.rank0:
                    if lost_count_before >= change_reference_tol:
                        #                        reference_time -= 1/86400.0*rt.attrs['inttime']
                        reference_time -= rt.attrs['inttime']
                        warnings.warn(
                            'Move the reference time one index earlier to compensate!',
                            ChangeReferenceTime)
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=0,
                                 added_count=0)
                    else:
                        warnings.warn(
                            'The number of recorded lost points was %d while tolerance is %d. Do not change the reference time.'
                            % (lost_count_before, change_reference_tol))
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=lost_count_before,
                                 added_count=0)
#                mpiutil.barrier()
            elif added_one_point > 2 * len(auto_inds) / 3:
                uniques, counts = np.unique(added_one_pos, return_counts=True)
                maxcount = np.argmax(counts)
                if counts[maxcount] > 2 * len(auto_inds) / 3:
                    added_position = uniques[maxcount]
                else:
                    raise AbnormalPoints(
                        'More than 2/3 baselines have detected additional points but do not have a universal adding position, some unexpected error happened!\nChannels that probably added one point and the position: %s'
                        % str(zip(added_one_list, added_one_pos)))
                if mpiutil.rank0:
                    if added_position == 0:
                        warnings.warn(
                            'One additional point before the data is detected!',
                            AdditionalPointBegin)
                    else:
                        warnings.warn(
                            'One additional point before index %d is detected!'
                            % on_points[added_position], AdditionalPointMiddle)
                if added_position == 0:
                    added_count_before += 1
                elif added_count_before == 0:
                    added_count_before += 1
#                on_start = on_start - 1
                on_points = np.array(on_points)
                on_points[added_position:] = on_points[added_position:] + 1
                off_points = np.array(off_points)
                off_points[added_position:] = off_points[added_position:] + 1
                if off_points[-1] >= total_time_span:
                    on_points = on_points[:-1]
                    off_points = off_points[:-1]
                if mpiutil.rank0:
                    if added_count_before >= change_reference_tol:
                        warnings.warn(
                            'Move the reference time one index later to compensate!',
                            ChangeReferenceTime)
                        #                        reference_time += 1/86400.0*rt.attrs['inttime']
                        reference_time += rt.attrs['inttime']
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=0,
                                 added_count=0)
                    else:
                        warnings.warn(
                            'The number of recorded added points was %d while tolerance is %d. Do not change the reference time.'
                            % (added_count_before, change_reference_tol))
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=0,
                                 added_count=added_count_before)
#                mpiutil.barrier()
            elif lost_one_point > 0 or abnormal_count > 0 or added_one_point > 0:
                if mpiutil.rank0:
                    np.savez(output_path(ns_prop_file),
                             period=period,
                             on_time=on_time,
                             off_time=off_time,
                             reference_time=reference_time,
                             lost_count=0,
                             added_count=0)
                    warnings.warn(
                        'Abnormal points are detected for some channel, number of abnormal bl is %d, number of channel that probably lost one point is %d, number of channel that probably added one point is %d'
                        % (abnormal_count, lost_one_point, added_one_point),
                        DetectedLostAddedAbnormal)
                    warnings.warn(
                        'Abnomal channels: %s\nChannels that probably lost one point: %s\nChannels that probably added one point: %s'
                        % (str(abnormal_list), str(lost_one_list),
                           str(added_one_list)), DetectedLostAddedAbnormal)
            else:
                if mpiutil.rank0:
                    np.savez(output_path(ns_prop_file),
                             period=period,
                             on_time=on_time,
                             off_time=off_time,
                             reference_time=reference_time,
                             lost_count=0,
                             added_count=0)

#================================================
            if mpiutil.rank0:
                print 'Noise source: period = %d, on_time = %d, off_time = %d' % (
                    period, on_time, off_time)
#            num_period = np.int(np.ceil(total_time_span / np.float(period)))
#            ns_on = np.array([False] * on_start + ([True] * on_time + [False] * off_time) * num_period)[:total_time_span]
            ns_on = np.array([False] * total_time_span)
            for i, j in zip(on_points, off_points):
                ns_on[i:j] = True

#            if mpiutil.rank0:
#                np.save('ns_on', ns_on)
        elif not rt.FRB_cal or ns_arr_num < num_noise:

            #            min_inds = 0 # min_inds = min(len(pinds), len(ninds), min_inds) if min_inds != 0 else min(len(pinds), len(ninds))
            ns_num_add = []
            for ns_arr_index, bl_ind in enumerate(auto_inds):
                this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
                vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                                  mask=rt.local_vis_mask[:, :, bl_ind])
                cnt = vis.count()  # number of not masked vals
                total_cnt = mpiutil.allreduce(cnt)
                vis_shp = rt.vis.shape
                ratio = float(total_cnt) / np.prod(
                    (vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
                if ratio < 0.5:  # too many masked vals
                    if rt.FRB_cal and ns_arr_num == -1:
                        ns_arr_pinds += [np.array([])]
                        ns_arr_ninds += [np.array([])]
                    if mpiutil.rank0:
                        warnings.warn(
                            'Too many masked values for auto-correlation of Channel: %d, does not use it'
                            % this_chan)
                    continue
                tt_mean = mpiutil.gather_array(
                    np.ma.mean(vis, axis=-1).filled(0),
                    root=None)  # mean for all freq, for a specific auto bl
                df = np.diff(tt_mean, axis=-1)
                pdf = np.where(df > 0, df, 0)
                pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0]
                #====================================
                if len(pinds) == 0:  # no raise, might be badchn, continue
                    if rt.FRB_cal and ns_arr_num == -1:
                        ns_arr_pinds += [np.array([])]
                        ns_arr_ninds += [np.array([])]
                    if mpiutil.rank0:
                        warnings.warn(
                            'No noise on signal is detected for Channel %d, it may be bad channel.'
                            % this_chan)
                    continue
#====================================
                pinds = pinds + 1
                pinds1 = [pinds[0]]
                for pi in pinds[1:]:
                    if pi - pinds1[-1] > 1:
                        pinds1.append(pi)
                pinds = np.array(pinds1)
                pT = Counter(
                    np.diff(pinds)).most_common(1)[0][0]  # period of pinds

                ndf = np.where(df < 0, df, 0)
                ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0]
                ninds = ninds + 1
                ninds = ninds[::-1]
                #====================================
                if len(ninds) == 0:  # no raise, might be badchn, continue
                    if rt.FRB_cal and ns_arr_num == -1:
                        ns_arr_pinds += [np.array([])]
                        ns_arr_ninds += [np.array([])]
                    if mpiutil.rank0:
                        warnings.warn(
                            'No noise off signal is detected for Channel %d, it may be bad channel.'
                            % this_chan)
                    continue
#====================================
                ninds1 = [ninds[0]]
                for ni in ninds[1:]:
                    if ni - ninds1[-1] < -1:
                        ninds1.append(ni)
                ninds = np.array(ninds1[::-1])
                nT = Counter(
                    np.diff(ninds)).most_common(1)[0][0]  # period of ninds

                ns_num_add += [min(len(pinds), len(ninds))]
                if rt.FRB_cal:
                    if ns_arr_num == -1:
                        ns_arr_pinds += [pinds]
                        ns_arr_ninds += [ninds]
                    else:
                        ns_arr_pinds[ns_arr_index] = np.concatenate(
                            [ns_arr_pinds[ns_arr_index], pinds + ns_arr_len])
                        ns_arr_ninds[ns_arr_index] = np.concatenate(
                            [ns_arr_ninds[ns_arr_index], ninds + ns_arr_len])
#==============================================
# continue for non-FRB case
                if pT != nT:  # failed to detect correct period
                    if mpiutil.rank0:
                        warnings.warn(
                            'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                            % (this_chan, pT, nT))
                    continue
                else:
                    period = pT

                ninds = ninds.reshape(-1, 1)
                dinds = (ninds - pinds).flatten()
                on_time = Counter(dinds[dinds > 0] %
                                  period).most_common(1)[0][0]
                off_time = Counter(-dinds[dinds < 0] %
                                   period).most_common(1)[0][0]

                if period != on_time + off_time:  # incorrect detect
                    if mpiutil.rank0:
                        warnings.warn(
                            'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                            % (this_chan, period, on_time, off_time))
                    continue
                else:
                    if 'noisesource' in rt.iterkeys():
                        if rt['noisesource'].shape[
                                0] == 1:  # only 1 noise source
                            start, stop, cycle = rt['noisesource'][0, :]
                            int_time = rt.attrs['inttime']
                            true_on_time = np.round((stop - start) / int_time)
                            true_period = np.round(cycle / int_time)
                            if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                                if mpiutil.rank0:
                                    warnings.warn(
                                        'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period != record_period %d, does not use it'
                                        % (this_chan, on_time, true_on_time,
                                           period, true_period))
                                continue
                        elif rt['noisesource'].shape[
                                0] >= 2:  # more than 1 noise source
                            if mpiutil.rank0:
                                warnings.warn(
                                    'More than 1 noise source, do not know how to deal with this currently'
                                )

                    # break if succeed
                    if not rt.FRB_cal:  # for FRB case, record all baseline
                        break

            else:
                if not rt.FRB_cal:
                    raise DetectNoiseFailure(
                        'Failed to detect noise source signal')

            if mpiutil.rank0:
                print 'Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                    period, on_time, off_time)
            on_start = Counter(pinds % period).most_common(1)[0][0]
            num_period = np.int(np.ceil(len(tt_mean) / np.float(period)))
            ns_on = np.array([False] * on_start +
                             ([True] * on_time + [False] * off_time) *
                             num_period)[:len(tt_mean)]
            #==============================================
            if rt.FRB_cal:
                ns_arr_len += total_time_span
                #                ns_arr_from_time = np.float128(ns_arr_end_time - ns_arr_start_time)*24.*3600./rt.attrs['inttime'] + 1
                ns_arr_from_time = np.around(
                    np.float128(ns_arr_end_time - ns_arr_start_time) /
                    rt.attrs['inttime']) + 1
                if ns_arr_len == ns_arr_from_time:
                    pass
                else:
                    raise IncontinuousData(
                        'Incontinuous data. Index span calculated from time is %.2f, while the sum of array length is %d! Can not deal with incontinuous data at present!'
                        % (ns_arr_from_time, ns_arr_len))
#                    print('Detected incontinuous data, use index span %d calculated from time instead of the sum of array length %d!'%(ns_arr_from_time, ns_arr_len))
#                    ns_arr_len = ns_arr_from_time
                if ns_arr_num < 0:
                    ns_arr_num = 0
#                ns_arr_num += min_inds
                ns_arr_num += np.around(np.average(ns_num_add))
                ns_arr_bl = rt.bl[auto_inds, 0]
            if mpiutil.rank0 and rt.FRB_cal:
                np.savez(output_path(ns_arr_file),
                         on_inds=ns_arr_pinds,
                         off_inds=ns_arr_ninds,
                         time_len=ns_arr_len,
                         ns_num=ns_arr_num,
                         auto_chn=ns_arr_bl,
                         start_time=ns_arr_start_time,
                         end_time=ns_arr_end_time)
                if ns_arr_num < num_noise:
                    raise NoiseNotEnough(
                        'Number of noise points %d is not enough for calibration(need %d), wait for next file!'
                        % (ns_arr_num, num_noise))
#            mpiutil.barrier()
#=================================================================
        if rt.FRB_cal and (not os.path.exists(
                output_path(ns_prop_file))) and ns_arr_num >= num_noise:
            if mpiutil.rank0:
                print(
                    'Got %d noise points (need %d) to build up noise property file!'
                    % (ns_arr_num, num_noise))
            for ns_arr_index, (pinds, ninds) in enumerate(
                    zip(ns_arr_pinds, ns_arr_ninds)):
                if len(pinds) < num_noise * 2. / 3. or len(
                        ninds) < num_noise * 2. / 3.:
                    print(
                        'Channel %d does not have enough noise points(%d, need %d) for calibration. Do not use it.'
                        % (ns_arr_bl[ns_arr_index], len(pinds),
                           int(2 * num_noise / 3.)))
                    continue
                pT = Counter(
                    np.diff(pinds)).most_common(1)[0][0]  # period of pinds
                nT = Counter(
                    np.diff(ninds)).most_common(1)[0][0]  # period of ninds

                #=================================================================
                if pT != nT:  # failed to detect correct period
                    if mpiutil.rank0:
                        warnings.warn(
                            'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                            % (ns_arr_bl[ns_arr_index], pT, nT))
                    continue
                else:
                    period = pT

                ninds = ninds.reshape(-1, 1)
                dinds = (ninds - pinds).flatten()
                on_time = Counter(dinds[dinds > 0] %
                                  period).most_common(1)[0][0]
                off_time = Counter(-dinds[dinds < 0] %
                                   period).most_common(1)[0][0]

                if period != on_time + off_time:  # incorrect detect
                    if mpiutil.rank0:
                        warnings.warn(
                            'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                            % (ns_arr_bl[ns_arr_index], period, on_time,
                               off_time))
                    continue
                else:
                    if 'noisesource' in rt.iterkeys():
                        if rt['noisesource'].shape[
                                0] == 1:  # only 1 noise source
                            start, stop, cycle = rt['noisesource'][0, :]
                            int_time = rt.attrs['inttime']
                            true_on_time = np.round((stop - start) / int_time)
                            true_period = np.round(cycle / int_time)
                            if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                                if mpiutil.rank0:
                                    warnings.warn(
                                        'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period != record_period %d, does not use it'
                                        % (this_chan, on_time, true_on_time,
                                           period, true_period))
                                continue
                        elif rt['noisesource'].shape[
                                0] >= 2:  # more than 1 noise source
                            if mpiutil.rank0:
                                warnings.warn(
                                    'More than 1 noise source, do not know how to deal with this currently'
                                )

                    # break if succeed

                    break

            else:
                raise DetectNoiseFailure(
                    'Failed to detect noise source signal')

            if mpiutil.rank0:
                print 'Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                    period, on_time, off_time)
            on_start = Counter(pinds % period).most_common(1)[0][0]
            this_time_len = total_time_span
            first_ind = ns_arr_len - this_time_len
            skip_inds = first_ind - on_start
            on_start = period - skip_inds % period
            num_period = np.int(np.ceil(total_time_span / np.float(period)))
            ns_on = np.array([False] * on_start +
                             ([True] * on_time + [False] * off_time) *
                             num_period)[:total_time_span]

        # import matplotlib
        # matplotlib.use('Agg')
        # import matplotlib.pyplot as plt
        # plt.figure()
        # plt.plot(np.where(ns_on, np.nan, tt_mean))
        # # plt.plot(pinds, tt_mean[pinds], 'RI')
        # # plt.plot(ninds, tt_mean[ninds], 'go')
        # plt.savefig('df.png')
        # err

        ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

        rt.create_main_time_ordered_dataset('ns_on', ns_on1)
        rt['ns_on'].attrs['period'] = period
        rt['ns_on'].attrs['on_time'] = on_time
        rt['ns_on'].attrs['off_time'] = off_time

        if (not rt['jul_date'][on_start] is None) and not os.path.exists(
                output_path(ns_prop_file)):
            #            np.savez(output_path(ns_prop_file), period = period, on_time = on_time, off_time = off_time, reference_time = np.float128(rt['jul_date'][on_start]))
            np.savez(
                output_path(ns_prop_file),
                period=period,
                on_time=on_time,
                off_time=off_time,
                reference_time=np.float128(rt.attrs['sec1970'][0] +
                                           on_start * rt.attrs['inttime']))
            #            mpiutil.barrier()
            if mpiutil.rank0:
                print('Save noise property file to %s' %
                      output_path(ns_prop_file))

        # set vis_mask corresponding to ns_on
        on_inds = np.where(rt['ns_on'].local_data[:])[0]
        rt.local_vis_mask[on_inds] = True

        if mask_near > 0:
            on_inds = np.where(ns_on)[0]
            new_on_inds = on_inds.tolist()
            for i in xrange(1, mask_near + 1):
                new_on_inds = new_on_inds + (on_inds - i).tolist() + (
                    on_inds + i).tolist()
            new_on_inds = np.unique(new_on_inds)

            if rt['vis_mask'].distributed:
                start = rt.vis_mask.local_offset[0]
                end = start + rt.vis_mask.local_shape[0]
            else:
                start = 0
                end = rt.vis_mask.shape[0]
            global_inds = np.arange(start, end).tolist()
            new_on_inds = np.intersect1d(new_on_inds, global_inds)
            local_on_inds = [global_inds.index(i) for i in new_on_inds]
            rt.local_vis_mask[
                local_on_inds] = True  # set mask using global slicing

        return super(Detect, self).process(rt)
Exemple #21
0
    def process(self, ts):

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        vis_conj = self.params['vis_conj']
        zero_diag = self.params['zero_diag']
        span = self.params['span']
        reserve_high_gain = self.params['reserve_high_gain']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_src_vis = self.params['save_src_vis']
        src_vis_file = self.params['src_vis_file']
        subtract_src = self.params['subtract_src']
        replace_with_src = self.params['replace_with_src']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        save_phs_change = self.params['save_phs_change']
        gain_file = self.params['gain_file']
        # temperature_convert = self.params['temperature_convert']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']

        if save_src_vis or subtract_src or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ps_cal for pol_type: %s' %
                                   pol_type)

            ts.redistribute('baseline')

            feedno = ts['feedno'][:].tolist()
            pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
            gain_pd = {
                'xx': 0,
                'yy': 1,
                0: 'xx',
                1: 'yy'
            }  # for gain related op
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians
            # transitsource = ts['transitsource'][:]
            # transit_time = transitsource[-1, 0] # second, sec1970
            # int_time = ts.attrs['inttime'] # second

            # get the calibrator
            try:
                s = calibrators.get_src(calibrator)
            except KeyError:
                if mpiutil.rank0:
                    print 'Calibrator %s is unavailable, available calibrators are:'
                    for key, d in calibrators.src_data.items():
                        print '%8s  ->  %12s' % (key, d[0])
                raise RuntimeError('Calibrator %s is unavailable')
            if mpiutil.rank0:
                print 'Try to calibrate with %s...' % s.src_name

            # get transit time of calibrator
            # array
            aa = ts.array
            aa.set_jultime(ts['jul_date'][0])  # the first obs time point
            next_transit = aa.next_transit(s)
            transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
            # get time zone
            pattern = '[-+]?\d+'
            tz = re.search(pattern, ts.attrs['timezone']).group()
            tz = int(tz)
            local_next_transit = ephem.Date(
                next_transit + tz * ephem.hour)  # plus 8h to get Beijing time
            # if transit_time > ts['jul_date'][-1]:
            if transit_time > max(ts['jul_date'][-1], ts['jul_date'][:].max()):
                raise RuntimeError(
                    'Data does not contain local transit time %s of source %s'
                    % (local_next_transit, calibrator))

            # the first transit index
            transit_inds = [np.searchsorted(ts['jul_date'][:], transit_time)]
            # find all other transit indices
            aa.set_jultime(ts['jul_date'][0] + 1.0)
            transit_time = a.phs.ephem2juldate(
                aa.next_transit(s))  # Julian date
            cnt = 2
            while (transit_time <= ts['jul_date'][-1]):
                transit_inds.append(
                    np.searchsorted(ts['jul_date'][:], transit_time))
                aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt)
                transit_time = a.phs.ephem2juldate(
                    aa.next_transit(s))  # Julian date
                cnt += 1

            if mpiutil.rank0:
                print 'transit ind of %s: %s, time: %s' % (
                    s.src_name, transit_inds, local_next_transit)

            ### now only use the first transit point to do the cal
            ### may need to improve in the future
            transit_ind = transit_inds[0]
            int_time = ts.attrs['inttime']  # second
            start_ind = transit_ind - np.int(span / int_time)
            end_ind = transit_ind + np.int(
                span /
                int_time) + 1  # plus 1 to make transit_ind is at the center

            start_ind = max(0, start_ind)
            end_ind = min(end_ind, ts.vis.shape[0])

            if vis_conj:
                ts.local_vis[:] = ts.local_vis.conj()

            nt = end_ind - start_ind
            t_inds = range(start_ind, end_ind)
            freq = ts.freq[:]  # MHz
            nf = len(freq)
            nlb = len(ts.local_bl[:])
            nfeed = len(feedno)
            tfp_inds = list(
                itertools.product(
                    t_inds, range(nf),
                    [pol.index('xx'), pol.index('yy')]))  # only for xx and yy
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    lvis[ii] = ts.local_vis[ti, fi, pi]
                    lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
                # gather vis from all process for separate bls
                gvis = mpiutil.gather_array(lvis,
                                            axis=1,
                                            root=ri,
                                            comm=ts.comm)
                gvis_mask = mpiutil.gather_array(lvis_mask,
                                                 axis=1,
                                                 root=ri,
                                                 comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei]  # inds for this process
                    this_vis = gvis
                    this_vis_mask = gvis_mask
            del tfp_inds
            del lvis
            del lvis_mask
            tfp_len = len(tfp_linds)

            # lotl_mask = np.zeros((tfp_len, nfeed, nfeed), dtype=bool)
            cnan = complex(np.nan, np.nan)  # complex nan
            if save_src_vis or subtract_src:
                # save calibrator src vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed),
                                   cnan,
                                   dtype=ts.vis.dtype)
                if save_src_vis:
                    # save sky vis
                    lsky_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)
                    # save outlier vis
                    lotl_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = []  # indices that shold be conj
            mis = [
            ]  # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [
            ]  # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            # get flus of the calibrator in the observing frequencies
            if show_progress and mpiutil.rank0:
                pg = progress.Progress(tfp_len, step=progress_step)
            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)
                # when noise on, just pass
                if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                    continue
                # aa.set_jultime(ts['jul_date'][ti])
                # s.compute(aa)
                # get the topocentric coordinate of the calibrator at the current time
                # s_top = s.get_crds('top', ncrd=3)
                # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
                Vmat.flat[mis] = np.ma.array(
                    this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
                Vmat.flat[mis_conj] = np.ma.array(
                    this_vis[ii, bis_conj],
                    mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

                if save_src_vis:
                    lsky_vis[ii] = Vmat

                # set invalid val to 0
                invalid = ~np.isfinite(Vmat)  # a bool array
                # if too many masks
                if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2:
                    continue
                Vmat[invalid] = 0
                # if all are zeros
                if np.allclose(Vmat, 0.0):
                    continue

                # fill diagonal of Vmat to 0
                if zero_diag:
                    np.fill_diagonal(Vmat, 0)

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(
                    np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                V0, S = rpca_decomp.decompose(Vmat,
                                              rank=1,
                                              S=S0,
                                              max_iter=200,
                                              threshold='hard',
                                              tol=1.0e-6,
                                              debug=False)

                # # find abnormal values in S
                # # first check diagonal elements
                # import pdb; pdb.set_trace()
                # svals = np.diag(S)
                # smed = np.median(svals.real) + 1.0J * np.median(svals.imag)
                # smad = rpca_decomp.MAD(svals)
                # # abnormal indices
                # abis =  np.where(np.abs(svals - smed) > 3.0 * smad)[0]
                # for abi in abis:
                #     lotl_mask[ii, abi, abi] = True
                # # then check non-diagonal elements
                # for rii in range(nfeed):
                #     for cii in range(nfeed):
                #         if rii == cii:
                #             continue
                #         rli = max(0, rii-2)
                #         rhi = min(nfeed, rii+3)
                #         cli = max(0, cii-2)
                #         chi = min(nfeed, cii+3)
                #         svals = np.array([ S[xi, yi] for xi in range(rli, rhi) for yi in range(cli, chi) if xi != yi ])
                #         smed = np.median(svals.real) + 1.0J * np.median(svals.imag)
                #         smad = rpca_decomp.MAD(svals)
                #         if np.abs(S[rii, cii] - smed) > 3.0 * smad:
                #             lotl_mask[ii, rii, cii] = True

                if save_src_vis or subtract_src:
                    lsrc_vis[ii] = V0
                    if save_src_vis:
                        lotl_vis[ii] = S

                # plot
                if plot_figs:
                    ind = ti - start_ind
                    # plot Vmat
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(Vmat.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(Vmat.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot V0
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(V0.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(V0.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                       pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(S.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(S.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot N
                    N = Vmat - V0 - S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(N.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(N.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()

                if apply_gain or save_gain:
                    # use v_ij = gi gj^* \int Ai Aj^* e^(2\pi i n \cdot uij) T(x) d^2n
                    # precisely, we shold have
                    # V0 = (lambda^2 * Sc / (2 k_B)) * gi gj^* Ai Aj^* e^(2\pi i n0 \cdot uij)
                    e, U = la.eigh(V0, eigvals=(nfeed - 1, nfeed - 1))
                    g = U[:, -1] * e[
                        -1]**0.5  # = \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai * e^(2\pi i n0 \cdot ui)
                    if g[0].real < 0:
                        g *= -1.0  # make all g[0] phase 0, instead of pi
                    lGain[ii] = g

                    # plot Gain
                    if plot_figs:
                        plt.figure()
                        plt.plot(feedno, g.real, 'b-', label='real')
                        plt.plot(feedno, g.real, 'bo')
                        plt.plot(feedno, g.imag, 'g-', label='imag')
                        plt.plot(feedno, g.imag, 'go')
                        plt.plot(feedno, np.abs(g), 'r-', label='abs')
                        plt.plot(feedno, np.abs(g), 'ro')
                        plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                        yl, yh = plt.ylim()
                        plt.ylim(yl, yh + (yh - yl) / 5)
                        plt.xlabel('Feed number')
                        plt.legend()
                        fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind,
                                                             fi, pol[pi])
                        if tag_output_iter:
                            fig_name = output_path(fig_name,
                                                   iteration=self.iteration)
                        else:
                            fig_name = output_path(fig_name)
                        plt.savefig(fig_name)
                        plt.close()

            # # apply outlier mask
            # nbl = len(bls)
            # lom = np.zeros((lotl_mask.shape[0], nbl), dtype=lotl_mask.dtype)
            # for bi, (fd1, fd2) in enumerate(bls):
            #     b1, b2 = feedno.index(fd1), feedno.index(fd2)
            #     lom[:, bi] = lotl_mask[:, b1, b2]
            # lom = mpiarray.MPIArray.wrap(lom, axis=0, comm=ts.comm)
            # lom = lom.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
            # ts.local_vis_mask[start_ind:end_ind, :, pol.index('xx')] |= lom[:, :, 0]
            # ts.local_vis_mask[start_ind:end_ind, :, pol.index('yy')] |= lom[:, :, 1]

            # subtract the vis of calibrator from self.vis
            if subtract_src:
                nbl = len(bls)
                lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype)
                for bi, (fd1, fd2) in enumerate(bls):
                    b1, b2 = feedno.index(fd1), feedno.index(fd2)
                    lv[:, bi] = lsrc_vis[:, b1, b2]
                lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
                lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
                if replace_with_src:
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('xx')] = lv[:, :, 0]
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('yy')] = lv[:, :, 1]
                else:
                    if 'ns_on' in ts.iterkeys():
                        lv[ts['ns_on']
                           [start_ind:
                            end_ind]] = 0  # avoid ns_on signal to become nan
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('xx')] -= lv[:, :, 0]
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('yy')] -= lv[:, :, 1]

                del lv

            if not save_src_vis:
                if subtract_src:
                    del lsrc_vis
            else:
                if tag_output_iter:
                    src_vis_file = output_path(src_vis_file,
                                               iteration=self.iteration)
                else:
                    src_vis_file = output_path(src_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(src_vis_file, 'w') as f:
                        # allocate space
                        shp = (nt, nf, 2, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis',
                                         shp,
                                         dtype=lotl_vis.dtype)
                        # f.create_dataset('outlier_mask', shp, dtype=lotl_mask.dtype)
                        f.attrs['calibrator'] = calibrator
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time',
                                             data=ts.time[start_ind:end_ind])
                            f.attrs['time'] = '/time'
                        f.attrs['freq'] = freq
                        f.attrs['pol'] = np.array(['xx', 'yy'])
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(src_vis_file, 'r+') as f:
                                    for ii, (ti, fi,
                                             pi) in enumerate(tfp_linds):
                                        ti_ = ti - start_ind
                                        pi_ = gain_pd[pol[pi]]
                                        f['sky_vis'][ti_, fi,
                                                     pi_] = lsky_vis[ii]
                                        f['src_vis'][ti_, fi,
                                                     pi_] = lsrc_vis[ii]
                                        f['outlier_vis'][ti_, fi,
                                                         pi_] = lotl_vis[ii]
                                        # f['outlier_mask'][ti_, fi, pi_] = lotl_mask[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' %
                                       src_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis
                # del lotl_mask

                mpiutil.barrier()

            if apply_gain or save_gain:
                # flag outliers in lGain along each feed
                lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
                for i in range(lGain.shape[0]):
                    valid_inds = np.where(np.isfinite(lGain[i]))[0]
                    if len(valid_inds) > 3:
                        vabs = np.abs(lGain[i, valid_inds])
                        vmed = np.median(vabs)
                        vabs_diff = np.abs(vabs - vmed)
                        vmad = np.median(vabs_diff) / 0.6745
                        if reserve_high_gain:
                            # reserve significantly higher ones, flag only significantly lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vmed - vabs > 3.0 * vmad, np.nan, vabs)
                        else:
                            # flag both significantly higher and lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vabs_diff > 3.0 * vmad, np.nan, vabs)

                # choose data slice near the transit time
                li = max(start_ind, transit_ind - 10) - start_ind
                hi = min(end_ind, transit_ind + 10 + 1) - start_ind
                ci = transit_ind - start_ind  # center index for transit_ind
                # compute s_top for this time range
                n0 = np.zeros(((hi - li), 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0[ti] = s.get_crds('top', ncrd=3)
                if save_phs_change:
                    n0t = np.zeros((nt, 3))
                    for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                        aa.set_jultime(jt)
                        s.compute(aa)
                        n0t[ti] = s.get_crds('top', ncrd=3)

                # get the positions of feeds
                feedpos = ts['feedpos'][:]

                # wrap and redistribute Gain and flagged G_abs
                Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm)
                Gain = Gain.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)
                G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm)
                G_abs = G_abs.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)

                fpd_inds = list(
                    itertools.product(range(nf), range(2),
                                      range(nfeed)))  # only for xx and yy
                fpd_linds = mpiutil.mpilist(fpd_inds,
                                            method='con',
                                            comm=ts.comm)
                del fpd_inds
                # create data to save the solved gain for each feed
                lgain = np.full((len(fpd_linds), ), cnan,
                                dtype=Gain.dtype)  # gain for each feed
                if save_phs_change:
                    lphs = np.full((nt, len(fpd_linds)),
                                   np.nan,
                                   dtype=Gain.real.dtype
                                   )  # phase change with time for each feed

                # check for conj
                num_conj = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)
                        e_phs_conj = np.dot(ef[inds],
                                            Gi[inds] / y[inds]) / len(inds)
                        eac = np.abs(e_phs_conj)
                        if eac > ea:
                            num_conj += 1
                # reduce num_conj from all processes
                num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
                if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                    if mpiutil.rank0:
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                        print '!!!   Detect data should be their conjugate...   !!!'
                        print '!!!   Correct it automatically...                !!!'
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    mpiutil.barrier()
                    # correct vis
                    ts.local_vis[:] = ts.local_vis.conj()
                    # correct G
                    Gain.local_array[:] = Gain.local_array.conj()

                # solve for gain
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        mag = np.mean(
                            y[inds]
                        )  # = \sqrt(lambda^2 * Sc / (2 k_B)) * |gi| Ai
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(), Gi[inds] /
                                       y[inds]) / len(inds)  # the phase of gi
                        ea = np.abs(e_phs)
                        if np.abs(ea - 1.0) < 0.1:
                            # compute gain for this feed
                            lgain[
                                ii] = mag * e_phs  # \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai
                            if save_phs_change:
                                lphs[:, ii] = np.angle(
                                    np.exp(-2.0J * np.pi * np.dot(n0t, ui)) *
                                    Gain.local_array[:, ii])
                        else:
                            e_phs_conj = np.dot(ef[inds],
                                                Gi[inds] / y[inds]) / len(inds)
                            eac = np.abs(e_phs_conj)
                            if eac > ea:
                                if np.abs(eac - 1.0) < 0.01:
                                    print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                        feedno[di], fi, gain_pd[pi])
                            else:
                                print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                                    feedno[di], fi, gain_pd[pi], ea)

                # gather local gain
                gain = mpiutil.gather_array(lgain,
                                            axis=0,
                                            root=None,
                                            comm=ts.comm)
                del lgain
                gain = gain.reshape(nf, 2, nfeed)
                if save_phs_change:
                    phs = mpiutil.gather_array(lphs,
                                               axis=1,
                                               root=0,
                                               comm=ts.comm)
                    del lphs
                    if mpiutil.rank0:
                        phs = phs.reshape(nt, nf, 2, nfeed)

                # normalize to get the exact gain
                Sc = s.get_jys(1.0e-3 * freq)
                # Omega = aa.ants[0].beam.Omega ### TODO: implement Omega for dish
                Ai = aa.ants[0].beam.response(n0[ci - li])
                lmd = const.c / (1.0e6 * freq)
                factor = np.sqrt(
                    (lmd**2 * 1.0e-26 * Sc) /
                    (2 * const.k_B)) * Ai  # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
                gain /= factor[:, np.newaxis, np.newaxis]

                # apply gain to vis
                if apply_gain:
                    for fi in range(nf):
                        for pi in [pol.index('xx'), pol.index('yy')]:
                            pi_ = gain_pd[pol[pi]]
                            for bi, (fd1, fd2) in enumerate(
                                    ts['blorder'].local_data):
                                g1 = gain[fi, pi_, feedno.index(fd1)]
                                g2 = gain[fi, pi_, feedno.index(fd2)]
                                if np.isfinite(g1) and np.isfinite(g2):
                                    if fd1 == fd2:
                                        # auto-correlation should be real
                                        ts.local_vis[:, fi, pi,
                                                     bi] /= (g1 *
                                                             np.conj(g2)).real
                                    else:
                                        ts.local_vis[:, fi, pi,
                                                     bi] /= (g1 * np.conj(g2))
                                else:
                                    # mask the un-calibrated vis
                                    ts.local_vis_mask[:, fi, pi, bi] = True

                    # in unit K after the calibration
                    ts.vis.attrs['unit'] = 'K'

                # save gain to file
                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file,
                                                iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            dset = f.create_dataset('Gain', (nt, nf, 2, nfeed),
                                                    dtype=Gain.dtype)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time'] = ts.time[start_ind:end_ind]
                            except RuntimeError:
                                f.create_dataset(
                                    'time', data=ts.time[start_ind:end_ind])
                                dset.attrs['time'] = '/time'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            dset.attrs['transit_ind'] = transit_ind
                            # save gain
                            dset = f.create_dataset('gain', data=gain)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'freq, pol, feed'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save phs
                            if save_phs_change:
                                f.create_dataset('phs', data=phs)

                    mpiutil.barrier()

                    # save Gain
                    for i in range(10):
                        try:
                            # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                            for ri in xrange(mpiutil.size):
                                if ri == mpiutil.rank:
                                    with h5py.File(gain_file, 'r+') as f:
                                        for ii, (ti, fi,
                                                 pi) in enumerate(tfp_linds):
                                            ti_ = ti - start_ind
                                            pi_ = gain_pd[pol[pi]]
                                            f['Gain'][ti_, fi, pi_] = lGain[ii]
                                mpiutil.barrier()
                            break
                        except IOError:
                            time.sleep(0.5)
                            continue
                    else:
                        raise RuntimeError('Could not open file: %s...' %
                                           gain_file)

                    mpiutil.barrier()

        return super(PsCal, self).process(ts)
    def load_tod_excl_main_data(self):
        """Load time ordered attributes and datasets (exclude the main data) from all files.

        After the parent class has loaded its time ordered data, if no
        ``sec1970`` dataset exists this method derives and creates the time
        ordered datasets ``sec1970`` (unit second), ``jul_date`` (unit day),
        ``local_hour`` (unit hour), ``az_alt`` and ``ra_dec`` (unit radian).
        Each is created distributed along time when time is the distributed
        axis of the main data, otherwise as a full array.

        It also sets these attributes:

        * ``sec1970``: ``continuous`` and, if there are gaps, ``break_inds``.
        * ``az_alt``: ``same_pointing``.
        * ``ra_dec``: ``same_dec``.
        """

        super(TimestreamCommon, self).load_tod_excl_main_data()

        if 'sec1970' not in self.iterkeys():
            # whether time is the axis the main data is distributed along;
            # every derived time ordered dataset below follows this choice
            time_dist = ('time' == self.main_data_axes[self.main_data_dist_axis])

            # generate sec1970 from each file's start time and the integration time
            int_time = self.infiles[0].attrs['inttime']
            sec1970s = []
            nts = []
            for fh in mpiutil.mpilist(self.infiles, method='con', comm=self.comm):
                sec1970s.append(fh.attrs['sec1970'])
                nts.append(fh[self.main_data_name].shape[0])
            sec1970 = np.zeros(sum(nts), dtype=np.float64) # precision float32 is not enough
            cum_nts = np.cumsum([0] + nts)
            for idx, (nt, sec) in enumerate(zip(nts, sec1970s)):
                sec1970[cum_nts[idx]:cum_nts[idx+1]] = np.array([ sec + i*int_time for i in xrange(nt)], dtype=np.float64) # precision float32 is not enough
            # gather local sec1970 so every process holds the full array
            sec1970 = mpiutil.gather_array(sec1970, root=None, comm=self.comm)
            # select the corresponding section
            sec1970 = sec1970[self.main_data_start:self.main_data_stop][self.main_data_select[0]]

            # determine if it is continuous in time; this must be done on the
            # full (global) array, before it is distributed, so that gaps at
            # process boundaries are found and break_inds are global indices
            sec_diff = np.diff(sec1970)
            break_inds = np.where(sec_diff > 1.5*int_time)[0]

            # if time is just the distributed axis, load sec1970 distributed
            if time_dist:
                sec1970 = mpiarray.MPIArray.from_numpy_array(sec1970)
            self.create_main_time_ordered_dataset('sec1970', data=sec1970)
            # create attrs of this dset
            self['sec1970'].attrs["unit"] = 'second'
            if len(break_inds) > 0:
                self['sec1970'].attrs["continuous"] = False
                self['sec1970'].attrs["break_inds"] = break_inds + 1
            else:
                self['sec1970'].attrs["continuous"] = True

            # generate julian date (iterates the local section when distributed)
            tzone = self.infiles[0].attrs['timezone']
            jul_date = np.array([ date_util.get_juldate(datetime.fromtimestamp(s), tzone=tzone) for s in sec1970 ], dtype=np.float64) # precision float32 is not enough
            # if time is just the distributed axis, load jul_date distributed
            if time_dist:
                jul_date = mpiarray.MPIArray.wrap(jul_date, axis=0)
            self.create_main_time_ordered_dataset('jul_date', data=jul_date)
            # create attrs of this dset
            self['jul_date'].attrs["unit"] = 'day'

            # generate local time in hour from 0 to 24.0
            def _hour(t):
                # 1 hour = 60 min = 3600 s = 3.6e9 microseconds
                return t.hour + t.minute/60.0 + t.second/3600.0 + t.microsecond/3.6e9
            local_hour = np.array([ _hour(datetime.fromtimestamp(s).time()) for s in sec1970 ], dtype=np.float64)
            # if time is just the distributed axis, load local_hour distributed
            if time_dist:
                local_hour = mpiarray.MPIArray.wrap(local_hour, axis=0)
            self.create_main_time_ordered_dataset('local_hour', data=local_hour)
            # create attrs of this dset
            self['local_hour'].attrs["unit"] = 'hour'

            # generate az, alt (one row per local time sample)
            az_alt = np.zeros((self['sec1970'].local_data.shape[0], 2), dtype=np.float32) # radians
            if self.is_dish:
                # antpointing = rt['antpointing'][-1, :, :] # degree
                # pointingtime = rt['pointingtime'][-1, :, :] # degree
                az_alt[:, 0] = 0.0 # az
                az_alt[:, 1] = np.pi/2 # alt
            elif self.is_cylinder:
                az_alt[:, 0] = np.pi/2 # az
                az_alt[:, 1] = np.pi/2 # alt
            else:
                raise RuntimeError('Unknown antenna type %s' % self.attrs['telescope'])

            # generate ra, dec of the antenna pointing
            aa = self.array
            ra_dec = np.zeros_like(az_alt) # radians
            for ti in xrange(az_alt.shape[0]):
                az, alt = az_alt[ti]
                az, alt = ephem.degrees(az), ephem.degrees(alt)
                aa.set_jultime(self['jul_date'].local_data[ti])
                ra_dec[ti] = aa.radec_of(az, alt) # in radians, a point in the sky above the observer

            # az_alt and ra_dec hold only the local time section exactly when
            # sec1970 was distributed, so use the same condition to wrap them
            if time_dist:
                az_alt = mpiarray.MPIArray.wrap(az_alt, axis=0)
                ra_dec = mpiarray.MPIArray.wrap(ra_dec, axis=0)
            self.create_main_time_ordered_dataset('az_alt', data=az_alt)
            self['az_alt'].attrs['unit'] = 'radian'
            self.create_main_time_ordered_dataset('ra_dec', data=ra_dec)
            self['ra_dec'].attrs['unit'] = 'radian'

            # determine if it is the same pointing
            if time_dist:
                az_alt = az_alt.local_array
                ra_dec = ra_dec.local_array
            # gather local az_alt (when not distributed every rank contributes
            # the full array; the duplicates do not affect the allclose checks)
            az_alt = mpiutil.gather_array(az_alt, root=None, comm=self.comm)
            if np.allclose(az_alt[:, 0], az_alt[0, 0]) and np.allclose(az_alt[:, 1], az_alt[0, 1]):
                self['az_alt'].attrs['same_pointing'] = True
            else:
                self['az_alt'].attrs['same_pointing'] = False
            # determine if it is the same dec
            # gather local ra_dec
            ra_dec = mpiutil.gather_array(ra_dec, root=None, comm=self.comm)
            if np.allclose(ra_dec[:, 1], ra_dec[0, 1]):
                self['ra_dec'].attrs['same_dec'] = True
            else:
                self['ra_dec'].attrs['same_dec'] = False