def wrap(cls, array, axis, comm=None):
    """Turn a set of numpy arrays into a distributed MPIArray object.

    This is needed for functions such as `np.fft.fft` which always return an
    `np.ndarray`.

    Parameters
    ----------
    array : np.ndarray
        Array to wrap (this rank's local piece).
    axis : integer
        Axis over which the array is distributed. The lengths are checked to
        try and ensure this is correct.
    comm : MPI.Comm, optional
        The communicator over which the array is distributed. If `None`
        (default), `mpiutil.world` is used.

    Returns
    -------
    dist_array : MPIArray
        An MPIArray view of the input (the data buffer is shared, not copied).

    Raises
    ------
    Exception
        If on any rank the local length of `axis` differs from the length
        that `mpiutil.split_local` assigns to that rank.
    """
    comm = mpiutil.world if comm is None else comm

    # Local length of the distributed axis, and its total over all ranks.
    local_len = array.shape[axis]
    global_len = mpiutil.allreduce(local_len, comm=comm)

    # The layout `split_local` would give an axis of that global length.
    expected_len, start, _ = mpiutil.split_local(global_len, comm=comm)

    # MAX-reduce the per-rank mismatch flag so that every rank sees the same
    # verdict and either all of them raise or none does.
    mismatch = mpiutil.allreduce(local_len != expected_len, op=mpiutil.MAX, comm=comm)
    if mismatch:
        raise Exception("Cannot wrap, distributed axis local length is incorrect.")

    local_shape = tuple(array.shape)
    global_shape = list(local_shape)
    global_shape[axis] = global_len
    local_offset = [0] * len(local_shape)
    local_offset[axis] = start

    # Re-view the buffer as the subclass and attach the distribution metadata.
    result = array.view(cls)
    result._global_shape = tuple(global_shape)
    result._axis = axis
    result._local_shape = local_shape
    result._local_offset = tuple(local_offset)
    result._comm = comm
    return result
def process(self, ts):
    """Calibrate `ts` against the transit of a point-source calibrator.

    For every (time, freq, pol) sample within `span` seconds of the first
    transit of `self.params['calibrator']` found in the data (xx and yy
    only), the feed-by-feed visibility matrix is decomposed with
    `rpca_decomp.decompose(..., rank=1)` into a rank-1 part `V0` (used as
    the calibrator signal) and a sparse outlier part `S`. Depending on the
    parameters this method then:

    * saves sky / source / outlier visibilities to `src_vis_file`
      (`save_src_vis`);
    * subtracts the source visibilities from `ts`, or replaces `ts` with
      them (`subtract_src`, `replace_with_src`);
    * derives per-feed gains from the leading eigenvector of `V0`, fits
      their phase against the geometric phase of the source, normalises them
      by the calibrator flux and beam response, applies them to `ts`
      (`apply_gain`) and/or saves them to `gain_file` (`save_gain`,
      `save_phs_change`).

    If more than half of the phase fits prefer the conjugate model, the
    visibilities (and gains) are conjugated automatically.

    Parameters
    ----------
    ts : Timestream
        Input data. Must have linear polarisation when any of the operations
        above is requested. Modified in place.

    Returns
    -------
    The result of the parent class's `process(ts)`.

    Raises
    ------
    RuntimeError
        If the pol type is not linear, the calibrator is unknown, the data do
        not cover the calibrator transit, or an output file still cannot be
        opened after 10 attempts.
    """
    assert isinstance(
        ts, Timestream
    ), '%s only works for Timestream object' % self.__class__.__name__

    calibrator = self.params['calibrator']
    vis_conj = self.params['vis_conj']
    zero_diag = self.params['zero_diag']
    span = self.params['span']
    reserve_high_gain = self.params['reserve_high_gain']
    plot_figs = self.params['plot_figs']
    fig_prefix = self.params['fig_name']
    tag_output_iter = self.params['tag_output_iter']
    save_src_vis = self.params['save_src_vis']
    src_vis_file = self.params['src_vis_file']
    subtract_src = self.params['subtract_src']
    replace_with_src = self.params['replace_with_src']
    apply_gain = self.params['apply_gain']
    save_gain = self.params['save_gain']
    save_phs_change = self.params['save_phs_change']
    gain_file = self.params['gain_file']
    show_progress = self.params['show_progress']
    progress_step = self.params['progress_step']

    # NOTE: under Python 2 a nested function must not reference any name
    # that this method later `del`s, so the helpers below only close over
    # names that are never deleted.

    def _out_path(name):
        # Resolve an output file name, tagged with the iteration if requested.
        if tag_output_iter:
            return output_path(name, iteration=self.iteration)
        return output_path(name)

    def _plot_complex_matrix(mat, tag, ind, fi, pi):
        # Save side-by-side images of the real and imaginary parts of `mat`.
        plt.figure(figsize=(13, 5))
        plt.subplot(121)
        plt.imshow(mat.real, aspect='equal', origin='lower', interpolation='nearest')
        plt.colorbar(shrink=1.0)
        plt.subplot(122)
        plt.imshow(mat.imag, aspect='equal', origin='lower', interpolation='nearest')
        plt.colorbar(shrink=1.0)
        fig_name = '%s_%s_%d_%d_%s.png' % (fig_prefix, tag, ind, fi, pol[pi])
        plt.savefig(_out_path(fig_name))
        plt.close()

    def _write_by_rank(file_name, datasets):
        # Write this rank's rows of every (dataset_name, local_array) pair into
        # the existing HDF5 file, at index (time - start_ind, freq, pol).
        # Ranks take turns, separated by barriers, because simultaneous writes
        # lose data when processes span several nodes. The whole pass is
        # retried up to 10 times on IOError.
        # NOTE(review): if only some ranks hit IOError their barrier calls no
        # longer line up with the other ranks' calls.
        for _ in range(10):
            try:
                for ri in range(mpiutil.size):
                    if ri == mpiutil.rank:
                        with h5py.File(file_name, 'r+') as f:
                            for ii, (ti, fi, pi) in enumerate(tfp_linds):
                                ti_ = ti - start_ind
                                pi_ = gain_pd[pol[pi]]
                                for name, local in datasets:
                                    f[name][ti_, fi, pi_] = local[ii]
                    mpiutil.barrier()
                break
            except IOError:
                time.sleep(0.5)
        else:
            raise RuntimeError('Could not open file: %s...' % file_name)

    if save_src_vis or subtract_src or apply_gain or save_gain:
        pol_type = ts['pol'].attrs['pol_type']
        if pol_type != 'linear':
            raise RuntimeError('Can not do ps_cal for pol_type: %s' % pol_type)

        ts.redistribute('baseline')

        feedno = ts['feedno'][:].tolist()
        pol = [ts.pol_dict[p] for p in ts['pol'][:]]  # as string
        gain_pd = {'xx': 0, 'yy': 1, 0: 'xx', 1: 'yy'}  # for gain related op
        bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm)

        # get the calibrator
        try:
            s = calibrators.get_src(calibrator)
        except KeyError:
            if mpiutil.rank0:
                print('Calibrator %s is unavailable, available calibrators are:' % calibrator)
                for key, d in calibrators.src_data.items():
                    print('%8s -> %12s' % (key, d[0]))
            raise RuntimeError('Calibrator %s is unavailable' % calibrator)
        if mpiutil.rank0:
            print('Try to calibrate with %s...' % s.src_name)

        # get the first transit time of the calibrator (Julian date)
        jul_date = ts['jul_date'][:]
        aa = ts.array
        aa.set_jultime(jul_date[0])  # the first obs time point
        next_transit = aa.next_transit(s)
        transit_time = a.phs.ephem2juldate(next_transit)
        # time-zone offset in hours: the first signed integer in the attr
        tz = int(re.search(r'[-+]?\d+', ts.attrs['timezone']).group())
        # shift by the time-zone offset to report the transit in local time
        local_next_transit = ephem.Date(next_transit + tz * ephem.hour)
        if transit_time > max(jul_date[-1], jul_date.max()):
            raise RuntimeError(
                'Data does not contain local transit time %s of source %s' %
                (local_next_transit, calibrator))

        # the first transit index
        transit_inds = [np.searchsorted(jul_date, transit_time)]
        # find all other transit indices, searching one day further each time
        aa.set_jultime(jul_date[0] + 1.0)
        transit_time = a.phs.ephem2juldate(aa.next_transit(s))
        cnt = 2
        while transit_time <= jul_date[-1]:
            transit_inds.append(np.searchsorted(jul_date, transit_time))
            aa.set_jultime(jul_date[0] + 1.0 * cnt)
            transit_time = a.phs.ephem2juldate(aa.next_transit(s))
            cnt += 1

        if mpiutil.rank0:
            print('transit ind of %s: %s, time: %s' % (s.src_name, transit_inds, local_next_transit))

        # only the first transit is used for the calibration for now
        transit_ind = transit_inds[0]
        int_time = ts.attrs['inttime']  # second
        half_width = int(span / int_time)
        start_ind = max(0, transit_ind - half_width)
        # +1 so that transit_ind sits at the centre of the window
        end_ind = min(transit_ind + half_width + 1, ts.vis.shape[0])

        if vis_conj:
            ts.local_vis[:] = ts.local_vis.conj()

        nt = end_ind - start_ind
        t_inds = range(start_ind, end_ind)
        freq = ts.freq[:]  # MHz
        nf = len(freq)
        nlb = len(ts.local_bl[:])
        nfeed = len(feedno)
        # (time, freq, pol) samples to process, xx and yy only
        tfp_inds = list(itertools.product(t_inds, range(nf), [pol.index('xx'), pol.index('yy')]))
        ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm)
        # regroup the data so that each process owns a contiguous chunk of
        # (time, freq, pol) samples with ALL baselines
        for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)):
            lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype)
            lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype)
            for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]):
                lvis[ii] = ts.local_vis[ti, fi, pi]
                lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi]
            # gather vis from all processes for separate bls
            gvis = mpiutil.gather_array(lvis, axis=1, root=ri, comm=ts.comm)
            gvis_mask = mpiutil.gather_array(lvis_mask, axis=1, root=ri, comm=ts.comm)
            if ri == mpiutil.rank:
                tfp_linds = tfp_inds[si:ei]  # inds for this process
                this_vis = gvis
                this_vis_mask = gvis_mask

        del tfp_inds
        del lvis
        del lvis_mask

        tfp_len = len(tfp_linds)

        cnan = complex(np.nan, np.nan)  # complex nan
        if save_src_vis or subtract_src:
            # calibrator (rank-1) vis
            lsrc_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)
            if save_src_vis:
                # sky vis, i.e. the input matrix
                lsky_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)
                # outlier vis
                lotl_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype)

        if apply_gain or save_gain:
            lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype)

        # indices mapping between baselines and the flattened nfeed x nfeed matrix
        bis_conj = []  # baselines whose conjugate fills the transposed entry
        mis = []  # flat matrix index of each baseline
        mis_conj = []  # flat matrix index of the conjugate entry
        for bi, (fdi, fdj) in enumerate(bls):
            ai, aj = feedno.index(fdi), feedno.index(fdj)
            mis.append(ai * nfeed + aj)
            if ai != aj:
                bis_conj.append(bi)
                mis_conj.append(aj * nfeed + ai)

        # visibility matrix for a single time, freq, pol (reused across samples)
        Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype)
        if show_progress and mpiutil.rank0:
            pg = progress.Progress(tfp_len, step=progress_step)
        for ii, (ti, fi, pi) in enumerate(tfp_linds):
            if show_progress and mpiutil.rank0:
                pg.show(ii)
            # when noise on, just pass
            if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]:
                continue

            # masked values become complex nan
            Vmat.flat[mis] = np.ma.array(this_vis[ii], mask=this_vis_mask[ii]).filled(cnan)
            Vmat.flat[mis_conj] = np.ma.array(
                this_vis[ii, bis_conj], mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan)

            if save_src_vis:
                lsky_vis[ii] = Vmat

            invalid = ~np.isfinite(Vmat)
            # skip the sample if more than 30% of the matrix is invalid
            if np.count_nonzero(invalid) > 0.3 * nfeed**2:
                continue
            Vmat[invalid] = 0
            # skip the sample if all entries are (close to) zero
            if np.allclose(Vmat, 0.0):
                continue
            if zero_diag:
                np.fill_diagonal(Vmat, 0)

            # initial outliers: entries more than 3 MAD from the median
            med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag)
            diff = Vmat - med
            S0 = np.where(np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0)
            # robust PCA: Vmat ~ V0 (rank 1) + S (sparse outliers) + noise
            V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=200,
                                          threshold='hard', tol=1.0e-6, debug=False)

            if save_src_vis or subtract_src:
                lsrc_vis[ii] = V0
                if save_src_vis:
                    lotl_vis[ii] = S

            if plot_figs:
                ind = ti - start_ind
                _plot_complex_matrix(Vmat, 'V', ind, fi, pi)
                _plot_complex_matrix(V0, 'V0', ind, fi, pi)
                _plot_complex_matrix(S, 'S', ind, fi, pi)
                _plot_complex_matrix(Vmat - V0 - S, 'N', ind, fi, pi)  # residual

            if apply_gain or save_gain:
                # use v_ij = gi gj^* \int Ai Aj^* e^(2\pi i n \cdot uij) T(x) d^2n
                # precisely, we should have
                # V0 = (lambda^2 * Sc / (2 k_B)) * gi gj^* Ai Aj^* e^(2\pi i n0 \cdot uij)
                # so g is the leading eigenvector scaled by sqrt(eigenvalue)
                e, U = la.eigh(V0, eigvals=(nfeed - 1, nfeed - 1))
                g = U[:, -1] * e[-1]**0.5  # = \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai * e^(2\pi i n0 \cdot ui)
                if g[0].real < 0:
                    g *= -1.0  # make all g[0] phase 0, instead of pi
                lGain[ii] = g

                if plot_figs:
                    plt.figure()
                    plt.plot(feedno, g.real, 'b-', label='real')
                    plt.plot(feedno, g.real, 'bo')
                    plt.plot(feedno, g.imag, 'g-', label='imag')
                    plt.plot(feedno, g.imag, 'go')
                    plt.plot(feedno, np.abs(g), 'r-', label='abs')
                    plt.plot(feedno, np.abs(g), 'ro')
                    plt.xlim(feedno[0] - 1, feedno[-1] + 1)
                    yl, yh = plt.ylim()
                    plt.ylim(yl, yh + (yh - yl) / 5)
                    plt.xlabel('Feed number')
                    plt.legend()
                    fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind, fi, pol[pi])
                    plt.savefig(_out_path(fig_name))
                    plt.close()

        # subtract the vis of the calibrator from ts (or replace ts with it)
        if subtract_src:
            # pick each baseline's entry of the source matrix; `mis` holds
            # exactly these flattened (feed_i, feed_j) indices
            lv = lsrc_vis.reshape(lsrc_vis.shape[0], nfeed * nfeed)[:, mis]
            lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm)
            lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1)
            if replace_with_src:
                ts.local_vis[start_ind:end_ind, :, pol.index('xx')] = lv[:, :, 0]
                ts.local_vis[start_ind:end_ind, :, pol.index('yy')] = lv[:, :, 1]
            else:
                if 'ns_on' in ts.iterkeys():
                    # skipped ns_on samples hold nan; don't let them turn the data into nan
                    lv[ts['ns_on'][start_ind:end_ind]] = 0
                ts.local_vis[start_ind:end_ind, :, pol.index('xx')] -= lv[:, :, 0]
                ts.local_vis[start_ind:end_ind, :, pol.index('yy')] -= lv[:, :, 1]

            del lv

        if not save_src_vis:
            if subtract_src:
                del lsrc_vis
        else:
            src_vis_file = _out_path(src_vis_file)
            # create the file and allocate space first on rank0
            if mpiutil.rank0:
                with h5py.File(src_vis_file, 'w') as f:
                    shp = (nt, nf, 2, nfeed, nfeed)
                    f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype)
                    f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype)
                    f.create_dataset('outlier_vis', shp, dtype=lotl_vis.dtype)
                    f.attrs['calibrator'] = calibrator
                    f.attrs['dim'] = 'time, freq, pol, feed, feed'
                    try:
                        f.attrs['time'] = ts.time[start_ind:end_ind]
                    except RuntimeError:
                        # too large for an attribute; store it as a dataset instead
                        f.create_dataset('time', data=ts.time[start_ind:end_ind])
                        f.attrs['time'] = '/time'
                    f.attrs['freq'] = freq
                    f.attrs['pol'] = np.array(['xx', 'yy'])
                    f.attrs['feed'] = np.array(feedno)

            mpiutil.barrier()

            _write_by_rank(src_vis_file, [('sky_vis', lsky_vis),
                                          ('src_vis', lsrc_vis),
                                          ('outlier_vis', lotl_vis)])

            del lsrc_vis
            del lsky_vis
            del lotl_vis

            mpiutil.barrier()

        if apply_gain or save_gain:
            # flag outliers of |lGain| along each feed row (median +/- 3 MAD)
            lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype)
            for i in range(lGain.shape[0]):
                valid_inds = np.where(np.isfinite(lGain[i]))[0]
                if len(valid_inds) > 3:
                    vabs = np.abs(lGain[i, valid_inds])
                    vmed = np.median(vabs)
                    vabs_diff = np.abs(vabs - vmed)
                    vmad = np.median(vabs_diff) / 0.6745
                    if reserve_high_gain:
                        # keep significantly higher ones, flag only significantly lower ones
                        lG_abs[i, valid_inds] = np.where(vmed - vabs > 3.0 * vmad, np.nan, vabs)
                    else:
                        # flag both significantly higher and lower ones
                        lG_abs[i, valid_inds] = np.where(vabs_diff > 3.0 * vmad, np.nan, vabs)

            # choose the data slice (up to +/-10 samples) near the transit time
            li = max(start_ind, transit_ind - 10) - start_ind
            hi = min(end_ind, transit_ind + 10 + 1) - start_ind
            ci = transit_ind - start_ind  # center index for transit_ind
            # topocentric source direction for this time range
            n0 = np.zeros(((hi - li), 3))
            for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]):
                aa.set_jultime(jt)
                s.compute(aa)
                n0[ti] = s.get_crds('top', ncrd=3)
            if save_phs_change:
                n0t = np.zeros((nt, 3))
                for ti, jt in enumerate(ts.time[start_ind:end_ind]):
                    aa.set_jultime(jt)
                    s.compute(aa)
                    n0t[ti] = s.get_crds('top', ncrd=3)

            # positions of the feeds
            feedpos = ts['feedpos'][:]

            def _to_feed_columns(local):
                # Wrap a (tfp, feed) array distributed over rows and turn it
                # into a (time, freq*pol*feed) array distributed over columns.
                arr = mpiarray.MPIArray.wrap(local, axis=0, comm=ts.comm)
                return arr.redistribute(axis=1).reshape(nt, nf, 2, None).redistribute(
                    axis=0).reshape(None, nf * 2 * nfeed).redistribute(axis=1)

            Gain = _to_feed_columns(lGain)
            G_abs = _to_feed_columns(lG_abs)

            fpd_inds = list(itertools.product(range(nf), range(2), range(nfeed)))  # only for xx and yy
            fpd_linds = mpiutil.mpilist(fpd_inds, method='con', comm=ts.comm)
            del fpd_inds
            # solved gain for each local (freq, pol, feed)
            lgain = np.full((len(fpd_linds),), cnan, dtype=Gain.dtype)
            if save_phs_change:
                # phase change with time for each local (freq, pol, feed)
                lphs = np.full((nt, len(fpd_linds)), np.nan, dtype=Gain.real.dtype)

            def _fit_phase(ii, fi, di):
                # Least-squares fit of the gain phase of local column `ii`
                # (freq `fi`, feed `di`) near transit, against the source's
                # geometric phase exp(2 pi i n0 . u). Returns None when fewer
                # than max(4, half of the window) samples are unflagged, else
                # (valid |G| values, u, e_phs, e_phs_conj).
                y = G_abs.local_array[li:hi, ii]
                inds = np.where(np.isfinite(y))[0]
                if len(inds) < max(4, 0.5 * len(y)):
                    return None
                # position of this feed (relative to the first feed) in wavelengths
                ui = (feedpos[di] - feedpos[0]) * (1.0e6 * freq[fi]) / const.c
                ef = np.exp(2.0J * np.pi * np.dot(n0, ui))
                Gi = Gain.local_array[li:hi, ii]
                ratio = Gi[inds] / y[inds]
                e_phs = np.dot(ef[inds].conj(), ratio) / len(inds)
                e_phs_conj = np.dot(ef[inds], ratio) / len(inds)
                return y[inds], ui, e_phs, e_phs_conj

            # count the fits for which the conjugate model matches better
            num_conj = 0
            for ii, (fi, pi, di) in enumerate(fpd_linds):
                fit = _fit_phase(ii, fi, di)
                if fit is not None and np.abs(fit[3]) > np.abs(fit[2]):
                    num_conj += 1
            num_conj = mpiutil.allreduce(num_conj, comm=ts.comm)
            if num_conj > 0.5 * (nf * 2 * nfeed):  # 2 for 2 pols
                if mpiutil.rank0:
                    print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                    print('!!!   Detect data should be their conjugate...   !!!')
                    print('!!!   Correct it automatically...                !!!')
                    print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                mpiutil.barrier()
                ts.local_vis[:] = ts.local_vis.conj()
                Gain.local_array[:] = Gain.local_array.conj()

            # solve for the gain of each feed
            for ii, (fi, pi, di) in enumerate(fpd_linds):
                fit = _fit_phase(ii, fi, di)
                if fit is None:
                    continue
                y_valid, ui, e_phs, e_phs_conj = fit
                # approximate magnitude = \sqrt(lambda^2 * Sc / (2 k_B)) * |gi| Ai
                mag = np.mean(y_valid)
                ea = np.abs(e_phs)
                if np.abs(ea - 1.0) < 0.1:
                    lgain[ii] = mag * e_phs  # \sqrt(lambda^2 * Sc / (2 k_B)) * gi Ai
                    if save_phs_change:
                        lphs[:, ii] = np.angle(
                            np.exp(-2.0J * np.pi * np.dot(n0t, ui)) * Gain.local_array[:, ii])
                else:
                    eac = np.abs(e_phs_conj)
                    if eac > ea:
                        if np.abs(eac - 1.0) < 0.01:
                            print('feedno = %d, fi = %d, pol = %s: may need to be conjugated' % (
                                feedno[di], fi, gain_pd[pi]))
                    else:
                        print('feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % (
                            feedno[di], fi, gain_pd[pi], ea))

            # gather local gain
            gain = mpiutil.gather_array(lgain, axis=0, root=None, comm=ts.comm)
            del lgain
            gain = gain.reshape(nf, 2, nfeed)
            if save_phs_change:
                phs = mpiutil.gather_array(lphs, axis=1, root=0, comm=ts.comm)
                del lphs
                if mpiutil.rank0:
                    phs = phs.reshape(nt, nf, 2, nfeed)

            # normalize by calibrator flux and beam response to get the exact gain
            Sc = s.get_jys(1.0e-3 * freq)
            # TODO: implement Omega for dish
            Ai = aa.ants[0].beam.response(n0[ci - li])
            lmd = const.c / (1.0e6 * freq)
            # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1
            factor = np.sqrt((lmd**2 * 1.0e-26 * Sc) / (2 * const.k_B)) * Ai
            gain /= factor[:, np.newaxis, np.newaxis]

            if apply_gain:
                for fi in range(nf):
                    for pi in [pol.index('xx'), pol.index('yy')]:
                        pi_ = gain_pd[pol[pi]]
                        for bi, (fd1, fd2) in enumerate(ts['blorder'].local_data):
                            g1 = gain[fi, pi_, feedno.index(fd1)]
                            g2 = gain[fi, pi_, feedno.index(fd2)]
                            if np.isfinite(g1) and np.isfinite(g2):
                                if fd1 == fd2:
                                    # auto-correlation should be real
                                    ts.local_vis[:, fi, pi, bi] /= (g1 * np.conj(g2)).real
                                else:
                                    ts.local_vis[:, fi, pi, bi] /= (g1 * np.conj(g2))
                            else:
                                # mask the un-calibrated vis
                                ts.local_vis_mask[:, fi, pi, bi] = True

                # in unit K after the calibration
                ts.vis.attrs['unit'] = 'K'

            if save_gain:
                gain_file = _out_path(gain_file)
                if mpiutil.rank0:
                    with h5py.File(gain_file, 'w') as f:
                        # allocate space for Gain
                        dset = f.create_dataset('Gain', (nt, nf, 2, nfeed), dtype=Gain.dtype)
                        dset.attrs['calibrator'] = calibrator
                        dset.attrs['dim'] = 'time, freq, pol, feed'
                        try:
                            dset.attrs['time'] = ts.time[start_ind:end_ind]
                        except RuntimeError:
                            f.create_dataset('time', data=ts.time[start_ind:end_ind])
                            dset.attrs['time'] = '/time'
                        dset.attrs['freq'] = freq
                        dset.attrs['pol'] = np.array(['xx', 'yy'])
                        dset.attrs['feed'] = np.array(feedno)
                        dset.attrs['transit_ind'] = transit_ind
                        # save gain
                        dset = f.create_dataset('gain', data=gain)
                        dset.attrs['calibrator'] = calibrator
                        dset.attrs['dim'] = 'freq, pol, feed'
                        dset.attrs['freq'] = freq
                        dset.attrs['pol'] = np.array(['xx', 'yy'])
                        dset.attrs['feed'] = np.array(feedno)
                        # save phs
                        if save_phs_change:
                            f.create_dataset('phs', data=phs)

                mpiutil.barrier()

                _write_by_rank(gain_file, [('Gain', lGain)])

        mpiutil.barrier()

    return super(PsCal, self).process(ts)
def process(self, rt):
    """Detect the periodic noise-source signal in `rt` and mask it.

    The noise-source period, on-time and off-time are estimated from the
    rising and falling jumps of the frequency-averaged auto-correlation of
    one channel. The channel is `self.params['channel']` if it is present,
    otherwise the first auto-correlation; the remaining auto-correlations are
    tried in turn until one gives a self-consistent result.

    The pattern is stored as the time-ordered dataset `ns_on`, with
    `period`, `on_time` and `off_time` attrs. The matching samples, plus
    `mask_near` samples on either side, are masked in `rt.vis_mask`.

    Parameters
    ----------
    rt : RawTimestream
        Input data; modified in place.

    Returns
    -------
    The result of the parent class's `process(rt)`.

    Raises
    ------
    RuntimeError
        If no auto-correlation yields a consistent noise-source pattern.
    """
    assert isinstance(
        rt, RawTimestream
    ), '%s only works for RawTimestream object currently' % self.__class__.__name__

    channel = self.params['channel']
    sigma = self.params['sigma']
    mask_near = max(0, int(self.params['mask_near']))

    def _most_common(values):
        # Most frequent element of `values`, or None when `values` is empty.
        counts = Counter(values).most_common(1)
        return counts[0][0] if counts else None

    rt.redistribute(0)  # make time the dist axis

    auto_inds = np.where(rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()  # inds for auto-correlations
    channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
    if channel is not None:
        if channel in channels:
            bl_ind = auto_inds[channels.index(channel)]
        else:
            bl_ind = auto_inds[0]
            if mpiutil.rank0:
                print('Warning: Required channel %d doen not in the data, use channel %d instead' % (
                    channel, rt.bl[bl_ind, 0]))
    else:
        bl_ind = auto_inds[0]
    # move the chosen channel to the first
    auto_inds.remove(bl_ind)
    auto_inds = [bl_ind] + auto_inds

    for bl_ind in auto_inds:
        this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
        vis = np.ma.array(rt.local_vis[:, :, bl_ind].real, mask=rt.local_vis_mask[:, :, bl_ind])
        cnt = vis.count()  # number of not masked vals
        total_cnt = mpiutil.allreduce(cnt)
        vis_shp = rt.vis.shape
        ratio = float(total_cnt) / np.prod((vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
        if ratio < 0.5:  # too many masked vals
            if mpiutil.rank0:
                warnings.warn('Too many masked values for auto-correlation of Channel: %d, does not use it' % this_chan)
            continue

        # frequency-averaged auto-correlation, gathered on every rank
        tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0), root=None)
        df = np.diff(tt_mean, axis=-1)

        # rising edges: sample just after a positive jump above mean + sigma*std
        pdf = np.where(df > 0, df, 0)
        pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0] + 1
        kept = []
        for pi in pinds:
            # drop an edge exactly one sample after the last kept one
            if not kept or pi - kept[-1] > 1:
                kept.append(pi)
        pinds = np.array(kept)
        pT = _most_common(np.diff(pinds))  # period of pinds

        # falling edges, scanned backwards so the later index of a pair is kept
        ndf = np.where(df < 0, df, 0)
        ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0] + 1
        kept = []
        for ni in ninds[::-1]:
            if not kept or ni - kept[-1] < -1:
                kept.append(ni)
        ninds = np.array(kept[::-1])
        nT = _most_common(np.diff(ninds))  # period of ninds

        if pT is None or nT is None:
            # fewer than two rising or falling edges: no period can be estimated
            if mpiutil.rank0:
                warnings.warn('Too few noise source edges detected for auto-correlation of Channel: %d, does not use it' % this_chan)
            continue

        if pT != nT:  # failed to detect correct period
            if mpiutil.rank0:
                warnings.warn(
                    'Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it'
                    % (this_chan, pT, nT))
            continue
        period = pT

        # all pairwise (falling - rising) index differences
        dinds = (ninds.reshape(-1, 1) - pinds).flatten()
        on_time = _most_common(dinds[dinds > 0] % period)
        off_time = _most_common(-dinds[dinds < 0] % period)
        if on_time is None or off_time is None:
            if mpiutil.rank0:
                warnings.warn('Failed to detect on_time / off_time for auto-correlation of Channel: %d, does not use it' % this_chan)
            continue
        if period != on_time + off_time:  # incorrect detect
            if mpiutil.rank0:
                warnings.warn(
                    'Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it'
                    % (this_chan, period, on_time, off_time))
            continue

        if 'noisesource' in rt.iterkeys():
            num_ns = rt['noisesource'].shape[0]
            if num_ns == 1:  # only 1 noise source
                start, stop, cycle = rt['noisesource'][0, :]
                int_time = rt.attrs['inttime']
                true_on_time = np.round((stop - start) / int_time)
                true_period = np.round(cycle / int_time)
                # NOTE(review): `and` rejects only when BOTH on_time and period
                # disagree with the record; `or` may have been intended.
                if on_time != true_on_time and period != true_period:
                    # inconsistent with the record in the data
                    if mpiutil.rank0:
                        warnings.warn(
                            'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it'
                            % (this_chan, on_time, true_on_time, period, true_period))
                    continue
            elif num_ns >= 2:  # more than 1 noise source
                if mpiutil.rank0:
                    warnings.warn('More than 1 noise source, do not know how to deal with this currently')

        # break if succeed
        break
    else:
        raise RuntimeError('Failed to detect noise source signal')

    if mpiutil.rank0:
        print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (
            period, on_time, off_time))

    # build the on/off pattern starting at index 0, then roll it to the
    # most common rising-edge phase
    num_period = int(np.ceil(len(tt_mean) / float(period)))
    tmp_ns_on = np.array(([True] * on_time + [False] * off_time) * num_period)[:len(tt_mean)]
    on_start = _most_common(pinds % period)
    ns_on = np.roll(tmp_ns_on, on_start)

    ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)
    rt.create_main_time_ordered_dataset('ns_on', ns_on1)
    rt['ns_on'].attrs['period'] = period
    rt['ns_on'].attrs['on_time'] = on_time
    rt['ns_on'].attrs['off_time'] = off_time

    # set vis_mask corresponding to ns_on
    on_inds = np.where(rt['ns_on'].local_data[:])[0]
    rt.local_vis_mask[on_inds] = True

    if mask_near > 0:
        # global indices of ns_on samples widened by +/- mask_near
        on_inds = np.where(ns_on)[0]
        near_inds = np.unique(np.concatenate(
            [on_inds + k for k in range(-mask_near, mask_near + 1)]))

        if rt['vis_mask'].distributed:
            start = rt.vis_mask.local_offset[0]
            end = start + rt.vis_mask.local_shape[0]
        else:
            start = 0
            end = rt.vis_mask.shape[0]
        # keep those inside this rank's time range and convert to local indices
        local_on_inds = near_inds[(near_inds >= start) & (near_inds < end)] - start
        rt.local_vis_mask[local_on_inds] = True

    return super(Detect, self).process(rt)
def process(self, ts):
    """Compute (and optionally plot) masked-visibility statistics.

    Counts the masked visibilities of `ts` per time sample and per
    frequency channel, summed over all baselines on all ranks. Masks at
    `ns_on` samples are ignored, and auto-correlations are excluded when
    `excl_auto` is set. When `plot_stats` is set, rank 0 plots the masked
    fraction (labelled "RFI (%)") against time and against frequency.

    Parameters
    ----------
    ts : RawTimestream or Timestream
        Input data. For a Timestream (4-d mask) only the first polarisation
        is used, on the assumption that masks agree across pols.

    Returns
    -------
    The result of the parent class's `process(ts)`.

    Raises
    ------
    RuntimeError
        If the local vis_mask is neither 3-d nor 4-d.
    """
    excl_auto = self.params['excl_auto']
    plot_stats = self.params['plot_stats']
    fig_prefix = self.params['fig_name']
    rotate_xdate = self.params['rotate_xdate']
    tag_output_iter = self.params['tag_output_iter']

    def _out_path(name):
        # Resolve an output file name, tagged with the iteration if requested.
        if tag_output_iter:
            return output_path(name, iteration=self.iteration)
        return output_path(name)

    ts.redistribute('baseline')

    mask = ts.local_vis_mask
    if mask.ndim == 4:  # Timestream
        # suppose masks are the same for all 4 pols
        mask = mask[:, :, 0]
    elif mask.ndim != 3:  # 3 means RawTimestream
        # wrap the shape in a 1-tuple: `%` with a bare tuple would raise TypeError
        raise RuntimeError('Incorrect vis_mask shape %s' % (mask.shape,))
    if excl_auto:
        bl = ts.local_bl
        vis_mask = mask[:, :, bl[:, 0] != bl[:, 1]].copy()
    else:
        vis_mask = mask.copy()
    nt, nf, lnb = vis_mask.shape

    # total number of bl
    nb = mpiutil.allreduce(lnb, comm=ts.comm)

    # un-mask ns-on positions
    if 'ns_on' in ts.iterkeys():
        vis_mask[ts['ns_on'][:]] = False

    def _masked_count(axes):
        # Per-rank masked count summed over `axes`, then summed over ranks
        # (the result is only meaningful on rank 0).
        counts = np.sum(vis_mask, axis=axes).reshape(-1, 1)
        counts = mpiutil.gather_array(counts, axis=1, root=0, comm=ts.comm)
        if mpiutil.rank0:
            counts = np.sum(counts, axis=1)
        return counts

    # statistics along the time axis
    time_mask = _masked_count((1, 2))
    # statistics along the frequency axis
    freq_mask = _masked_count((0, 2))

    if plot_stats and mpiutil.rank0:
        time_fig_name = _out_path('%s_%s.png' % (fig_prefix, 'time'))
        # plot time_mask (no separate plt.figure(): it would leak an empty figure)
        fig, ax = plt.subplots()
        # NOTE(review): a fixed +8 h offset is applied; ts.attrs['timezone']
        # may be the better source for this.
        x_vals = np.array([(datetime.utcfromtimestamp(s) + timedelta(hours=8))
                           for s in ts['sec1970'][:]])
        xlabel = '%s' % x_vals[0].date()
        x_vals = mdates.date2num(x_vals)
        ax.plot(x_vals, 100 * time_mask / float(nf * nb))
        ax.xaxis_date()
        date_format = mdates.DateFormatter('%H:%M')
        ax.xaxis.set_major_formatter(date_format)
        if rotate_xdate:
            # set the x-axis tick labels to diagonal so it fits better
            fig.autofmt_xdate()
        else:
            # reduce the number of tick locators
            locator = MaxNLocator(nbins=6)
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_minor_locator(AutoMinorLocator(2))
        ax.set_xlabel(xlabel)
        ax.set_ylabel(r'RFI (%)')
        plt.savefig(time_fig_name)
        plt.close(fig)

        freq_fig_name = _out_path('%s_%s.png' % (fig_prefix, 'freq'))
        # plot freq_mask
        plt.figure()
        plt.plot(ts.freq[:], 100 * freq_mask / float(nt * nb))
        plt.grid(True)
        plt.xlabel(r'$\nu$ / MHz')
        plt.ylabel(r'RFI (%)')
        plt.savefig(freq_fig_name)
        plt.close()

    return super(Stats, self).process(ts)
def process(self, data):
    """Regrid visibility data onto a regular grid in hour angle.

    Parameters
    ----------
    data : TimeStream
        Time-ordered data.

    Returns
    -------
    new_data : SiderealStream
        The regridded data centered on the source RA, or `None` when the
        regridding failed on any rank.
    """
    data.redistribute("freq")

    # Plain ndarray views of the weights and visibilities.
    weight = data.weight[:].view(np.ndarray)
    vis_data = data.vis[:].view(np.ndarray)

    # Apparent source RA at the first sample (including precession effects),
    # followed by the catalogue RA which is kept for reference.
    ra, _ = ephem.object_coords(self.src, data.time[0], deg=True, obs=self.observer)
    ra_icrs, _ = ephem.object_coords(self.src, deg=True, obs=self.observer)

    # Local hour angle of every input sample.
    lha = unwrap_lha(self.observer.unix_to_lsa(data.time), ra)

    try:
        new_grid, new_vis, ni = self._regrid(vis_data, weight, lha)
    except (np.linalg.LinAlgError, ValueError) as err:
        self.log.error(str(err))
        ok = 0
    else:
        ok = 1

    # The summed flag equals the number of ranks only if every rank succeeded.
    if mpiutil.allreduce(ok) != mpiutil.size:
        self.log.warning("Regridding failed. Skipping transit.")
        return None

    # Zero out grid points outside the hour-angle range this transit covers.
    outside = (new_grid < lha.min()) | (new_grid > lha.max())
    grid_mask = np.ones_like(new_grid)
    grid_mask[outside] = 0.0
    new_vis *= grid_mask
    ni *= grid_mask

    if data.distributed:
        dist_axis = data.vis.distributed_axis
        new_vis = mpiarray.MPIArray.wrap(new_vis, axis=dist_axis)
        ni = mpiarray.MPIArray.wrap(ni, axis=dist_axis)

    new_data = SiderealStream(
        axes_from=data, attrs_from=data, ra=(new_grid + ra) % 360.0, comm=data.comm
    )
    new_data.redistribute("freq")
    new_data.vis[:] = new_vis
    new_data.weight[:] = ni
    new_data.attrs["cirs_ra"] = ra
    new_data.attrs["icrs_ra"] = ra_icrs
    return new_data
def process(self, rt):
    """Detect the noise-source switching pattern and mask the noise-on samples.

    Edges of the noise-source signal are found in the time derivative of the
    frequency-averaged auto-correlations. Jumps above ``mean + sigma * std``
    of the positive derivative are "on" edges, and jumps below
    ``mean - sigma * std`` of the negative derivative are "off" edges. The
    resulting boolean pattern is stored as the ``ns_on`` dataset of `rt`, with
    ``period``, ``on_time`` and ``off_time`` attributes counted in time
    samples. The "on" samples, optionally widened by ``mask_near`` samples
    on each side, are then set in ``rt.local_vis_mask``.

    With ``FRB_cal`` set, three situations are handled:

    * If ``ns_prop_file`` exists, the timing it records is used to predict
      the noise points of this data. This data is then only checked for
      lost, added or abnormal points, and the file is updated.
    * Otherwise, the detected edges are accumulated in ``ns_arr_file`` over
      successive files until ``num_noise`` noise points are available.
    * Once enough points are available, the timing is derived from the
      accumulated edges and ``ns_prop_file`` is written.

    Parameters
    ----------
    rt : RawTimestream
        Input data. It is redistributed along the time axis (axis 0).

    Returns
    -------
    The result of the parent class ``process`` applied to `rt`.

    Raises
    ------
    OverlapData, IncontinuousData
        If the accumulated files overlap in time or are not contiguous.
    NoNoisePoint, NoiseNotEnough, DetectNoiseFailure
        If no usable or not enough noise signal is found.
    AbnormalPoints
        If the detected edges disagree with the recorded timing.
    """
    assert isinstance(rt, RawTimestream), '%s only works for RawTimestream object currently' % self.__class__.__name__

    channel = self.params['channel']
    sigma = self.params['sigma']
    mask_near = max(0, int(self.params['mask_near']))
    # parameters of the FRB (multi-file accumulation) mode
    rt.FRB_cal = self.params['FRB_cal']
    ns_arr_file = self.params['ns_arr_file']
    ns_prop_file = self.params['ns_prop_file']
    num_noise = self.params['num_noise']
    change_reference_tol = self.params['change_reference_tol']
    abnormal_ns_save_file = 'ns_cal/abnormal_ns.npz'

    rt.redistribute(0)  # make time the dist axis

    # total number of time samples
    total_time_span = rt['sec1970'].shape[0]

    auto_inds = np.where(rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()  # inds for auto-correlations
    channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels
    if channel is not None:
        if channel in channels:
            bl_ind = auto_inds[channels.index(channel)]
        else:
            bl_ind = auto_inds[0]
            if mpiutil.rank0:
                print('Warning: Required channel %d doen not in the data, use channel %d instead' % (channel, rt.bl[bl_ind, 0]))
    else:
        bl_ind = auto_inds[0]
    # move the chosen channel to the front so it is tried first
    auto_inds.remove(bl_ind)
    auto_inds = [bl_ind] + auto_inds

    if rt.FRB_cal:
        # sec1970 time stamp of the last sample of this data
        ns_arr_end_time = rt.attrs['sec1970'][0] + rt.attrs['inttime'] * (total_time_span - 1)
        if os.path.exists(output_path(ns_arr_file)) and not os.path.exists(output_path(ns_prop_file)):
            # continue the accumulation recorded from previous files
            filein = np.load(output_path(ns_arr_file))
            # add 0 to avoid a single float or int become an array
            ns_arr_pinds = list(filein['on_inds'] + 0)
            ns_arr_ninds = list(filein['off_inds'] + 0)
            ns_arr_len = filein['time_len'] + 0
            ns_arr_num = filein['ns_num'] + 0
            ns_arr_bl = filein['auto_chn'] + 0
            ns_arr_start_time = filein['start_time']
            # positive when this data starts at or before the end of the
            # previously accumulated data
            overlap_index_span = np.around(np.float128(filein['end_time'] - rt.attrs['sec1970'][0]) / rt.attrs['inttime']) + 1
            if overlap_index_span > 0:
                raise OverlapData('Overlap of data occured when trying to build up noise property file! In julian date, the end time of previous data is %.10f, while the start time of this data is %.10f. The overlap span in index is %d.' % (filein['end_time'], rt.attrs['sec1970'][0], overlap_index_span))
        elif not os.path.exists(output_path(ns_prop_file)):
            # first file: start a fresh accumulation
            ns_arr_pinds = []
            ns_arr_ninds = []
            ns_arr_len = 0
            ns_arr_num = -1  # to distinguish the first process
            ns_arr_start_time = rt.attrs['sec1970'][0]

    if rt.FRB_cal and os.path.exists(output_path(ns_prop_file)):
        # ---- timing already known: predict noise points, check for lost / added / abnormal ones ----
        if mpiutil.rank0:
            print('Use existing ns property file %s to do calibration.' % output_path(ns_prop_file))
        ns_prop_data = np.load(output_path(ns_prop_file))
        period = ns_prop_data['period'] + 0
        on_time = ns_prop_data['on_time'] + 0
        off_time = ns_prop_data['off_time'] + 0
        reference_time = ns_prop_data['reference_time'] + 0  # in sec1970, is a start point of noise
        if 'lost_count' in ns_prop_data.files:
            lost_count_before = ns_prop_data['lost_count']
        else:
            lost_count_before = 0
        if 'added_count' in ns_prop_data.files:
            added_count_before = ns_prop_data['added_count']
        else:
            added_count_before = 0
        this_time_start = rt.attrs['sec1970'][0]
        # the number of index between the reference time and the start point
        skip_inds = int(np.around((this_time_start - reference_time) / rt.attrs['inttime']))
        on_start = period - skip_inds % period
        if total_time_span <= on_start + on_time:
            raise NoNoisePoint('Calculated from previous data, this data contains no noise point or does not contain complete noise signal!')

        # check whether there are lost points
        # only consider the case that there is one lost point and only consider the on points
        abnormal_count = -1  # stays -1 if no channel has enough unmasked data
        lost_one_point = 0
        added_one_point = 0
        abnormal_list = []
        lost_one_list = []
        added_one_list = []
        lost_one_pos = []
        added_one_pos = []
        complete_period_num = (total_time_span - on_start - on_time - 1) // period  # the number of periods in the data
        # expected "on" and "off" edge indices
        on_points = [on_start + i * period for i in range(complete_period_num + 1)]
        off_points = [on_start + on_time + i * period for i in range(complete_period_num + 1)]
        for bl_ind in auto_inds:
            this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
            vis = np.ma.array(rt.local_vis[:, :, bl_ind].real, mask=rt.local_vis_mask[:, :, bl_ind])
            cnt = vis.count()  # number of not masked vals
            total_cnt = mpiutil.allreduce(cnt)
            vis_shp = rt.vis.shape
            ratio = float(total_cnt) / np.prod((vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
            if ratio < 0.5:  # too many masked vals
                continue

            if abnormal_count < 0:
                abnormal_count = 0
            tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0), root=None)  # mean for all freq, for a specific auto bl
            df = np.diff(tt_mean, axis=-1)
            # rising ("on") edges
            pdf = np.where(df > 0, df, 0)
            pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0]
            pinds = pinds + 1
            if len(pinds) == 0:  # no raise, might be badchn, continue
                continue
            # drop an index that directly follows the previously kept one
            pinds1 = [pinds[0]]
            for pi in pinds[1:]:
                if pi - pinds1[-1] > 1:
                    pinds1.append(pi)
            pinds = np.array(pinds1)

            # falling ("off") edges
            ndf = np.where(df < 0, df, 0)
            ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0]
            ninds = ninds + 1
            ninds = ninds[::-1]
            if len(ninds) == 0:  # no raise, might be badchn, continue
                continue
            # scanning backwards, drop an index that directly precedes the previously kept one
            ninds1 = [ninds[0]]
            for ni in ninds[1:]:
                if ni - ninds1[-1] < -1:
                    ninds1.append(ni)
            ninds = np.array(ninds1[::-1])

            cmp_signal, cmp_res = cmp_cd(on_points, off_points, pinds, ninds, period, 2. / 3)
            if cmp_signal == 'ok':  # normal
                continue
            elif cmp_signal == 'abnormal':
                abnormal_count += 1
                abnormal_list += [this_chan]
            elif cmp_signal == 'lost':  # lost point
                lost_one_point += 1
                lost_one_list += [this_chan]
                lost_one_pos += [cmp_res]
                continue
            elif cmp_signal == 'add':  # added point
                added_one_point += 1
                added_one_list += [this_chan]
                added_one_pos += [cmp_res]
                continue
            else:
                raise Exception('Unknown comparison signal!')

        if abnormal_count < 0:
            raise NoNoisePoint('No noise points are detected from this data or the data contains too many masked points!')
        elif abnormal_count > len(auto_inds) / 3:
            # record the last examined edges for inspection, then abort on all ranks
            if mpiutil.rank0:
                np.savez(output_path(abnormal_ns_save_file), on_inds=pinds, off_inds=ninds, on_start=on_start, period=period, on_time=on_time)
            mpiutil.barrier()
            raise AbnormalPoints('Something rather than one lost point happened. The expected start point is %d, period %d, on_time %d, but the pinds and ninds are: ' % (on_start, period, on_time), pinds, ninds)
        elif lost_one_point > 2 * len(auto_inds) / 3:
            # most channels report one lost sample; they must agree on where
            uniques, counts = np.unique(lost_one_pos, return_counts=True)
            maxcount = np.argmax(counts)
            if counts[maxcount] > 2 * len(auto_inds) / 3:
                lost_position = uniques[maxcount]
            else:
                raise AbnormalPoints('More than 2/3 baselines have detected lost points but do not have a universal lost position, some unexpected error happened!\nChannels that probably lost one point and the position: %s' % str(zip(lost_one_list, lost_one_pos)))
            if mpiutil.rank0:
                if lost_position == 0:
                    warnings.warn('One lost point before the data is detected!', LostPointBegin)
                else:
                    warnings.warn('One lost point before index %d is detected!' % on_points[lost_position], LostPointMiddle)
            if lost_position == 0 or lost_count_before == 0:
                lost_count_before += 1
            # move the expected edges from the lost position onwards one sample earlier
            on_points = np.array(on_points)
            on_points[lost_position:] = on_points[lost_position:] - 1
            off_points = np.array(off_points)
            off_points[lost_position:] = off_points[lost_position:] - 1
            if on_points[0] < 0:
                on_points = on_points[1:]
                off_points = off_points[1:]
            if mpiutil.rank0:
                if lost_count_before >= change_reference_tol:
                    reference_time -= rt.attrs['inttime']
                    warnings.warn('Move the reference time one index earlier to compensate!', ChangeReferenceTime)
                    np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=reference_time, lost_count=0, added_count=0)
                else:
                    warnings.warn('The number of recorded lost points was %d while tolerance is %d. Do not change the reference time.' % (lost_count_before, change_reference_tol))
                    np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=reference_time, lost_count=lost_count_before, added_count=0)
        elif added_one_point > 2 * len(auto_inds) / 3:
            # most channels report one extra sample; they must agree on where
            uniques, counts = np.unique(added_one_pos, return_counts=True)
            maxcount = np.argmax(counts)
            if counts[maxcount] > 2 * len(auto_inds) / 3:
                added_position = uniques[maxcount]
            else:
                raise AbnormalPoints('More than 2/3 baselines have detected additional points but do not have a universal adding position, some unexpected error happened!\nChannels that probably added one point and the position: %s' % str(zip(added_one_list, added_one_pos)))
            if mpiutil.rank0:
                if added_position == 0:
                    warnings.warn('One additional point before the data is detected!', AdditionalPointBegin)
                else:
                    warnings.warn('One additional point before index %d is detected!' % on_points[added_position], AdditionalPointMiddle)
            if added_position == 0 or added_count_before == 0:
                added_count_before += 1
            # move the expected edges from the added position onwards one sample later
            on_points = np.array(on_points)
            on_points[added_position:] = on_points[added_position:] + 1
            off_points = np.array(off_points)
            off_points[added_position:] = off_points[added_position:] + 1
            if off_points[-1] >= total_time_span:
                on_points = on_points[:-1]
                off_points = off_points[:-1]
            if mpiutil.rank0:
                if added_count_before >= change_reference_tol:
                    warnings.warn('Move the reference time one index later to compensate!', ChangeReferenceTime)
                    reference_time += rt.attrs['inttime']
                    np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=reference_time, lost_count=0, added_count=0)
                else:
                    warnings.warn('The number of recorded added points was %d while tolerance is %d. Do not change the reference time.' % (added_count_before, change_reference_tol))
                    np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=reference_time, lost_count=0, added_count=added_count_before)
        elif lost_one_point > 0 or abnormal_count > 0 or added_one_point > 0:
            # a minority of channels disagree: keep the timing, reset the counters, report
            if mpiutil.rank0:
                np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=reference_time, lost_count=0, added_count=0)
                warnings.warn('Abnormal points are detected for some channel, number of abnormal bl is %d, number of channel that probably lost one point is %d, number of channel that probably added one point is %d' % (abnormal_count, lost_one_point, added_one_point), DetectedLostAddedAbnormal)
                warnings.warn('Abnomal channels: %s\nChannels that probably lost one point: %s\nChannels that probably added one point: %s' % (str(abnormal_list), str(lost_one_list), str(added_one_list)), DetectedLostAddedAbnormal)
        else:
            if mpiutil.rank0:
                np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=reference_time, lost_count=0, added_count=0)

        if mpiutil.rank0:
            print('Noise source: period = %d, on_time = %d, off_time = %d' % (period, on_time, off_time))
        ns_on = np.array([False] * total_time_span)
        for i, j in zip(on_points, off_points):
            ns_on[i:j] = True

    elif not rt.FRB_cal or ns_arr_num < num_noise:
        # ---- detect the timing from this data (and accumulate it in FRB mode) ----
        ns_num_add = []  # number of noise points found in each usable channel
        for ns_arr_index, bl_ind in enumerate(auto_inds):
            this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
            vis = np.ma.array(rt.local_vis[:, :, bl_ind].real, mask=rt.local_vis_mask[:, :, bl_ind])
            cnt = vis.count()  # number of not masked vals
            total_cnt = mpiutil.allreduce(cnt)
            vis_shp = rt.vis.shape
            ratio = float(total_cnt) / np.prod((vis_shp[0], vis_shp[1]))  # ratio of un-masked vals
            if ratio < 0.5:  # too many masked vals
                if rt.FRB_cal and ns_arr_num == -1:
                    # keep one (empty) entry per channel in the accumulation
                    ns_arr_pinds += [np.array([])]
                    ns_arr_ninds += [np.array([])]
                if mpiutil.rank0:
                    warnings.warn('Too many masked values for auto-correlation of Channel: %d, does not use it' % this_chan)
                continue
            tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0), root=None)  # mean for all freq, for a specific auto bl
            df = np.diff(tt_mean, axis=-1)
            # rising ("on") edges
            pdf = np.where(df > 0, df, 0)
            pinds = np.where(pdf > pdf.mean() + sigma * pdf.std())[0]
            if len(pinds) == 0:  # no raise, might be badchn, continue
                if rt.FRB_cal and ns_arr_num == -1:
                    ns_arr_pinds += [np.array([])]
                    ns_arr_ninds += [np.array([])]
                if mpiutil.rank0:
                    warnings.warn('No noise on signal is detected for Channel %d, it may be bad channel.' % this_chan)
                continue
            pinds = pinds + 1
            # drop an index that directly follows the previously kept one
            pinds1 = [pinds[0]]
            for pi in pinds[1:]:
                if pi - pinds1[-1] > 1:
                    pinds1.append(pi)
            pinds = np.array(pinds1)
            pT = Counter(np.diff(pinds)).most_common(1)[0][0]  # period of pinds

            # falling ("off") edges
            ndf = np.where(df < 0, df, 0)
            ninds = np.where(ndf < ndf.mean() - sigma * ndf.std())[0]
            ninds = ninds + 1
            ninds = ninds[::-1]
            if len(ninds) == 0:  # no raise, might be badchn, continue
                if rt.FRB_cal and ns_arr_num == -1:
                    ns_arr_pinds += [np.array([])]
                    ns_arr_ninds += [np.array([])]
                if mpiutil.rank0:
                    warnings.warn('No noise off signal is detected for Channel %d, it may be bad channel.' % this_chan)
                continue
            # scanning backwards, drop an index that directly precedes the previously kept one
            ninds1 = [ninds[0]]
            for ni in ninds[1:]:
                if ni - ninds1[-1] < -1:
                    ninds1.append(ni)
            ninds = np.array(ninds1[::-1])
            nT = Counter(np.diff(ninds)).most_common(1)[0][0]  # period of ninds

            ns_num_add += [min(len(pinds), len(ninds))]
            if rt.FRB_cal:
                # store edges as indices into the accumulated (multi-file) array
                if ns_arr_num == -1:
                    ns_arr_pinds += [pinds]
                    ns_arr_ninds += [ninds]
                else:
                    ns_arr_pinds[ns_arr_index] = np.concatenate([ns_arr_pinds[ns_arr_index], pinds + ns_arr_len])
                    ns_arr_ninds[ns_arr_index] = np.concatenate([ns_arr_ninds[ns_arr_index], ninds + ns_arr_len])

            if pT != nT:  # failed to detect correct period
                if mpiutil.rank0:
                    warnings.warn('Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it' % (this_chan, pT, nT))
                continue
            period = pT
            # on_time / off_time: most common distance from an "on" edge to the next "off" edge (mod period) and vice versa
            ninds = ninds.reshape(-1, 1)
            dinds = (ninds - pinds).flatten()
            on_time = Counter(dinds[dinds > 0] % period).most_common(1)[0][0]
            off_time = Counter(-dinds[dinds < 0] % period).most_common(1)[0][0]
            if period != on_time + off_time:  # incorrect detect
                if mpiutil.rank0:
                    warnings.warn('Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it' % (this_chan, period, on_time, off_time))
                continue
            if 'noisesource' in rt.iterkeys():
                if rt['noisesource'].shape[0] == 1:  # only 1 noise source
                    start, stop, cycle = rt['noisesource'][0, :]
                    int_time = rt.attrs['inttime']
                    true_on_time = np.round((stop - start) / int_time)
                    true_period = np.round(cycle / int_time)
                    # NOTE(review): rejected only if BOTH values disagree; confirm `and` (not `or`) is intended
                    if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                        if mpiutil.rank0:
                            warnings.warn('Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it' % (this_chan, on_time, true_on_time, period, true_period))
                        continue
                elif rt['noisesource'].shape[0] >= 2:  # more than 1 noise source
                    if mpiutil.rank0:
                        warnings.warn('More than 1 noise source, do not know how to deal with this currently')
            # break if succeed
            if not rt.FRB_cal:  # for FRB case, record all baseline
                break
        else:
            if not rt.FRB_cal:
                raise DetectNoiseFailure('Failed to detect noise source signal')

        if mpiutil.rank0:
            print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (period, on_time, off_time))
        # most common phase of the "on" edges within one period
        on_start = Counter(pinds % period).most_common(1)[0][0]
        num_period = int(np.ceil(len(tt_mean) / float(period)))
        ns_on = np.array([False] * on_start + ([True] * on_time + [False] * off_time) * num_period)[:len(tt_mean)]

        if rt.FRB_cal:
            ns_arr_len += total_time_span
            # the index span implied by the time stamps must equal the summed array length
            ns_arr_from_time = np.around(np.float128(ns_arr_end_time - ns_arr_start_time) / rt.attrs['inttime']) + 1
            if ns_arr_len != ns_arr_from_time:
                raise IncontinuousData('Incontinuous data. Index span calculated from time is %.2f, while the sum of array length is %d! Can not deal with incontinuous data at present!' % (ns_arr_from_time, ns_arr_len))
            if ns_arr_num < 0:
                ns_arr_num = 0
            ns_arr_num += np.around(np.average(ns_num_add))
            ns_arr_bl = rt.bl[auto_inds, 0]
            if mpiutil.rank0:
                np.savez(output_path(ns_arr_file), on_inds=ns_arr_pinds, off_inds=ns_arr_ninds, time_len=ns_arr_len, ns_num=ns_arr_num, auto_chn=ns_arr_bl, start_time=ns_arr_start_time, end_time=ns_arr_end_time)
            if ns_arr_num < num_noise:
                raise NoiseNotEnough('Number of noise points %d is not enough for calibration(need %d), wait for next file!' % (ns_arr_num, num_noise))

    if rt.FRB_cal and (not os.path.exists(output_path(ns_prop_file))) and ns_arr_num >= num_noise:
        # ---- enough accumulated noise points: derive the timing from them ----
        if mpiutil.rank0:
            print('Got %d noise points (need %d) to build up noise property file!' % (ns_arr_num, num_noise))
        for ns_arr_index, (pinds, ninds) in enumerate(zip(ns_arr_pinds, ns_arr_ninds)):
            chan = ns_arr_bl[ns_arr_index]  # channel of this accumulated entry
            if len(pinds) < num_noise * 2. / 3. or len(ninds) < num_noise * 2. / 3.:
                if mpiutil.rank0:
                    print('Channel %d does not have enough noise points(%d, need %d) for calibration. Do not use it.' % (chan, len(pinds), int(2 * num_noise / 3.)))
                continue
            pT = Counter(np.diff(pinds)).most_common(1)[0][0]  # period of pinds
            nT = Counter(np.diff(ninds)).most_common(1)[0][0]  # period of ninds
            if pT != nT:  # failed to detect correct period
                if mpiutil.rank0:
                    warnings.warn('Failed to detect correct period for auto-correlation of Channel: %d, positive T %d != negative T %d, does not use it' % (chan, pT, nT))
                continue
            period = pT
            ninds = ninds.reshape(-1, 1)
            dinds = (ninds - pinds).flatten()
            on_time = Counter(dinds[dinds > 0] % period).most_common(1)[0][0]
            off_time = Counter(-dinds[dinds < 0] % period).most_common(1)[0][0]
            if period != on_time + off_time:  # incorrect detect
                if mpiutil.rank0:
                    warnings.warn('Incorrect detect for auto-correlation of Channel: %d, period %d != on_time %d + off_time %d, does not use it' % (chan, period, on_time, off_time))
                continue
            if 'noisesource' in rt.iterkeys():
                if rt['noisesource'].shape[0] == 1:  # only 1 noise source
                    start, stop, cycle = rt['noisesource'][0, :]
                    int_time = rt.attrs['inttime']
                    true_on_time = np.round((stop - start) / int_time)
                    true_period = np.round(cycle / int_time)
                    # NOTE(review): rejected only if BOTH values disagree; confirm `and` (not `or`) is intended
                    if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                        if mpiutil.rank0:
                            warnings.warn('Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it' % (chan, on_time, true_on_time, period, true_period))
                        continue
                elif rt['noisesource'].shape[0] >= 2:  # more than 1 noise source
                    if mpiutil.rank0:
                        warnings.warn('More than 1 noise source, do not know how to deal with this currently')
            # break if succeed
            break
        else:
            raise DetectNoiseFailure('Failed to detect noise source signal')

        if mpiutil.rank0:
            print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (period, on_time, off_time))
        on_start = Counter(pinds % period).most_common(1)[0][0]
        # pinds index the accumulated array; re-phase on_start to the first sample of this data
        this_time_len = total_time_span
        first_ind = ns_arr_len - this_time_len
        skip_inds = first_ind - on_start
        on_start = period - skip_inds % period
        num_period = int(np.ceil(total_time_span / float(period)))
        ns_on = np.array([False] * on_start + ([True] * on_time + [False] * off_time) * num_period)[:total_time_span]

    ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

    rt.create_main_time_ordered_dataset('ns_on', ns_on1)
    rt['ns_on'].attrs['period'] = period
    rt['ns_on'].attrs['on_time'] = on_time
    rt['ns_on'].attrs['off_time'] = off_time

    if (rt['jul_date'][on_start] is not None) and not os.path.exists(output_path(ns_prop_file)):
        # a single rank writes the shared file (all ranks used to write it concurrently)
        if mpiutil.rank0:
            np.savez(output_path(ns_prop_file), period=period, on_time=on_time, off_time=off_time, reference_time=np.float128(rt.attrs['sec1970'][0] + on_start * rt.attrs['inttime']))
            print('Save noise property file to %s' % output_path(ns_prop_file))

    # set vis_mask corresponding to ns_on
    on_inds = np.where(rt['ns_on'].local_data[:])[0]
    rt.local_vis_mask[on_inds] = True

    if mask_near > 0:
        # also mask `mask_near` samples on each side of every "on" sample (global indices)
        on_inds = np.where(ns_on)[0]
        near_inds = np.unique(np.concatenate([on_inds + k for k in range(-mask_near, mask_near + 1)]))

        if rt['vis_mask'].distributed:
            start = rt.vis_mask.local_offset[0]
            end = start + rt.vis_mask.local_shape[0]
        else:
            start = 0
            end = rt.vis_mask.shape[0]
        # keep the indices held by this rank and convert them to local ones
        local_on_inds = near_inds[(near_inds >= start) & (near_inds < end)] - start
        rt.local_vis_mask[local_on_inds] = True

    return super(Detect, self).process(rt)
def wrap(cls, array, axis, comm=None):
    """Turn the local piece of a distributed array into an MPIArray view.

    This is needed after functions such as `np.fft.fft`, which always return
    a plain `np.ndarray`. No data is copied: the result is a view of `array`.

    Parameters
    ----------
    array : np.ndarray
        This rank's section of the array.
    axis : integer
        Axis over which the array is distributed. The local lengths along
        this axis are checked against the split given by
        ``mpiutil.split_local`` to help ensure this is correct.
    comm : MPI.Comm, optional
        The communicator over which the array is distributed. If `None`
        (default), use ``mpiutil.world``.

    Returns
    -------
    dist_array : MPIArray
        An MPIArray view of the input.

    Raises
    ------
    ValueError
        On every rank, if any rank's local length along `axis` differs from
        the length ``mpiutil.split_local`` assigns to it.
    """
    if comm is None:
        comm = mpiutil.world

    local_shape = array.shape
    local_len = local_shape[axis]
    # The global length of the distributed axis is the sum over all ranks.
    global_len = mpiutil.allreduce(local_len, comm=comm)

    # The layout this rank is expected to hold.
    expected_len, local_start, _ = mpiutil.split_local(global_len, comm=comm)

    # Reduce the mismatch flag so that all ranks raise (or proceed) together.
    layout_issue = mpiutil.allreduce(local_len != expected_len, op=mpiutil.MAX, comm=comm)
    if layout_issue:
        raise ValueError("Cannot wrap, distributed axis local length is incorrect.")

    global_shape = list(local_shape)
    global_shape[axis] = global_len
    local_offset = [0] * len(local_shape)
    local_offset[axis] = local_start

    dist_arr = array.view(cls)
    dist_arr._global_shape = tuple(global_shape)
    dist_arr._axis = axis
    dist_arr._local_shape = tuple(local_shape)
    dist_arr._local_offset = tuple(local_offset)
    dist_arr._comm = comm

    return dist_arr
def process(self, rt):
    """Calibrate the visibilities of `rt` using the noise-source signal.

    Requires the ``ns_on`` dataset in `rt`, whose ``on_time`` attribute is
    also read. The "on" edges of ``ns_on`` give the time indices (``inds``)
    passed to ``self.cal``. ``self.cal`` is applied through
    ``rt.freq_and_bl_data_operate`` and does the actual calibration; this
    method prepares its inputs and optionally saves the resulting gains.

    In ``rt.ps_first`` mode, the absolute (point-source) gain read from
    ``absolute_gain_filename`` is first divided out of the local
    visibilities. The transit normalization gain (``normal_gain_file``) is
    then either loaded, or marked to be built (``build_normal_file``) and
    written after calibration.

    When ``save_gain`` is set, the noise-source phases (and amplitudes unless
    ``phs_only``) are gathered to rank 0 and written to ``gain_file``.

    Parameters
    ----------
    rt : RawTimestream
        Input data. It is redistributed along the baseline axis.

    Returns
    -------
    The result of the parent class ``process`` applied to `rt`.

    Raises
    ------
    RuntimeError
        If ``ns_on`` is missing or ``num_mean`` is not positive.
    NoPsGainFile, TransitGainNotRecorded, BeforeStableTime, AllMasked
        On the ps_first / gain-saving failures described in each message.
    """
    assert isinstance(rt, RawTimestream), '%s only works for RawTimestream object currently' % self.__class__.__name__

    if 'ns_on' not in rt.iterkeys():
        raise RuntimeError('No noise source info, can not do noise source calibration')

    # baselines held by this rank; presumably the same split that
    # rt.redistribute('baseline') uses below -- TODO confirm
    local_bl_size, local_bl_start, local_bl_end = mpiutil.split_local(len(rt['blorder']))

    rt.redistribute('baseline')

    num_mean = self.params['num_mean']
    phs_only = self.params['phs_only']
    save_gain = self.params['save_gain']
    tag_output_iter = self.params['tag_output_iter']
    gain_file = self.params['gain_file']
    bl_incl = self.params['bl_incl']
    bl_excl = self.params['bl_excl']
    freq_incl = self.params['freq_incl']
    freq_excl = self.params['freq_excl']
    ns_stable = self.params['ns_stable']
    last_transit2stable = self.params['last_transit2stable']

    normal_gain_file = self.params['normal_gain_file']
    rt.gain_file = self.params['gain_file']
    absolute_gain_filename = self.params['absolute_gain_filename']
    use_center_data = self.params['use_center_data']
    if use_center_data and rt['ns_on'].attrs['on_time'] <= 2:
        warnings.warn('The on_time %d <= 2, cannot get rid of the beginning and ending points. Use the whole average automatically!' % rt['ns_on'].attrs['on_time'])
        # NOTE(review): this local is not read again in this method; confirm
        # that self.cal really falls back to the whole average.
        use_center_data = False

    # apply abs gain to data if ps_first
    if rt.ps_first:
        if absolute_gain_filename is None:
            raise NoPsGainFile('Absent parameter absolute_gain_file. In ps_first mode, absolute_gain_filename is required to process!')
        if not os.path.exists(output_path(absolute_gain_filename)):
            raise NoPsGainFile('No absolute gain file %s, do the ps calibration first!' % output_path(absolute_gain_filename))
        # read-only: the mode belongs to h5py.File, not to output_path
        with h5py.File(output_path(absolute_gain_filename), 'r') as agfilein:
            if not ns_stable:
                # the transit recorded in the absolute gain file decides whether
                # this data builds the normalization gain or loads it
                if rt['sec1970'][0] <= agfilein.attrs['transit_time'] and rt['sec1970'][-1] >= agfilein.attrs['transit_time']:
                    rt.normal_index = np.int64(np.around((agfilein.attrs['transit_time'] - rt['sec1970'][0]) / rt.attrs['inttime']))
                    if mpiutil.rank0:
                        print('Detected transit, time index %d, build up transit normalization gain file!' % rt.normal_index)
                    build_normal_file = True
                    rt.normal_phs = np.zeros((rt['vis'].shape[1], local_bl_size))
                    rt.normal_phs.fill(np.nan)
                    if not phs_only:
                        rt.normal_amp = np.zeros((rt['vis'].shape[1], local_bl_size))
                        rt.normal_amp.fill(np.nan)
                else:
                    rt.normal_index = None  # no transit flag
                    build_normal_file = False
                    if not os.path.exists(output_path(normal_gain_file)):
                        # + 8/24 day: presumably converts to local (UTC+8) time for display
                        time_info = (aipy.phs.juldate2ephem(agfilein.attrs['transit_jul'] + 8./24.), aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.), aipy.phs.juldate2ephem(rt['jul_date'][-1] + 8./24.))
                        raise TransitGainNotRecorded('The transit %s is not in time range %s to %s and no transit normalization gain was recorded!' % time_info)
                    if mpiutil.rank0:
                        print('No transit, use existing transit normalization gain file!')
                    with h5py.File(output_path(normal_gain_file), 'r') as filein:
                        rt.normal_phs = filein['phs'][:, local_bl_start:local_bl_end]
                        if not phs_only:
                            rt.normal_amp = filein['amp'][:, local_bl_start:local_bl_end]
            else:
                rt.normal_index = None  # no transit flag
                rt.normal_phs = np.zeros((rt['vis'].shape[1], local_bl_size))
                rt.normal_phs.fill(np.nan)
                if not phs_only:
                    rt.normal_amp = np.zeros((rt['vis'].shape[1], local_bl_size))
                    rt.normal_amp.fill(np.nan)
                # sec1970 time after which the noise source is considered stable
                try:
                    stable_time = rt['pointingtime'][-1, -1]
                except KeyError:
                    stable_time = rt['transitsource'][-1, 0]
                if stable_time > 0:  # the time when the antenna will be pointing at the target region was recorded, use it
                    stable_time += 300  # plus 5 min to ensure stable
                else:
                    stable_time = rt['transitsource'][-2, 0] + last_transit2stable
                stable_time_jul = ephem.julian_date(datetime.utcfromtimestamp(stable_time))  # convert to julian date, convenient to display
                if not os.path.exists(output_path(normal_gain_file)):
                    if mpiutil.rank0:
                        print('Normalization gain file has not been built yet. Try to build it!')
                        print('Recorded transit Time: %s' % aipy.phs.juldate2ephem(agfilein.attrs['transit_jul'] + 8./24.))
                        print('Last transit Time: %s' % aipy.phs.juldate2ephem(ephem.julian_date(datetime.utcfromtimestamp(rt['transitsource'][-2, 0])) + 8./24.))
                        print('First time point of this data: %s' % aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.))
                    build_normal_file = True
                    if rt.attrs['sec1970'][0] < stable_time:
                        raise BeforeStableTime('The beginning time point is %s, but only after %s, will the noise source be stable. Abort the noise calibration!' % (aipy.phs.juldate2ephem(rt['jul_date'][0] + 8./24.), aipy.phs.juldate2ephem(stable_time_jul + 8./24.)))
                else:
                    if mpiutil.rank0:
                        print('Use existing normalization gain file!')
                    build_normal_file = False
                    with h5py.File(output_path(normal_gain_file), 'r') as filein:
                        rt.normal_phs = filein['phs'][:, local_bl_start:local_bl_end]
                        if not phs_only:
                            rt.normal_amp = filein['amp'][:, local_bl_start:local_bl_end]

            # apply absolute gain: divide each local baseline by g_i * conj(g_j)
            absgain = agfilein['gain'][:]
            polfeed, _ = bl2pol_feed_inds(rt.local_bl, agfilein['gain'].attrs['feed'][:], agfilein['gain'].attrs['pol'][:])
            for ii, (ipol, ifeed, jpol, jfeed) in enumerate(polfeed):
                rt.local_vis[:, :, ii] /= (absgain[None, :, ipol, ifeed] * absgain[None, :, jpol, jfeed].conj())

    nt = rt.local_vis.shape[0]
    if num_mean <= 0:
        raise RuntimeError('Invalid num_mean = %s' % num_mean)
    ns_on = np.where(rt['ns_on'][:], 1, 0)
    diff_ns = np.diff(ns_on)
    inds = np.where(diff_ns == 1)[0]  # NOTE: these are inds just 1 before the first ON
    if not rt.FRB_cal:  # for FRB there might be just one noise point, avoid waste
        if inds[0] - 1 < 0:  # no off data in the beginning to use
            inds = inds[1:]
        if inds[-1] + 2 > nt - 1:  # no on data in the end to use
            inds = inds[:-1]

    if save_gain:
        num_inds = len(inds)
        shp = (num_inds,) + rt.local_vis.shape[1:]
        dtype = rt.local_vis.real.dtype
        # create dataset to record ns_cal_time_inds
        rt.create_time_ordered_dataset('ns_cal_time_inds', inds)
        # create dataset to record ns_cal_phase (NaN until filled by self.cal)
        ns_cal_phase = np.empty(shp, dtype=dtype)
        ns_cal_phase[:] = np.nan
        ns_cal_phase = mpiarray.MPIArray.wrap(ns_cal_phase, axis=2, comm=rt.comm)
        rt.create_freq_and_bl_ordered_dataset('ns_cal_phase', ns_cal_phase, axis_order=(None, 1, 2))
        rt['ns_cal_phase'].attrs['unit'] = 'radians'
        if not phs_only:
            # create dataset to record ns_cal_amp
            ns_cal_amp = np.empty(shp, dtype=dtype)
            ns_cal_amp[:] = np.nan
            ns_cal_amp = mpiarray.MPIArray.wrap(ns_cal_amp, axis=2, comm=rt.comm)
            rt.create_freq_and_bl_ordered_dataset('ns_cal_amp', ns_cal_amp, axis_order=(None, 1, 2))

    if bl_incl == 'all':
        bls_plt = [tuple(bl) for bl in rt.bl]
    else:
        bls_plt = [bl for bl in bl_incl if bl not in bl_excl]

    if freq_incl == 'all':
        freq_plt = range(rt.freq.shape[0])
    else:
        freq_plt = [fi for fi in freq_incl if fi not in freq_excl]

    show_progress = self.params['show_progress']
    progress_step = self.params['progress_step']

    # build_normal_file is only defined (and only passed on) in ps_first mode
    extra_kwargs = {'build_normal_file': build_normal_file} if rt.ps_first else {}
    rt.freq_and_bl_data_operate(self.cal, full_data=True, show_progress=show_progress, progress_step=progress_step, keep_dist_axis=False, num_mean=num_mean, inds=inds, bls_plt=bls_plt, freq_plt=freq_plt, **extra_kwargs)

    if save_gain:
        # percentage (over all ranks) of flagged entries in rt.interp_mask_count,
        # which is presumably filled by self.cal
        interp_mask_ratio = mpiutil.allreduce(np.sum(rt.interp_mask_count)) / 1. / mpiutil.allreduce(np.size(rt.interp_mask_count)) * 100.
        if interp_mask_ratio > 50. and rt.normal_index is not None:
            warnings.warn('%.1f%% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!' % interp_mask_ratio, NotEnoughPointToInterpolateWarning)
            # NOTE(review): nesting reconstructed from whitespace-mangled source;
            # confirm this flag should only be set when a transit is present.
            if interp_mask_ratio > 80.:
                rt.interp_all_masked = True
        # gather bl_order to rank0
        bl_order = mpiutil.gather_array(rt['blorder'].local_data, axis=0, root=0, comm=rt.comm)
        # gather ns_cal_phase / ns_cal_amp to rank 0
        ns_cal_phase = mpiutil.gather_array(rt['ns_cal_phase'].local_data, axis=2, root=0, comm=rt.comm)
        phs_unit = rt['ns_cal_phase'].attrs['unit']
        rt.delete_a_dataset('ns_cal_phase', reserve_hint=False)
        if not phs_only:
            ns_cal_amp = mpiutil.gather_array(rt['ns_cal_amp'].local_data, axis=2, root=0, comm=rt.comm)
            rt.delete_a_dataset('ns_cal_amp', reserve_hint=False)

        if tag_output_iter:
            gain_file = output_path(gain_file, iteration=self.iteration)
        else:
            gain_file = output_path(gain_file)

        if rt.ps_first:
            phase_finite_count = mpiutil.allreduce(np.isfinite(rt.normal_phs).sum())
            if not phs_only:
                amp_finite_count = mpiutil.allreduce(np.isfinite(rt.normal_amp).sum())

        if mpiutil.rank0:
            if rt.ps_first and build_normal_file:
                if phase_finite_count == 0:
                    raise AllMasked('All values are masked when calculating phase!')
                if not phs_only and amp_finite_count == 0:
                    raise AllMasked('All values are masked when calculating amplitude!')
                # create empty (freq, bl) datasets; each rank fills its columns below
                with h5py.File(output_path(normal_gain_file), 'w') as tapfilein:
                    tapfilein.create_dataset('amp', (rt['vis'].shape[1], rt['vis'].shape[2]))
                    tapfilein.create_dataset('phs', (rt['vis'].shape[1], rt['vis'].shape[2]))
            with h5py.File(gain_file, 'w') as f:
                # save time
                f.create_dataset('time', data=rt['jul_date'][:])
                f['time'].attrs['unit'] = 'Julian date'
                # save freq
                f.create_dataset('freq', data=rt['freq'][:])
                f['freq'].attrs['unit'] = rt['freq'].attrs['unit']
                # save bl
                f.create_dataset('bl_order', data=bl_order)
                # save ns_cal_time_inds
                f.create_dataset('ns_cal_time_inds', data=rt['ns_cal_time_inds'][:])
                # save ns_cal_phase
                f.create_dataset('ns_cal_phase', data=ns_cal_phase)
                f['ns_cal_phase'].attrs['unit'] = phs_unit
                f['ns_cal_phase'].attrs['dim'] = '(time, freq, bl)'
                if not phs_only:
                    # save ns_cal_amp
                    f.create_dataset('ns_cal_amp', data=ns_cal_amp)
                # save channo
                f.create_dataset('channo', data=rt['channo'][:])
                f['channo'].attrs['dim'] = rt['channo'].attrs['dimname']
                if rt.exclude_bad:
                    f['channo'].attrs['badchn'] = rt['channo'].attrs['badchn']

                if absolute_gain_filename is not None:
                    if not os.path.exists(output_path(absolute_gain_filename)):
                        raise NoPsGainFile('No absolute gain file %s, do the ps calibration first!' % output_path(absolute_gain_filename))
                    # read-only: the mode belongs to h5py.File, not to output_path
                    with h5py.File(output_path(absolute_gain_filename), 'r') as abs_gain:
                        new_gain = uni_gain(abs_gain, f, phs_only=phs_only)
                    f.create_dataset('uni_gain', data=new_gain)
                    f['uni_gain'].attrs['dim'] = '(time, freq, bl)'

        if rt.ps_first and build_normal_file:
            if mpiutil.rank0:
                print('Start write normalization gain into %s' % output_path(normal_gain_file))
            # ranks write their own baseline columns in turn; retry on IOError
            # (presumably transient file-locking failures)
            for _ in range(10):
                try:
                    for ii in range(mpiutil.size):
                        if ii == mpiutil.rank:
                            with h5py.File(output_path(normal_gain_file), 'r+') as fileout:
                                fileout['phs'][:, local_bl_start:local_bl_end] = rt.normal_phs[:, :]
                                if not phs_only:
                                    fileout['amp'][:, local_bl_start:local_bl_end] = rt.normal_amp[:, :]
                        mpiutil.barrier()
                    break
                except IOError:
                    time.sleep(0.5)
                    continue
            else:
                # previously the failure was swallowed silently
                warnings.warn('Failed to write normalization gain into %s after 10 attempts!' % output_path(normal_gain_file))

        rt.delete_a_dataset('ns_cal_time_inds', reserve_hint=False)

    return super(NsCal, self).process(rt)
def generate(self, regen=False):
    """Calculate the total Fisher matrix and bias and save to a file.

    The m-modes ``0 .. self.telescope.mmax`` are partitioned over the MPI
    ranks with ``mpiutil.partition_list_mpi``. Each rank sums the real part
    of the Fisher matrices and biases returned by ``self.fisher_bias_m`` for
    its own m-modes. The partial sums are all-reduced into ``self.fisher`` /
    ``self.bias``. Rank 0 then writes them to ``<psdir>/fisher.hdf5``,
    together with:

    - the derived covariance, errors and correlation matrix;
    - the band definitions for the current ``self.bandtype``.

    Parameters
    ----------
    regen : boolean, optional
        Force regeneration if products already exist (default `False`).
    """

    if mpiutil.rank0:
        st = time.time()
        print("======== Starting PS calculation ========")

    ffile = self.psdir + "/fisher.hdf5"

    if os.path.exists(ffile) and not regen:
        # Report once rather than once per rank.
        if mpiutil.rank0:
            print("Fisher matrix file: %s exists. Skipping..." % ffile)
        return

    mpiutil.barrier()

    # Pre-compute all the angular power spectra for the bands
    self.genbands()

    # Calculate Fisher and bias for each m.
    # Pair up each list item with its position, then partition by MPI rank.
    zlist = list(enumerate(range(self.telescope.mmax + 1)))
    llist = mpiutil.partition_list_mpi(zlist)
    # Operate on this rank's sublist
    fisher_bias_list = [self.fisher_bias_m(item) for _, item in llist]

    if fisher_bias_list:
        # Unpack into separate sequences of Fisher matrices and biases, then
        # sum over all local m-modes to get this process's contribution.
        fisher_loc, bias_loc = zip(*fisher_bias_list)
        fisher_loc = np.sum(np.array(fisher_loc), axis=0).real  # Be careful of the .real here
        bias_loc = np.sum(np.array(bias_loc), axis=0).real  # Be careful of the .real here
    else:
        # No m-modes were assigned to this rank (more ranks than m-modes).
        # Unpacking an empty zip would raise here while the other ranks
        # wait in allreduce, so contribute zero instead.
        # NOTE(review): assumes mpiutil.allreduce sums Python objects with
        # `+`, so a scalar 0.0 broadcasts against the other ranks' arrays.
        fisher_loc = 0.0
        bias_loc = 0.0

    self.fisher = mpiutil.allreduce(fisher_loc, op=MPI.SUM)
    self.bias = mpiutil.allreduce(bias_loc, op=MPI.SUM)

    # Write out all the PS estimation products
    if mpiutil.rank0:
        et = time.time()
        print("======== Ending PS calculation (time=%f) ========" % (et - st))

        # Check to ensure that the Fisher matrix isn't all zeros.
        if not (self.fisher == 0).all():
            # Generate derived quantities (covariance, errors, correlation)
            cv = la.pinv(self.fisher, rcond=1e-8)
            err = cv.diagonal() ** 0.5
            cr = cv / np.outer(err, err)
        else:
            cv = np.zeros_like(self.fisher)
            err = cv.diagonal()
            cr = np.zeros_like(self.fisher)

        # The context manager closes the file even if a write fails.
        with h5py.File(ffile, "w") as f:
            # np.bytes_ is the type np.string_ aliased (removed in NumPy 2.0).
            f.attrs["bandtype"] = np.bytes_(self.bandtype)  # HDF5 string issues

            f.create_dataset("fisher/", data=self.fisher)
            f.create_dataset("bias/", data=self.bias)
            f.create_dataset("covariance/", data=cv)
            f.create_dataset("errors/", data=err)
            f.create_dataset("correlation/", data=cr)

            f.create_dataset("band_power/", data=self.band_power)

            if self.bandtype == "polar":
                f.create_dataset("k_start/", data=self.k_start)
                f.create_dataset("k_end/", data=self.k_end)
                f.create_dataset("k_center/", data=self.k_center)

                f.create_dataset("theta_start/", data=self.theta_start)
                f.create_dataset("theta_end/", data=self.theta_end)
                f.create_dataset("theta_center/", data=self.theta_center)

                f.create_dataset("k_bands", data=self.k_bands)
                f.create_dataset("theta_bands", data=self.theta_bands)

            elif self.bandtype == "cartesian":
                f.create_dataset("kpar_start/", data=self.kpar_start)
                f.create_dataset("kpar_end/", data=self.kpar_end)
                f.create_dataset("kpar_center/", data=self.kpar_center)

                f.create_dataset("kperp_start/", data=self.kperp_start)
                f.create_dataset("kperp_end/", data=self.kperp_end)
                f.create_dataset("kperp_center/", data=self.kperp_center)

                f.create_dataset("kpar_bands", data=self.kpar_bands)
                f.create_dataset("kperp_bands", data=self.kperp_bands)
def process(self, rt):
    """Detect the noise-source ON/OFF pattern from auto-correlations.

    For each auto-correlation, the real part of the visibility is averaged
    over its second axis (presumably frequency -- confirm). Masked points
    are filled with 0, and the resulting per-time series is sorted, keeping
    positive values only. The largest ratio between consecutive sorted
    values is taken as the gap separating noise-OFF from noise-ON samples.
    It is accepted only if it exceeds twice both of its neighbouring
    ratios.

    The channel given by ``params['channel']`` (or the first
    auto-correlation) is tried first, then the others in order. Channels
    with less than half of their values unmasked are skipped.

    From the resulting boolean series, the most common complete ON and OFF
    run lengths become ``on_time`` / ``off_time``. The series is stored as
    the time-ordered dataset ``ns_on``, with attributes ``period``,
    ``on_time`` and ``off_time``. Visibilities at ON times, plus
    ``params['mask_near']`` samples on either side, are masked.

    Parameters
    ----------
    rt : RawTimestream
        Data to process; it is redistributed along axis 0 (time).

    Returns
    -------
    The result of the parent class's ``process(rt)``.

    Raises
    ------
    RuntimeError
        If the data has no auto-correlation, no channel yields a usable
        threshold, or no complete ON/OFF cycle is found.
    """
    assert isinstance(
        rt, RawTimestream
    ), '%s only works for RawTimestream object currently' % self.__class__.__name__

    channel = self.params['channel']
    mask_near = max(0, int(self.params['mask_near']))

    rt.redistribute(0)  # make time the dist axis

    auto_inds = np.where(
        rt.bl[:, 0] == rt.bl[:, 1])[0].tolist()  # inds for auto-correlations
    if not auto_inds:
        raise RuntimeError(
            'No auto-correlation in the data, can not detect the noise source')
    channels = [rt.bl[ai, 0] for ai in auto_inds]  # all chosen channels

    bl_ind = auto_inds[0]
    if channel is not None:
        if channel in channels:
            bl_ind = auto_inds[channels.index(channel)]
        elif mpiutil.rank0:
            print('Warning: Required channel %d doen not in the data, use channel %d instead' % (
                channel, rt.bl[bl_ind, 0]))
    # move the chosen channel to the first
    auto_inds.remove(bl_ind)
    auto_inds.insert(0, bl_ind)

    vis_shp = rt.vis.shape
    total_size = np.prod((vis_shp[0], vis_shp[1]))  # global number of (time, axis-1) points
    for bl_ind in auto_inds:
        this_chan = rt.bl[bl_ind, 0]  # channel of this bl_ind
        vis = np.ma.array(rt.local_vis[:, :, bl_ind].real,
                          mask=rt.local_vis_mask[:, :, bl_ind])
        total_cnt = mpiutil.allreduce(vis.count())  # number of not masked vals
        ratio = float(total_cnt) / total_size  # ratio of un-masked vals
        if ratio < 0.5:  # too many masked vals
            if mpiutil.rank0:
                warnings.warn(
                    'Too many masked values for auto-correlation of Channel: %d, does not use it'
                    % this_chan)
            continue

        tt_mean = mpiutil.gather_array(np.ma.mean(vis, axis=-1).filled(0),
                                       root=None)
        tt_mean_sort = np.sort(tt_mean)
        tt_mean_sort = tt_mean_sort[tt_mean_sort > 0]
        ttms_div = tt_mean_sort[1:] / tt_mean_sort[:-1]
        if len(ttms_div) < 3:
            # Too few positive samples for a jump with a neighbour on each side.
            continue
        ind = np.argmax(ttms_div)
        # The jump needs a neighbouring ratio on both sides; the last ratio
        # has no right neighbour (indexing ind + 1 would be out of range).
        if 0 < ind < len(ttms_div) - 1 and ttms_div[ind] > 2.0 * max(
                ttms_div[ind - 1], ttms_div[ind + 1]):
            sep = np.sqrt(tt_mean_sort[ind] * tt_mean_sort[ind + 1])
            break
    else:
        raise RuntimeError(
            'Failed to get the threshold to separate ns signal out')

    ns_on = tt_mean > sep  # boolean array over the global time axis

    # Split ns_on into runs of equal values. The final run may be cut off by
    # the end of the data, so only runs followed by a transition are counted.
    run_starts = np.concatenate(([0], np.flatnonzero(ns_on[1:] != ns_on[:-1]) + 1))
    run_lens = np.diff(np.concatenate((run_starts, [len(ns_on)])))[:-1]
    run_vals = ns_on[run_starts[:-1]]
    nTs = run_lens[run_vals].tolist()  # lengths of complete ON runs
    nFs = run_lens[~run_vals].tolist()  # lengths of complete OFF runs
    if not nTs or not nFs:
        raise RuntimeError(
            'Failed to find a complete on/off cycle of the noise source')

    on_time = Counter(nTs).most_common(1)[0][0]
    off_time = Counter(nFs).most_common(1)[0][0]
    period = on_time + off_time

    if 'noisesource' in rt.iterkeys():
        num_ns = rt['noisesource'].shape[0]
        if num_ns == 1:  # only 1 noise source
            start, stop, cycle = rt['noisesource'][0, :]
            int_time = rt.attrs['inttime']
            true_on_time = np.round((stop - start) / int_time)
            true_period = np.round(cycle / int_time)
            # NOTE(review): warns only when BOTH on_time and period differ
            # from the record; confirm `and` (rather than `or`) is intended.
            if on_time != true_on_time and period != true_period:  # inconsistant with the record in the data
                if mpiutil.rank0:
                    warnings.warn(
                        'Detected noise source info is inconsistant with the record in the data for auto-correlation of Channel: %d: on_time %d != record_on_time %d, period %d != record_period %d, does not use it'
                        % (this_chan, on_time, true_on_time, period, true_period))
        elif num_ns >= 2:  # more than 1 noise source
            if mpiutil.rank0:
                warnings.warn(
                    'More than 1 noise source, do not know how to deal with this currently'
                )

    if mpiutil.rank0:
        print('Detected noise source: period = %d, on_time = %d, off_time = %d' % (
            period, on_time, off_time))
    ns_on1 = mpiarray.MPIArray.from_numpy_array(ns_on)

    rt.create_main_time_ordered_dataset('ns_on', ns_on1)
    rt['ns_on'].attrs['period'] = period
    rt['ns_on'].attrs['on_time'] = on_time
    rt['ns_on'].attrs['off_time'] = off_time

    # set vis_mask corresponding to ns_on
    on_inds = np.where(rt['ns_on'].local_data[:])[0]
    rt.local_vis_mask[on_inds] = True

    if mask_near > 0:
        # Global ON indices widened by mask_near samples on each side.
        on_inds = np.where(ns_on)[0]
        offsets = np.arange(-mask_near, mask_near + 1)
        new_on_inds = np.unique((on_inds[:, np.newaxis] + offsets).ravel())

        if rt['vis_mask'].distributed:
            start = rt.vis_mask.local_offset[0]
            end = start + rt.vis_mask.local_shape[0]
        else:
            start = 0
            end = rt.vis_mask.shape[0]
        # Keep the indices in this process's [start, end) time range and
        # convert them to local indices.
        in_local = (new_on_inds >= start) & (new_on_inds < end)
        local_on_inds = new_on_inds[in_local] - start
        rt.local_vis_mask[local_on_inds] = True  # set mask using local indices

    return super(Detect, self).process(rt)
def read_input(self):
    """Method for (maybe iteratively) reading data from input data files.

    Each call reads one chunk from the current file group
    ``self.input_grps[self.grp_cnt]``. The chunk starts
    ``iteration * days * const.sday`` seconds after the group's first
    sample. It spans ``days * const.sday`` seconds (``const.sday``
    presumably being a sidereal day in seconds) plus
    ``2 * params['extra_inttime']`` extra integrations, capped at the end of
    the group.

    When the chunk reaches the end of the group, or is empty,
    ``self.next_grp`` is set and ``self.grp_cnt`` is advanced exactly once.
    The next call then restarts the iteration on the following group.

    On the first loaded chunk, ``self.start_ra`` is set to
    ``ra_dec[extra_inttime, 0]`` of that chunk. It is then attached to every
    returned chunk as ``tod.vis.attrs['start_ra']``.

    Returns
    -------
    tod : object or None
        The loaded and data-selected chunk. `None` is returned when the chunk
        is empty, shorter than ``params['drop_days']`` days, or not longer
        than ``params['extra_inttime']`` integrations.
    """

    days = self.params['days']
    extra_inttime = self.params['extra_inttime']
    drop_days = self.params['drop_days']
    mode = self.params['mode']
    dist_axis = self.params['dist_axis']

    ngrp = len(self.input_grps)

    if self.next_grp and ngrp > 1:
        if mpiutil.rank0:
            print('Start file group %d of %d...' % (self.grp_cnt, ngrp))
        self.restart_iteration()  # re-start iteration for each group
        self.next_grp = False
        self.abs_start = None
        self.abs_stop = None

    input_files = self.input_grps[self.grp_cnt]
    start = self.start[self.grp_cnt]
    stop = self.stop[self.grp_cnt]

    if self.int_time is None:
        # NOTE: here assume all files have the same int_time
        with h5py.File(self.input_files[0], 'r') as f:
            self.int_time = f.attrs['inttime']

    if self.abs_start is None or self.abs_stop is None:
        tmp_tod = self._Tod_class(input_files, mode, start, stop, dist_axis)
        self.abs_start = tmp_tod.main_data_start
        self.abs_stop = tmp_tod.main_data_stop
        del tmp_tod

    iteration = self.iteration if self.iterable else 0
    # Number of integrations per chunk (int() replaces the removed np.int).
    this_start = self.abs_start + int(
        np.around(iteration * days * const.sday / self.int_time))
    this_stop = min(
        self.abs_stop,
        self.abs_start + int(
            np.around((iteration + 1) * days * const.sday / self.int_time))
        + 2 * extra_inttime)

    chunk_empty = this_start >= this_stop
    if this_stop >= self.abs_stop or chunk_empty:
        # Advance to the next group only once, even when both conditions
        # hold; incrementing twice would skip a whole file group.
        self.next_grp = True
        self.grp_cnt += 1
    if chunk_empty:
        return None

    this_span = self.int_time * (this_stop - this_start)  # in unit second
    if this_span < drop_days * const.sday:
        if mpiutil.rank0:
            print('Not enough span time, drop it...')
        return None
    elif (this_stop - this_start) <= extra_inttime:  # use int comparision
        if mpiutil.rank0:
            print('Not enough span time (less than `extra_inttime`), drop it...')
        return None

    tod = self._Tod_class(input_files, mode, this_start, this_stop, dist_axis)

    tod, _ = self.data_select(tod)

    tod.load_all()  # load in all data

    if self.start_ra is None:  # the first iteration
        if 'time' == tod.main_data_axes[tod.main_data_dist_axis]:
            # ra_dec is distributed among processes:
            # find the point ra_dec[extra_inttime, 0] of the global array
            local_offset = tod['ra_dec'].local_offset[0]
            local_shape = tod['ra_dec'].local_shape[0]
            if local_offset <= extra_inttime < local_offset + local_shape:
                in_this = 1
                start_ra = tod['ra_dec'].local_data[extra_inttime - local_offset, 0]
            else:
                in_this = 0
                start_ra = None

            # get the rank holding that point
            _, in_rank = mpiutil.allreduce((in_this, tod.rank),
                                           op=mpiutil.MAXLOC, comm=tod.comm)
            # bcast from this rank
            self.start_ra = mpiutil.bcast(start_ra, root=in_rank, comm=tod.comm)
        else:
            # Time is not the distributed axis here, so index the dataset
            # directly (the original referenced an undefined `ra_dec`).
            self.start_ra = tod['ra_dec'][extra_inttime, 0]

    tod.vis.attrs['start_ra'] = self.start_ra  # used for re_order

    return tod
def process(self, ts): assert isinstance( ts, Timestream ), '%s only works for Timestream object' % self.__class__.__name__ # if mpiutil.rank0: # saveVmat = [] calibrator = self.params['calibrator'] catalog = self.params['catalog'] vis_conj = self.params['vis_conj'] zero_diag = self.params['zero_diag'] span = self.params['span'] reserve_high_gain = self.params['reserve_high_gain'] plot_figs = self.params['plot_figs'] fig_prefix = self.params['fig_name'] tag_output_iter = self.params['tag_output_iter'] save_src_vis = self.params['save_src_vis'] src_vis_file = self.params['src_vis_file'] subtract_src = self.params['subtract_src'] apply_gain = self.params['apply_gain'] save_gain = self.params['save_gain'] save_phs_change = self.params['save_phs_change'] gain_file = self.params['gain_file'] temperature_convert = self.params['temperature_convert'] show_progress = self.params['show_progress'] progress_step = self.params['progress_step'] srcdict = self.params['srcdict'] max_iter = self.params['max_iter'] # if mpiutil.rank0: # print(ts.keys()) # print(ts.attrs.keys()) # # import sys # sys.exit() if save_src_vis or subtract_src or apply_gain or save_gain: pol_type = ts['pol'].attrs['pol_type'] if pol_type != 'linear': raise RuntimeError('Can not do ps_cal for pol_type: %s' % pol_type) ts.redistribute('baseline') feedno = ts['feedno'][:].tolist() pol = [ts.pol_dict[p] for p in ts['pol'][:]] # as string gain_pd = { 'xx': 0, 'yy': 1, 0: 'xx', 1: 'yy' } # for gain related op bls = mpiutil.gather_array(ts.local_bl[:], root=None, comm=ts.comm) # # antpointing = np.radians(ts['antpointing'][-1, :, :]) # radians # transitsource = ts['transitsource'][:] # transit_time = transitsource[-1, 0] # second, sec1970 # int_time = ts.attrs['inttime'] # second # calibrator #see whether the source should be observed #only for single calibrator if ts.is_dish: obsd_sources = ts['transitsource'].attrs['srcname'] obsd_sources = obsd_sources.split(',') # srcdict = {'cas':'CassiopeiaA'} src_index = 
0.1 for src_short, src_long in srcdict.items(): # print(src_short,src_long,obsd_sources,calibrator) if (src_short == calibrator) and (src_long in obsd_sources): src_index = obsd_sources.index(src_long) break else: raise Exception( 'the transit source is not included in the observation plan!' ) obs_transit_time = ts['transitsource'][src_index][0] # srclist, cutoff, catalogs = a.scripting.parse_srcs(calibrator, catalog) # cat = a.src.get_catalog(srclist, cutoff, catalogs) # assert(len(cat) == 1), 'Allow only one calibrator' # s = cat.values()[0] # get the calibrator try: s = calibrators.get_src(calibrator) except KeyError: if mpiutil.rank0: print 'Calibrator %s is unavailable, available calibrators are:' for key, d in calibrators.src_data.items(): print '%8s -> %12s' % (key, d[0]) raise RuntimeError('Calibrator %s is unavailable') if mpiutil.rank0: print 'Try to calibrate with %s...' % s.src_name # if mpiutil.rank0: # print 'Calibrating for source %s with' % calibrator, # print 'strength', s._jys, 'Jy', # print 'measured at', s.mfreq, 'GHz', # print 'with index', s.index # get transit time of calibrator # array aa = ts.array aa.set_jultime(ts['jul_date'][0]) # the first obs time point next_transit = aa.next_transit(s) next_transit = aa.next_transit( s) if not ts.is_dish else ephem.date( datetime.utcfromtimestamp(obs_transit_time)) transit_time = a.phs.ephem2juldate(next_transit) # Julian date # get time zone pattern = '[-+]?\d+' tz = re.search(pattern, ts.attrs['timezone']).group() tz = int(tz) local_next_transit = ephem.Date( next_transit + tz * ephem.hour) # plus 8h to get Beijing time # if transit_time > ts['jul_date'][-1]: #======================================================== if ts.is_dish: if (transit_time > max(ts['jul_date'][-1], ts['jul_date'][:].max())) or (transit_time < min( ts['jul_date'][0], ts['jul_date'][:].min())): # raise RuntimeError('dish data does not contain local transit time %s of source %s' % (local_next_transit, calibrator)) raise 
NoTransit( 'Dish data does not contain local transit time %s of source %s' % (local_next_transit, calibrator)) transit_inds = [ np.searchsorted(ts['jul_date'][:], transit_time) ] peak_span = int(np.around(10. / ts.attrs['inttime'])) peak_start = ts['jul_date'][transit_inds[0] - peak_span] peak_end = ts['jul_date'][transit_inds[0] + peak_span] tspan = (peak_end - peak_start) * 86400. if mpiutil.rank0: print('transit peak: %s' % local_next_transit) print('%d point (should be about 10s) previous: %s' % (peak_span, ephem.date( a.phs.juldate2ephem(peak_start) + tz * ephem.hour))) print( '%d point (should be about 10s) later: %s' % (peak_span, ephem.date( a.phs.juldate2ephem(peak_end) + tz * ephem.hour))) if tspan > 30.: warnings.warn( 'Peak of transit is not continuous in the data! May lead to poor performance!' ) else: if transit_time > max(ts['jul_date'][-1], ts['jul_date'][:].max()): # raise RuntimeError('Cylinder data does not contain local transit time %s of source %s' % (local_next_transit, calibrator)) raise NoTransit( 'Cylinder data does not contain local transit time %s of source %s' % (local_next_transit, calibrator)) # the first transit index transit_inds = [ np.searchsorted(ts['jul_date'][:], transit_time) ] # find all other transit indices aa.set_jultime(ts['jul_date'][0] + 1.0) transit_time = a.phs.ephem2juldate( aa.next_transit(s)) # Julian date cnt = 2 while (transit_time <= ts['jul_date'][-1]): transit_inds.append( np.searchsorted(ts['jul_date'][:], transit_time)) aa.set_jultime(ts['jul_date'][0] + 1.0 * cnt) # Julian date # transit_time = a.phs.ephem2juldate(aa.next_transit(s) if not ts.is_dish else ephem.date(datetime.utcfromtimestamp(obs_transit_time))) transit_time = a.phs.ephem2juldate(aa.next_transit(s)) cnt += 1 #======================================================== if mpiutil.rank0: print 'transit ind of %s: %s, time: %s' % ( s.src_name, transit_inds, local_next_transit) if (not ts.ps_first) and ts.interp_all_masked: raise 
NotEnoughPointToInterpolateError( 'More than 80% of the data was masked due to shortage of noise points for interpolation(need at least 4 to perform cubic spline)! The pointsource calibration may not be done due to too many masked points!' ) ### now only use the first transit point to do the cal ### may need to improve in the future transit_ind = transit_inds[0] int_time = ts.attrs['inttime'] # second start_ind = transit_ind - np.int(span / int_time) end_ind = transit_ind + np.int( span / int_time) + 1 # plus 1 to make transit_ind is at the center start_ind = max(0, start_ind) end_ind = min(end_ind, ts.vis.shape[0]) if vis_conj: ts.local_vis[:] = ts.local_vis.conj() nt = end_ind - start_ind t_inds = range(start_ind, end_ind) freq = ts.freq[:] nf = len(freq) nlb = len(ts.local_bl[:]) nfeed = len(feedno) tfp_inds = list( itertools.product( t_inds, range(nf), [pol.index('xx'), pol.index('yy')])) # only for xx and yy ns, ss, es = mpiutil.split_all(len(tfp_inds), comm=ts.comm) # gather data to make each process to have its own data which has all bls for ri, (ni, si, ei) in enumerate(zip(ns, ss, es)): lvis = np.zeros((ni, nlb), dtype=ts.vis.dtype) lvis_mask = np.zeros((ni, nlb), dtype=ts.vis_mask.dtype) for ii, (ti, fi, pi) in enumerate(tfp_inds[si:ei]): lvis[ii] = ts.local_vis[ti, fi, pi] lvis_mask[ii] = ts.local_vis_mask[ti, fi, pi] # gather vis from all process for separate bls gvis = mpiutil.gather_array(lvis, axis=1, root=ri, comm=ts.comm) gvis_mask = mpiutil.gather_array(lvis_mask, axis=1, root=ri, comm=ts.comm) if ri == mpiutil.rank: tfp_linds = tfp_inds[si:ei] # inds for this process this_vis = gvis this_vis_mask = gvis_mask del tfp_inds del lvis del lvis_mask tfp_len = len(tfp_linds) cnan = complex(np.nan, np.nan) # complex nan if save_src_vis or subtract_src: # save calibrator src vis lsrc_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype) if save_src_vis: # save sky vis lsky_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype) # save 
outlier vis lotl_vis = np.full((tfp_len, nfeed, nfeed), cnan, dtype=ts.vis.dtype) if apply_gain or save_gain: lGain = np.full((tfp_len, nfeed), cnan, dtype=ts.vis.dtype) # find indices mapping between Vmat and vis # bis = range(nbl) bis_conj = [] # indices that shold be conj mis = [ ] # indices in the nfeed x nfeed matrix by flatten it to a vector mis_conj = [ ] # indices (of conj vis) in the nfeed x nfeed matrix by flatten it to a vector for bi, (fdi, fdj) in enumerate(bls): ai, aj = feedno.index(fdi), feedno.index(fdj) mis.append(ai * nfeed + aj) if ai != aj: bis_conj.append(bi) mis_conj.append(aj * nfeed + ai) # construct visibility matrix for a single time, freq, pol Vmat = np.full((nfeed, nfeed), cnan, dtype=ts.vis.dtype) # get flus of the calibrator in the observing frequencies Sc = s.get_jys(freq) if show_progress and mpiutil.rank0: pg = progress.Progress(tfp_len, step=progress_step) for ii, (ti, fi, pi) in enumerate(tfp_linds): if show_progress and mpiutil.rank0: pg.show(ii) # when noise on, just pass if 'ns_on' in ts.iterkeys() and ts['ns_on'][ti]: continue # aa.set_jultime(ts['jul_date'][ti]) # s.compute(aa) # get the topocentric coordinate of the calibrator at the current time # s_top = s.get_crds('top', ncrd=3) # aa.sim_cache(cat.get_crds('eq', ncrd=3)) # for compute bm_response and sim Vmat.flat[mis] = np.ma.array( this_vis[ii], mask=this_vis_mask[ii]).filled(cnan) Vmat.flat[mis_conj] = np.ma.array( this_vis[ii, bis_conj], mask=this_vis_mask[ii, bis_conj]).conj().filled(cnan) # if mpiutil.rank0: # saveVmat += [Vmat.copy()] if save_src_vis: lsky_vis[ii] = Vmat # set invalid val to 0 invalid = ~np.isfinite(Vmat) # a bool array # if too many masks if np.where(invalid)[0].shape[0] > 0.3 * nfeed**2: continue Vmat[invalid] = 0 # # if all are zeros if np.allclose(Vmat, 0.0): continue # fill diagonal of Vmat to 0 if zero_diag: np.fill_diagonal(Vmat, 0) # initialize the outliers med = np.median(Vmat.real) + 1.0J * np.median(Vmat.imag) diff = Vmat - med S0 = 
np.where( np.abs(diff) > 3.0 * rpca_decomp.MAD(Vmat), diff, 0) # stable PCA decomposition # V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=100, threshold='hard', tol=1.0e-6, debug=False) V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=max_iter, threshold='hard', tol=1.0e-6, debug=False) # V0, S = rpca_decomp.decompose(Vmat, rank=1, S=S0, max_iter=1, threshold='hard', tol=1., debug=False) if save_src_vis or subtract_src: lsrc_vis[ii] = V0 if save_src_vis: lotl_vis[ii] = S # plot if plot_figs: ind = ti - start_ind # plot Vmat plt.figure(figsize=(13, 5)) plt.subplot(121) plt.imshow(Vmat.real, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) plt.subplot(122) plt.imshow(Vmat.imag, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) fig_name = '%s_V_%d_%d_%s.png' % (fig_prefix, ind, fi, pol[pi]) if tag_output_iter: fig_name = output_path(fig_name, iteration=self.iteration) else: fig_name = output_path(fig_name) plt.savefig(fig_name) plt.close() # plot V0 plt.figure(figsize=(13, 5)) plt.subplot(121) plt.imshow(V0.real, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) plt.subplot(122) plt.imshow(V0.imag, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) fig_name = '%s_V0_%d_%d_%s.png' % (fig_prefix, ind, fi, pol[pi]) if tag_output_iter: fig_name = output_path(fig_name, iteration=self.iteration) else: fig_name = output_path(fig_name) plt.savefig(fig_name) plt.close() # plot S plt.figure(figsize=(13, 5)) plt.subplot(121) plt.imshow(S.real, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) plt.subplot(122) plt.imshow(S.imag, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) fig_name = '%s_S_%d_%d_%s.png' % (fig_prefix, ind, fi, pol[pi]) if tag_output_iter: fig_name = output_path(fig_name, iteration=self.iteration) else: fig_name = output_path(fig_name) 
plt.savefig(fig_name) plt.close() # plot N N = Vmat - V0 - S plt.figure(figsize=(13, 5)) plt.subplot(121) plt.imshow(N.real, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) plt.subplot(122) plt.imshow(N.imag, aspect='equal', origin='lower', interpolation='nearest') plt.colorbar(shrink=1.0) fig_name = '%s_N_%d_%d_%s.png' % (fig_prefix, ind, fi, pol[pi]) if tag_output_iter: fig_name = output_path(fig_name, iteration=self.iteration) else: fig_name = output_path(fig_name) plt.savefig(fig_name) plt.close() if apply_gain or save_gain: # if mpiutil.rank0: # np.save('Vmat',saveVmat) e, U = la.eigh(V0 / Sc[fi], eigvals=(nfeed - 1, nfeed - 1)) g = U[:, -1] * e[-1]**0.5 if g[0].real < 0: g *= -1.0 # make all g[0] phase 0, instead of pi lGain[ii] = g # plot Gain liplot = max(start_ind, transit_ind - 1) hiplot = min(end_ind, transit_ind + 1 + 1) if plot_figs and ti >= liplot and ti <= hiplot: # if ti >= liplot and ti <= hiplot: ind = ti - start_ind plt.figure() plt.plot(feedno, g.real, 'b-', label='real') plt.plot(feedno, g.real, 'bo') plt.plot(feedno, g.imag, 'g-', label='imag') plt.plot(feedno, g.imag, 'go') plt.plot(feedno, np.abs(g), 'r-', label='abs') plt.plot(feedno, np.abs(g), 'ro') plt.xlim(feedno[0] - 1, feedno[-1] + 1) yl, yh = plt.ylim() plt.ylim(yl, yh + (yh - yl) / 5) plt.xlabel('Feed number') plt.legend() fig_name = '%s_ants_%d_%d_%s.png' % (fig_prefix, ind, fi, pol[pi]) print('plot %s' % fig_name) if tag_output_iter: fig_name = output_path(fig_name, iteration=self.iteration) else: fig_name = output_path(fig_name) plt.savefig(fig_name) plt.close() # subtract the vis of calibrator from self.vis if subtract_src: nbl = len(bls) lv = np.zeros((lsrc_vis.shape[0], nbl), dtype=lsrc_vis.dtype) for bi, (fd1, fd2) in enumerate(bls): b1, b2 = feedno.index(fd1), feedno.index(fd2) lv[:, bi] = lsrc_vis[:, b1, b2] lv = mpiarray.MPIArray.wrap(lv, axis=0, comm=ts.comm) lv = lv.redistribute(axis=1).local_array.reshape(nt, nf, 2, -1) if 'ns_on' in 
ts.iterkeys(): lv[ts['ns_on'] [start_ind: end_ind]] = 0 # avoid ns_on signal to become nan ts.local_vis[start_ind:end_ind, :, pol.index('xx')] -= lv[:, :, 0] ts.local_vis[start_ind:end_ind, :, pol.index('yy')] -= lv[:, :, 1] del lv if not save_src_vis: if subtract_src: del lsrc_vis else: if tag_output_iter: src_vis_file = output_path(src_vis_file, iteration=self.iteration) else: src_vis_file = output_path(src_vis_file) # create file and allocate space first by rank0 if mpiutil.rank0: with h5py.File(src_vis_file, 'w') as f: # allocate space shp = (nt, nf, 2, nfeed, nfeed) f.create_dataset('sky_vis', shp, dtype=lsky_vis.dtype) f.create_dataset('src_vis', shp, dtype=lsrc_vis.dtype) f.create_dataset('outlier_vis', shp, dtype=lotl_vis.dtype) f.attrs['calibrator'] = calibrator f.attrs['dim'] = 'time, freq, pol, feed, feed' try: f.attrs['time'] = ts.time[start_ind:end_ind] except RuntimeError: f.create_dataset('time', data=ts.time[start_ind:end_ind]) f.attrs['time'] = '/time' f.attrs['freq'] = freq f.attrs['pol'] = np.array(['xx', 'yy']) f.attrs['feed'] = np.array(feedno) mpiutil.barrier() # write data to file for i in range(10): try: # NOTE: if write simultaneously, will loss data with processes distributed in several nodes for ri in xrange(mpiutil.size): if ri == mpiutil.rank: with h5py.File(src_vis_file, 'r+') as f: for ii, (ti, fi, pi) in enumerate(tfp_linds): ti_ = ti - start_ind pi_ = gain_pd[pol[pi]] f['sky_vis'][ti_, fi, pi_] = lsky_vis[ii] f['src_vis'][ti_, fi, pi_] = lsrc_vis[ii] f['outlier_vis'][ti_, fi, pi_] = lotl_vis[ii] mpiutil.barrier() break except IOError: time.sleep(0.5) continue else: raise RuntimeError('Could not open file: %s...' 
% src_vis_file) del lsrc_vis del lsky_vis del lotl_vis mpiutil.barrier() if apply_gain or save_gain: # flag outliers in lGain along each feed lG_abs = np.full_like(lGain, np.nan, dtype=lGain.real.dtype) for i in range(lGain.shape[0]): valid_inds = np.where(np.isfinite(lGain[i]))[0] if len(valid_inds) > 3: vabs = np.abs(lGain[i, valid_inds]) vmed = np.median(vabs) vabs_diff = np.abs(vabs - vmed) vmad = np.median(vabs_diff) / 0.6745 if reserve_high_gain: # reserve significantly higher ones, flag only significantly lower ones lG_abs[i, valid_inds] = np.where( vmed - vabs > 3.0 * vmad, np.nan, vabs) else: # flag both significantly higher and lower ones lG_abs[i, valid_inds] = np.where( vabs_diff > 3.0 * vmad, np.nan, vabs) # choose data slice near the transit time li = max(start_ind, transit_ind - 10) - start_ind hi = min(end_ind, transit_ind + 10 + 1) - start_ind # compute s_top for this time range n0 = np.zeros(((hi - li), 3)) for ti, jt in enumerate(ts.time[start_ind:end_ind][li:hi]): aa.set_jultime(jt) s.compute(aa) n0[ti] = s.get_crds('top', ncrd=3) if save_phs_change: n0t = np.zeros((nt, 3)) for ti, jt in enumerate(ts.time[start_ind:end_ind]): aa.set_jultime(jt) s.compute(aa) n0t[ti] = s.get_crds('top', ncrd=3) # get the positions of feeds feedpos = ts['feedpos'][:] # wrap and redistribute Gain and flagged G_abs Gain = mpiarray.MPIArray.wrap(lGain, axis=0, comm=ts.comm) Gain = Gain.redistribute(axis=1).reshape( nt, nf, 2, None).redistribute(axis=0).reshape( None, nf * 2 * nfeed).redistribute(axis=1) G_abs = mpiarray.MPIArray.wrap(lG_abs, axis=0, comm=ts.comm) G_abs = G_abs.redistribute(axis=1).reshape( nt, nf, 2, None).redistribute(axis=0).reshape( None, nf * 2 * nfeed).redistribute(axis=1) fpd_inds = list( itertools.product(range(nf), range(2), range(nfeed))) # only for xx and yy fpd_linds = mpiutil.mpilist(fpd_inds, method='con', comm=ts.comm) del fpd_inds # create data to save the solved gain for each feed lgain = np.full((len(fpd_linds), ), cnan, 
dtype=Gain.dtype) # gain for each feed if save_phs_change: lphs = np.full((nt, len(fpd_linds)), np.nan, dtype=Gain.real.dtype ) # phase change with time for each feed # check for conj num_conj = 0 for ii, (fi, pi, di) in enumerate(fpd_linds): y = G_abs.local_array[li:hi, ii] inds = np.where(np.isfinite(y))[0] if len(inds) >= max(4, 0.5 * len(y)): # get the approximate magnitude by averaging the central G_abs # solve phase by least square fit ui = (feedpos[di] - feedpos[0]) * ( 1.0e6 * freq[fi] ) / const.c # position of this feed (relative to the first feed) in unit of wavelength exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui)) ef = exp_factor Gi = Gain.local_array[li:hi, ii] e_phs = np.dot(ef[inds].conj(), Gi[inds] / y[inds]) / len(inds) ea = np.abs(e_phs) e_phs_conj = np.dot(ef[inds], Gi[inds] / y[inds]) / len(inds) eac = np.abs(e_phs_conj) if eac > ea: num_conj += 1 # reduce num_conj from all processes num_conj = mpiutil.allreduce(num_conj, comm=ts.comm) if num_conj > 0.5 * (nf * 2 * nfeed): # 2 for 2 pols if mpiutil.rank0: print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!' print '!!! Detect data should be their conjugate... !!!' print '!!! Correct it automatically... !!!' print '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!' 
mpiutil.barrier() # correct vis ts.local_vis[:] = ts.local_vis.conj() # correct G Gain.local_array[:] = Gain.local_array.conj() # solve for gain # feedplotG = [] # savenum = 0 for ii, (fi, pi, di) in enumerate(fpd_linds): y = G_abs.local_array[li:hi, ii] inds = np.where(np.isfinite(y))[0] if len(inds) >= max(4, 0.5 * len(y)): # get the approximate magnitude by averaging the central G_abs mag = np.mean(y[inds]) # solve phase by least square fit ui = (feedpos[di] - feedpos[0]) * ( 1.0e6 * freq[fi] ) / const.c # position of this feed (relative to the first feed) in unit of wavelength exp_factor = np.exp(2.0J * np.pi * np.dot(n0, ui)) ef = exp_factor Gi = Gain.local_array[li:hi, ii] # feedplotG += [Gi[10]] # print(feedplotG) e_phs = np.dot(ef[inds].conj(), Gi[inds] / y[inds]) / len(inds) ea = np.abs(e_phs) if np.abs(ea - 1.0) < 0.1: # compute gain for this feed lgain[ii] = mag * e_phs if save_phs_change: lphs[:, ii] = np.angle( np.exp(-2.0J * np.pi * np.dot(n0t, ui)) * Gain.local_array[:, ii]) else: e_phs_conj = np.dot(ef[inds], Gi[inds] / y[inds]) / len(inds) eac = np.abs(e_phs_conj) if eac > ea: if np.abs(eac - 1.0) < 0.01: print 'feedno = %d, fi = %d, pol = %s: may need to be conjugated' % ( feedno[di], fi, gain_pd[pi]) else: print 'feedno = %d, fi = %d, pol = %s: maybe wrong abs(e_phs): %s' % ( feedno[di], fi, gain_pd[pi], ea) # import matplotlib.pyplot as plt # import os.path # feedplotG = np.array(feedplotG) # if mpiutil.rank0: # if os.path.isfile('feedplotG%d.npy'%savenum): # savenum += 1 # np.save('feedplotG%d'%savenum,feedplotG) # plt.plot(feedplotG) # plt.savefig('Gfig%d'%savenum) # gather local gain gain = mpiutil.gather_array(lgain, axis=0, root=None, comm=ts.comm) del lgain gain = gain.reshape(nf, 2, nfeed) if save_phs_change: phs = mpiutil.gather_array(lphs, axis=1, root=0, comm=ts.comm) del lphs if mpiutil.rank0: phs = phs.reshape(nt, nf, 2, nfeed) # apply gain to vis if apply_gain: for fi in range(nf): for pi in [pol.index('xx'), pol.index('yy')]: pi_ = 
gain_pd[pol[pi]] for bi, (fd1, fd2) in enumerate( ts['blorder'].local_data): g1 = gain[fi, pi_, feedno.index(fd1)] g2 = gain[fi, pi_, feedno.index(fd2)] if np.isfinite(g1) and np.isfinite(g2): ts.local_vis[:, fi, pi, bi] /= (g1 * np.conj(g2)) else: # mask the un-calibrated vis ts.local_vis_mask[:, fi, pi, bi] = True # save gain to file if save_gain: if tag_output_iter: gain_file = output_path(gain_file, iteration=self.iteration) else: gain_file = output_path(gain_file) if mpiutil.rank0: with h5py.File(gain_file, 'w') as f: # allocate space for Gain dset = f.create_dataset('Gain', (nt, nf, 2, nfeed), dtype=Gain.dtype) dset.attrs['calibrator'] = calibrator dset.attrs['dim'] = 'time, freq, pol, feed' try: dset.attrs['time'] = ts.time[start_ind:end_ind] except RuntimeError: f.create_dataset( 'time', data=ts.time[start_ind:end_ind]) dset.attrs['time'] = '/time' dset.attrs['freq'] = freq dset.attrs['pol'] = np.array(['xx', 'yy']) dset.attrs['feed'] = np.array(feedno) # save gain dset = f.create_dataset('gain', data=gain) dset.attrs['calibrator'] = calibrator dset.attrs['dim'] = 'freq, pol, feed' dset.attrs['freq'] = freq dset.attrs['pol'] = np.array(['xx', 'yy']) dset.attrs['feed'] = np.array(feedno) # save phs if save_phs_change: f.create_dataset('phs', data=phs) # save transit index and transit time for the case do the ps cal first f.attrs['transit_index'] = transit_inds[0] # f.attrs['transit_time'] = ts['jul_date'][transit_inds[0]] f.attrs['transit_jul'] = ts['jul_date'][ transit_inds[0]] f.attrs['transit_time'] = ts['sec1970'][ transit_inds[0]] # f.attrs['transit_time'] = ts.attrs['sec1970'] + ts.attrs['inttime']*transit_inds[0] # in sec1970 to improve numeric precision if os.path.exists(output_path( ts.ns_gain_file)) and not ts.ps_first: with h5py.File(output_path(ts.ns_gain_file), 'r+') as ns_file: phs_only = not ('ns_cal_amp' in ns_file.keys()) # exclude_bad = 'badchn' in ns_file['channo'].attrs.keys() new_gain = uni_gain(f, ns_file, phs_only=phs_only) 
ns_file.create_dataset('uni_gain', data=new_gain) ns_file['uni_gain'].attrs[ 'dim'] = '(time, freq, bl)' mpiutil.barrier() # save Gain for i in range(10): try: # NOTE: if write simultaneously, will loss data with processes distributed in several nodes for ri in xrange(mpiutil.size): if ri == mpiutil.rank: with h5py.File(gain_file, 'r+') as f: for ii, (ti, fi, pi) in enumerate(tfp_linds): ti_ = ti - start_ind pi_ = gain_pd[pol[pi]] f['Gain'][ti_, fi, pi_] = lGain[ii] mpiutil.barrier() break except IOError: time.sleep(0.5) continue else: raise RuntimeError('Could not open file: %s...' % gain_file) mpiutil.barrier() # convert vis from intensity unit to temperature unit in K if temperature_convert: if 'unit' in ts.vis.attrs.keys() and ts.vis.attrs['unit'] == 'K': if mpiutil.rank0: print 'vis is already in unit K, do nothing...' else: factor = 1.0e-26 * (const.c**2 / (2 * const.k_B * (1.0e6 * freq)**2) ) # NOTE: 1Jy = 1.0e-26 W m^-2 Hz^-1 ts.local_vis[:] *= factor[np.newaxis, :, np.newaxis, np.newaxis] ts.vis.attrs['unit'] = 'K' return super(PsCal, self).process(ts)