Exemple #1
0
    def process(self, ts):
        """Calibrate `ts` with the periodic noise source by eigen-decomposition.

        For every noise source ON period, the (mean ON - mean OFF) visibility
        of all baselines is assembled into an nfeed x nfeed matrix for each
        (time, freq, pol). The matrix is split by `rpca_decomp.decompose`
        into a rank-1 part and an outlier part, and the gain of each feed is
        taken from the principal eigenvector of the rank-1 part. The gains
        are normalized by their median amplitude and interpolated (amplitude
        and unwrapped phase separately) to all time points.

        Parameters read from ``self.params``:

        - ``num_mean``: number of OFF samples before each ON period that are
          averaged to get the OFF level, must be > 0.
        - ``save_ns_vis`` / ``ns_vis_file``: whether and where to save the
          sky, rank-1 (source) and outlier visibility matrices.
        - ``apply_gain``: whether to divide the visibilities by the gains.
        - ``save_gain`` / ``gain_file``: whether and where to save the gains.
        - ``show_progress`` / ``progress_step``: progress display on rank 0.
        - ``tag_output_iter``: whether to tag output paths with
          ``self.iteration``.

        Nothing is computed when ``save_ns_vis``, ``apply_gain`` and
        ``save_gain`` are all false.

        Returns the result of the parent class's ``process(ts)``.

        Raises ``RuntimeError`` if the polarization type is not 'linear', if
        ``num_mean`` is invalid, if no usable noise source ON period is
        found, or if the noise source vis file can not be written after
        10 tries.
        """

        assert isinstance(ts, Timestream), '%s only works for Timestream object' % self.__class__.__name__

        num_mean = self.params['num_mean']
        save_ns_vis = self.params['save_ns_vis']
        ns_vis_file = self.params['ns_vis_file']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        gain_file = self.params['gain_file']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']
        tag_output_iter = self.params['tag_output_iter']

        if save_ns_vis or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ns_eigcal for pol_type: %s' % pol_type)

            ts.redistribute('baseline')

            nt = ts.local_vis.shape[0]
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            feedno = ts['feedno'][:].tolist()

            nf = ts.local_freq.shape[0]
            npol = ts.local_pol.shape[0]
            nlb = ts.local_bl.shape[0]
            nfeed = len(feedno)

            if num_mean <= 0:
                raise RuntimeError('Invalid num_mean = %s' % num_mean)

            # locate the noise source ON periods from the ON/OFF transitions
            ns_on = ts['ns_on'][:]
            ns_on = np.where(ns_on, 1, 0)
            diff_ns = np.diff(ns_on)
            on_si = np.where(diff_ns==1)[0] + 1 # start inds of ON
            on_ei = np.where(diff_ns==-1)[0] + 1 # (end inds + 1) of ON
            if len(on_si) == 0 or len(on_ei) == 0:
                # without this check the indexing below fails with a bare IndexError
                raise RuntimeError('No complete noise source ON period found in the data')
            if on_ei[0] < on_si[0]:
                # data begins while the noise source is ON, drop this partial period
                on_ei = on_ei[1:]
            if len(on_ei) == 0 or on_si[-1] > on_ei[-1]:
                # data ends while the noise source is ON, drop this partial period
                on_si = on_si[:-1]

            if len(on_si) > 0 and on_si[0] < num_mean+1: # not enough off data in the beginning to use
                on_si = on_si[1:]
                on_ei = on_ei[1:]

            if len(on_si) != len(on_ei):
                raise RuntimeError('len(on_si) != len(on_ei)')
            num_on = len(on_si)
            if num_on == 0:
                raise RuntimeError('No usable noise source ON period found in the data')
            # cal inds are the center inds of ON, floor division keeps them
            # integer indices whatever the division semantics in effect
            cal_inds = (on_si + on_ei) // 2


            # find indices mapping between Vmat and vis
            bis_conj = [] # indices of baselines that should also be filled as conj
            mis = [] # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [] # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)


            tfp_inds = list(itertools.product(range(num_on), range(nf), range(npol)))
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lon_off = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    si_on, ei_on = on_si[ti], on_ei[ti]
                    # masked mean of the num_mean OFF samples that end one sample before ON starts
                    off_sl = slice(si_on-num_mean-1, si_on-1)
                    off_mean = np.ma.mean(np.ma.array(ts.local_vis[off_sl, fi, pi], mask=ts.local_vis_mask[off_sl, fi, pi]), axis=0)
                    if ei_on - si_on > 3:
                        # does not use the two ends if there are more than three ONs
                        on_mean = np.mean(ts.local_vis[si_on+1:ei_on-1, fi, pi], axis=0)
                    else:
                        on_mean = np.mean(ts.local_vis[si_on:ei_on, fi, pi], axis=0)
                    # mean of ON - mean of OFF
                    lon_off[ii] = on_mean - off_mean

                # gather on_off from all process for separate bls
                on_off = mpiutil.gather_array(lon_off, axis=1, root=ri, comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei] # inds for this process
                    this_on_off = on_off
            del tfp_inds
            del lon_off
            tfp_len = len(tfp_linds)


            cnan = complex(np.nan, np.nan) # complex nan
            if save_ns_vis:
                # save the extracted noise source vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)
                # save sky vis
                lsky_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)
                # save outlier vis
                lotl_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lgain = np.zeros((tfp_len, nfeed), dtype=ts.vis.dtype)
                lgain_mask = np.zeros((tfp_len, nfeed), dtype=bool)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)

            if show_progress and mpiutil.rank0:
                pg = progress.Progress(len(tfp_linds), step=progress_step)

            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)

                Vmat.flat[mis] = this_on_off[ii]
                Vmat.flat[mis_conj] = this_on_off[ii, bis_conj].conj()

                if save_ns_vis:
                    lsky_vis[ii] = Vmat

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(np.abs(diff)>3.0*rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=100, threshold='hard', tol=1.0e-6, debug=False)
                if save_ns_vis:
                    lsrc_vis[ii] = V0
                    lotl_vis[ii] = S

                if apply_gain or save_gain:
                    # only the largest eigenvalue and its eigenvector are needed
                    e, U = la.eigh(V0, eigvals=(nfeed-1, nfeed-1))
                    g = U[:, -1] * e[-1]**0.5
                    # g = U[:, -1] * nfeed**0.5 # to make g_i g_j^* ~ 1
                    if g[0].real < 0:
                        g *= -1.0 # make all g[0] phase 0, instead of pi
                    lgain[ii] = g
                    ### maybe does not flag abnormal values here to simplify the programming, the flag can be done in ps_cal
                    # flag feeds whose gain amplitude deviates more than 3 scaled MAD from the median
                    gabs = np.abs(g)
                    gmed = np.median(gabs)
                    gabs_diff = np.abs(gabs - gmed)
                    gmad = np.median(gabs_diff) / 0.6745
                    lgain_mask[ii, np.where(gabs_diff>3.0*gmad)[0]] = True # mask invalid feeds

            if save_ns_vis:
                if tag_output_iter:
                    ns_vis_file = output_path(ns_vis_file, iteration=self.iteration)
                else:
                    ns_vis_file = output_path(ns_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(ns_vis_file, 'w') as f:
                        # allocate space
                        shp = (num_on, nf, npol, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis', shp, dtype=lotl_vis.dtype)
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time_inds'] = cal_inds
                        except RuntimeError:
                            # fall back to a dataset if it can not be stored as an attribute
                            f.create_dataset('time_inds', data=cal_inds)
                            f.attrs['time_inds'] = '/time_inds'
                        f.attrs['freq'] = ts.freq
                        f.attrs['pol'] = ts.pol
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file, retry up to 10 times on IOError
                # NOTE(review): a retry on one rank repeats barriers that the
                # other ranks may already have passed -- confirm this can not
                # desynchronize the processes
                for i in range(10):
                    try:
                        # NOTE: if write simultaneously, will lose data with processes distributed in several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(ns_vis_file, 'r+') as f:
                                    for ii, (ti, fi, pi) in enumerate(tfp_linds):
                                        f['sky_vis'][ti, fi, pi] = lsky_vis[ii]
                                        f['src_vis'][ti, fi, pi] = lsrc_vis[ii]
                                        f['outlier_vis'][ti, fi, pi] = lotl_vis[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' % ns_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis

                mpiutil.barrier()

            if apply_gain or save_gain:
                gain = mpiutil.gather_array(lgain, axis=0, root=None, comm=ts.comm)
                gain_mask = mpiutil.gather_array(lgain_mask, axis=0, root=None, comm=ts.comm)
                del lgain
                del lgain_mask
                gain = gain.reshape(num_on, nf, npol, nfeed)
                gain_mask = gain_mask.reshape(num_on, nf, npol, nfeed)

                # normalize gain to make its amp ~ 1
                gain_med = np.ma.median(np.ma.array(np.abs(gain), mask=gain_mask))
                gain /= gain_med

                # NOTE: no phase change compensation is applied to the gain

                gain_alltimes = np.full((nt, nf, npol, nfeed), cnan, dtype=gain.dtype)
                gain_alltimes_mask = np.zeros((nt, nf, npol, nfeed), dtype=bool)

                # interpolate to all time points
                all_inds = np.arange(nt)
                for fi in range(nf):
                    for pi in range(npol):
                        for di in range(nfeed):
                            valid_inds = np.where(np.logical_not(gain_mask[:, fi, pi, di]))[0]
                            if len(valid_inds) < 0.75 * num_on:
                                # not enough points to do good interpolation
                                gain_alltimes_mask[:, fi, pi, di] = True
                            else:
                                # interpolate amp and phase to avoid abrupt changes
                                valid_gain = gain[valid_inds, fi, pi, di]
                                amp = InterpolatedUnivariateSpline(cal_inds[valid_inds], np.abs(valid_gain))(all_inds)
                                phs = InterpolatedUnivariateSpline(cal_inds[valid_inds], np.unwrap(np.angle(valid_gain)))(all_inds)
                                gain_alltimes[:, fi, pi, di] = amp * np.exp(1.0J * phs)

                # apply gain to vis, only when asked to (not when only saving the gain)
                if apply_gain:
                    for bi, (fd1, fd2) in enumerate(ts['blorder'].local_data):
                        # feed lookups do not depend on freq and pol
                        i1, i2 = feedno.index(fd1), feedno.index(fd2)
                        for fi in range(nf):
                            for pi in range(npol):
                                g1 = gain_alltimes[:, fi, pi, i1]
                                g1_mask = gain_alltimes_mask[:, fi, pi, i1]
                                g2 = gain_alltimes[:, fi, pi, i2]
                                g2_mask = gain_alltimes_mask[:, fi, pi, i2]
                                g12 = g1 * np.conj(g2)
                                g12_mask = np.logical_or(g1_mask, g2_mask)

                                if fd1 == fd2:
                                    # auto-correlation should be real
                                    ts.local_vis[:, fi, pi, bi] /= g12.real
                                else:
                                    ts.local_vis[:, fi, pi, bi] /= g12
                                ts.local_vis_mask[:, fi, pi, bi] = np.logical_or(ts.local_vis_mask[:, fi, pi, bi], g12_mask)


                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file, iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # save the normalized gain (without phase compensation)
                            dset = f.create_dataset('gain', data=gain)
                            f.create_dataset('gain_mask', data=gain_mask)
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time_inds'] = cal_inds
                            except RuntimeError:
                                # fall back to a dataset if it can not be stored as an attribute
                                f.create_dataset('time_inds', data=cal_inds)
                                dset.attrs['time_inds'] = '/time_inds'
                            dset.attrs['freq'] = ts.freq
                            dset.attrs['pol'] = ts.pol
                            dset.attrs['feed'] = np.array(feedno)
                            dset.attrs['gain_med'] = gain_med # record the normalization factor
                            # save gain_alltimes
                            dset = f.create_dataset('gain_alltimes', data=gain_alltimes)
                            f.create_dataset('gain_alltimes_mask', data=gain_alltimes_mask)
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time_inds'] = all_inds
                            except RuntimeError:
                                f.create_dataset('all_time_inds', data=all_inds)
                                dset.attrs['time_inds'] = '/all_time_inds'
                            dset.attrs['freq'] = ts.freq
                            dset.attrs['pol'] = ts.pol
                            dset.attrs['feed'] = np.array(feedno)

                            f.create_dataset('time', data=ts.local_time)

                    mpiutil.barrier()


        return super(NsCal, self).process(ts)
Exemple #2
0
    def process(self, ts):

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        vis_conj = self.params['vis_conj']
        zero_diag = self.params['zero_diag']
        span = self.params['span']
        reserve_high_gain = self.params['reserve_high_gain']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_src_vis = self.params['save_src_vis']
        src_vis_file = self.params['src_vis_file']
        subtract_src = self.params['subtract_src']
        replace_with_src = self.params['replace_with_src']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        save_phs_change = self.params['save_phs_change']
        gain_file = self.params['gain_file']
        # temperature_convert = self.params['temperature_convert']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']

        if save_src_vis or subtract_src or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ps_cal for pol_type: %s' %
                                   pol_type)

            ts.redistribute('baseline')

            feedno = ts['feedno'][:].tolist()
            pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
            gain_pd = {
                'xx': 0,
                'yy': 1,
                0: 'xx',
                1: 'yy'
            }  # for gain related op
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians
            # transitsource = ts['transitsource'][:]
            # transit_time = transitsource[-1, 0] # second, sec1970
            # int_time = ts.attrs['inttime'] # second

            # get the calibrator
            try:
                s = calibrators.get_src(calibrator)
            except KeyError:
                if mpiutil.rank0:
                    print 'Calibrator %s is unavailable, available calibrators are:'
                    for key, d in calibrators.src_data.items():
                        print '%8s  ->  %12s' % (key, d[0])
                raise RuntimeError('Calibrator %s is unavailable')
            if mpiutil.rank0:
                print 'Try to calibrate with %s...' % s.src_name

            # get transit time of calibrator
            # array
            aa = ts.array
            aa.set_jultime(ts['jul_date'][0])  # the first obs time point
            next_transit = aa.next_transit(s)
            transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
            # get time zone
            pattern = '[-+]?\d+'
            tz = re.search(pattern, ts.attrs['timezone']).group()
            tz = int(tz)
            local_next_transit = ephem.Date(
                next_transit + tz * ephem.hour)  # plus 8h to get Beijing time
            # if transit_time > ts['jul_date'][-1]:
            if transit_time > max(ts['jul_date'][-1], ts['jul_date'][:].max()):
                raise RuntimeError(
                    'Data does not contain local transit time %s of source %s'
                    % (local_next_transit, calibrator))

            # the first transit index
            transit_inds = [np.searchsorted(ts['jul_date'][:], transit_time)]
            # find all other transit indices
            aa.set_jultime(ts['jul_date'][0] + 1.0)
            transit_time = a.phs.ephem2juldate(
                aa.next_transit(s))  # Julian date
            cnt = 2
            while (transit_time <= ts['jul_date'][-1]):
                transit_inds.append(
                    np.searchsorted(ts['jul_date'][:], transit_time))
                aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt)
                transit_time = a.phs.ephem2juldate(
                    aa.next_transit(s))  # Julian date
                cnt += 1

            if mpiutil.rank0:
                print 'transit ind of %s: %s, time: %s' % (
                    s.src_name, transit_inds, local_next_transit)

            ### now only use the first transit point to do the cal
            ### may need to improve in the future
            transit_ind = transit_inds[0]
            int_time = ts.attrs['inttime']  # second
            start_ind = transit_ind - np.int(span / int_time)
            end_ind = transit_ind + np.int(
                span /
                int_time) + 1  # plus 1 to make transit_ind is at the center

            start_ind = max(0, start_ind)
            end_ind = min(end_ind, ts.vis.shape[0])

            if vis_conj:
                ts.local_vis[:] = ts.local_vis.conj()

            nt = end_ind - start_ind
            t_inds = range(start_ind, end_ind)
            freq = ts.freq[:]  # MHz
            nf = len(freq)
            nlb = len(ts.local_bl[:])
            nfeed = len(feedno)
            tfp_inds = list(
                itertools.product(
                    t_inds, range(nf),
                    [pol.index('xx'), pol.index('yy')]))  # only for xx and yy
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    lvis[ii] = ts.local_vis[ti, fi, pi]
                    lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
                # gather vis from all process for separate bls
                gvis = mpiutil.gather_array(lvis,
                                            axis=1,
                                            root=ri,
                                            comm=ts.comm)
                gvis_mask = mpiutil.gather_array(lvis_mask,
                                                 axis=1,
                                                 root=ri,
                                                 comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei]  # inds for this process
                    this_vis = gvis
                    this_vis_mask = gvis_mask
            del tfp_inds
            del lvis
            del lvis_mask
            tfp_len = len(tfp_linds)

            # lotl_mask = np.zeros((tfp_len, nfeed, nfeed), dtype=bool)
            cnan = complex(np.nan, np.nan)  # complex nan
            if save_src_vis or subtract_src:
                # save calibrator src vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed),
                                   cnan,
                                   dtype=ts.vis.dtype)
                if save_src_vis:
                    # save sky vis
                    lsky_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)
                    # save outlier vis
                    lotl_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = []  # indices that shold be conj
            mis = [
            ]  # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [
            ]  # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            # get flus of the calibrator in the observing frequencies
            if show_progress and mpiutil.rank0:
                pg = progress.Progress(tfp_len, step=progress_step)
            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)
                # when noise on, just pass
                if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                    continue
                # aa.set_jultime(ts['jul_date'][ti])
                # s.compute(aa)
                # get the topocentric coordinate of the calibrator at the current time
                # s_top = s.get_crds('top', ncrd=3)
                # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
                Vmat.flat[mis] = np.ma.array(
                    this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
                Vmat.flat[mis_conj] = np.ma.array(
                    this_vis[ii, bis_conj],
                    mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

                if save_src_vis:
                    lsky_vis[ii] = Vmat

                # set invalid val to 0
                invalid = ~np.isfinite(Vmat)  # a bool array
                # if too many masks
                if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2:
                    continue
                Vmat[invalid] = 0
                # if all are zeros
                if np.allclose(Vmat, 0.0):
                    continue

                # fill diagonal of Vmat to 0
                if zero_diag:
                    np.fill_diagonal(Vmat, 0)

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(
                    np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                V0, S = rpca_decomp.decompose(Vmat,
                                              rank=1,
                                              S=S0,
                                              max_iter=200,
                                              threshold='hard',
                                              tol=1.0e-6,
                                              debug=False)

                # # find abnormal values in S
                # # first check diagonal elements
                # import pdb; pdb.set_trace()
                # svals = np.diag(S)
                # smed = np.median(svals.real) + 1.0J * np.median(svals.imag)
                # smad = rpca_decomp.MAD(svals)
                # # abnormal indices
                # abis =  np.where(np.abs(svals - smed) > 3.0 * smad)[0]
                # for abi in abis:
                #     lotl_mask[ii, abi, abi] = True
                # # then check non-diagonal elements
                # for rii in range(nfeed):
                #     for cii in range(nfeed):
                #         if rii == cii:
                #             continue
                #         rli = max(0, rii-2)
                #         rhi = min(nfeed, rii+3)
                #         cli = max(0, cii-2)
                #         chi = min(nfeed, cii+3)
                #         svals = np.array([ S[xi, yi] for xi in range(rli, rhi) for yi in range(cli, chi) if xi != yi ])
                #         smed = np.median(svals.real) + 1.0J * np.median(svals.imag)
                #         smad = rpca_decomp.MAD(svals)
                #         if np.abs(S[rii, cii] - smed) > 3.0 * smad:
                #             lotl_mask[ii, rii, cii] = True

                if save_src_vis or subtract_src:
                    lsrc_vis[ii] = V0
                    if save_src_vis:
                        lotl_vis[ii] = S

                # plot
                if plot_figs:
                    ind = ti - start_ind
                    # plot Vmat
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(Vmat.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(Vmat.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot V0
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(V0.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(V0.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                       pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(S.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(S.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot N
                    N = Vmat - V0 - S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(N.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(N.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()

                if apply_gain or save_gain:
                    # use v_ij = gi gj^* \int Ai Aj^* e^(2\pi i n \cdot uij) T(x) d^2n
                    # more precisely, we should have
                    # V0 = (lambda^2 * Sc / (2 k_B)) * gi gj^* Ai Aj^* e^(2\pi i n0 \cdot uij)
                    e, U = la.eigh(V0, eigvals=(nfeed - 1, nfeed - 1))
                    g = U[:, -1] * e[
                        -1]**0.5  # = \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai * e^(2\pi i n0 \cdot ui)
                    if g[0].real < 0:
                        g *= -1.0  # make all g[0] phase 0, instead of pi
                    lGain[ii] = g

                    # plot Gain
                    if plot_figs:
                        plt.figure()
                        plt.plot(feedno, g.real, 'b-', label='real')
                        plt.plot(feedno, g.real, 'bo')
                        plt.plot(feedno, g.imag, 'g-', label='imag')
                        plt.plot(feedno, g.imag, 'go')
                        plt.plot(feedno, np.abs(g), 'r-', label='abs')
                        plt.plot(feedno, np.abs(g), 'ro')
                        plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                        yl, yh = plt.ylim()
                        plt.ylim(yl, yh + (yh - yl) / 5)
                        plt.xlabel('Feed number')
                        plt.legend()
                        fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind,
                                                             fi, pol[pi])
                        if tag_output_iter:
                            fig_name = output_path(fig_name,
                                                   iteration=self.iteration)
                        else:
                            fig_name = output_path(fig_name)
                        plt.savefig(fig_name)
                        plt.close()

            # # apply outlier mask
            # nbl = len(bls)
            # lom = np.zeros((lotl_mask.shape[0], nbl), dtype=lotl_mask.dtype)
            # for bi, (fd1, fd2) in enumerate(bls):
            #     b1, b2 = feedno.index(fd1), feedno.index(fd2)
            #     lom[:, bi] = lotl_mask[:, b1, b2]
            # lom = mpiarray.MPIArray.wrap(lom, axis=0, comm=ts.comm)
            # lom = lom.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
            # ts.local_vis_mask[start_ind:end_ind, :, pol.index('xx')] |= lom[:, :, 0]
            # ts.local_vis_mask[start_ind:end_ind, :, pol.index('yy')] |= lom[:, :, 1]

            # subtract the vis of calibrator from self.vis
            if subtract_src:
                nbl = len(bls)
                lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype)
                for bi, (fd1, fd2) in enumerate(bls):
                    b1, b2 = feedno.index(fd1), feedno.index(fd2)
                    lv[:, bi] = lsrc_vis[:, b1, b2]
                lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
                lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
                if replace_with_src:
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('xx')] = lv[:, :, 0]
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('yy')] = lv[:, :, 1]
                else:
                    if 'ns_on' in ts.iterkeys():
                        lv[ts['ns_on']
                           [start_ind:
                            end_ind]] = 0  # avoid ns_on signal to become nan
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('xx')] -= lv[:, :, 0]
                    ts.local_vis[start_ind:end_ind, :,
                                 pol.index('yy')] -= lv[:, :, 1]

                del lv

            if not save_src_vis:
                if subtract_src:
                    del lsrc_vis
            else:
                if tag_output_iter:
                    src_vis_file = output_path(src_vis_file,
                                               iteration=self.iteration)
                else:
                    src_vis_file = output_path(src_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(src_vis_file, 'w') as f:
                        # allocate space
                        shp = (nt, nf, 2, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis',
                                         shp,
                                         dtype=lotl_vis.dtype)
                        # f.create_dataset('outlier_mask', shp, dtype=lotl_mask.dtype)
                        f.attrs['calibrator'] = calibrator
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time',
                                             data=ts.time[start_ind:end_ind])
                            f.attrs['time'] = '/time'
                        f.attrs['freq'] = freq
                        f.attrs['pol'] = np.array(['xx', 'yy'])
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: writing simultaneously loses data when the processes are distributed over several nodes, so each rank writes in turn
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(src_vis_file, 'r+') as f:
                                    for ii, (ti, fi,
                                             pi) in enumerate(tfp_linds):
                                        ti_ = ti - start_ind
                                        pi_ = gain_pd[pol[pi]]
                                        f['sky_vis'][ti_, fi,
                                                     pi_] = lsky_vis[ii]
                                        f['src_vis'][ti_, fi,
                                                     pi_] = lsrc_vis[ii]
                                        f['outlier_vis'][ti_, fi,
                                                         pi_] = lotl_vis[ii]
                                        # f['outlier_mask'][ti_, fi, pi_] = lotl_mask[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' %
                                       src_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis
                # del lotl_mask

                mpiutil.barrier()

            if apply_gain or save_gain:
                # flag outliers in lGain along each feed
                lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
                for i in range(lGain.shape[0]):
                    valid_inds = np.where(np.isfinite(lGain[i]))[0]
                    if len(valid_inds) > 3:
                        vabs = np.abs(lGain[i, valid_inds])
                        vmed = np.median(vabs)
                        vabs_diff = np.abs(vabs - vmed)
                        vmad = np.median(vabs_diff) / 0.6745
                        if reserve_high_gain:
                            # reserve significantly higher ones, flag only significantly lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vmed - vabs > 3.0 * vmad, np.nan, vabs)
                        else:
                            # flag both significantly higher and lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vabs_diff > 3.0 * vmad, np.nan, vabs)

                # choose data slice near the transit time
                li = max(start_ind, transit_ind - 10) - start_ind
                hi = min(end_ind, transit_ind + 10 + 1) - start_ind
                ci = transit_ind - start_ind  # center index for transit_ind
                # compute s_top for this time range
                n0 = np.zeros(((hi - li), 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0[ti] = s.get_crds('top', ncrd=3)
                if save_phs_change:
                    n0t = np.zeros((nt, 3))
                    for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                        aa.set_jultime(jt)
                        s.compute(aa)
                        n0t[ti] = s.get_crds('top', ncrd=3)

                # get the positions of feeds
                feedpos = ts['feedpos'][:]

                # wrap and redistribute Gain and flagged G_abs
                Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm)
                Gain = Gain.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)
                G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm)
                G_abs = G_abs.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)

                fpd_inds = list(
                    itertools.product(range(nf), range(2),
                                      range(nfeed)))  # only for xx and yy
                fpd_linds = mpiutil.mpilist(fpd_inds,
                                            method='con',
                                            comm=ts.comm)
                del fpd_inds
                # create data to save the solved gain for each feed
                lgain = np.full((len(fpd_linds), ), cnan,
                                dtype=Gain.dtype)  # gain for each feed
                if save_phs_change:
                    lphs = np.full((nt, len(fpd_linds)),
                                   np.nan,
                                   dtype=Gain.real.dtype
                                   )  # phase change with time for each feed

                # check for conj
                num_conj = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)
                        e_phs_conj = np.dot(ef[inds],
                                            Gi[inds] / y[inds]) / len(inds)
                        eac = np.abs(e_phs_conj)
                        if eac > ea:
                            num_conj += 1
                # reduce num_conj from all processes
                num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
                if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                    if mpiutil.rank0:
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                        print '!!!   Detect data should be their conjugate...   !!!'
                        print '!!!   Correct it automatically...                !!!'
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    mpiutil.barrier()
                    # correct vis
                    ts.local_vis[:] = ts.local_vis.conj()
                    # correct G
                    Gain.local_array[:] = Gain.local_array.conj()

                # solve for gain
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        mag = np.mean(
                            y[inds]
                        )  # = \sqrt(lambda^2 * Sc / (2 k_B)) * |gi| Ai
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(), Gi[inds] /
                                       y[inds]) / len(inds)  # the phase of gi
                        ea = np.abs(e_phs)
                        if np.abs(ea - 1.0) < 0.1:
                            # compute gain for this feed
                            lgain[
                                ii] = mag * e_phs  # \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai
                            if save_phs_change:
                                lphs[:, ii] = np.angle(
                                    np.exp(-2.0J * np.pi * np.dot(n0t, ui)) *
                                    Gain.local_array[:, ii])
                        else:
                            e_phs_conj = np.dot(ef[inds],
                                                Gi[inds] / y[inds]) / len(inds)
                            eac = np.abs(e_phs_conj)
                            if eac > ea:
                                if np.abs(eac - 1.0) < 0.01:
                                    print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                        feedno[di], fi, gain_pd[pi])
                            else:
                                print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                                    feedno[di], fi, gain_pd[pi], ea)

                # gather local gain
                gain = mpiutil.gather_array(lgain,
                                            axis=0,
                                            root=None,
                                            comm=ts.comm)
                del lgain
                gain = gain.reshape(nf, 2, nfeed)
                if save_phs_change:
                    phs = mpiutil.gather_array(lphs,
                                               axis=1,
                                               root=0,
                                               comm=ts.comm)
                    del lphs
                    if mpiutil.rank0:
                        phs = phs.reshape(nt, nf, 2, nfeed)

                # normalize to get the exact gain
                Sc = s.get_jys(1.0e-3 * freq)
                # Omega = aa.ants[0].beam.Omega ### TODO: implement Omega for dish
                Ai = aa.ants[0].beam.response(n0[ci - li])
                lmd = const.c / (1.0e6 * freq)
                factor = np.sqrt(
                    (lmd**2 * 1.0e-26 * Sc) /
                    (2 * const.k_B)) * Ai  # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
                gain /= factor[:, np.newaxis, np.newaxis]

                # apply gain to vis
                if apply_gain:
                    for fi in range(nf):
                        for pi in [pol.index('xx'), pol.index('yy')]:
                            pi_ = gain_pd[pol[pi]]
                            for bi, (fd1, fd2) in enumerate(
                                    ts['blorder'].local_data):
                                g1 = gain[fi, pi_, feedno.index(fd1)]
                                g2 = gain[fi, pi_, feedno.index(fd2)]
                                if np.isfinite(g1) and np.isfinite(g2):
                                    if fd1 == fd2:
                                        # auto-correlation should be real
                                        ts.local_vis[:, fi, pi,
                                                     bi] /= (g1 *
                                                             np.conj(g2)).real
                                    else:
                                        ts.local_vis[:, fi, pi,
                                                     bi] /= (g1 * np.conj(g2))
                                else:
                                    # mask the un-calibrated vis
                                    ts.local_vis_mask[:, fi, pi, bi] = True

                    # in unit K after the calibration
                    ts.vis.attrs['unit'] = 'K'

                # save gain to file
                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file,
                                                iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            dset = f.create_dataset('Gain', (nt, nf, 2, nfeed),
                                                    dtype=Gain.dtype)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time'] = ts.time[start_ind:end_ind]
                            except RuntimeError:
                                f.create_dataset(
                                    'time', data=ts.time[start_ind:end_ind])
                                dset.attrs['time'] = '/time'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            dset.attrs['transit_ind'] = transit_ind
                            # save gain
                            dset = f.create_dataset('gain', data=gain)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'freq, pol, feed'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save phs
                            if save_phs_change:
                                f.create_dataset('phs', data=phs)

                    mpiutil.barrier()

                    # save Gain
                    for i in range(10):
                        try:
                            # NOTE: writing simultaneously loses data when the processes are distributed over several nodes, so each rank writes in turn
                            for ri in xrange(mpiutil.size):
                                if ri == mpiutil.rank:
                                    with h5py.File(gain_file, 'r+') as f:
                                        for ii, (ti, fi,
                                                 pi) in enumerate(tfp_linds):
                                            ti_ = ti - start_ind
                                            pi_ = gain_pd[pol[pi]]
                                            f['Gain'][ti_, fi, pi_] = lGain[ii]
                                mpiutil.barrier()
                            break
                        except IOError:
                            time.sleep(0.5)
                            continue
                    else:
                        raise RuntimeError('Could not open file: %s...' %
                                           gain_file)

                    mpiutil.barrier()

        return super(PsCal, self).process(ts)
Exemple #3
0
    def process(self, ts):

        assert isinstance(
            ts, Timestream
        ), '%s only works for Timestream object' % self.__class__.__name__

        #        if mpiutil.rank0:
        #            saveVmat = []

        calibrator = self.params['calibrator']
        catalog = self.params['catalog']
        vis_conj = self.params['vis_conj']
        zero_diag = self.params['zero_diag']
        span = self.params['span']
        reserve_high_gain = self.params['reserve_high_gain']
        plot_figs = self.params['plot_figs']
        fig_prefix = self.params['fig_name']
        tag_output_iter = self.params['tag_output_iter']
        save_src_vis = self.params['save_src_vis']
        src_vis_file = self.params['src_vis_file']
        subtract_src = self.params['subtract_src']
        apply_gain = self.params['apply_gain']
        save_gain = self.params['save_gain']
        save_phs_change = self.params['save_phs_change']
        gain_file = self.params['gain_file']
        temperature_convert = self.params['temperature_convert']
        show_progress = self.params['show_progress']
        progress_step = self.params['progress_step']
        srcdict = self.params['srcdict']
        max_iter = self.params['max_iter']

        #        if mpiutil.rank0:
        #            print(ts.keys())
        #            print(ts.attrs.keys())
        #
        #        import sys
        #        sys.exit()

        if save_src_vis or subtract_src or apply_gain or save_gain:
            pol_type = ts['pol'].attrs['pol_type']
            if pol_type != 'linear':
                raise RuntimeError('Can not do ps_cal for pol_type: %s' %
                                   pol_type)

            ts.redistribute('baseline')

            feedno = ts['feedno'][:].tolist()
            pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
            gain_pd = {
                'xx': 0,
                'yy': 1,
                0: 'xx',
                1: 'yy'
            }  # for gain related op
            bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)
            # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians
            # transitsource = ts['transitsource'][:]
            # transit_time = transitsource[-1, 0] # second, sec1970
            # int_time = ts.attrs['inttime'] # second

            # calibrator
            #see whether the source should be observed
            #only for single calibrator
            if ts.is_dish:
                obsd_sources = ts['transitsource'].attrs['srcname']
                obsd_sources = obsd_sources.split(',')
                #                srcdict = {'cas':'CassiopeiaA'}
                src_index = 0.1
                for src_short, src_long in srcdict.items():
                    #                    print(src_short,src_long,obsd_sources,calibrator)
                    if (src_short == calibrator) and (src_long
                                                      in obsd_sources):
                        src_index = obsd_sources.index(src_long)
                        break
                else:
                    raise Exception(
                        'the transit source is not included in the observation plan!'
                    )
                obs_transit_time = ts['transitsource'][src_index][0]

#            srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog)
#            cat = a.src.get_catalog(srclist, cutoff, catalogs)
#            assert(len(cat) == 1), 'Allow only one calibrator'
#            s = cat.values()[0]

# get the calibrator
            try:
                s = calibrators.get_src(calibrator)
            except KeyError:
                if mpiutil.rank0:
                    print 'Calibrator %s is unavailable, available calibrators are:'
                    for key, d in calibrators.src_data.items():
                        print '%8s  ->  %12s' % (key, d[0])
                raise RuntimeError('Calibrator %s is unavailable')
            if mpiutil.rank0:
                print 'Try to calibrate with %s...' % s.src_name
#            if mpiutil.rank0:
#                print 'Calibrating for source %s with' % calibrator,
#                print 'strength', s._jys, 'Jy',
#                print 'measured at', s.mfreq, 'GHz',
#                print 'with index', s.index

# get transit time of calibrator
# array
            aa = ts.array
            aa.set_jultime(ts['jul_date'][0])  # the first obs time point
            next_transit = aa.next_transit(s)
            next_transit = aa.next_transit(
                s) if not ts.is_dish else ephem.date(
                    datetime.utcfromtimestamp(obs_transit_time))
            transit_time = a.phs.ephem2juldate(next_transit)  # Julian date
            # get time zone
            pattern = '[-+]?\d+'
            tz = re.search(pattern, ts.attrs['timezone']).group()
            tz = int(tz)

            local_next_transit = ephem.Date(
                next_transit + tz * ephem.hour)  # plus 8h to get Beijing time
            # if transit_time > ts['jul_date'][-1]:
            #========================================================
            if ts.is_dish:
                if (transit_time >
                        max(ts['jul_date'][-1],
                            ts['jul_date'][:].max())) or (transit_time < min(
                                ts['jul_date'][0], ts['jul_date'][:].min())):
                    #                    raise RuntimeError('dish data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))
                    raise NoTransit(
                        'Dish data does not contain local transit time %s of source %s'
                        % (local_next_transit, calibrator))
                transit_inds = [
                    np.searchsorted(ts['jul_date'][:], transit_time)
                ]
                peak_span = int(np.around(10. / ts.attrs['inttime']))
                peak_start = ts['jul_date'][transit_inds[0] - peak_span]
                peak_end = ts['jul_date'][transit_inds[0] + peak_span]
                tspan = (peak_end - peak_start) * 86400.
                if mpiutil.rank0:
                    print('transit peak: %s' % local_next_transit)
                    print('%d point (should be about 10s) previous: %s' %
                          (peak_span,
                           ephem.date(
                               a.phs.juldate2ephem(peak_start) +
                               tz * ephem.hour)))
                    print(
                        '%d point (should be about 10s) later: %s' %
                        (peak_span,
                         ephem.date(
                             a.phs.juldate2ephem(peak_end) + tz * ephem.hour)))
                if tspan > 30.:
                    warnings.warn(
                        'Peak of transit is not continuous in the data! May lead to poor performance!'
                    )
            else:
                if transit_time > max(ts['jul_date'][-1],
                                      ts['jul_date'][:].max()):
                    #                    raise RuntimeError('Cylinder data does not contain local transit time %s of source %s' % (local_next_transit, calibrator))
                    raise NoTransit(
                        'Cylinder data does not contain local transit time %s of source %s'
                        % (local_next_transit, calibrator))

                # the first transit index
                transit_inds = [
                    np.searchsorted(ts['jul_date'][:], transit_time)
                ]
                # find all other transit indices
                aa.set_jultime(ts['jul_date'][0] + 1.0)
                transit_time = a.phs.ephem2juldate(
                    aa.next_transit(s))  # Julian date
                cnt = 2
                while (transit_time <= ts['jul_date'][-1]):
                    transit_inds.append(
                        np.searchsorted(ts['jul_date'][:], transit_time))
                    aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt)
                    # Julian date
                    #                transit_time = a.phs.ephem2juldate(aa.next_transit(s) if not ts.is_dish else ephem.date(datetime.utcfromtimestamp(obs_transit_time)))
                    transit_time = a.phs.ephem2juldate(aa.next_transit(s))
                    cnt += 1
#========================================================

            if mpiutil.rank0:
                print 'transit ind of %s: %s, time: %s' % (
                    s.src_name, transit_inds, local_next_transit)

            if (not ts.ps_first) and ts.interp_all_masked:
                raise NotEnoughPointToInterpolateError(
                    'More than 80% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!'
                )
            ### now only use the first transit point to do the cal
            ### may need to improve in the future
            transit_ind = transit_inds[0]
            int_time = ts.attrs['inttime']  # second
            start_ind = transit_ind - np.int(span / int_time)
            end_ind = transit_ind + np.int(
                span /
                int_time) + 1  # plus 1 to make transit_ind is at the center

            start_ind = max(0, start_ind)
            end_ind = min(end_ind, ts.vis.shape[0])

            if vis_conj:
                ts.local_vis[:] = ts.local_vis.conj()

            nt = end_ind - start_ind
            t_inds = range(start_ind, end_ind)
            freq = ts.freq[:]
            nf = len(freq)
            nlb = len(ts.local_bl[:])
            nfeed = len(feedno)
            tfp_inds = list(
                itertools.product(
                    t_inds, range(nf),
                    [pol.index('xx'), pol.index('yy')]))  # only for xx and yy
            ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
            # gather data to make each process to have its own data which has all bls
            for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
                lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
                lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
                for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                    lvis[ii] = ts.local_vis[ti, fi, pi]
                    lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
                # gather vis from all process for separate bls
                gvis = mpiutil.gather_array(lvis,
                                            axis=1,
                                            root=ri,
                                            comm=ts.comm)
                gvis_mask = mpiutil.gather_array(lvis_mask,
                                                 axis=1,
                                                 root=ri,
                                                 comm=ts.comm)
                if ri == mpiutil.rank:
                    tfp_linds = tfp_inds[si:ei]  # inds for this process
                    this_vis = gvis
                    this_vis_mask = gvis_mask
            del tfp_inds
            del lvis
            del lvis_mask
            tfp_len = len(tfp_linds)

            cnan = complex(np.nan, np.nan)  # complex nan
            if save_src_vis or subtract_src:
                # save calibrator src vis
                lsrc_vis = np.full((tfp_len, nfeed, nfeed),
                                   cnan,
                                   dtype=ts.vis.dtype)
                if save_src_vis:
                    # save sky vis
                    lsky_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)
                    # save outlier vis
                    lotl_vis = np.full((tfp_len, nfeed, nfeed),
                                       cnan,
                                       dtype=ts.vis.dtype)

            if apply_gain or save_gain:
                lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

            # find indices mapping between Vmat and vis
            # bis = range(nbl)
            bis_conj = []  # indices that shold be conj
            mis = [
            ]  # indices in the nfeed x nfeed matrix by flatten it to a vector
            mis_conj = [
            ]  # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector
            for bi, (fdi, fdj) in enumerate(bls):
                ai, aj = feedno.index(fdi), feedno.index(fdj)
                mis.append(ai * nfeed + aj)
                if ai != aj:
                    bis_conj.append(bi)
                    mis_conj.append(aj * nfeed + ai)

            # construct visibility matrix for a single time, freq, pol
            Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            # get the flux of the calibrator at the observing frequencies
            Sc = s.get_jys(freq)
            if show_progress and mpiutil.rank0:
                pg = progress.Progress(tfp_len, step=progress_step)
            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                if show_progress and mpiutil.rank0:
                    pg.show(ii)
                # when noise on, just pass
                if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                    continue
                # aa.set_jultime(ts['jul_date'][ti])
                # s.compute(aa)
                # get the topocentric coordinate of the calibrator at the current time
                # s_top = s.get_crds('top', ncrd=3)
                # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim
                Vmat.flat[mis] = np.ma.array(
                    this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
                Vmat.flat[mis_conj] = np.ma.array(
                    this_vis[ii, bis_conj],
                    mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

                #                if mpiutil.rank0:
                #                    saveVmat += [Vmat.copy()]

                if save_src_vis:
                    lsky_vis[ii] = Vmat

                # set invalid val to 0
                invalid = ~np.isfinite(Vmat)  # a bool array
                # if too many masks

                if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2:
                    continue
                Vmat[invalid] = 0
                #                # if all are zeros
                if np.allclose(Vmat, 0.0):
                    continue

                # fill diagonal of Vmat to 0
                if zero_diag:
                    np.fill_diagonal(Vmat, 0)

                # initialize the outliers
                med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
                diff = Vmat - med
                S0 = np.where(
                    np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
                # stable PCA decomposition
                #                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=100, threshold='hard', tol=1.0e-6, debug=False)
                V0, S = rpca_decomp.decompose(Vmat,
                                              rank=1,
                                              S=S0,
                                              max_iter=max_iter,
                                              threshold='hard',
                                              tol=1.0e-6,
                                              debug=False)
                #                V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=1, threshold='hard', tol=1., debug=False)
                if save_src_vis or subtract_src:
                    lsrc_vis[ii] = V0
                    if save_src_vis:
                        lotl_vis[ii] = S

                # plot
                if plot_figs:
                    ind = ti - start_ind
                    # plot Vmat
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(Vmat.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(Vmat.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot V0
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(V0.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(V0.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                       pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(S.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(S.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()
                    # plot N
                    N = Vmat - V0 - S
                    plt.figure(figsize=(13, 5))
                    plt.subplot(121)
                    plt.imshow(N.real,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    plt.subplot(122)
                    plt.imshow(N.imag,
                               aspect='equal',
                               origin='lower',
                               interpolation='nearest')
                    plt.colorbar(shrink=1.0)
                    fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi,
                                                      pol[pi])
                    if tag_output_iter:
                        fig_name = output_path(fig_name,
                                               iteration=self.iteration)
                    else:
                        fig_name = output_path(fig_name)
                    plt.savefig(fig_name)
                    plt.close()

                if apply_gain or save_gain:
                    #                    if mpiutil.rank0:
                    #                        np.save('Vmat',saveVmat)
                    e, U = la.eigh(V0 / Sc[fi], eigvals=(nfeed - 1, nfeed - 1))
                    g = U[:, -1] * e[-1]**0.5
                    if g[0].real < 0:
                        g *= -1.0  # make all g[0] phase 0, instead of pi
                    lGain[ii] = g

                    # plot Gain
                    liplot = max(start_ind, transit_ind - 1)
                    hiplot = min(end_ind, transit_ind + 1 + 1)
                    if plot_figs and ti >= liplot and ti <= hiplot:
                        #                    if ti >= liplot and ti <= hiplot:
                        ind = ti - start_ind
                        plt.figure()
                        plt.plot(feedno, g.real, 'b-', label='real')
                        plt.plot(feedno, g.real, 'bo')
                        plt.plot(feedno, g.imag, 'g-', label='imag')
                        plt.plot(feedno, g.imag, 'go')
                        plt.plot(feedno, np.abs(g), 'r-', label='abs')
                        plt.plot(feedno, np.abs(g), 'ro')
                        plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                        yl, yh = plt.ylim()
                        plt.ylim(yl, yh + (yh - yl) / 5)
                        plt.xlabel('Feed number')
                        plt.legend()
                        fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind,
                                                             fi, pol[pi])
                        print('plot %s' % fig_name)
                        if tag_output_iter:
                            fig_name = output_path(fig_name,
                                                   iteration=self.iteration)
                        else:
                            fig_name = output_path(fig_name)
                        plt.savefig(fig_name)
                        plt.close()

            # subtract the vis of calibrator from self.vis
            if subtract_src:
                nbl = len(bls)
                lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype)
                for bi, (fd1, fd2) in enumerate(bls):
                    b1, b2 = feedno.index(fd1), feedno.index(fd2)
                    lv[:, bi] = lsrc_vis[:, b1, b2]
                lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
                lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
                if 'ns_on' in ts.iterkeys():
                    lv[ts['ns_on']
                       [start_ind:
                        end_ind]] = 0  # avoid ns_on signal to become nan
                ts.local_vis[start_ind:end_ind, :,
                             pol.index('xx')] -= lv[:, :, 0]
                ts.local_vis[start_ind:end_ind, :,
                             pol.index('yy')] -= lv[:, :, 1]

                del lv

            if not save_src_vis:
                if subtract_src:
                    del lsrc_vis
            else:
                if tag_output_iter:
                    src_vis_file = output_path(src_vis_file,
                                               iteration=self.iteration)
                else:
                    src_vis_file = output_path(src_vis_file)
                # create file and allocate space first by rank0
                if mpiutil.rank0:
                    with h5py.File(src_vis_file, 'w') as f:
                        # allocate space
                        shp = (nt, nf, 2, nfeed, nfeed)
                        f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                        f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                        f.create_dataset('outlier_vis',
                                         shp,
                                         dtype=lotl_vis.dtype)
                        f.attrs['calibrator'] = calibrator
                        f.attrs['dim'] = 'time, freq, pol, feed, feed'
                        try:
                            f.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time',
                                             data=ts.time[start_ind:end_ind])
                            f.attrs['time'] = '/time'
                        f.attrs['freq'] = freq
                        f.attrs['pol'] = np.array(['xx', 'yy'])
                        f.attrs['feed'] = np.array(feedno)

                mpiutil.barrier()

                # write data to file
                for i in range(10):
                    try:
                        # NOTE: writing simultaneously loses data when processes are distributed across several nodes
                        for ri in xrange(mpiutil.size):
                            if ri == mpiutil.rank:
                                with h5py.File(src_vis_file, 'r+') as f:
                                    for ii, (ti, fi,
                                             pi) in enumerate(tfp_linds):
                                        ti_ = ti - start_ind
                                        pi_ = gain_pd[pol[pi]]
                                        f['sky_vis'][ti_, fi,
                                                     pi_] = lsky_vis[ii]
                                        f['src_vis'][ti_, fi,
                                                     pi_] = lsrc_vis[ii]
                                        f['outlier_vis'][ti_, fi,
                                                         pi_] = lotl_vis[ii]
                            mpiutil.barrier()
                        break
                    except IOError:
                        time.sleep(0.5)
                        continue
                else:
                    raise RuntimeError('Could not open file: %s...' %
                                       src_vis_file)

                del lsrc_vis
                del lsky_vis
                del lotl_vis

                mpiutil.barrier()

            if apply_gain or save_gain:
                # flag outliers in lGain along each feed
                lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
                for i in range(lGain.shape[0]):
                    valid_inds = np.where(np.isfinite(lGain[i]))[0]
                    if len(valid_inds) > 3:
                        vabs = np.abs(lGain[i, valid_inds])
                        vmed = np.median(vabs)
                        vabs_diff = np.abs(vabs - vmed)
                        vmad = np.median(vabs_diff) / 0.6745
                        if reserve_high_gain:
                            # reserve significantly higher ones, flag only significantly lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vmed - vabs > 3.0 * vmad, np.nan, vabs)
                        else:
                            # flag both significantly higher and lower ones
                            lG_abs[i, valid_inds] = np.where(
                                vabs_diff > 3.0 * vmad, np.nan, vabs)

                # choose data slice near the transit time
                li = max(start_ind, transit_ind - 10) - start_ind
                hi = min(end_ind, transit_ind + 10 + 1) - start_ind
                # compute s_top for this time range
                n0 = np.zeros(((hi - li), 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0[ti] = s.get_crds('top', ncrd=3)
                if save_phs_change:
                    n0t = np.zeros((nt, 3))
                    for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                        aa.set_jultime(jt)
                        s.compute(aa)
                        n0t[ti] = s.get_crds('top', ncrd=3)

                # get the positions of feeds
                feedpos = ts['feedpos'][:]

                # wrap and redistribute Gain and flagged G_abs
                Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm)
                Gain = Gain.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)
                G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm)
                G_abs = G_abs.redistribute(axis=1).reshape(
                    nt, nf, 2, None).redistribute(axis=0).reshape(
                        None, nf * 2 * nfeed).redistribute(axis=1)

                fpd_inds = list(
                    itertools.product(range(nf), range(2),
                                      range(nfeed)))  # only for xx and yy
                fpd_linds = mpiutil.mpilist(fpd_inds,
                                            method='con',
                                            comm=ts.comm)
                del fpd_inds
                # create data to save the solved gain for each feed
                lgain = np.full((len(fpd_linds), ), cnan,
                                dtype=Gain.dtype)  # gain for each feed
                if save_phs_change:
                    lphs = np.full((nt, len(fpd_linds)),
                                   np.nan,
                                   dtype=Gain.real.dtype
                                   )  # phase change with time for each feed

                # check for conj
                num_conj = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]
                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]
                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)
                        e_phs_conj = np.dot(ef[inds],
                                            Gi[inds] / y[inds]) / len(inds)
                        eac = np.abs(e_phs_conj)
                        if eac > ea:
                            num_conj += 1
                # reduce num_conj from all processes
                num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
                if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                    if mpiutil.rank0:
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                        print '!!!   Detect data should be their conjugate...   !!!'
                        print '!!!   Correct it automatically...                !!!'
                        print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    mpiutil.barrier()
                    # correct vis
                    ts.local_vis[:] = ts.local_vis.conj()
                    # correct G
                    Gain.local_array[:] = Gain.local_array.conj()

                # solve for gain
#                feedplotG = []
#                savenum = 0
                for ii, (fi, pi, di) in enumerate(fpd_linds):
                    y = G_abs.local_array[li:hi, ii]
                    inds = np.where(np.isfinite(y))[0]

                    if len(inds) >= max(4, 0.5 * len(y)):
                        # get the approximate magnitude by averaging the central G_abs
                        mag = np.mean(y[inds])
                        # solve phase by least square fit
                        ui = (feedpos[di] - feedpos[0]) * (
                            1.0e6 * freq[fi]
                        ) / const.c  # position of this feed (relative to the first feed) in unit of wavelength
                        exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui))
                        ef = exp_factor
                        Gi = Gain.local_array[li:hi, ii]

                        #                        feedplotG += [Gi[10]]
                        #                        print(feedplotG)

                        e_phs = np.dot(ef[inds].conj(),
                                       Gi[inds] / y[inds]) / len(inds)
                        ea = np.abs(e_phs)

                        if np.abs(ea - 1.0) < 0.1:
                            # compute gain for this feed
                            lgain[ii] = mag * e_phs
                            if save_phs_change:
                                lphs[:, ii] = np.angle(
                                    np.exp(-2.0J * np.pi * np.dot(n0t, ui)) *
                                    Gain.local_array[:, ii])
                        else:
                            e_phs_conj = np.dot(ef[inds],
                                                Gi[inds] / y[inds]) / len(inds)
                            eac = np.abs(e_phs_conj)
                            if eac > ea:
                                if np.abs(eac - 1.0) < 0.01:
                                    print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                        feedno[di], fi, gain_pd[pi])
                            else:
                                print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                                    feedno[di], fi, gain_pd[pi], ea)

#                import matplotlib.pyplot as plt
#                import os.path
#                feedplotG = np.array(feedplotG)
#                if mpiutil.rank0:
#                    if os.path.isfile('feedplotG%d.npy'%savenum):
#                        savenum += 1
#                    np.save('feedplotG%d'%savenum,feedplotG)
#                    plt.plot(feedplotG)
#                    plt.savefig('Gfig%d'%savenum)

# gather local gain
                gain = mpiutil.gather_array(lgain,
                                            axis=0,
                                            root=None,
                                            comm=ts.comm)
                del lgain
                gain = gain.reshape(nf, 2, nfeed)
                if save_phs_change:
                    phs = mpiutil.gather_array(lphs,
                                               axis=1,
                                               root=0,
                                               comm=ts.comm)
                    del lphs
                    if mpiutil.rank0:
                        phs = phs.reshape(nt, nf, 2, nfeed)

                # apply gain to vis
                if apply_gain:
                    for fi in range(nf):
                        for pi in [pol.index('xx'), pol.index('yy')]:
                            pi_ = gain_pd[pol[pi]]
                            for bi, (fd1, fd2) in enumerate(
                                    ts['blorder'].local_data):
                                g1 = gain[fi, pi_, feedno.index(fd1)]
                                g2 = gain[fi, pi_, feedno.index(fd2)]
                                if np.isfinite(g1) and np.isfinite(g2):
                                    ts.local_vis[:, fi, pi,
                                                 bi] /= (g1 * np.conj(g2))
                                else:
                                    # mask the un-calibrated vis
                                    ts.local_vis_mask[:, fi, pi, bi] = True

                # save gain to file
                if save_gain:
                    if tag_output_iter:
                        gain_file = output_path(gain_file,
                                                iteration=self.iteration)
                    else:
                        gain_file = output_path(gain_file)
                    if mpiutil.rank0:
                        with h5py.File(gain_file, 'w') as f:
                            # allocate space for Gain
                            dset = f.create_dataset('Gain', (nt, nf, 2, nfeed),
                                                    dtype=Gain.dtype)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'time, freq, pol, feed'
                            try:
                                dset.attrs['time'] = ts.time[start_ind:end_ind]
                            except RuntimeError:
                                f.create_dataset(
                                    'time', data=ts.time[start_ind:end_ind])
                                dset.attrs['time'] = '/time'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save gain
                            dset = f.create_dataset('gain', data=gain)
                            dset.attrs['calibrator'] = calibrator
                            dset.attrs['dim'] = 'freq, pol, feed'
                            dset.attrs['freq'] = freq
                            dset.attrs['pol'] = np.array(['xx', 'yy'])
                            dset.attrs['feed'] = np.array(feedno)
                            # save phs

                            if save_phs_change:
                                f.create_dataset('phs', data=phs)
                            # save transit index and transit time for the case where the ps cal is done first
                            f.attrs['transit_index'] = transit_inds[0]
                            #                            f.attrs['transit_time'] = ts['jul_date'][transit_inds[0]]
                            f.attrs['transit_jul'] = ts['jul_date'][
                                transit_inds[0]]
                            f.attrs['transit_time'] = ts['sec1970'][
                                transit_inds[0]]
                            #                            f.attrs['transit_time'] = ts.attrs['sec1970'] + ts.attrs['inttime']*transit_inds[0] # in sec1970 to improve numeric precision

                            if os.path.exists(output_path(
                                    ts.ns_gain_file)) and not ts.ps_first:
                                with h5py.File(output_path(ts.ns_gain_file),
                                               'r+') as ns_file:
                                    phs_only = not ('ns_cal_amp'
                                                    in ns_file.keys())
                                    #                                    exclude_bad = 'badchn' in ns_file['channo'].attrs.keys()
                                    new_gain = uni_gain(f,
                                                        ns_file,
                                                        phs_only=phs_only)
                                    ns_file.create_dataset('uni_gain',
                                                           data=new_gain)
                                    ns_file['uni_gain'].attrs[
                                        'dim'] = '(time, freq, bl)'

                    mpiutil.barrier()

                    # save Gain
                    for i in range(10):
                        try:
                            # NOTE: writing simultaneously loses data when processes are distributed across several nodes
                            for ri in xrange(mpiutil.size):
                                if ri == mpiutil.rank:
                                    with h5py.File(gain_file, 'r+') as f:
                                        for ii, (ti, fi,
                                                 pi) in enumerate(tfp_linds):
                                            ti_ = ti - start_ind
                                            pi_ = gain_pd[pol[pi]]
                                            f['Gain'][ti_, fi, pi_] = lGain[ii]
                                mpiutil.barrier()
                            break
                        except IOError:
                            time.sleep(0.5)
                            continue
                    else:
                        raise RuntimeError('Could not open file: %s...' %
                                           gain_file)

                    mpiutil.barrier()

        # convert vis from intensity unit to temperature unit in K
        if temperature_convert:
            if 'unit' in ts.vis.attrs.keys() and ts.vis.attrs['unit'] == 'K':
                if mpiutil.rank0:
                    print 'vis is already in unit K, do nothing...'
            else:
                factor = 1.0e-26 * (const.c**2 / (2 * const.k_B *
                                                  (1.0e6 * freq)**2)
                                    )  # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
                ts.local_vis[:] *= factor[np.newaxis, :, np.newaxis,
                                          np.newaxis]
                ts.vis.attrs['unit'] = 'K'

        return super(PsCal, self).process(ts)