Пример #1
0
    def wrap(cls, array, axis, comm=None):
        """Turn a set of numpy arrays into a distributed MPIArray object.

        Each rank passes in its own local section of the array. The section
        is assumed to sit at the offset along `axis` that
        `mpiutil.split_local` assigns to that rank. This is needed for
        functions such as `np.fft.fft` which always return an `np.ndarray`.

        This routine calls `mpiutil.allreduce` on `comm`, so it must be called
        on every rank of the communicator.

        Parameters
        ----------
        array : np.ndarray
            Local section of the array to wrap.
        axis : integer
            Axis over which the array is distributed. Negative values count
            from the last axis and are normalised to a non-negative index.
            The lengths are checked to try and ensure this is correct.
        comm : MPI.Comm, optional
            The communicator over which the array is distributed. If `None`
            (default), use `mpiutil.world` (presumably `MPI.COMM_WORLD`).

        Returns
        -------
        dist_array : MPIArray
            An MPIArray view of the input (the data is not copied).

        Raises
        ------
        IndexError
            If `axis` is out of bounds for `array`.
        ValueError
            If, on any rank, the local length along `axis` differs from the
            length `mpiutil.split_local` expects for that rank. This is a
            subclass of `Exception`, so existing `except Exception` handlers
            still catch it.
        """

        if comm is None:
            comm = mpiutil.world

        # Normalise the axis so the stored `_axis` is always in [0, ndim).
        # A negative axis would otherwise index correctly below but be
        # recorded as a negative value on the returned array.
        ndim = len(array.shape)
        if not -ndim <= axis < ndim:
            raise IndexError(
                "axis %d is out of bounds for array of dimension %d" % (axis, ndim)
            )
        axis = axis % ndim

        # Get axis length, both locally, and globally
        axlen = array.shape[axis]
        totallen = mpiutil.allreduce(axlen, comm=comm)

        # Figure out what the distributed layout should be
        local_num, local_start, _ = mpiutil.split_local(totallen, comm=comm)

        # Check the local layout is consistent with what we expect. Reducing
        # with MAX tells every rank whether *any* rank is inconsistent, so all
        # ranks take the same branch below.
        layout_issue = mpiutil.allreduce(axlen != local_num, op=mpiutil.MAX, comm=comm)

        if layout_issue:
            raise ValueError("Cannot wrap, distributed axis local length is incorrect.")

        # Set shape and offset
        lshape = array.shape
        global_shape = list(lshape)
        global_shape[axis] = totallen

        loffset = [0] * len(lshape)
        loffset[axis] = local_start

        # Setup attributes of class; `view` re-types the input without copying
        dist_arr = array.view(cls)
        dist_arr._global_shape = tuple(global_shape)
        dist_arr._axis = axis
        dist_arr._local_shape = tuple(lshape)
        dist_arr._local_offset = tuple(loffset)
        dist_arr._comm = comm

        return dist_arr
Пример #2
0
    def process(self, ts):

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        vis_conj = self.params['vis_conj']
        zero_diag = self.params['zero_diag']
        span = self.params['span']
        reserve_high_gain = self.params['reserve_high_gain']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_src_vis = self.params['save_src_vis']
        src_vis_file = self.params['src_vis_file']
        subtract_src = self.params['subtract_src']
        replace_with_src = self.params['replace_with_src']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        save_phs_change = self.params['save_phs_change']
        gain_file = self.params['gain_file']
        # temperature_convert = self.params['temperature_convert']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']

        if save_src_vis or subtract_src or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ps_cal for pol_type: %s' %
                                   pol_type)

            ts.redistribute('baseline')

            feedno = ts['feedno'][:].tolist()
            pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
            gain_pd = {
                'xx': 0,
                'yy': 1,
                0: 'xx',
                1: 'yy'
            }  # for gain related op
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians
            # transitsource = ts['transitsource'][:]
            # transit_time = transitsource[-1, 0] # second, sec1970
            # int_time = ts.attrs['inttime'] # second

            # get the calibrator
            try:
                s = calibrators.get_src(calibrator)
            except KeyError:
                if mpiutil.rank0:
                    print 'Calibrator %s is unavailable, available calibrators are:'
                    for key, d in calibrators.src_data.items():
                        print '%8s  ->  %12s' % (key, d[0])
                raise RuntimeError('Calibrator %s is unavailable')
            if mpiutil.rank0:
                print 'Try to calibrate with %s...' % s.src_name

            # get transit time of calibrator
            # array
            aa = ts.array
            aa.set_jultime(ts['jul_date'][0])  # the first obs time point
            next_transit = aa.next_transit(s)
            transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
            # get time zone
            pattern = '[-+]?\d+'
            tz = re.search(pattern, ts.attrs['timezone']).group()
            tz = int(tz)
            local_next_transit = ephem.Date(
                next_transit + tz * ephem.hour)  # plus 8h to get Beijing time
            # if transit_time > ts['jul_date'][-1]:
            if transit_time > max(ts['jul_date'][-1], ts['jul_date'][:].max()):
                raise RuntimeError(
                    'Data does not contain local transit time %s of source %s'
                    % (local_next_transit, calibrator))

            # the first transit index
            transit_inds = [np.searchsorted(ts['jul_date'][:], transit_time)]
            # find all other transit indices
            aa.set_jultime(ts['jul_date'][0] + 1.0)
            transit_time = a.phs.ephem2juldate(
                aa.next_transit(s))  # Julian date
            cnt = 2
            while (transit_time <= ts['jul_date'][-1]):
                transit_inds.append(
                    np.searchsorted(ts['jul_date'][:], transit_time))
                aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt)
                transit_time = a.phs.ephem2juldate(
                    aa.next_transit(s))  # Julian date
                cnt += 1

            if mpiutil.rank0:
                print 'transit ind of %s: %s, time: %s' % (
                    s.src_name, transit_inds, local_next_transit)

            ### now only use the first transit point to do the cal
            ### may need to improve in the future
            transit_ind = transit_inds[0]
            int_time = ts.attrs['inttime']  # second
            start_ind = transit_ind - np.int(span / int_time)
            end_ind = transit_ind + np.int(
                span /
                int_time) + 1  # plus 1 to make transit_ind is at the center

            start_ind = max(0, start_ind)
            end_ind = min(end_ind, ts.vis.shape[0])

            if vis_conj:
                ts.local_vis[:] = ts.local_vis.conj()

            nt = end_ind - start_ind
            t_inds = range(start_ind, end_ind)
            freq = ts.freq[:]  # MHz
            nf = len(freq)
            nlb = len(ts.local_bl[:])
            nfeed = len(feedno)
            tfp_inds = list(
                itertools.product(
                    t_inds, range(nf),
                    [pol.index('xx'), pol.index('yy')]))  # only for xx and yy
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    lvis[ii] = ts.local_vis[ti, fi, pi]
                    lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
                # gather vis from all process for separate bls
                gvis = mpiutil.gather_array(lvis,
                                            axis=1,
                                            root=ri,
                                            comm=ts.comm)
                gvis_mask = mpiutil.gather_array(lvis_mask,
                                                 axis=1,
                                                 root=ri,
                                                 comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei]  # inds for this process
                    this_vis = gvis
                    this_vis_mask = gvis_mask
            del tfp_inds
            del lvis
            del lvis_mask
            tfp_len = len(tfp_linds)

            # lotl_mask = np.zeros((tfp_len, nfeed, nfeed), dtype=bool)
            cnan = complex(np.nan, np.nan)  # complex nan
            if save_src_vis or subtract_src:
                # save calibrator src vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed),
                                   cnan,
                                   dtype=ts.vis.dtype)
                if save_src_vis:
                    # save sky vis
                    lsky_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)
                    # save outlier vis
                    lotl_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = []  # indices that shold be conj
            mis = [
            ]  # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [
            ]  # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            # get flus of the calibrator in the observing frequencies
            if show_progress and mpiutil.rank0:
                pg = progress.Progress(tfp_len, step=progress_step)
            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)
                # when noise on, just pass
                if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                    continue
                # aa.set_jultime(ts['jul_date'][ti])
                # s.compute(aa)
                # get the topocentric coordinate of the calibrator at the current time
                # s_top = s.get_crds('top', ncrd=3)
                # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
                Vmat.flat[mis] = np.ma.array(
                    this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
                Vmat.flat[mis_conj] = np.ma.array(
                    this_vis[ii, bis_conj],
                    mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

                if save_src_vis:
                    lsky_vis[ii] = Vmat

                # set invalid val to 0
                invalid = ~np.isfinite(Vmat)  # a bool array
                # if too many masks
                if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2:
                    continue
                Vmat[invalid] = 0
                # if all are zeros
                if np.allclose(Vmat, 0.0):
                    continue

                # fill diagonal of Vmat to 0
                if zero_diag:
                    np.fill_diagonal(Vmat, 0)

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(
                    np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                V0, S = rpca_decomp.decompose(Vmat,
                                              rank=1,
                                              S=S0,
                                              max_iter=200,
                                              threshold='hard',
                                              tol=1.0e-6,
                                              debug=False)

                # # find abnormal values in S
                # # first check diagonal elements
                # import pdb; pdb.set_trace()
                # svals = np.diag(S)
                # smed = np.median(svals.real) + 1.0J * np.median(svals.imag)
                # smad = rpca_decomp.MAD(svals)
                # # abnormal indices
                # abis =  np.where(np.abs(svals - smed) > 3.0 * smad)[0]
                # for abi in abis:
                #     lotl_mask[ii, abi, abi] = True
                # # then check non-diagonal elements
                # for rii in range(nfeed):
                #     for cii in range(nfeed):
                #         if rii == cii:
                #             continue
                #         rli = max(0, rii-2)
                #         rhi = min(nfeed, rii+3)
                #         cli = max(0, cii-2)
                #         chi = min(nfeed, cii+3)
                #         svals = np.array([ S[xi, yi] for xi in range(rli, rhi) for yi in range(cli, chi) if xi != yi ])
                #         smed = np.median(svals.real) + 1.0J * np.median(svals.imag)
                #         smad = rpca_decomp.MAD(svals)
                #         if np.abs(S[rii, cii] - smed) > 3.0 * smad:
                #             lotl_mask[ii, rii, cii] = True

                if save_src_vis or subtract_src:
                    lsrc_vis[ii] = V0
                    if save_src_vis:
                        lotl_vis[ii] = S

                # plot
                if plot_figs:
                    ind = ti - start_ind
                    # plot Vmat
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(Vmat.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(Vmat.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot V0
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(V0.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(V0.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                       pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(S.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(S.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot N
                    N = Vmat - V0 - S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(N.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(N.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()

                if apply_gain or save_gain:
                    # use v_ij = gi gj^* \int Ai Aj^* e^(2\pi i n \cdot uij) T(x) d^2n
                    # precisely, we shold have
                    # V0 = (lambda^2 * Sc / (2 k_B)) * gi gj^* Ai Aj^* e^(2\pi i n0 \cdot uij)
                    e, U = la.eigh(V0, eigvals=(nfeed - 1, nfeed - 1))
                    g = U[:, -1] * e[
                        -1]**0.5  # = \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai * e^(2\pi i n0 \cdot ui)
                    if g[0].real < 0:
                        g *= -1.0  # make all g[0] phase 0, instead of pi
                    lGain[ii] = g

                    # plot Gain
                    if plot_figs:
                        plt.figure()
                        plt.plot(feedno, g.real, 'b-', label='real')
                        plt.plot(feedno, g.real, 'bo')
                        plt.plot(feedno, g.imag, 'g-', label='imag')
                        plt.plot(feedno, g.imag, 'go')
                        plt.plot(feedno, np.abs(g), 'r-', label='abs')
                        plt.plot(feedno, np.abs(g), 'ro')
                        plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                        yl, yh = plt.ylim()
                        plt.ylim(yl, yh + (yh - yl) / 5)
                        plt.xlabel('Feed number')
                        plt.legend()
                        fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind,
                                                             fi, pol[pi])
                        if tag_output_iter:
                            fig_name = output_path(fig_name,
                                                   iteration=self.iteration)
                        else:
                            fig_name = output_path(fig_name)
                        plt.savefig(fig_name)
                        plt.close()

            # # apply outlier mask
            # nbl = len(bls)
            # lom = np.zeros((lotl_mask.shape[0], nbl), dtype=lotl_mask.dtype)
            # for bi, (fd1, fd2) in enumerate(bls):
            #     b1, b2 = feedno.index(fd1), feedno.index(fd2)
            #     lom[:, bi] = lotl_mask[:, b1, b2]
            # lom = mpiarray.MPIArray.wrap(lom, axis=0, comm=ts.comm)
            # lom = lom.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
            # ts.local_vis_mask[start_ind:end_ind, :, pol.index('xx')] |= lom[:, :, 0]
            # ts.local_vis_mask[start_ind:end_ind, :, pol.index('yy')] |= lom[:, :, 1]

            # subtract the vis of calibrator from self.vis
            if subtract_src:
                nbl = len(bls)
                lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype)
                for bi, (fd1, fd2) in enumerate(bls):
                    b1, b2 = feedno.index(fd1), feedno.index(fd2)
                    lv[:, bi] = lsrc_vis[:, b1, b2]
                lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
                lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
                if replace_with_src:
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('xx')] = lv[:, :, 0]
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('yy')] = lv[:, :, 1]
                else:
                    if 'ns_on' in ts.iterkeys():
                        lv[ts['ns_on']
                           [start_ind:
                            end_ind]] = 0  # avoid ns_on signal to become nan
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('xx')] -= lv[:, :, 0]
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('yy')] -= lv[:, :, 1]

                del lv

            if not save_src_vis:
                if subtract_src:
                    del lsrc_vis
            else:
                if tag_output_iter:
                    src_vis_file = output_path(src_vis_file,
                                               iteration=self.iteration)
                else:
                    src_vis_file = output_path(src_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(src_vis_file, 'w') as f:
                        # allocate space
                        shp = (nt, nf, 2, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis',
                                         shp,
                                         dtype=lotl_vis.dtype)
                        # f.create_dataset('outlier_mask', shp, dtype=lotl_mask.dtype)
                        f.attrs['calibrator'] = calibrator
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time',
                                             data=ts.time[start_ind:end_ind])
                            f.attrs['time'] = '/time'
                        f.attrs['freq'] = freq
                        f.attrs['pol'] = np.array(['xx', 'yy'])
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(src_vis_file, 'r+') as f:
                                    for ii, (ti, fi,
                                             pi) in enumerate(tfp_linds):
                                        ti_ = ti - start_ind
                                        pi_ = gain_pd[pol[pi]]
                                        f['sky_vis'][ti_, fi,
                                                     pi_] = lsky_vis[ii]
                                        f['src_vis'][ti_, fi,
                                                     pi_] = lsrc_vis[ii]
                                        f['outlier_vis'][ti_, fi,
                                                         pi_] = lotl_vis[ii]
                                        # f['outlier_mask'][ti_, fi, pi_] = lotl_mask[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' %
                                       src_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis
                # del lotl_mask

                mpiutil.barrier()

            if apply_gain or save_gain:
                # flag outliers in lGain along each feed
                lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
                for i in range(lGain.shape[0]):
                    valid_inds = np.where(np.isfinite(lGain[i]))[0]
                    if len(valid_inds) > 3:
                        vabs = np.abs(lGain[i, valid_inds])
                        vmed = np.median(vabs)
                        vabs_diff = np.abs(vabs - vmed)
                        vmad = np.median(vabs_diff) / 0.6745
                        if reserve_high_gain:
                            # reserve significantly higher ones, flag only significantly lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vmed - vabs > 3.0 * vmad, np.nan, vabs)
                        else:
                            # flag both significantly higher and lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vabs_diff > 3.0 * vmad, np.nan, vabs)

                # choose data slice near the transit time
                li = max(start_ind, transit_ind - 10) - start_ind
                hi = min(end_ind, transit_ind + 10 + 1) - start_ind
                ci = transit_ind - start_ind  # center index for transit_ind
                # compute s_top for this time range
                n0 = np.zeros(((hi - li), 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0[ti] = s.get_crds('top', ncrd=3)
                if save_phs_change:
                    n0t = np.zeros((nt, 3))
                    for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                        aa.set_jultime(jt)
                        s.compute(aa)
                        n0t[ti] = s.get_crds('top', ncrd=3)

                # get the positions of feeds
                feedpos = ts['feedpos'][:]

                # wrap and redistribute Gain and flagged G_abs
                Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm)
                Gain = Gain.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)
                G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm)
                G_abs = G_abs.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)

                fpd_inds = list(
                    itertools.product(range(nf), range(2),
                                      range(nfeed)))  # only for xx and yy
                fpd_linds = mpiutil.mpilist(fpd_inds,
                                            method='con',
                                            comm=ts.comm)
                del fpd_inds
                # create data to save the solved gain for each feed
                lgain = np.full((len(fpd_linds), ), cnan,
                                dtype=Gain.dtype)  # gain for each feed
                if save_phs_change:
                    lphs = np.full((nt, len(fpd_linds)),
                                   np.nan,
                                   dtype=Gain.real.dtype
                                   )  # phase change with time for each feed

                # check for conj
                num_conj = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)
                        e_phs_conj = np.dot(ef[inds],
                                            Gi[inds] / y[inds]) / len(inds)
                        eac = np.abs(e_phs_conj)
                        if eac > ea:
                            num_conj += 1
                # reduce num_conj from all processes
                num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
                if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                    if mpiutil.rank0:
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                        print '!!!   Detect data should be their conjugate...   !!!'
                        print '!!!   Correct it automatically...                !!!'
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    mpiutil.barrier()
                    # correct vis
                    ts.local_vis[:] = ts.local_vis.conj()
                    # correct G
                    Gain.local_array[:] = Gain.local_array.conj()

                # solve for gain
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        mag = np.mean(
                            y[inds]
                        )  # = \sqrt(lambda^2 * Sc / (2 k_B)) * |gi| Ai
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(), Gi[inds] /
                                       y[inds]) / len(inds)  # the phase of gi
                        ea = np.abs(e_phs)
                        if np.abs(ea - 1.0) < 0.1:
                            # compute gain for this feed
                            lgain[
                                ii] = mag * e_phs  # \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai
                            if save_phs_change:
                                lphs[:, ii] = np.angle(
                                    np.exp(-2.0J * np.pi * np.dot(n0t, ui)) *
                                    Gain.local_array[:, ii])
                        else:
                            e_phs_conj = np.dot(ef[inds],
                                                Gi[inds] / y[inds]) / len(inds)
                            eac = np.abs(e_phs_conj)
                            if eac > ea:
                                if np.abs(eac - 1.0) < 0.01:
                                    print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                        feedno[di], fi, gain_pd[pi])
                            else:
                                print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                                    feedno[di], fi, gain_pd[pi], ea)

                # gather local gain
                gain = mpiutil.gather_array(lgain,
                                            axis=0,
                                            root=None,
                                            comm=ts.comm)
                del lgain
                gain = gain.reshape(nf, 2, nfeed)
                if save_phs_change:
                    phs = mpiutil.gather_array(lphs,
                                               axis=1,
                                               root=0,
                                               comm=ts.comm)
                    del lphs
                    if mpiutil.rank0:
                        phs = phs.reshape(nt, nf, 2, nfeed)

                # normalize to get the exact gain
                Sc = s.get_jys(1.0e-3 * freq)
                # Omega = aa.ants[0].beam.Omega ### TODO: implement Omega for dish
                Ai = aa.ants[0].beam.response(n0[ci - li])
                lmd = const.c / (1.0e6 * freq)
                factor = np.sqrt(
                    (lmd**2 * 1.0e-26 * Sc) /
                    (2 * const.k_B)) * Ai  # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
                gain /= factor[:, np.newaxis, np.newaxis]

                # apply gain to vis
                if apply_gain:
                    for fi in range(nf):
                        for pi in [pol.index('xx'), pol.index('yy')]:
                            pi_ = gain_pd[pol[pi]]
                            for bi, (fd1, fd2) in enumerate(
                                    ts['blorder'].local_data):
                                g1 = gain[fi, pi_, feedno.index(fd1)]
                                g2 = gain[fi, pi_, feedno.index(fd2)]
                                if np.isfinite(g1) and np.isfinite(g2):
                                    if fd1 == fd2:
                                        # auto-correlation should be real
                                        ts.local_vis[:, fi, pi,
                                                     bi] /= (g1 *
                                                             np.conj(g2)).real
                                    else:
                                        ts.local_vis[:, fi, pi,
                                                     bi] /= (g1 * np.conj(g2))
                                else:
                                    # mask the un-calibrated vis
                                    ts.local_vis_mask[:, fi, pi, bi] = True

                    # in unit K after the calibration
                    ts.vis.attrs['unit'] = 'K'

                # save gain to file
                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file,
                                                iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            dset = f.create_dataset('Gain', (nt, nf, 2, nfeed),
                                                    dtype=Gain.dtype)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time'] = ts.time[start_ind:end_ind]
                            except RuntimeError:
                                f.create_dataset(
                                    'time', data=ts.time[start_ind:end_ind])
                                dset.attrs['time'] = '/time'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            dset.attrs['transit_ind'] = transit_ind
                            # save gain
                            dset = f.create_dataset('gain', data=gain)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'freq, pol, feed'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save phs
                            if save_phs_change:
                                f.create_dataset('phs', data=phs)

                    mpiutil.barrier()

                    # save Gain
                    for i in range(10):
                        try:
                            # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                            for ri in xrange(mpiutil.size):
                                if ri == mpiutil.rank:
                                    with h5py.File(gain_file, 'r+') as f:
                                        for ii, (ti, fi,
                                                 pi) in enumerate(tfp_linds):
                                            ti_ = ti - start_ind
                                            pi_ = gain_pd[pol[pi]]
                                            f['Gain'][ti_, fi, pi_] = lGain[ii]
                                mpiutil.barrier()
                            break
                        except IOError:
                            time.sleep(0.5)
                            continue
                    else:
                        raise RuntimeError('Could not open file: %s...' %
                                           gain_file)

                    mpiutil.barrier()

        return super(PsCal, self).process(ts)
Пример #3
0
    def process(self, rt):
        """Detect the noise-source on/off pattern and mask it in ``rt``.

        Auto-correlations are examined one at a time. The one requested by the
        ``channel`` parameter goes first; if it is absent, the first
        auto-correlation is used instead. For each one:

        - The frequency-averaged time stream is differenced.
        - Jumps more than ``sigma`` standard deviations beyond the mean of the
          positive (negative) differences are taken as switch-on (switch-off)
          points.
        - The most common spacing of these points gives the period.
        - The most common on->off / off->on offsets give ``on_time`` /
          ``off_time``.

        The first auto-correlation that passes all consistency checks is used.

        The result is stored as the time-ordered dataset ``ns_on``, with
        ``period``, ``on_time`` and ``off_time`` attributes. The noise-on
        samples are set in ``vis_mask``; when ``mask_near > 0`` they are first
        widened by ``mask_near`` samples on each side.

        Parameters
        ----------
        rt : RawTimestream
            Raw time stream; it is redistributed along the time axis (axis 0).

        Returns
        -------
        The value returned by the parent class ``process`` for ``rt``.

        Raises
        ------
        RuntimeError
            If no auto-correlation yields a consistent noise-source pattern.
        """

        assert isinstance(
            rt, RawTimestream
        ), '%s only works for RawTimestream object currently' % self.__class__.__name__

        channel = self.params['channel']
        sigma = self.params['sigma']
        mask_near = max(0, int(self.params['mask_near']))

        rt.redistribute(0)  # make time the dist axis

        # inds for auto-correlations
        auto_inds = np.where(rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()
        channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
        if channel is not None:
            if channel in channels:
                bl_ind = auto_inds[channels.index(channel)]
            else:
                bl_ind = auto_inds[0]
                if mpiutil.rank0:
                    print('Warning: Required channel %d doen not in the data, use channel %d instead' % (
                        channel, rt.bl[bl_ind, 0]))
        else:
            bl_ind = auto_inds[0]
        # move the chosen channel to the first
        auto_inds.remove(bl_ind)
        auto_inds = [bl_ind] + auto_inds

        for bl_ind in auto_inds:
            this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
            vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                              mask=rt.local_vis_mask[:, :, bl_ind])
            cnt = vis.count()  # number of not masked vals
            total_cnt = mpiutil.allreduce(cnt)
            vis_shp = rt.vis.shape
            # ratio of un-masked vals over the global (axis 0, axis 1) plane
            ratio = float(total_cnt) / np.prod((vis_shp[0], vis_shp[1]))
            if ratio < 0.5:  # too many masked vals
                if mpiutil.rank0:
                    warnings.warn(
                        'Too many masked values for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            # Frequency-averaged time stream, gathered to every rank
            # (root=None) so that all ranks take the same decisions below.
            tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0),
                                           root=None)
            df = np.diff(tt_mean, axis=-1)

            # Switch-on candidates: large positive jumps. A candidate is
            # dropped if it lies within one sample of the last kept one.
            pdf = np.where(df > 0, df, 0)
            pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0] + 1
            pinds1 = []
            for pi in pinds:
                if not pinds1 or pi - pinds1[-1] > 1:
                    pinds1.append(pi)
            pinds = np.array(pinds1)

            # Switch-off candidates: large negative jumps, thinned the same
            # way but scanning backwards in time.
            ndf = np.where(df < 0, df, 0)
            ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0] + 1
            ninds1 = []
            for ni in ninds[::-1]:
                if not ninds1 or ni - ninds1[-1] < -1:
                    ninds1.append(ni)
            ninds = np.array(ninds1[::-1])

            # At least two edges of each kind are needed to measure a period;
            # with fewer, the Counter lookups below would raise IndexError.
            if len(pinds) < 2 or len(ninds) < 2:
                if mpiutil.rank0:
                    warnings.warn(
                        'Too few noise source edges detected for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            pT = Counter(np.diff(pinds)).most_common(1)[0][0]  # period of pinds
            nT = Counter(np.diff(ninds)).most_common(1)[0][0]  # period of ninds
            if pT != nT:  # failed to detect correct period
                if mpiutil.rank0:
                    warnings.warn(
                        'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                        % (this_chan, pT, nT))
                continue
            period = pT

            # offsets between every switch-off and every switch-on point
            dinds = (ninds.reshape(-1, 1) - pinds).flatten()
            on_offsets = dinds[dinds > 0] % period
            off_offsets = -dinds[dinds < 0] % period
            if len(on_offsets) == 0 or len(off_offsets) == 0:
                if mpiutil.rank0:
                    warnings.warn(
                        'Failed to detect on/off time for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue
            on_time = Counter(on_offsets).most_common(1)[0][0]
            off_time = Counter(off_offsets).most_common(1)[0][0]

            if period != on_time + off_time:  # incorrect detect
                if mpiutil.rank0:
                    warnings.warn(
                        'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                        % (this_chan, period, on_time, off_time))
                continue

            if 'noisesource' in rt.iterkeys():
                if rt['noisesource'].shape[0] == 1:  # only 1 noise source
                    start, stop, cycle = rt['noisesource'][0, :]
                    int_time = rt.attrs['inttime']
                    true_on_time = np.round((stop - start) / int_time)
                    true_period = np.round(cycle / int_time)
                    # NOTE(review): rejected only when BOTH on_time and period
                    # disagree with the record; confirm `and` (not `or`) is intended.
                    if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                        if mpiutil.rank0:
                            warnings.warn(
                                'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it'
                                % (this_chan, on_time, true_on_time,
                                   period, true_period))
                        continue
                elif rt['noisesource'].shape[0] >= 2:  # more than 1 noise source
                    if mpiutil.rank0:
                        warnings.warn(
                            'More than 1 noise source, do not know how to deal with this currently'
                        )

            # this auto-correlation passed every check
            break

        else:
            raise RuntimeError('Failed to detect noise source signal')

        if mpiutil.rank0:
            print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                period, on_time, off_time))
        num_period = int(np.ceil(len(tt_mean) / float(period)))
        tmp_ns_on = np.array(([True] * on_time + [False] * off_time) *
                             num_period)[:len(tt_mean)]
        # phase of the pattern: most common switch-on index modulo period
        on_start = Counter(pinds % period).most_common(1)[0][0]
        ns_on = np.roll(tmp_ns_on, on_start)

        ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

        rt.create_main_time_ordered_dataset('ns_on', ns_on1)
        rt['ns_on'].attrs['period'] = period
        rt['ns_on'].attrs['on_time'] = on_time
        rt['ns_on'].attrs['off_time'] = off_time

        # set vis_mask corresponding to ns_on
        on_inds = np.where(rt['ns_on'].local_data[:])[0]
        rt.local_vis_mask[on_inds] = True

        if mask_near > 0:
            # widen every noise-on sample by mask_near samples on each side
            on_inds = np.where(ns_on)[0]
            new_on_inds = np.unique(np.concatenate(
                [on_inds + i for i in range(-mask_near, mask_near + 1)]))

            # global time range held by this rank
            if rt['vis_mask'].distributed:
                start = rt.vis_mask.local_offset[0]
                end = start + rt.vis_mask.local_shape[0]
            else:
                start = 0
                end = rt.vis_mask.shape[0]
            new_on_inds = np.intersect1d(new_on_inds, np.arange(start, end))
            # convert the global time indices to local ones
            local_on_inds = (new_on_inds - start).astype(int)
            rt.local_vis_mask[local_on_inds] = True

        return super(Detect, self).process(rt)
Пример #4
0
    def process(self, ts):
        """Compute masked-sample (RFI) statistics and optionally plot them.

        Building the mask copy:

        - The visibility mask is copied, dropping auto-correlations when
          ``excl_auto`` is set.
        - For a 4-D ``Timestream`` only polarisation index 0 is used.
        - Samples flagged in ``ns_on`` (if present) are un-masked in the copy.

        Masked samples are then counted per time (summed over frequency and
        baseline) and per frequency (summed over time and baseline). The
        per-rank counts are gathered to rank 0 and summed there.

        When ``plot_stats`` is set, rank 0 saves ``<fig_name>_time.png`` and
        ``<fig_name>_freq.png``, with names passed through ``output_path``.
        Each figure shows the masked percentage.

        Parameters
        ----------
        ts : RawTimestream or Timestream
            Data whose ``vis_mask`` is analysed; it is redistributed along the
            baseline axis. The mask itself is not modified (a copy is used).

        Returns
        -------
        The value returned by the parent class ``process`` for ``ts``.

        Raises
        ------
        RuntimeError
            If ``vis_mask`` is neither 3- nor 4-dimensional.
        """

        excl_auto = self.params['excl_auto']
        plot_stats = self.params['plot_stats']
        fig_prefix = self.params['fig_name']
        rotate_xdate = self.params['rotate_xdate']
        tag_output_iter = self.params['tag_output_iter']

        ts.redistribute('baseline')

        if ts.local_vis_mask.ndim == 3:  # RawTimestream
            if excl_auto:
                bl = ts.local_bl
                vis_mask = ts.local_vis_mask[:, :, bl[:, 0] != bl[:, 1]].copy()
            else:
                vis_mask = ts.local_vis_mask.copy()
        elif ts.local_vis_mask.ndim == 4:  # Timestream
            # suppose masks are the same for all 4 pols
            if excl_auto:
                bl = ts.local_bl
                vis_mask = ts.local_vis_mask[:, :, 0,
                                             bl[:, 0] != bl[:, 1]].copy()
            else:
                vis_mask = ts.local_vis_mask[:, :, 0].copy()
        else:
            # wrap the shape tuple so '%s' formats it whole instead of
            # raising TypeError for multi-dimensional shapes
            raise RuntimeError('Incorrect vis_mask shape %s' %
                               (ts.local_vis_mask.shape,))
        nt, nf, lnb = vis_mask.shape

        # total number of bl
        nb = mpiutil.allreduce(lnb, comm=ts.comm)

        # un-mask ns-on positions
        if 'ns_on' in ts.iterkeys():
            vis_mask[ts['ns_on'][:]] = False

        def _gather_sum(local_counts):
            # Gather the per-rank count vectors to rank 0 as columns and sum
            # them there; other ranks keep what gather_array returns.
            counts = mpiutil.gather_array(local_counts.reshape(-1, 1),
                                          axis=1,
                                          root=0,
                                          comm=ts.comm)
            if mpiutil.rank0:
                counts = np.sum(counts, axis=1)
            return counts

        def _fig_path(kind):
            # output figure path '<fig_prefix>_<kind>.png'
            name = '%s_%s.png' % (fig_prefix, kind)
            if tag_output_iter:
                return output_path(name, iteration=self.iteration)
            return output_path(name)

        # statistics along time axis
        time_mask = _gather_sum(np.sum(vis_mask, axis=(1, 2)))
        # statistics along frequency axis
        freq_mask = _gather_sum(np.sum(vis_mask, axis=(0, 2)))

        if plot_stats and mpiutil.rank0:
            # plot time_mask; times are shifted from UTC by +8 h
            # (presumably local Beijing time -- confirm)
            fig, ax = plt.subplots()
            x_vals = np.array([
                datetime.utcfromtimestamp(s) + timedelta(hours=8)
                for s in ts['sec1970'][:]
            ])
            xlabel = '%s' % x_vals[0].date()
            x_vals = mdates.date2num(x_vals)
            ax.plot(x_vals, 100 * time_mask / float(nf * nb))
            ax.xaxis_date()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
            if rotate_xdate:
                # set the x-axis tick labels to diagonal so it fits better
                fig.autofmt_xdate()
            else:
                # reduce the number of tick locators
                ax.xaxis.set_major_locator(MaxNLocator(nbins=6))
                ax.xaxis.set_minor_locator(AutoMinorLocator(2))
            ax.set_xlabel(xlabel)
            ax.set_ylabel(r'RFI (%)')
            fig.savefig(_fig_path('time'))
            plt.close(fig)

            # plot freq_mask
            fig, ax = plt.subplots()
            ax.plot(ts.freq[:], 100 * freq_mask / float(nt * nb))
            ax.grid(True)
            ax.set_xlabel(r'$\nu$ / MHz')
            ax.set_ylabel(r'RFI (%)')
            fig.savefig(_fig_path('freq'))
            plt.close(fig)

        return super(Stats, self).process(ts)
Пример #5
0
    def process(self, data):
        """Regrid visibility data onto a regular grid in hour angle.

        The apparent (CIRS) RA of ``self.src`` at the first timestamp is used
        to convert the sample times to local hour angle. ``self._regrid`` then
        resamples the visibilities and weights onto its output grid. Grid
        points outside the hour-angle range covered by the data get zero
        visibility and zero weight.

        Parameters
        ----------
        data : TimeStream
            Time-ordered data. It is redistributed over frequency.

        Returns
        -------
        new_data : SiderealStream or None
            The regridded data on an RA axis ``(grid + source RA) % 360``.
            The source's apparent and catalogue RA are stored in the
            ``cirs_ra`` and ``icrs_ra`` attributes. Returns ``None`` if
            regridding failed on any rank of ``data.comm``.
        """

        # Redistribute if needed
        data.redistribute("freq")

        # View of data
        weight = data.weight[:].view(np.ndarray)
        vis_data = data.vis[:].view(np.ndarray)

        # Get apparent source RA, including precession effects
        ra, _ = ephem.object_coords(self.src, data.time[0], deg=True, obs=self.observer)
        # Get catalogue RA for reference
        ra_icrs, _ = ephem.object_coords(self.src, deg=True, obs=self.observer)

        # Convert input times to hour angle
        lha = unwrap_lha(self.observer.unix_to_lsa(data.time), ra)

        # Perform regridding. Failure is flagged and MAX-reduced over the
        # data's own communicator, so every rank skips if any rank failed.
        failed = 0
        try:
            new_grid, new_vis, ni = self._regrid(vis_data, weight, lha)
        except (np.linalg.LinAlgError, ValueError) as e:
            self.log.error(str(e))
            failed = 1
        failed = mpiutil.allreduce(failed, op=mpiutil.MAX, comm=data.comm)
        if failed:
            self.log.warning("Regridding failed. Skipping transit.")
            return None

        # mask out regions beyond bounds of this transit
        grid_mask = np.ones_like(new_grid)
        grid_mask[new_grid < lha.min()] = 0.0
        grid_mask[new_grid > lha.max()] = 0.0
        new_vis *= grid_mask
        ni *= grid_mask

        # Wrap to produce MPIArray distributed over the data's communicator
        if data.distributed:
            axis = data.vis.distributed_axis
            new_vis = mpiarray.MPIArray.wrap(new_vis, axis=axis, comm=data.comm)
            ni = mpiarray.MPIArray.wrap(ni, axis=axis, comm=data.comm)

        # Create new container for output
        ra_grid = (new_grid + ra) % 360.0
        new_data = SiderealStream(
            axes_from=data, attrs_from=data, ra=ra_grid, comm=data.comm
        )
        new_data.redistribute("freq")
        new_data.vis[:] = new_vis
        new_data.weight[:] = ni
        new_data.attrs["cirs_ra"] = ra
        new_data.attrs["icrs_ra"] = ra_icrs

        return new_data
Пример #6
0
    def process(self, rt):

        assert isinstance(
            rt, RawTimestream
        ), '%s only works for RawTimestream object currently' % self.__class__.__name__

        channel = self.params['channel']
        sigma = self.params['sigma']
        mask_near = max(0, int(self.params['mask_near']))
        #================================================
        rt.FRB_cal = self.params['FRB_cal']
        ns_arr_file = self.params['ns_arr_file']
        ns_prop_file = self.params['ns_prop_file']
        num_noise = self.params['num_noise']
        change_reference_tol = self.params['change_reference_tol']
        abnormal_ns_save_file = 'ns_cal/abnormal_ns.npz'
        #================================================

        rt.redistribute(0)  # make time the dist axis

        #        time_span = rt.local_vis.shape[0]
        #        total_time_span = mpiutil.allreduce(time_span)
        total_time_span = rt['sec1970'].shape[0]

        auto_inds = np.where(
            rt.bl[:, 0] == rt.bl[:,
                                 1])[0].tolist()  # inds for auto-correlations
        channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
        if channel is not None:
            if channel in channels:
                bl_ind = auto_inds[channels.index(channel)]
            else:
                bl_ind = auto_inds[0]
                if mpiutil.rank0:
                    print 'Warning: Required channel %d doen not in the data, use channel %d instead' % (
                        channel, rt.bl[bl_ind, 0])
        else:
            bl_ind = auto_inds[0]
        # move the chosen channel to the first
        auto_inds.remove(bl_ind)
        auto_inds = [bl_ind] + auto_inds
        if rt.FRB_cal:
            #            ns_arr_end_time = mpiutil.bcast(rt['jul_date'][total_time_span - 1], root = mpiutil.size - 1)
            ns_arr_end_time = rt.attrs['sec1970'][0] + rt.attrs['inttime'] * (
                total_time_span - 1)
            if os.path.exists(output_path(ns_arr_file)) and not os.path.exists(
                    output_path(ns_prop_file)):
                filein = np.load(output_path(ns_arr_file))
                # add 0 to avoid a single float or int become an array
                ns_arr_pinds = filein['on_inds'] + 0
                ns_arr_ninds = filein['off_inds'] + 0
                ns_arr_pinds = list(ns_arr_pinds)
                ns_arr_ninds = list(ns_arr_ninds)
                ns_arr_len = filein['time_len'] + 0
                ns_arr_num = filein['ns_num'] + 0
                ns_arr_bl = filein['auto_chn'] + 0
                ns_arr_start_time = filein['start_time']
                #                overlap_index_span = np.around(np.float128(filein['end_time'] - mpiutil.bcast(rt['jul_date'][0], root = 0))*24.*3600./rt.attrs['inttime']) + 1
                overlap_index_span = np.around(
                    np.float128(filein['end_time'] - rt.attrs['sec1970'][0]) /
                    rt.attrs['inttime']) + 1
                if overlap_index_span > 0:
                    raise OverlapData(
                        'Overlap of data occured when trying to build up noise property file! In julian date, the end time of previous data is %.10f, while the start time of this data is %.10f. The overlap span in index is %d.'
                        % (filein['end_time'], rt.attrs['sec1970'][0],
                           overlap_index_span))
            elif not os.path.exists(output_path(ns_prop_file)):
                ns_arr_pinds = []
                ns_arr_ninds = []
                ns_arr_len = 0
                ns_arr_num = -1  # to distinguish the first process

                #                ns_arr_start_time = mpiutil.bcast(rt['jul_date'][0], root = 0)
                ns_arr_start_time = rt.attrs['sec1970'][0]

        if rt.FRB_cal and os.path.exists(output_path(ns_prop_file)):
            if mpiutil.rank0:
                print('Use existing ns property file %s to do calibration.' %
                      output_path(ns_prop_file))
            ns_prop_data = np.load(output_path(ns_prop_file))
            period = ns_prop_data['period'] + 0
            on_time = ns_prop_data['on_time'] + 0
            off_time = ns_prop_data['off_time'] + 0
            reference_time = ns_prop_data[
                'reference_time'] + 0  # in sec1970, is a start point of noise
            if 'lost_count' in ns_prop_data.files:
                lost_count_before = ns_prop_data['lost_count']
            else:
                lost_count_before = 0
            if 'added_count' in ns_prop_data.files:
                added_count_before = ns_prop_data['added_count']
            else:
                added_count_before = 0

#            this_time_start = mpiutil.bcast(rt['jul_date'][0], root=0)
#            skip_inds = int(np.around((this_time_start - reference_time)*86400.0/rt.attrs['inttime'])) # the number of index between the reference time and the start point
            this_time_start = rt.attrs['sec1970'][0]
            skip_inds = int(
                np.around(
                    (this_time_start - reference_time) / rt.attrs['inttime'])
            )  # the number of index between the reference time and the start point
            on_start = period - skip_inds % period
            #            if total_time_span < period:
            #                raise Exception('Time span of data %d is shorter than period %d!'%(total_time_span, period))
            if total_time_span <= on_start + on_time:
                raise NoNoisePoint(
                    'Calculated from previous data, this data contains no noise point or does not contain complete noise signal!'
                )
            # check whether there are lost points
            # only consider that the case that there are one lost point and only consider the on points
#================================================
            abnormal_count = -1
            lost_one_point = 0
            added_one_point = 0
            abnormal_list = []
            lost_one_list = []
            added_one_list = []
            lost_one_pos = []
            added_one_pos = []
            complete_period_num = (
                total_time_span - on_start - on_time -
                1) // period  # the number of periods in the data
            on_points = [
                on_start + i * period for i in range(complete_period_num + 1)
            ]
            off_points = [
                on_start + on_time + i * period
                for i in range(complete_period_num + 1)
            ]
            for bl_ind in auto_inds:
                this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
                vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                                  mask=rt.local_vis_mask[:, :, bl_ind])
                cnt = vis.count()  # number of not masked vals
                total_cnt = mpiutil.allreduce(cnt)
                vis_shp = rt.vis.shape
                ratio = float(total_cnt) / np.prod(
                    (vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
                if ratio < 0.5:  # too many masked vals
                    continue

                if abnormal_count < 0:
                    abnormal_count = 0
                tt_mean = mpiutil.gather_array(
                    np.ma.mean(vis, axis=-1).filled(0),
                    root=None)  # mean for all freq, for a specific auto bl
                df = np.diff(tt_mean, axis=-1)
                pdf = np.where(df > 0, df, 0)
                pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0]
                pinds = pinds + 1
                #====================================
                if len(pinds) == 0:  # no raise, might be badchn, continue
                    continue
#====================================
                pinds1 = [pinds[0]]
                for pi in pinds[1:]:
                    if pi - pinds1[-1] > 1:
                        pinds1.append(pi)
                pinds = np.array(pinds1)

                ndf = np.where(df < 0, df, 0)
                ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0]
                ninds = ninds + 1
                ninds = ninds[::-1]
                #====================================
                if len(ninds) == 0:  # no raise, might be badchn, continue
                    continue
#====================================
                ninds1 = [ninds[0]]
                for ni in ninds[1:]:
                    if ni - ninds1[-1] < -1:
                        ninds1.append(ni)
                ninds = np.array(ninds1[::-1])
                cmp_signal, cmp_res = cmp_cd(on_points, off_points, pinds,
                                             ninds, period, 2. / 3)
                if cmp_signal == 'ok':  # normal
                    continue
                elif cmp_signal == 'abnormal':  # abnormal
                    abnormal_count += 1
                    abnormal_list += [this_chan]
                elif cmp_signal == 'lost':  # lost point
                    lost_one_point += 1
                    lost_one_list += [this_chan]
                    lost_one_pos += [cmp_res]
                    continue
                elif cmp_signal == 'add':  # added point
                    added_one_point += 1
                    added_one_list += [this_chan]
                    added_one_pos += [cmp_res]
                    continue
                else:
                    raise Exception('Unknown comparison signal!')

#                if on_start in pinds or on_start + on_time in ninds:
#                    # to avoid the effect of interference
#                    continue
#                elif on_start - 1 in pinds or on_start + on_time - 1 in ninds:
#                    lost_one_point += 1
#                    lost_one_list += [this_chan]
#                    continue
#                elif on_start - 1 < 0 and on_start + period - 1 in pinds:
#                    lost_one_point += 1
#                    lost_one_list += [this_chan]
#                    continue
#                else:
#                    abnormal_count += 1
#                    abnormal_list += [this_chan]

            if abnormal_count < 0:
                raise NoNoisePoint(
                    'No noise points are detected from this data or the data contains too many masked points!'
                )
            elif abnormal_count > len(auto_inds) / 3:
                if mpiutil.rank0:
                    np.savez(output_path(abnormal_ns_save_file),
                             on_inds=pinds,
                             off_inds=ninds,
                             on_start=on_start,
                             period=period,
                             on_time=on_time)
                mpiutil.barrier()
                raise AbnormalPoints(
                    'Something rather than one lost point happened. The expected start point is %d, period %d, on_time %d, but the pinds and ninds are: '
                    % (on_start, period, on_time), pinds, ninds)
            elif lost_one_point > 2 * len(auto_inds) / 3:
                uniques, counts = np.unique(lost_one_pos, return_counts=True)
                maxcount = np.argmax(counts)
                if counts[maxcount] > 2 * len(auto_inds) / 3:
                    lost_position = uniques[maxcount]
                else:
                    raise AbnormalPoints(
                        'More than 2/3 baselines have detected lost points but do not have a universal lost position, some unexpected error happened!\nChannels that probably lost one point and the position: %s'
                        % str(zip(lost_one_list, lost_one_pos)))
                if mpiutil.rank0:
                    if lost_position == 0:
                        warnings.warn(
                            'One lost point before the data is detected!',
                            LostPointBegin)
                    else:
                        warnings.warn(
                            'One lost point before index %d is detected!' %
                            on_points[lost_position], LostPointMiddle)
#                on_start = on_start - 1
                if lost_position == 0:
                    lost_count_before += 1
                elif lost_count_before == 0:
                    lost_count_before += 1
                on_points = np.array(on_points)
                on_points[lost_position:] = on_points[lost_position:] - 1
                off_points = np.array(off_points)
                off_points[lost_position:] = off_points[lost_position:] - 1
                if on_points[0] < 0:
                    on_points = on_points[1:]
                    off_points = off_points[1:]
                if mpiutil.rank0:
                    if lost_count_before >= change_reference_tol:
                        #                        reference_time -= 1/86400.0*rt.attrs['inttime']
                        reference_time -= rt.attrs['inttime']
                        warnings.warn(
                            'Move the reference time one index earlier to compensate!',
                            ChangeReferenceTime)
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=0,
                                 added_count=0)
                    else:
                        warnings.warn(
                            'The number of recorded lost points was %d while tolerance is %d. Do not change the reference time.'
                            % (lost_count_before, change_reference_tol))
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=lost_count_before,
                                 added_count=0)
#                mpiutil.barrier()
            elif added_one_point > 2 * len(auto_inds) / 3:
                uniques, counts = np.unique(added_one_pos, return_counts=True)
                maxcount = np.argmax(counts)
                if counts[maxcount] > 2 * len(auto_inds) / 3:
                    added_position = uniques[maxcount]
                else:
                    raise AbnormalPoints(
                        'More than 2/3 baselines have detected additional points but do not have a universal adding position, some unexpected error happened!\nChannels that probably added one point and the position: %s'
                        % str(zip(added_one_list, added_one_pos)))
                if mpiutil.rank0:
                    if added_position == 0:
                        warnings.warn(
                            'One additional point before the data is detected!',
                            AdditionalPointBegin)
                    else:
                        warnings.warn(
                            'One additional point before index %d is detected!'
                            % on_points[added_position], AdditionalPointMiddle)
                if added_position == 0:
                    added_count_before += 1
                elif added_count_before == 0:
                    added_count_before += 1
#                on_start = on_start - 1
                on_points = np.array(on_points)
                on_points[added_position:] = on_points[added_position:] + 1
                off_points = np.array(off_points)
                off_points[added_position:] = off_points[added_position:] + 1
                if off_points[-1] >= total_time_span:
                    on_points = on_points[:-1]
                    off_points = off_points[:-1]
                if mpiutil.rank0:
                    if added_count_before >= change_reference_tol:
                        warnings.warn(
                            'Move the reference time one index later to compensate!',
                            ChangeReferenceTime)
                        #                        reference_time += 1/86400.0*rt.attrs['inttime']
                        reference_time += rt.attrs['inttime']
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=0,
                                 added_count=0)
                    else:
                        warnings.warn(
                            'The number of recorded added points was %d while tolerance is %d. Do not change the reference time.'
                            % (added_count_before, change_reference_tol))
                        np.savez(output_path(ns_prop_file),
                                 period=period,
                                 on_time=on_time,
                                 off_time=off_time,
                                 reference_time=reference_time,
                                 lost_count=0,
                                 added_count=added_count_before)
#                mpiutil.barrier()
            elif lost_one_point > 0 or abnormal_count > 0 or added_one_point > 0:
                if mpiutil.rank0:
                    np.savez(output_path(ns_prop_file),
                             period=period,
                             on_time=on_time,
                             off_time=off_time,
                             reference_time=reference_time,
                             lost_count=0,
                             added_count=0)
                    warnings.warn(
                        'Abnormal points are detected for some channel, number of abnormal bl is %d, number of channel that probably lost one point is %d, number of channel that probably added one point is %d'
                        % (abnormal_count, lost_one_point, added_one_point),
                        DetectedLostAddedAbnormal)
                    warnings.warn(
                        'Abnomal channels: %s\nChannels that probably lost one point: %s\nChannels that probably added one point: %s'
                        % (str(abnormal_list), str(lost_one_list),
                           str(added_one_list)), DetectedLostAddedAbnormal)
            else:
                if mpiutil.rank0:
                    np.savez(output_path(ns_prop_file),
                             period=period,
                             on_time=on_time,
                             off_time=off_time,
                             reference_time=reference_time,
                             lost_count=0,
                             added_count=0)

#================================================
            if mpiutil.rank0:
                print 'Noise source: period = %d, on_time = %d, off_time = %d' % (
                    period, on_time, off_time)
#            num_period = np.int(np.ceil(total_time_span / np.float(period)))
#            ns_on = np.array([False] * on_start + ([True] * on_time + [False] * off_time) * num_period)[:total_time_span]
            ns_on = np.array([False] * total_time_span)
            for i, j in zip(on_points, off_points):
                ns_on[i:j] = True

#            if mpiutil.rank0:
#                np.save('ns_on', ns_on)
        elif not rt.FRB_cal or ns_arr_num < num_noise:

            #            min_inds = 0 # min_inds = min(len(pinds), len(ninds), min_inds) if min_inds != 0 else min(len(pinds), len(ninds))
            ns_num_add = []
            for ns_arr_index, bl_ind in enumerate(auto_inds):
                this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
                vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                                  mask=rt.local_vis_mask[:, :, bl_ind])
                cnt = vis.count()  # number of not masked vals
                total_cnt = mpiutil.allreduce(cnt)
                vis_shp = rt.vis.shape
                ratio = float(total_cnt) / np.prod(
                    (vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
                if ratio < 0.5:  # too many masked vals
                    if rt.FRB_cal and ns_arr_num == -1:
                        ns_arr_pinds += [np.array([])]
                        ns_arr_ninds += [np.array([])]
                    if mpiutil.rank0:
                        warnings.warn(
                            'Too many masked values for auto-correlation of Channel: %d, does not use it'
                            % this_chan)
                    continue
                tt_mean = mpiutil.gather_array(
                    np.ma.mean(vis, axis=-1).filled(0),
                    root=None)  # mean for all freq, for a specific auto bl
                df = np.diff(tt_mean, axis=-1)
                pdf = np.where(df > 0, df, 0)
                pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0]
                #====================================
                if len(pinds) == 0:  # no raise, might be badchn, continue
                    if rt.FRB_cal and ns_arr_num == -1:
                        ns_arr_pinds += [np.array([])]
                        ns_arr_ninds += [np.array([])]
                    if mpiutil.rank0:
                        warnings.warn(
                            'No noise on signal is detected for Channel %d, it may be bad channel.'
                            % this_chan)
                    continue
#====================================
                pinds = pinds + 1
                pinds1 = [pinds[0]]
                for pi in pinds[1:]:
                    if pi - pinds1[-1] > 1:
                        pinds1.append(pi)
                pinds = np.array(pinds1)
                pT = Counter(
                    np.diff(pinds)).most_common(1)[0][0]  # period of pinds

                ndf = np.where(df < 0, df, 0)
                ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0]
                ninds = ninds + 1
                ninds = ninds[::-1]
                #====================================
                if len(ninds) == 0:  # no raise, might be badchn, continue
                    if rt.FRB_cal and ns_arr_num == -1:
                        ns_arr_pinds += [np.array([])]
                        ns_arr_ninds += [np.array([])]
                    if mpiutil.rank0:
                        warnings.warn(
                            'No noise off signal is detected for Channel %d, it may be bad channel.'
                            % this_chan)
                    continue
#====================================
                ninds1 = [ninds[0]]
                for ni in ninds[1:]:
                    if ni - ninds1[-1] < -1:
                        ninds1.append(ni)
                ninds = np.array(ninds1[::-1])
                nT = Counter(
                    np.diff(ninds)).most_common(1)[0][0]  # period of ninds

                ns_num_add += [min(len(pinds), len(ninds))]
                if rt.FRB_cal:
                    if ns_arr_num == -1:
                        ns_arr_pinds += [pinds]
                        ns_arr_ninds += [ninds]
                    else:
                        ns_arr_pinds[ns_arr_index] = np.concatenate(
                            [ns_arr_pinds[ns_arr_index], pinds + ns_arr_len])
                        ns_arr_ninds[ns_arr_index] = np.concatenate(
                            [ns_arr_ninds[ns_arr_index], ninds + ns_arr_len])
#==============================================
# continue for non-FRB case
                if pT != nT:  # failed to detect correct period
                    if mpiutil.rank0:
                        warnings.warn(
                            'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                            % (this_chan, pT, nT))
                    continue
                else:
                    period = pT

                ninds = ninds.reshape(-1, 1)
                dinds = (ninds - pinds).flatten()
                on_time = Counter(dinds[dinds > 0] %
                                  period).most_common(1)[0][0]
                off_time = Counter(-dinds[dinds < 0] %
                                   period).most_common(1)[0][0]

                if period != on_time + off_time:  # incorrect detect
                    if mpiutil.rank0:
                        warnings.warn(
                            'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                            % (this_chan, period, on_time, off_time))
                    continue
                else:
                    if 'noisesource' in rt.iterkeys():
                        if rt['noisesource'].shape[
                                0] == 1:  # only 1 noise source
                            start, stop, cycle = rt['noisesource'][0, :]
                            int_time = rt.attrs['inttime']
                            true_on_time = np.round((stop - start) / int_time)
                            true_period = np.round(cycle / int_time)
                            if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                                if mpiutil.rank0:
                                    warnings.warn(
                                        'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period != record_period %d, does not use it'
                                        % (this_chan, on_time, true_on_time,
                                           period, true_period))
                                continue
                        elif rt['noisesource'].shape[
                                0] >= 2:  # more than 1 noise source
                            if mpiutil.rank0:
                                warnings.warn(
                                    'More than 1 noise source, do not know how to deal with this currently'
                                )

                    # break if succeed
                    if not rt.FRB_cal:  # for FRB case, record all baseline
                        break

            else:
                if not rt.FRB_cal:
                    raise DetectNoiseFailure(
                        'Failed to detect noise source signal')

            if mpiutil.rank0:
                print 'Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                    period, on_time, off_time)
            on_start = Counter(pinds % period).most_common(1)[0][0]
            num_period = np.int(np.ceil(len(tt_mean) / np.float(period)))
            ns_on = np.array([False] * on_start +
                             ([True] * on_time + [False] * off_time) *
                             num_period)[:len(tt_mean)]
            #==============================================
            if rt.FRB_cal:
                ns_arr_len += total_time_span
                #                ns_arr_from_time = np.float128(ns_arr_end_time - ns_arr_start_time)*24.*3600./rt.attrs['inttime'] + 1
                ns_arr_from_time = np.around(
                    np.float128(ns_arr_end_time - ns_arr_start_time) /
                    rt.attrs['inttime']) + 1
                if ns_arr_len == ns_arr_from_time:
                    pass
                else:
                    raise IncontinuousData(
                        'Incontinuous data. Index span calculated from time is %.2f, while the sum of array length is %d! Can not deal with incontinuous data at present!'
                        % (ns_arr_from_time, ns_arr_len))
#                    print('Detected incontinuous data, use index span %d calculated from time instead of the sum of array length %d!'%(ns_arr_from_time, ns_arr_len))
#                    ns_arr_len = ns_arr_from_time
                if ns_arr_num < 0:
                    ns_arr_num = 0
#                ns_arr_num += min_inds
                ns_arr_num += np.around(np.average(ns_num_add))
                ns_arr_bl = rt.bl[auto_inds, 0]
            if mpiutil.rank0 and rt.FRB_cal:
                np.savez(output_path(ns_arr_file),
                         on_inds=ns_arr_pinds,
                         off_inds=ns_arr_ninds,
                         time_len=ns_arr_len,
                         ns_num=ns_arr_num,
                         auto_chn=ns_arr_bl,
                         start_time=ns_arr_start_time,
                         end_time=ns_arr_end_time)
                if ns_arr_num < num_noise:
                    raise NoiseNotEnough(
                        'Number of noise points %d is not enough for calibration(need %d), wait for next file!'
                        % (ns_arr_num, num_noise))
#            mpiutil.barrier()
#=================================================================
        if rt.FRB_cal and (not os.path.exists(
                output_path(ns_prop_file))) and ns_arr_num >= num_noise:
            if mpiutil.rank0:
                print(
                    'Got %d noise points (need %d) to build up noise property file!'
                    % (ns_arr_num, num_noise))
            for ns_arr_index, (pinds, ninds) in enumerate(
                    zip(ns_arr_pinds, ns_arr_ninds)):
                if len(pinds) < num_noise * 2. / 3. or len(
                        ninds) < num_noise * 2. / 3.:
                    print(
                        'Channel %d does not have enough noise points(%d, need %d) for calibration. Do not use it.'
                        % (ns_arr_bl[ns_arr_index], len(pinds),
                           int(2 * num_noise / 3.)))
                    continue
                pT = Counter(
                    np.diff(pinds)).most_common(1)[0][0]  # period of pinds
                nT = Counter(
                    np.diff(ninds)).most_common(1)[0][0]  # period of ninds

                #=================================================================
                if pT != nT:  # failed to detect correct period
                    if mpiutil.rank0:
                        warnings.warn(
                            'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                            % (ns_arr_bl[ns_arr_index], pT, nT))
                    continue
                else:
                    period = pT

                ninds = ninds.reshape(-1, 1)
                dinds = (ninds - pinds).flatten()
                on_time = Counter(dinds[dinds > 0] %
                                  period).most_common(1)[0][0]
                off_time = Counter(-dinds[dinds < 0] %
                                   period).most_common(1)[0][0]

                if period != on_time + off_time:  # incorrect detect
                    if mpiutil.rank0:
                        warnings.warn(
                            'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                            % (ns_arr_bl[ns_arr_index], period, on_time,
                               off_time))
                    continue
                else:
                    if 'noisesource' in rt.iterkeys():
                        if rt['noisesource'].shape[
                                0] == 1:  # only 1 noise source
                            start, stop, cycle = rt['noisesource'][0, :]
                            int_time = rt.attrs['inttime']
                            true_on_time = np.round((stop - start) / int_time)
                            true_period = np.round(cycle / int_time)
                            if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                                if mpiutil.rank0:
                                    warnings.warn(
                                        'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period != record_period %d, does not use it'
                                        % (this_chan, on_time, true_on_time,
                                           period, true_period))
                                continue
                        elif rt['noisesource'].shape[
                                0] >= 2:  # more than 1 noise source
                            if mpiutil.rank0:
                                warnings.warn(
                                    'More than 1 noise source, do not know how to deal with this currently'
                                )

                    # break if succeed

                    break

            else:
                raise DetectNoiseFailure(
                    'Failed to detect noise source signal')

            if mpiutil.rank0:
                print 'Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                    period, on_time, off_time)
            on_start = Counter(pinds % period).most_common(1)[0][0]
            this_time_len = total_time_span
            first_ind = ns_arr_len - this_time_len
            skip_inds = first_ind - on_start
            on_start = period - skip_inds % period
            num_period = np.int(np.ceil(total_time_span / np.float(period)))
            ns_on = np.array([False] * on_start +
                             ([True] * on_time + [False] * off_time) *
                             num_period)[:total_time_span]

        # import matplotlib
        # matplotlib.use('Agg')
        # import matplotlib.pyplot as plt
        # plt.figure()
        # plt.plot(np.where(ns_on, np.nan, tt_mean))
        # # plt.plot(pinds, tt_mean[pinds], 'RI')
        # # plt.plot(ninds, tt_mean[ninds], 'go')
        # plt.savefig('df.png')
        # err

        ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

        rt.create_main_time_ordered_dataset('ns_on', ns_on1)
        rt['ns_on'].attrs['period'] = period
        rt['ns_on'].attrs['on_time'] = on_time
        rt['ns_on'].attrs['off_time'] = off_time

        if (not rt['jul_date'][on_start] is None) and not os.path.exists(
                output_path(ns_prop_file)):
            #            np.savez(output_path(ns_prop_file), period = period, on_time = on_time, off_time = off_time, reference_time = np.float128(rt['jul_date'][on_start]))
            np.savez(
                output_path(ns_prop_file),
                period=period,
                on_time=on_time,
                off_time=off_time,
                reference_time=np.float128(rt.attrs['sec1970'][0] +
                                           on_start * rt.attrs['inttime']))
            #            mpiutil.barrier()
            if mpiutil.rank0:
                print('Save noise property file to %s' %
                      output_path(ns_prop_file))

        # set vis_mask corresponding to ns_on
        on_inds = np.where(rt['ns_on'].local_data[:])[0]
        rt.local_vis_mask[on_inds] = True

        if mask_near > 0:
            on_inds = np.where(ns_on)[0]
            new_on_inds = on_inds.tolist()
            for i in xrange(1, mask_near + 1):
                new_on_inds = new_on_inds + (on_inds - i).tolist() + (
                    on_inds + i).tolist()
            new_on_inds = np.unique(new_on_inds)

            if rt['vis_mask'].distributed:
                start = rt.vis_mask.local_offset[0]
                end = start + rt.vis_mask.local_shape[0]
            else:
                start = 0
                end = rt.vis_mask.shape[0]
            global_inds = np.arange(start, end).tolist()
            new_on_inds = np.intersect1d(new_on_inds, global_inds)
            local_on_inds = [global_inds.index(i) for i in new_on_inds]
            rt.local_vis_mask[
                local_on_inds] = True  # set mask using global slicing

        return super(Detect, self).process(rt)
Пример #7
0
    def wrap(cls, array, axis, comm=None):
        """Turn a local numpy array section into a distributed MPIArray view.

        This is needed for functions such as `np.fft.fft` which always return
        an `np.ndarray`. Each rank passes its own local section; the local
        lengths along `axis` are summed to get the global length and checked
        against the layout that `mpiutil.split_local` produces for it.

        This is a collective operation over `comm` (it performs allreduces),
        so every rank in the communicator must call it.

        Parameters
        ----------
        array : np.ndarray
            Local section of the array to wrap.
        axis : integer
            Axis over which the array is distributed. Negative values count
            from the last axis and are normalised to a non-negative index
            before being stored. The lengths are checked to try and ensure
            this is correct.
        comm : MPI.Comm, optional
            The communicator over which the array is distributed. If `None`
            (default), use `mpiutil.world`.

        Returns
        -------
        dist_array : MPIArray
            An MPIArray view of the input (created with `array.view`, so it
            shares memory with `array`).

        Raises
        ------
        IndexError
            If `axis` is out of range for `array`.
        ValueError
            If, on any rank, the local length along `axis` does not match the
            expected distributed layout. (Subclass of `Exception`, so existing
            generic handlers still catch it.)
        """
        if comm is None:
            comm = mpiutil.world

        # Local length along the distributed axis. Indexing `shape` raises
        # IndexError for an invalid axis; once validated, normalise a negative
        # axis so the stored `_axis` is always a non-negative index.
        axlen = array.shape[axis]
        axis = axis % array.ndim

        # Global length along the distributed axis, summed over all ranks
        totallen = mpiutil.allreduce(axlen, comm=comm)

        # Figure out what the distributed layout should be
        local_num, local_start, _ = mpiutil.split_local(totallen, comm=comm)

        # Check the local layout is consistent with what we expect. The MAX
        # reduction sends the result to all ranks, so every rank takes the
        # same branch below.
        layout_issue = mpiutil.allreduce(axlen != local_num,
                                         op=mpiutil.MAX,
                                         comm=comm)

        if layout_issue:
            raise ValueError(
                "Cannot wrap, distributed axis local length is incorrect.")

        # Set shape and offset: only the distributed axis differs between the
        # local and global views.
        lshape = array.shape
        global_shape = list(lshape)
        global_shape[axis] = totallen

        loffset = [0] * len(lshape)
        loffset[axis] = local_start

        # Setup attributes of class on a view (no data copy)
        dist_arr = array.view(cls)
        dist_arr._global_shape = tuple(global_shape)
        dist_arr._axis = axis
        dist_arr._local_shape = tuple(lshape)
        dist_arr._local_offset = tuple(loffset)
        dist_arr._comm = comm

        return dist_arr
Пример #8
0
    def process(self, rt):

        assert isinstance(rt, RawTimestream), '%s only works for RawTimestream object currently' % self.__class__.__name__

        if not 'ns_on' in rt.iterkeys():
            raise RuntimeError('No noise source info, can not do noise source calibration')

        local_bl_size, local_bl_start, local_bl_end = mpiutil.split_local(len(rt['blorder']))

        rt.redistribute('baseline')

        num_mean = self.params['num_mean']
        phs_only = self.params['phs_only']
        save_gain = self.params['save_gain']
        tag_output_iter = self.params['tag_output_iter']
        gain_file = self.params['gain_file']
        bl_incl = self.params['bl_incl']
        bl_excl = self.params['bl_excl']
        freq_incl = self.params['freq_incl']
        freq_excl = self.params['freq_excl']
        ns_stable = self.params['ns_stable']
        last_transit2stable = self.params['last_transit2stable']

#===================================
        normal_gain_file = self.params['normal_gain_file']
        rt.gain_file = self.params['gain_file']
        absolute_gain_filename = self.params['absolute_gain_filename']
        use_center_data = self.params['use_center_data']
        if use_center_data and rt['ns_on'].attrs['on_time'] <= 2:
            warnings.warn('The period %d <= 2, cannot get rid of the beginning and ending points. Use the whole average automatically!')
            use_center_data = False
#===================================
        # apply abs gain to data if ps_first
        if rt.ps_first:
            if absolute_gain_filename is None:
                raise NoPsGainFile('Absent parameter absolute_gain_file. In ps_first mode, absolute_gain_filename is required to process!')
            if not os.path.exists(output_path(absolute_gain_filename)):
                raise NoPsGainFile('No absolute gain file %s, do the ps calibration first!'%output_path(absolute_gain_filename))
            with h5py.File(output_path(absolute_gain_filename, 'r')) as agfilein:
                if not ns_stable:
#                    try: # if there is transit in this batch of data, build up transit amp and phs
#                        rt.normal_index = np.where(np.abs(rt['jul_date'] - agfilein.attrs['transit_time']) < 2.e-6)[0][0] # to avoid numeric error. 1.e-6 julian date is about 0.1s
                    if rt['sec1970'][0] <= agfilein.attrs['transit_time'] and rt['sec1970'][-1] >= agfilein.attrs['transit_time']:
                        rt.normal_index = np.int64(np.around((agfilein.attrs['transit_time'] - rt['sec1970'][0])/rt.attrs['inttime']))
                        if mpiutil.rank0:
                            print('Detected transit, time index %d, build up transit normalization gain file!'%rt.normal_index)
                        build_normal_file = True
                        rt.normal_phs = np.zeros((rt['vis'].shape[1], local_bl_size))
                        rt.normal_phs.fill(np.nan)
                        if not phs_only:
                            rt.normal_amp = np.zeros((rt['vis'].shape[1], local_bl_size))
                            rt.normal_amp.fill(np.nan)
#                    except IndexError: # no transit in this batch of data, load transit amp and phs from file
                    else: 
                        rt.normal_index = None # no transit flag
                        build_normal_file = False
                        if not os.path.exists(output_path(normal_gain_file)):
                            time_info = (aipy.phs.juldate2ephem(agfilein.attrs['transit_jul'] + 8./24.), aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.), aipy.phs.juldate2ephem(rt['jul_date'][-1] + 8./24.))
                            raise TransitGainNotRecorded('The transit %s is not in time range %s to %s and no transit normalization gain was recorded!'%time_info)
                        else:
                            if mpiutil.rank0:
                                print('No transit, use existing transit normalization gain file!')
                            with h5py.File(output_path(normal_gain_file), 'r') as filein:
                                rt.normal_phs = filein['phs'][:, local_bl_start:local_bl_end]
                                if not phs_only:
                                    rt.normal_amp = filein['amp'][:, local_bl_start:local_bl_end]
                else:
                    rt.normal_index = None # no transit flag
                    rt.normal_phs = np.zeros((rt['vis'].shape[1], local_bl_size))
                    rt.normal_phs.fill(np.nan)
                    if not phs_only:
                        rt.normal_amp = np.zeros((rt['vis'].shape[1], local_bl_size))
                        rt.normal_amp.fill(np.nan)
                    try:
                        stable_time = rt['pointingtime'][-1,-1]
                    except KeyError:
                        stable_time = rt['transitsource'][-1, 0]
                    if stable_time > 0: # the time when the antenna will be pointing at the target region was recorded, use it
                        stable_time += 300 # plus 5 min to ensure stable
                    else:
                        stable_time = rt['transitsource'][-2, 0] + last_transit2stable
                    stable_time_jul = datetime.utcfromtimestamp(stable_time)
                    stable_time_jul = ephem.julian_date(stable_time_jul) # convert to julian date, convenient to display
                    if not os.path.exists(output_path(normal_gain_file)):
                        if mpiutil.rank0:
                            print('Normalization gain file has not been built yet. Try to build it!')
                            print('Recorded transit Time: %s'%aipy.phs.juldate2ephem(agfilein.attrs['transit_jul'] + 8./24.))
                            print('Last transit Time: %s'%aipy.phs.juldate2ephem(ephem.julian_date(datetime.utcfromtimestamp(rt['transitsource'][-2,0])) + 8./24.))
                            print('First time point of this data: %s'%aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.))
                        build_normal_file = True
#                        if rt['jul_date'][0] < stable_time:
                        if rt.attrs['sec1970'][0] < stable_time:
                            raise BeforeStableTime('The beginning time point is %s, but only after %s, will the noise source be stable. Abort the noise calibration!'%(aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.), aipy.phs.juldate2ephem(stable_time_jul + 8./24.)))
                    else:
                        if mpiutil.rank0:
                            print('Use existing normalization gain file!')
                        build_normal_file = False
                        with h5py.File(output_path(normal_gain_file), 'r') as filein:
                            rt.normal_phs = filein['phs'][:, local_bl_start:local_bl_end]
                            if not phs_only:
                                rt.normal_amp = filein['amp'][:, local_bl_start:local_bl_end]
                # apply absolute gain
                absgain = agfilein['gain'][:]
                polfeed, _ = bl2pol_feed_inds(rt.local_bl, agfilein['gain'].attrs['feed'][:], agfilein['gain'].attrs['pol'][:])
                for ii, (ipol, ifeed, jpol, jfeed) in enumerate(polfeed):
#                    rt.local_vis[:,:,ii] /= (absgain[None,:,ipol,ifeed-1] * absgain[None,:,jpol,jfeed-1].conj())
                    rt.local_vis[:,:,ii] /= (absgain[None,:,ipol,ifeed] * absgain[None,:,jpol,jfeed].conj())
#===================================
        nt = rt.local_vis.shape[0]
        if num_mean <= 0:
            raise RuntimeError('Invalid num_mean = %s' % num_mean)
        ns_on = rt['ns_on'][:]
        ns_on = np.where(ns_on, 1, 0)
        diff_ns = np.diff(ns_on)
        inds = np.where(diff_ns==1)[0] # NOTE: these are inds just 1 before the first ON
        if not rt.FRB_cal: # for FRB there might be just one noise point, avoid waste
            if inds[0]-1 < 0: # no off data in the beginning to use
                inds = inds[1:]
            if inds[-1]+2 > nt-1: # no on data in the end to use
                inds = inds[:-1]

        if save_gain:
            num_inds = len(inds)
            shp = (num_inds,)+rt.local_vis.shape[1:]
            dtype = rt.local_vis.real.dtype
            # create dataset to record ns_cal_time_inds
            rt.create_time_ordered_dataset('ns_cal_time_inds', inds)
            # create dataset to record ns_cal_phase
            ns_cal_phase = np.empty(shp, dtype=dtype)
            ns_cal_phase[:] = np.nan
            ns_cal_phase = mpiarray.MPIArray.wrap(ns_cal_phase, axis=2, comm=rt.comm)
            rt.create_freq_and_bl_ordered_dataset('ns_cal_phase', ns_cal_phase, axis_order=(None, 1, 2))
            rt['ns_cal_phase'].attrs['unit'] = 'radians'
            if not phs_only:
                # create dataset to record ns_cal_amp
                ns_cal_amp = np.empty(shp, dtype=dtype)
                ns_cal_amp[:] = np.nan
                ns_cal_amp = mpiarray.MPIArray.wrap(ns_cal_amp, axis=2, comm=rt.comm)
                rt.create_freq_and_bl_ordered_dataset('ns_cal_amp', ns_cal_amp, axis_order=(None, 1, 2))

        if bl_incl == 'all':
            bls_plt = [ tuple(bl) for bl in rt.bl ]
        else:
            bls_plt = [ bl for bl in bl_incl if not bl in bl_excl ]

        if freq_incl == 'all':
            freq_plt = range(rt.freq.shape[0])
        else:
            freq_plt = [ fi for fi in freq_incl if not fi in freq_excl ]

        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']

        if rt.ps_first:
            rt.freq_and_bl_data_operate(self.cal, full_data=True, show_progress=show_progress, progress_step=progress_step, keep_dist_axis=False, num_mean=num_mean, inds=inds, bls_plt=bls_plt, freq_plt=freq_plt, build_normal_file = build_normal_file)
        else:
            rt.freq_and_bl_data_operate(self.cal, full_data=True, show_progress=show_progress, progress_step=progress_step, keep_dist_axis=False, num_mean=num_mean, inds=inds, bls_plt=bls_plt, freq_plt=freq_plt)

        if save_gain:
            interp_mask_ratio = mpiutil.allreduce(np.sum(rt.interp_mask_count))/1./mpiutil.allreduce(np.size(rt.interp_mask_count)) * 100.
            if interp_mask_ratio > 50. and rt.normal_index is not None:
                warnings.warn('%.1f%% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!'%interp_mask_ratio, NotEnoughPointToInterpolateWarning)
            if interp_mask_ratio > 80.:
                rt.interp_all_masked = True
            # gather bl_order to rank0
            bl_order = mpiutil.gather_array(rt['blorder'].local_data, axis=0, root=0, comm=rt.comm)
            # gather ns_cal_phase / ns_cal_amp to rank 0
            ns_cal_phase = mpiutil.gather_array(rt['ns_cal_phase'].local_data, axis=2, root=0, comm=rt.comm)
            phs_unit = rt['ns_cal_phase'].attrs['unit']
            rt.delete_a_dataset('ns_cal_phase', reserve_hint=False)
            if not phs_only:
                ns_cal_amp = mpiutil.gather_array(rt['ns_cal_amp'].local_data, axis=2, root=0, comm=rt.comm)
                rt.delete_a_dataset('ns_cal_amp', reserve_hint=False)

            if tag_output_iter:
                gain_file = output_path(gain_file, iteration=self.iteration)
            else:
                gain_file = output_path(gain_file)

            if rt.ps_first:
                phase_finite_count = mpiutil.allreduce(np.isfinite(rt.normal_phs).sum())
                if not phs_only:
                    amp_finite_count = mpiutil.allreduce(np.isfinite(rt.normal_amp).sum())
            if mpiutil.rank0:
                if rt.ps_first and build_normal_file:
                    if phase_finite_count == 0:
                        raise AllMasked('All values are masked when calculating phase!')
                    if not phs_only:
                        if amp_finite_count == 0:
                            raise AllMasked('All values are masked when calculating amplitude!')
                    with h5py.File(output_path(normal_gain_file), 'w') as tapfilein:
                        tapfilein.create_dataset('amp',(rt['vis'].shape[1], rt['vis'].shape[2]))
                        tapfilein.create_dataset('phs',(rt['vis'].shape[1], rt['vis'].shape[2]))
                with h5py.File(gain_file, 'w') as f:
                    # save time
                    f.create_dataset('time', data=rt['jul_date'][:])
                    f['time'].attrs['unit'] = 'Julian date'
                    # save freq
                    f.create_dataset('freq', data=rt['freq'][:])
                    f['freq'].attrs['unit'] = rt['freq'].attrs['unit']
                    # save bl
                    f.create_dataset('bl_order', data=bl_order)
                    # save ns_cal_time_inds
                    f.create_dataset('ns_cal_time_inds', data=rt['ns_cal_time_inds'][:])
                    # save ns_cal_phase
                    f.create_dataset('ns_cal_phase', data=ns_cal_phase)
                    f['ns_cal_phase'].attrs['unit'] = phs_unit
                    f['ns_cal_phase'].attrs['dim'] = '(time, freq, bl)'
                    if not phs_only:
                        # save ns_cal_amp
                        f.create_dataset('ns_cal_amp', data=ns_cal_amp)

                    # save channo
                    f.create_dataset('channo', data=rt['channo'][:])
                    f['channo'].attrs['dim'] = rt['channo'].attrs['dimname']
                    if rt.exclude_bad:
                        f['channo'].attrs['badchn'] = rt['channo'].attrs['badchn']

                    if not (absolute_gain_filename is None):
                        if not os.path.exists(output_path(absolute_gain_filename)):
                            raise NoPsGainFile('No absolute gain file %s, do the ps calibration first!'%output_path(absolute_gain_filename))
                        with h5py.File(output_path(absolute_gain_filename,'r')) as abs_gain:
                            new_gain = uni_gain(abs_gain, f, phs_only = phs_only)
                            f.create_dataset('uni_gain', data = new_gain)
                            f['uni_gain'].attrs['dim'] = '(time, freq, bl)'

            if rt.ps_first and build_normal_file:
                if mpiutil.rank0:
                    print('Start write normalization gain into %s'%output_path(normal_gain_file))
                for i in range(10):
                    try:
                        for ii in range(mpiutil.size):
                            if ii == mpiutil.rank:
                                with h5py.File(output_path(normal_gain_file), 'r+') as fileout:
                                    fileout['phs'][:,local_bl_start:local_bl_end] = rt.normal_phs[:,:]
                                    if not phs_only:
                                        fileout['amp'][:,local_bl_start:local_bl_end] = rt.normal_amp[:,:]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
            rt.delete_a_dataset('ns_cal_time_inds', reserve_hint=False)

        return super(NsCal, self).process(rt)
Пример #9
0
    def generate(self, regen=False):
        """Calculate the total Fisher matrix and bias and save to a file.

        The m-modes ``0 .. self.telescope.mmax`` are partitioned over the MPI
        processes. Each process sums ``fisher_bias_m`` over its share. The
        partial sums are added across processes with an allreduce, so every
        rank ends up with ``self.fisher`` and ``self.bias``. Rank 0 then
        derives the covariance, errors and correlation matrix and writes all
        products to ``<psdir>/fisher.hdf5``.

        Parameters
        ----------
        regen : boolean, optional
            Force regeneration if products already exist (default `False`).
        """

        if mpiutil.rank0:
            st = time.time()
            print("======== Starting PS calculation ========")

        ffile = self.psdir + "/fisher.hdf5"

        if os.path.exists(ffile) and not regen:
            # Every rank takes this branch; report it only once.
            if mpiutil.rank0:
                print("Fisher matrix file: %s exists. Skipping..." % ffile)
            return

        mpiutil.barrier()

        # Pre-compute all the angular power spectra for the bands
        self.genbands()

        # Calculate Fisher and bias for each m
        # Pair up each list item with its position.
        zlist = list(enumerate(range(self.telescope.mmax + 1)))
        # Partition list based on MPI rank
        llist = mpiutil.partition_list_mpi(zlist)
        # Operate on sublist
        fisher_bias_list = [self.fisher_bias_m(item) for ind, item in llist]

        # Unpack into separate lists of the Fisher matrix and bias.
        # NOTE(review): if this rank received no m-modes (more processes than
        # mmax + 1), `zip(*[])` is empty and this unpacking raises ValueError.
        fisher_loc, bias_loc = zip(*fisher_bias_list)

        # Sum over all local m-modes to get the overall Fisher and bias per
        # process. Only the real part is kept (be careful of the .real here).
        fisher_loc = np.sum(np.array(fisher_loc), axis=0).real
        bias_loc = np.sum(np.array(bias_loc), axis=0).real

        self.fisher = mpiutil.allreduce(fisher_loc, op=MPI.SUM)
        self.bias = mpiutil.allreduce(bias_loc, op=MPI.SUM)

        # Write out all the PS estimation products
        if mpiutil.rank0:
            et = time.time()
            print("======== Ending PS calculation (time=%f) ========" %
                  (et - st))

            # Check to see ensure that Fisher matrix isn't all zeros.
            if not (self.fisher == 0).all():
                # Generate derived quantities (covariance, errors..)
                cv = la.pinv(self.fisher, rcond=1e-8)
                err = cv.diagonal()**0.5
                cr = cv / np.outer(err, err)
            else:
                cv = np.zeros_like(self.fisher)
                err = cv.diagonal()
                cr = np.zeros_like(self.fisher)

            # `with` guarantees the file is closed even if a write fails.
            with h5py.File(ffile, "w") as f:
                # Store as fixed-width bytes (HDF5 string issues). np.bytes_ is
                # the type `np.string_` aliased; that alias is gone in NumPy 2.0.
                f.attrs["bandtype"] = np.bytes_(self.bandtype)

                f.create_dataset("fisher/", data=self.fisher)
                f.create_dataset("bias/", data=self.bias)
                f.create_dataset("covariance/", data=cv)
                f.create_dataset("errors/", data=err)
                f.create_dataset("correlation/", data=cr)

                f.create_dataset("band_power/", data=self.band_power)

                if self.bandtype == "polar":
                    f.create_dataset("k_start/", data=self.k_start)
                    f.create_dataset("k_end/", data=self.k_end)
                    f.create_dataset("k_center/", data=self.k_center)

                    f.create_dataset("theta_start/", data=self.theta_start)
                    f.create_dataset("theta_end/", data=self.theta_end)
                    f.create_dataset("theta_center/", data=self.theta_center)

                    f.create_dataset("k_bands", data=self.k_bands)
                    f.create_dataset("theta_bands", data=self.theta_bands)

                elif self.bandtype == "cartesian":

                    f.create_dataset("kpar_start/", data=self.kpar_start)
                    f.create_dataset("kpar_end/", data=self.kpar_end)
                    f.create_dataset("kpar_center/", data=self.kpar_center)

                    f.create_dataset("kperp_start/", data=self.kperp_start)
                    f.create_dataset("kperp_end/", data=self.kperp_end)
                    f.create_dataset("kperp_center/", data=self.kperp_center)

                    f.create_dataset("kpar_bands", data=self.kpar_bands)
                    f.create_dataset("kperp_bands", data=self.kperp_bands)
Пример #10
0
    def process(self, rt):
        """Detect the noise-source ON/OFF pattern and mask the ON time points.

        The auto-correlations are tried one by one. The requested ``channel``
        goes first; if that channel is absent, the first auto-correlation is
        used. Auto-correlations with less than half of their values unmasked
        are skipped.

        For each auto-correlation tried:

        * The real part of the visibility is averaged over its second axis
          (presumably frequency -- TODO confirm), giving one value per time
          point. These values are gathered to every rank and sorted.
        * The largest ratio between consecutive positive values is taken as
          the ON/OFF threshold. It is accepted only when it exceeds twice the
          ratios on both sides of it.

        From the resulting ON/OFF series, ``on_time`` and ``off_time`` are the
        most common lengths of the ON runs and of the OFF runs. The trailing
        run of the series is not counted. ``period = on_time + off_time``, in
        time points.

        The series is stored as the time-ordered dataset ``ns_on`` with these
        three attributes. The visibility mask is set for every ON point and,
        when ``mask_near > 0``, for up to ``mask_near`` points on each side.

        Parameters
        ----------
        rt : RawTimestream
            Data container; it is redistributed along axis 0 (time).

        Returns
        -------
        The result of the parent class's ``process`` on ``rt``.

        Raises
        ------
        RuntimeError
            If there is no auto-correlation baseline, no auto-correlation
            yields a usable threshold, or the detected series contains no
            full ON run or no full OFF run.
        """

        assert isinstance(
            rt, RawTimestream
        ), '%s only works for RawTimestream object currently' % self.__class__.__name__

        channel = self.params['channel']
        mask_near = max(0, int(self.params['mask_near']))

        rt.redistribute(0)  # make time the dist axis

        # inds for auto-correlations (baselines whose two channels coincide)
        auto_inds = np.where(rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()
        if not auto_inds:
            raise RuntimeError(
                'No auto-correlation in the data to detect the noise source')
        channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
        if channel is not None:
            if channel in channels:
                bl_ind = auto_inds[channels.index(channel)]
            else:
                bl_ind = auto_inds[0]
                if mpiutil.rank0:
                    print('Warning: Required channel %d doen not in the data, use channel %d instead' % (
                        channel, rt.bl[bl_ind, 0]))
        else:
            bl_ind = auto_inds[0]
        # move the chosen channel to the first
        auto_inds.remove(bl_ind)
        auto_inds = [bl_ind] + auto_inds

        vis_shp = rt.vis.shape  # global shape, used for the un-masked ratio
        for bl_ind in auto_inds:
            this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
            vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                              mask=rt.local_vis_mask[:, :, bl_ind])
            cnt = vis.count()  # number of not masked vals
            total_cnt = mpiutil.allreduce(cnt)
            ratio = float(total_cnt) / np.prod(
                (vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
            if ratio < 0.5:  # too many masked vals
                if mpiutil.rank0:
                    warnings.warn(
                        'Too many masked values for auto-correlation of Channel: %d, does not use it'
                        % this_chan)
                continue

            # tt_mean is gathered to every rank, so all ranks take the same
            # continue/break decisions below.
            tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0),
                                           root=None)
            tt_mean_sort = np.sort(tt_mean)
            tt_mean_sort = tt_mean_sort[tt_mean_sort > 0]
            ttms_div = tt_mean_sort[1:] / tt_mean_sort[:-1]
            # The jump must have a neighbour ratio on each side to be judged;
            # too few points (or a jump at either end) means no clear threshold.
            if len(ttms_div) < 3:
                continue
            ind = np.argmax(ttms_div)
            if 0 < ind < len(ttms_div) - 1 and ttms_div[ind] > 2.0 * max(
                    ttms_div[ind - 1], ttms_div[ind + 1]):
                # geometric mean of the two values around the jump
                sep = np.sqrt(tt_mean_sort[ind] * tt_mean_sort[ind + 1])
                break
        else:
            raise RuntimeError(
                'Failed to get the threshold to separate ns signal out')

        ns_on = tt_mean > sep  # boolean ON flag for every (global) time point

        # Lengths of the ON runs (nTs) and OFF runs (nFs); a run is recorded
        # when the flag switches, so the trailing run is not included.
        nTs = []
        nFs = []
        nT = 1 if ns_on[0] else 0
        nF = 1 if not ns_on[0] else 0
        for i in range(1, len(ns_on)):
            if ns_on[i]:
                if ns_on[i] == ns_on[i - 1]:
                    nT += 1
                else:
                    nT = 1
                    nFs.append(nF)
            else:
                if ns_on[i] == ns_on[i - 1]:
                    nF += 1
                else:
                    nF = 1
                    nTs.append(nT)
        if not nTs or not nFs:
            raise RuntimeError(
                'Failed to find a full noise source ON and OFF period in the data')
        on_time = Counter(nTs).most_common(1)[0][0]
        off_time = Counter(nFs).most_common(1)[0][0]
        period = on_time + off_time

        if 'noisesource' in rt.iterkeys():
            if rt['noisesource'].shape[0] == 1:  # only 1 noise source
                start, stop, cycle = rt['noisesource'][0, :]
                int_time = rt.attrs['inttime']
                true_on_time = np.round((stop - start) / int_time)
                true_period = np.round(cycle / int_time)
                # inconsistent with the record in the data if either differs
                if on_time != true_on_time or period != true_period:
                    if mpiutil.rank0:
                        warnings.warn(
                            'Detected noise source info is inconsistent with the record in the data for auto-correlation of Channel: %d: on_time = %d (record_on_time = %d), period = %d (record_period = %d)'
                            % (this_chan, on_time, true_on_time, period,
                               true_period))
            elif rt['noisesource'].shape[0] >= 2:  # more than 1 noise source
                if mpiutil.rank0:
                    warnings.warn(
                        'More than 1 noise source, do not know how to deal with this currently'
                    )

        if mpiutil.rank0:
            print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (
                period, on_time, off_time))

        ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

        rt.create_main_time_ordered_dataset('ns_on', ns_on1)
        rt['ns_on'].attrs['period'] = period
        rt['ns_on'].attrs['on_time'] = on_time
        rt['ns_on'].attrs['off_time'] = off_time

        # set vis_mask corresponding to ns_on
        on_inds = np.where(rt['ns_on'].local_data[:])[0]
        rt.local_vis_mask[on_inds] = True

        if mask_near > 0:
            on_inds = np.where(ns_on)[0]
            # global ON indices shifted by -mask_near .. +mask_near; indices
            # outside this process's range are dropped below
            new_on_inds = np.unique(np.concatenate(
                [on_inds + i for i in range(-mask_near, mask_near + 1)]))

            if rt['vis_mask'].distributed:
                start = rt.vis_mask.local_offset[0]
                end = start + rt.vis_mask.local_shape[0]
            else:
                start = 0
                end = rt.vis_mask.shape[0]
            # keep indices in [start, end) and convert them to local indices
            in_local = (new_on_inds >= start) & (new_on_inds < end)
            local_on_inds = new_on_inds[in_local] - start
            rt.local_vis_mask[local_on_inds] = True

        return super(Detect, self).process(rt)
Пример #11
0
    def read_input(self):
        """Method for (maybe iteratively) reading data from input data files.

        Each call reads the chunk of the current file group that belongs to
        the current iteration. The chunk covers ``days`` times ``const.sday``
        seconds (presumably sidereal days -- TODO confirm), converted to
        integration points, plus ``2 * extra_inttime`` extra points, clipped
        to the end of the group.

        When a chunk reaches the end of the group, the group counter is
        advanced. If there is more than one group, the next call then
        restarts the iteration on the next group.

        Returns
        -------
        The selected and fully loaded timestream object, or None if this
        chunk is empty or shorter than ``drop_days`` / ``extra_inttime``.
        """

        days = self.params['days']
        extra_inttime = self.params['extra_inttime']
        drop_days = self.params['drop_days']
        mode = self.params['mode']
        dist_axis = self.params['dist_axis']

        ngrp = len(self.input_grps)

        if self.next_grp and ngrp > 1:
            if mpiutil.rank0:
                print('Start file group %d of %d...' % (self.grp_cnt, ngrp))
            self.restart_iteration()  # re-start iteration for each group
            self.next_grp = False
            self.abs_start = None
            self.abs_stop = None

        input_files = self.input_grps[self.grp_cnt]
        start = self.start[self.grp_cnt]
        stop = self.stop[self.grp_cnt]

        if self.int_time is None:
            # NOTE: here assume all files have the same int_time
            with h5py.File(self.input_files[0], 'r') as f:
                self.int_time = f.attrs['inttime']

        if self.abs_start is None or self.abs_stop is None:
            tmp_tod = self._Tod_class(input_files, mode, start, stop,
                                      dist_axis)
            self.abs_start = tmp_tod.main_data_start
            self.abs_stop = tmp_tod.main_data_stop
            del tmp_tod

        iteration = self.iteration if self.iterable else 0
        # `int` instead of the `np.int` alias, which NumPy 1.24 removed
        this_start = self.abs_start + int(
            np.around(iteration * days * const.sday / self.int_time))
        this_stop = min(
            self.abs_stop, self.abs_start + int(
                np.around(
                    (iteration + 1) * days * const.sday / self.int_time)) +
            2 * extra_inttime)
        # Advance the group counter at most once per call: an empty chunk is
        # also one that reaches abs_stop, and counting it twice would skip a
        # whole file group.
        if this_start >= this_stop:
            self.next_grp = True
            self.grp_cnt += 1
            return None
        if this_stop >= self.abs_stop:
            self.next_grp = True
            self.grp_cnt += 1

        this_span = self.int_time * (this_stop - this_start)  # in unit second
        if this_span < drop_days * const.sday:
            if mpiutil.rank0:
                print('Not enough span time, drop it...')
            return None
        elif (this_stop - this_start) <= extra_inttime:  # use int comparision
            if mpiutil.rank0:
                print('Not enough span time (less than `extra_inttime`), drop it...')
            return None

        tod = self._Tod_class(input_files, mode, this_start, this_stop,
                              dist_axis)

        tod, _ = self.data_select(tod)

        tod.load_all()  # load in all data

        if self.start_ra is None:  # the first iteration
            if 'time' == tod.main_data_axes[tod.main_data_dist_axis]:
                # ra_dec is distributed among processes
                # find the point of ra_dec[extra_inttime, 0] of the global array
                local_offset = tod['ra_dec'].local_offset[0]
                local_shape = tod['ra_dec'].local_shape[0]
                if local_offset <= extra_inttime < local_offset + local_shape:
                    in_this = 1
                    start_ra = tod['ra_dec'].local_data[extra_inttime -
                                                        local_offset, 0]
                else:
                    in_this = 0
                    start_ra = None

                # get the rank holding that point (MAXLOC picks in_this == 1)
                max_val, in_rank = mpiutil.allreduce((in_this, tod.rank),
                                                     op=mpiutil.MAXLOC,
                                                     comm=tod.comm)
                # bcast from this rank
                start_ra = mpiutil.bcast(start_ra, root=in_rank, comm=tod.comm)
                self.start_ra = start_ra
            else:
                # Not distributed along time: presumably every process holds
                # the whole ra_dec, so index it directly (there is no local
                # `ra_dec` name here).
                self.start_ra = tod['ra_dec'][extra_inttime, 0]

        tod.vis.attrs['start_ra'] = self.start_ra  # used for re_order

        return tod
Пример #12
0
    def process(self, ts):

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        #        if mpiutil.rank0:
        #            saveVmat = []

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        vis_conj = self.params['vis_conj']
        zero_diag = self.params['zero_diag']
        span = self.params['span']
        reserve_high_gain = self.params['reserve_high_gain']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_src_vis = self.params['save_src_vis']
        src_vis_file = self.params['src_vis_file']
        subtract_src = self.params['subtract_src']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        save_phs_change = self.params['save_phs_change']
        gain_file = self.params['gain_file']
        temperature_convert = self.params['temperature_convert']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']
        srcdict = self.params['srcdict']
        max_iter = self.params['max_iter']

        #        if mpiutil.rank0:
        #            print(ts.keys())
        #            print(ts.attrs.keys())
        #
        #        import sys
        #        sys.exit()

        if save_src_vis or subtract_src or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ps_cal for pol_type: %s' %
                                   pol_type)

            ts.redistribute('baseline')

            feedno = ts['feedno'][:].tolist()
            pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
            gain_pd = {
                'xx': 0,
                'yy': 1,
                0: 'xx',
                1: 'yy'
            }  # for gain related op
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians
            # transitsource = ts['transitsource'][:]
            # transit_time = transitsource[-1, 0] # second, sec1970
            # int_time = ts.attrs['inttime'] # second

            # calibrator
            #see whether the source should be observed
            #only for single calibrator
            if ts.is_dish:
                obsd_sources = ts['transitsource'].attrs['srcname']
                obsd_sources = obsd_sources.split(',')
                #                srcdict = {'cas':'CassiopeiaA'}
                src_index = 0.1
                for src_short, src_long in srcdict.items():
                    #                    print(src_short,src_long,obsd_sources,calibrator)
                    if (src_short == calibrator) and (src_long
                                                      in obsd_sources):
                        src_index = obsd_sources.index(src_long)
                        break
                else:
                    raise Exception(
                        'the transit source is not included in the observation plan!'
                    )
                obs_transit_time = ts['transitsource'][src_index][0]

#            srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog)
#            cat = a.src.get_catalog(srclist, cutoff, catalogs)
#            assert(len(cat) == 1), 'Allow only one calibrator'
#            s = cat.values()[0]

# get the calibrator
            try:
                s = calibrators.get_src(calibrator)
            except KeyError:
                if mpiutil.rank0:
                    print 'Calibrator %s is unavailable, available calibrators are:'
                    for key, d in calibrators.src_data.items():
                        print '%8s  ->  %12s' % (key, d[0])
                raise RuntimeError('Calibrator %s is unavailable')
            if mpiutil.rank0:
                print 'Try to calibrate with %s...' % s.src_name
#            if mpiutil.rank0:
#                print 'Calibrating for source %s with' % calibrator,
#                print 'strength', s._jys, 'Jy',
#                print 'measured at', s.mfreq, 'GHz',
#                print 'with index', s.index

# get transit time of calibrator
# array
            aa = ts.array
            aa.set_jultime(ts['jul_date'][0])  # the first obs time point
            next_transit = aa.next_transit(s)
            next_transit = aa.next_transit(
                s) if not ts.is_dish else ephem.date(
                    datetime.utcfromtimestamp(obs_transit_time))
            transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
            # get time zone
            pattern = '[-+]?\d+'
            tz = re.search(pattern, ts.attrs['timezone']).group()
            tz = int(tz)

            local_next_transit = ephem.Date(
                next_transit + tz * ephem.hour)  # plus 8h to get Beijing time
            # if transit_time > ts['jul_date'][-1]:
            #========================================================
            if ts.is_dish:
                if (transit_time >
                        max(ts['jul_date'][-1],
                            ts['jul_date'][:].max())) or (transit_time < min(
                                ts['jul_date'][0], ts['jul_date'][:].min())):
                    #                    raise RuntimeError('dish data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))
                    raise NoTransit(
                        'Dish data does not contain local transit time %s of source %s'
                        % (local_next_transit, calibrator))
                transit_inds = [
                    np.searchsorted(ts['jul_date'][:], transit_time)
                ]
                peak_span = int(np.around(10. / ts.attrs['inttime']))
                peak_start = ts['jul_date'][transit_inds[0] - peak_span]
                peak_end = ts['jul_date'][transit_inds[0] + peak_span]
                tspan = (peak_end - peak_start) * 86400.
                if mpiutil.rank0:
                    print('transit peak: %s' % local_next_transit)
                    print('%d point (should be about 10s) previous: %s' %
                          (peak_span,
                           ephem.date(
                               a.phs.juldate2ephem(peak_start) +
                               tz * ephem.hour)))
                    print(
                        '%d point (should be about 10s) later: %s' %
                        (peak_span,
                         ephem.date(
                             a.phs.juldate2ephem(peak_end) + tz * ephem.hour)))
                if tspan > 30.:
                    warnings.warn(
                        'Peak of transit is not continuous in the data! May lead to poor performance!'
                    )
            else:
                if transit_time > max(ts['jul_date'][-1],
                                      ts['jul_date'][:].max()):
                    #                    raise RuntimeError('Cylinder data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))
                    raise NoTransit(
                        'Cylinder data does not contain local transit time %s of source %s'
                        % (local_next_transit, calibrator))

                # the first transit index
                transit_inds = [
                    np.searchsorted(ts['jul_date'][:], transit_time)
                ]
                # find all other transit indices
                aa.set_jultime(ts['jul_date'][0] + 1.0)
                transit_time = a.phs.ephem2juldate(
                    aa.next_transit(s))  # Julian date
                cnt = 2
                while (transit_time <= ts['jul_date'][-1]):
                    transit_inds.append(
                        np.searchsorted(ts['jul_date'][:], transit_time))
                    aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt)
                    # Julian date
                    #                transit_time = a.phs.ephem2juldate(aa.next_transit(s) if not ts.is_dish else ephem.date(datetime.utcfromtimestamp(obs_transit_time)))
                    transit_time = a.phs.ephem2juldate(aa.next_transit(s))
                    cnt += 1
#========================================================

            if mpiutil.rank0:
                print 'transit ind of %s: %s, time: %s' % (
                    s.src_name, transit_inds, local_next_transit)

            if (not ts.ps_first) and ts.interp_all_masked:
                raise NotEnoughPointToInterpolateError(
                    'More than 80% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!'
                )
            ### now only use the first transit point to do the cal
            ### may need to improve in the future
            transit_ind = transit_inds[0]
            int_time = ts.attrs['inttime']  # second
            start_ind = transit_ind - np.int(span / int_time)
            end_ind = transit_ind + np.int(
                span /
                int_time) + 1  # plus 1 to make transit_ind is at the center

            start_ind = max(0, start_ind)
            end_ind = min(end_ind, ts.vis.shape[0])

            if vis_conj:
                ts.local_vis[:] = ts.local_vis.conj()

            nt = end_ind - start_ind
            t_inds = range(start_ind, end_ind)
            freq = ts.freq[:]
            nf = len(freq)
            nlb = len(ts.local_bl[:])
            nfeed = len(feedno)
            tfp_inds = list(
                itertools.product(
                    t_inds, range(nf),
                    [pol.index('xx'), pol.index('yy')]))  # only for xx and yy
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    lvis[ii] = ts.local_vis[ti, fi, pi]
                    lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
                # gather vis from all process for separate bls
                gvis = mpiutil.gather_array(lvis,
                                            axis=1,
                                            root=ri,
                                            comm=ts.comm)
                gvis_mask = mpiutil.gather_array(lvis_mask,
                                                 axis=1,
                                                 root=ri,
                                                 comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei]  # inds for this process
                    this_vis = gvis
                    this_vis_mask = gvis_mask
            del tfp_inds
            del lvis
            del lvis_mask
            tfp_len = len(tfp_linds)

            cnan = complex(np.nan, np.nan)  # complex nan
            if save_src_vis or subtract_src:
                # save calibrator src vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed),
                                   cnan,
                                   dtype=ts.vis.dtype)
                if save_src_vis:
                    # save sky vis
                    lsky_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)
                    # save outlier vis
                    lotl_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = []  # indices that shold be conj
            mis = [
            ]  # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [
            ]  # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            # get flus of the calibrator in the observing frequencies
            Sc = s.get_jys(freq)
            if show_progress and mpiutil.rank0:
                pg = progress.Progress(tfp_len, step=progress_step)
            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)
                # when noise on, just pass
                if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                    continue
                # aa.set_jultime(ts['jul_date'][ti])
                # s.compute(aa)
                # get the topocentric coordinate of the calibrator at the current time
                # s_top = s.get_crds('top', ncrd=3)
                # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
                Vmat.flat[mis] = np.ma.array(
                    this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
                Vmat.flat[mis_conj] = np.ma.array(
                    this_vis[ii, bis_conj],
                    mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

                #                if mpiutil.rank0:
                #                    saveVmat += [Vmat.copy()]

                if save_src_vis:
                    lsky_vis[ii] = Vmat

                # set invalid val to 0
                invalid = ~np.isfinite(Vmat)  # a bool array
                # if too many masks

                if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2:
                    continue
                Vmat[invalid] = 0
                #                # if all are zeros
                if np.allclose(Vmat, 0.0):
                    continue

                # fill diagonal of Vmat to 0
                if zero_diag:
                    np.fill_diagonal(Vmat, 0)

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(
                    np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                #                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=100, threshold='hard', tol=1.0e-6, debug=False)
                V0, S = rpca_decomp.decompose(Vmat,
                                              rank=1,
                                              S=S0,
                                              max_iter=max_iter,
                                              threshold='hard',
                                              tol=1.0e-6,
                                              debug=False)
                #                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=1, threshold='hard', tol=1., debug=False)
                if save_src_vis or subtract_src:
                    lsrc_vis[ii] = V0
                    if save_src_vis:
                        lotl_vis[ii] = S

                # plot
                if plot_figs:
                    ind = ti - start_ind
                    # plot Vmat
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(Vmat.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(Vmat.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot V0
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(V0.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(V0.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                       pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(S.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(S.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot N
                    N = Vmat - V0 - S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(N.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(N.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()

                if apply_gain or save_gain:
                    #                    if mpiutil.rank0:
                    #                        np.save('Vmat',saveVmat)
                    e, U = la.eigh(V0 / Sc[fi], eigvals=(nfeed - 1, nfeed - 1))
                    g = U[:, -1] * e[-1]**0.5
                    if g[0].real < 0:
                        g *= -1.0  # make all g[0] phase 0, instead of pi
                    lGain[ii] = g

                    # plot Gain
                    liplot = max(start_ind, transit_ind - 1)
                    hiplot = min(end_ind, transit_ind + 1 + 1)
                    if plot_figs and ti >= liplot and ti <= hiplot:
                        #                    if ti >= liplot and ti <= hiplot:
                        ind = ti - start_ind
                        plt.figure()
                        plt.plot(feedno, g.real, 'b-', label='real')
                        plt.plot(feedno, g.real, 'bo')
                        plt.plot(feedno, g.imag, 'g-', label='imag')
                        plt.plot(feedno, g.imag, 'go')
                        plt.plot(feedno, np.abs(g), 'r-', label='abs')
                        plt.plot(feedno, np.abs(g), 'ro')
                        plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                        yl, yh = plt.ylim()
                        plt.ylim(yl, yh + (yh - yl) / 5)
                        plt.xlabel('Feed number')
                        plt.legend()
                        fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind,
                                                             fi, pol[pi])
                        print('plot %s' % fig_name)
                        if tag_output_iter:
                            fig_name = output_path(fig_name,
                                                   iteration=self.iteration)
                        else:
                            fig_name = output_path(fig_name)
                        plt.savefig(fig_name)
                        plt.close()

            # subtract the vis of calibrator from self.vis
            if subtract_src:
                nbl = len(bls)
                lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype)
                for bi, (fd1, fd2) in enumerate(bls):
                    b1, b2 = feedno.index(fd1), feedno.index(fd2)
                    lv[:, bi] = lsrc_vis[:, b1, b2]
                lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
                lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
                if 'ns_on' in ts.iterkeys():
                    lv[ts['ns_on']
                       [start_ind:
                        end_ind]] = 0  # avoid ns_on signal to become nan
                ts.local_vis[start_ind:end_ind, :,
                             pol.index('xx')] -= lv[:, :, 0]
                ts.local_vis[start_ind:end_ind, :,
                             pol.index('yy')] -= lv[:, :, 1]

                del lv

            if not save_src_vis:
                if subtract_src:
                    del lsrc_vis
            else:
                if tag_output_iter:
                    src_vis_file = output_path(src_vis_file,
                                               iteration=self.iteration)
                else:
                    src_vis_file = output_path(src_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(src_vis_file, 'w') as f:
                        # allocate space
                        shp = (nt, nf, 2, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis',
                                         shp,
                                         dtype=lotl_vis.dtype)
                        f.attrs['calibrator'] = calibrator
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time',
                                             data=ts.time[start_ind:end_ind])
                            f.attrs['time'] = '/time'
                        f.attrs['freq'] = freq
                        f.attrs['pol'] = np.array(['xx', 'yy'])
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(src_vis_file, 'r+') as f:
                                    for ii, (ti, fi,
                                             pi) in enumerate(tfp_linds):
                                        ti_ = ti - start_ind
                                        pi_ = gain_pd[pol[pi]]
                                        f['sky_vis'][ti_, fi,
                                                     pi_] = lsky_vis[ii]
                                        f['src_vis'][ti_, fi,
                                                     pi_] = lsrc_vis[ii]
                                        f['outlier_vis'][ti_, fi,
                                                         pi_] = lotl_vis[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' %
                                       src_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis

                mpiutil.barrier()

            if apply_gain or save_gain:
                # flag outliers in lGain along each feed
                lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
                for i in range(lGain.shape[0]):
                    valid_inds = np.where(np.isfinite(lGain[i]))[0]
                    if len(valid_inds) > 3:
                        vabs = np.abs(lGain[i, valid_inds])
                        vmed = np.median(vabs)
                        vabs_diff = np.abs(vabs - vmed)
                        vmad = np.median(vabs_diff) / 0.6745
                        if reserve_high_gain:
                            # reserve significantly higher ones, flag only significantly lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vmed - vabs > 3.0 * vmad, np.nan, vabs)
                        else:
                            # flag both significantly higher and lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vabs_diff > 3.0 * vmad, np.nan, vabs)

                # choose data slice near the transit time
                li = max(start_ind, transit_ind - 10) - start_ind
                hi = min(end_ind, transit_ind + 10 + 1) - start_ind
                # compute s_top for this time range
                n0 = np.zeros(((hi - li), 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0[ti] = s.get_crds('top', ncrd=3)
                if save_phs_change:
                    n0t = np.zeros((nt, 3))
                    for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                        aa.set_jultime(jt)
                        s.compute(aa)
                        n0t[ti] = s.get_crds('top', ncrd=3)

                # get the positions of feeds
                feedpos = ts['feedpos'][:]

                # wrap and redistribute Gain and flagged G_abs
                Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm)
                Gain = Gain.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)
                G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm)
                G_abs = G_abs.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)

                fpd_inds = list(
                    itertools.product(range(nf), range(2),
                                      range(nfeed)))  # only for xx and yy
                fpd_linds = mpiutil.mpilist(fpd_inds,
                                            method='con',
                                            comm=ts.comm)
                del fpd_inds
                # create data to save the solved gain for each feed
                lgain = np.full((len(fpd_linds), ), cnan,
                                dtype=Gain.dtype)  # gain for each feed
                if save_phs_change:
                    lphs = np.full((nt, len(fpd_linds)),
                                   np.nan,
                                   dtype=Gain.real.dtype
                                   )  # phase change with time for each feed

                # check for conj
                num_conj = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)
                        e_phs_conj = np.dot(ef[inds],
                                            Gi[inds] / y[inds]) / len(inds)
                        eac = np.abs(e_phs_conj)
                        if eac > ea:
                            num_conj += 1
                # reduce num_conj from all processes
                num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
                if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                    if mpiutil.rank0:
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                        print '!!!   Detect data should be their conjugate...   !!!'
                        print '!!!   Correct it automatically...                !!!'
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    mpiutil.barrier()
                    # correct vis
                    ts.local_vis[:] = ts.local_vis.conj()
                    # correct G
                    Gain.local_array[:] = Gain.local_array.conj()

                # solve for gain
#                feedplotG = []
#                savenum = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]

                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        mag = np.mean(y[inds])
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]

                        #                        feedplotG += [Gi[10]]
                        #                        print(feedplotG)

                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)

                        if np.abs(ea - 1.0) < 0.1:
                            # compute gain for this feed
                            lgain[ii] = mag * e_phs
                            if save_phs_change:
                                lphs[:, ii] = np.angle(
                                    np.exp(-2.0J * np.pi * np.dot(n0t, ui)) *
                                    Gain.local_array[:, ii])
                        else:
                            e_phs_conj = np.dot(ef[inds],
                                                Gi[inds] / y[inds]) / len(inds)
                            eac = np.abs(e_phs_conj)
                            if eac > ea:
                                if np.abs(eac - 1.0) < 0.01:
                                    print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                        feedno[di], fi, gain_pd[pi])
                            else:
                                print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                                    feedno[di], fi, gain_pd[pi], ea)

#                import matplotlib.pyplot as plt
#                import os.path
#                feedplotG = np.array(feedplotG)
#                if mpiutil.rank0:
#                    if os.path.isfile('feedplotG%d.npy'%savenum):
#                        savenum += 1
#                    np.save('feedplotG%d'%savenum,feedplotG)
#                    plt.plot(feedplotG)
#                    plt.savefig('Gfig%d'%savenum)

# gather local gain
                gain = mpiutil.gather_array(lgain,
                                            axis=0,
                                            root=None,
                                            comm=ts.comm)
                del lgain
                gain = gain.reshape(nf, 2, nfeed)
                if save_phs_change:
                    phs = mpiutil.gather_array(lphs,
                                               axis=1,
                                               root=0,
                                               comm=ts.comm)
                    del lphs
                    if mpiutil.rank0:
                        phs = phs.reshape(nt, nf, 2, nfeed)

                # apply gain to vis
                if apply_gain:
                    for fi in range(nf):
                        for pi in [pol.index('xx'), pol.index('yy')]:
                            pi_ = gain_pd[pol[pi]]
                            for bi, (fd1, fd2) in enumerate(
                                    ts['blorder'].local_data):
                                g1 = gain[fi, pi_, feedno.index(fd1)]
                                g2 = gain[fi, pi_, feedno.index(fd2)]
                                if np.isfinite(g1) and np.isfinite(g2):
                                    ts.local_vis[:, fi, pi,
                                                 bi] /= (g1 * np.conj(g2))
                                else:
                                    # mask the un-calibrated vis
                                    ts.local_vis_mask[:, fi, pi, bi] = True

                # save gain to file
                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file,
                                                iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            dset = f.create_dataset('Gain', (nt, nf, 2, nfeed),
                                                    dtype=Gain.dtype)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time'] = ts.time[start_ind:end_ind]
                            except RuntimeError:
                                f.create_dataset(
                                    'time', data=ts.time[start_ind:end_ind])
                                dset.attrs['time'] = '/time'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save gain
                            dset = f.create_dataset('gain', data=gain)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'freq, pol, feed'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save phs

                            if save_phs_change:
                                f.create_dataset('phs', data=phs)
                            # save transit index and transit time for the case do the ps cal first
                            f.attrs['transit_index'] = transit_inds[0]
                            #                            f.attrs['transit_time'] = ts['jul_date'][transit_inds[0]]
                            f.attrs['transit_jul'] = ts['jul_date'][
                                transit_inds[0]]
                            f.attrs['transit_time'] = ts['sec1970'][
                                transit_inds[0]]
                            #                            f.attrs['transit_time'] = ts.attrs['sec1970'] + ts.attrs['inttime']*transit_inds[0] # in sec1970 to improve numeric precision

                            if os.path.exists(output_path(
                                    ts.ns_gain_file)) and not ts.ps_first:
                                with h5py.File(output_path(ts.ns_gain_file),
                                               'r+') as ns_file:
                                    phs_only = not ('ns_cal_amp'
                                                    in ns_file.keys())
                                    #                                    exclude_bad = 'badchn' in ns_file['channo'].attrs.keys()
                                    new_gain = uni_gain(f,
                                                        ns_file,
                                                        phs_only=phs_only)
                                    ns_file.create_dataset('uni_gain',
                                                           data=new_gain)
                                    ns_file['uni_gain'].attrs[
                                        'dim'] = '(time, freq, bl)'

                    mpiutil.barrier()

                    # save Gain
                    for i in range(10):
                        try:
                            # NOTE: if write simultaneously, will loss data with processes distributed in several nodes
                            for ri in xrange(mpiutil.size):
                                if ri == mpiutil.rank:
                                    with h5py.File(gain_file, 'r+') as f:
                                        for ii, (ti, fi,
                                                 pi) in enumerate(tfp_linds):
                                            ti_ = ti - start_ind
                                            pi_ = gain_pd[pol[pi]]
                                            f['Gain'][ti_, fi, pi_] = lGain[ii]
                                mpiutil.barrier()
                            break
                        except IOError:
                            time.sleep(0.5)
                            continue
                    else:
                        raise RuntimeError('Could not open file: %s...' %
                                           gain_file)

                    mpiutil.barrier()

        # convert vis from intensity unit to temperature unit in K
        if temperature_convert:
            if 'unit' in ts.vis.attrs.keys() and ts.vis.attrs['unit'] == 'K':
                if mpiutil.rank0:
                    print 'vis is already in unit K, do nothing...'
            else:
                factor = 1.0e-26 * (const.c**2 / (2 * const.k_B *
                                                  (1.0e6 * freq)**2)
                                    )  # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
                ts.local_vis[:] *= factor[np.newaxis, :, np.newaxis,
                                          np.newaxis]
                ts.vis.attrs['unit'] = 'K'

        return super(PsCal, self).process(ts)