Example #1
    def _classify(self, obs_config, input_file):
        """
        Run the ML classifier

        :param dict obs_config: Observation config
        :param str input_file: HDF5 file to process
        :return: prefix of output figure path
        """

        output_prefix = "{output_dir}/triggers/ranked_CB{beam:02d}".format(
            **obs_config)

        # Synthesized beams
        if self.process_sb:
            sb_option = '--synthesized_beams'
        else:
            sb_option = ''

        # Galactic DM
        dmgal = util.get_ymw16(obs_config['parset'], obs_config['beam'],
                               self.logger)

        # Add optional classifier models (1D time and DM-time)
        model_option = ''
        if self.model_dmtime:
            model_option += ' --fn_model_dm {model_dir}/{model_dmtime}'.format(
                **obs_config)
        if self.model_1dtime:
            model_option += ' --fn_model_time {model_dir}/{model_1dtime}'.format(
                **obs_config)

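        # assemble the classifier command: select the GPU(s), add the optional model
        # and synthesized-beam options, both probability thresholds, the Galactic DM,
        # the input file, and the default freq-time model as positional arguments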
        cmd = "export CUDA_VISIBLE_DEVICES={ml_gpus}; /home/{user}/python36/bin/python3 {classifier} " \
              " {model_option} {sb_option} " \
              " --pthresh {pthresh_freqtime} --save_ranked --plot_ranked --fnout={output_prefix} {input_file} " \
              " --pthresh_dm {pthresh_dmtime} --DMgal {dmgal} " \
              " {model_dir}/20190416freq_time.hdf5".format(user=os.getlogin(),
                                                           output_prefix=output_prefix, sb_option=sb_option,
                                                           model_option=model_option, dmgal=dmgal,
                                                           input_file=input_file, **obs_config)
        self.logger.info("Running {}".format(cmd))
        os.system(cmd)
        # ToDo: count number of output figures
        return output_prefix
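
As an aside: os.system discards the classifier's exit status here. A minimal self-contained sketch, with a hypothetical stand-in command, of how subprocess.run could surface failures (a shell is still required because the command uses export and ;):

import subprocess

cmd = "export CUDA_VISIBLE_DEVICES=0; echo classifier"  # hypothetical stand-in command
result = subprocess.run(cmd, shell=True)
if result.returncode != 0:
    # hypothetical handling; the original does not check the exit status
    print("Classifier exited with code", result.returncode)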
Example #2
    def _get_overview(self, obs_config):
        """
        Generate observation overview file

        :param dict obs_config: Observation config
        """
        # For now, replicate the old command
        parset = obs_config['parset']
        info_file = os.path.join(obs_config['result_dir'], 'info.yaml')
        info = {}
        info['utc_start'] = parset['task.startTime']
        info['tobs'] = parset['task.duration']
        info['source'] = parset['task.source.name']
        # get YMW16 DM for central beam
        info['ymw16'] = "{:.2f}".format(util.get_ymw16(parset, logger=self.logger))
        info['telescopes'] = parset['task.telescopes'].replace('[', '').replace(']', '')
        info['taskid'] = parset['task.taskID']
        with open(info_file, 'w') as f:
            yaml.dump(info, f, default_flow_style=False)
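
For reference, a minimal self-contained sketch, with hypothetical values, of the output format yaml.dump produces with default_flow_style=False (one key per line, numeric-looking strings quoted):

import yaml

# hypothetical subset of the info dict built above
info = {'source': 'B0531+21', 'tobs': '300', 'ymw16': '56.77'}
print(yaml.dump(info, default_flow_style=False))
# source: B0531+21
# tobs: '300'
# ymw16: '56.77'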
Example #3
    def _generate_info_file(self):
        """
        Generate observation info files

        :return: info (dict), coordinates of each CB (dict)
        """
        # generate observation summary file
        fname = os.path.join(self.central_result_dir, 'info.yaml')
        # start with the observation parset
        parset = self.obs_config['parset']
        info = parset.copy()
        # format telescope list
        info['task.telescopes'] = info['task.telescopes'].replace('[',
                                                                  '').replace(
                                                                      ']', '')
        # Add central frequency
        info['freq'] = self.obs_config['freq']
        # Add YMW16 DM limit for CB00
        info['ymw16'] = util.get_ymw16(self.obs_config['parset'], 0,
                                       self.logger)
        # Add exact start time (startpacket)
        info['startpacket'] = self.obs_config['startpacket']
        # Add classifier probability thresholds
        with open(self.config_file, 'r') as f:
            classifier_config = yaml.load(
                f, Loader=yaml.SafeLoader)['processor']['classifier']
        info['classifier_threshold_freqtime'] = classifier_config[
            'thresh_freqtime']
        info['classifier_threshold_dmtime'] = classifier_config[
            'thresh_dmtime']
        # add path to website
        # get the FQDN in a way that actually includes the domain;
        # simply calling socket.getfqdn does not do so on ARTS
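        # getaddrinfo returns (family, type, proto, canonname, sockaddr) tuples;
        # with AI_CANONNAME set, entry [0][3] holds the canonical host name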
        fqdn = socket.getaddrinfo(socket.gethostname(), None, 0,
                                  socket.SOCK_DGRAM, 0,
                                  socket.AI_CANONNAME)[0][3]
        info['web_link'] = 'http://{fqdn}/~{user}/darc/{webdir}/' \
                           '{date}/{datetimesource}'.format(fqdn=fqdn, user=os.getlogin(),
                                                            webdir=self.webdir, **self.obs_config)
        # save the file
        with open(fname, 'w') as f:
            yaml.dump(info, f, default_flow_style=False)

        # generate file with coordinates
        coordinates = {}
        for beam in self.obs_config['beams']:
            try:
                key = "task.beamSet.0.compoundBeam.{}.phaseCenter".format(beam)
                c1, c2 = ast.literal_eval(parset[key].replace('deg', ''))
                if parset['task.directionReferenceFrame'] == 'HADEC':
                    # convert HADEC to J2000 RA/Dec at the midpoint of the observation
                    midpoint = Time(parset['task.startTime']) + .5 * float(
                        parset['task.duration']) * u.s
                    pointing = util.hadec_to_radec(c1 * u.deg, c2 * u.deg,
                                                   midpoint)
                else:
                    pointing = SkyCoord(c1, c2, unit=(u.deg, u.deg))
            except Exception as e:
                self.logger.error(
                    "Failed to get pointing for CB{:02d}: {}".format(beam, e))
                coordinates[beam] = ['-1', '-1', '-1', '-1']
            else:
                # get pretty strings
                ra = pointing.ra.to_string(unit=u.hourangle,
                                           sep=':',
                                           pad=True,
                                           precision=1)
                dec = pointing.dec.to_string(unit=u.deg,
                                             sep=':',
                                             pad=True,
                                             precision=1)
                gl, gb = pointing.galactic.to_string(precision=8).split(' ')
                coordinates[beam] = [ra, dec, gl, gb]

        # save to result dir
        with open(os.path.join(self.central_result_dir, 'coordinates.txt'),
                  'w') as f:
            f.write("#CB RA Dec Gl Gb\n")
            for beam, coord in coordinates.items():
                f.write("{:02d} {} {} {} {}\n".format(beam, *coord))

        return info, coordinates
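
A minimal self-contained sketch of the phaseCenter parsing used above, with a hypothetical parset value:

import ast

# hypothetical parset entry for one compound beam
value = "[293.94876 deg, 16.27778 deg]"
c1, c2 = ast.literal_eval(value.replace('deg', ''))
print(c1, c2)  # 293.94876 16.27778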
Example #4
    def _visualize(self):
        """
        Run the visualization of candidates
        """
        ncand = len(self.files)
        self.logger.debug(f"Visualizing {ncand} candidates")

        # get max galactic DM
        dmgal = util.get_ymw16(self.obs_config['parset'],
                               self.obs_config['beam'], self.logger)
        # DMgal is zero if the lookup failed; in that case set it to infinity,
        # so that no plots are marked instead of all of them
        if dmgal == 0:
            dmgal = np.inf

        # get plot order
        order = self._get_plot_order()
        # get the number of plot pages
        nplot_per_page = self.config.nplot_per_side**2
        npage = int(np.ceil(len(order) / nplot_per_page))
        # order files, then split per page
        try:
            files = self.files[order]
        except IndexError:
            self.logger.error("Failed to get plot order")
            return

        num_full_page, nplot_last_incomplete_page = divmod(
            len(files), nplot_per_page)
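        # e.g. 10 files with 4 plots per page yield 2 full pages plus a final page of 2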
        files_split = []
        for page in range(num_full_page):
            files_split.append(files[page * nplot_per_page:(page + 1) *
                                     nplot_per_page])
        if nplot_last_incomplete_page != 0:
            files_split.append(files[-nplot_last_incomplete_page:])

        for page in range(npage):
            for plot_type in self.config.plot_types:
                # create figure
                fig, axes = plt.subplots(nrows=self.config.nplot_per_side,
                                         ncols=self.config.nplot_per_side,
                                         figsize=(self.config.figsize,
                                                  self.config.figsize))
                axes = axes.flatten()
                # loop over the files
                for i, fname in enumerate(files_split[page]):
                    # load the data and parameters
                    data, params = self._load_data(fname, plot_type)
                    try:
                        ntime = data.shape[1]
                    except IndexError:
                        ntime = len(data)
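                    # time axis in ms, centred on the candidate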
                    times = np.arange(-ntime / 2,
                                      ntime / 2) * params['tsamp'] * 1e3

                    ax = axes[i]
                    xlabel = 'Time (ms)'
                    if plot_type == 'freq_time':
                        nfreq = data.shape[0]
                        ylabel = 'Frequency (MHz)'
                        title = 'p:{prob_freqtime:.2f} DM:{dm:.2f} t:{toa:.2f}\n' \
                                'S/N:{snr:.2f} width:{downsamp} SB:{sb}'.format(**params)
                        freqs = np.linspace(
                            0,
                            BANDWIDTH.to(u.MHz).value,
                            nfreq) + self.obs_config['min_freq']
                        X, Y = np.meshgrid(times, freqs)
                        ax.pcolormesh(X,
                                      Y,
                                      data,
                                      cmap=self.config.cmap_freqtime,
                                      shading='nearest')
                        # Add DM 0 curve
                        delays = util.dm_to_delay(
                            params['dm'] * u.pc / u.cm**3, freqs[0] * u.MHz,
                            freqs * u.MHz).to(u.ms).value
                        ax.plot(times[0] + delays, freqs, c='r', alpha=.5)
                    elif plot_type == 'dm_time':
                        ylabel = r'DM (pc cm$^{-3}$)'
                        title = 'p:{prob_dmtime:.2f} DM:{dm:.2f} t:{toa:.2f}\n' \
                                'S/N:{snr:.2f} width:{downsamp} SB:{sb}'.format(**params)
                        X, Y = np.meshgrid(times, params['dms'])
                        ax.pcolormesh(X,
                                      Y,
                                      data,
                                      cmap=self.config.cmap_dmtime,
                                      shading='nearest')
                        # add line if DM 0 is in plot range
                        if min(params['dms']) <= 0 <= max(params['dms']):
                            ax.axhline(0, c='r', alpha=.5)
                    elif plot_type == '1d_time':
                        ylabel = 'Power (norm.)'
                        title = 'DM:{dm:.2f} t:{toa:.2f}\n' \
                                'S/N:{snr:.2f} width:{downsamp} SB:{sb}'.format(**params)
                        ax.plot(times, data, c=self.config.colour_1dtime)
                    else:
                        raise ProcessorException(
                            f"Unknown plot type: {plot_type}, should not be able to get here!"
                        )

                    # add plot title
                    ax.set_title(title)
                    # ylabel only the first column
                    if ax.is_first_col():
                        ax.set_ylabel(ylabel)
                    # xlabel only on the bottom plot of each column. This is a bit tricky:
                    # on the last page, the bottom plot is not necessarily in the last possible row
                    if (page != npage - 1) and ax.is_last_row():
                        ax.set_xlabel(xlabel)
                    else:
                        # a plot is the bottom one in a column if the number of remaining plots is less than a full row
                        nplot_remaining = len(files_split[page]) - i - 1
                        if nplot_remaining < self.config.nplot_per_side:
                            ax.set_xlabel(xlabel)
                    ax.set_xlim(times[0], times[-1])
                    # add red border if DM > DMgal
                    if params['dm'] > dmgal:
                        plt.setp(ax.spines.values(),
                                 color='red',
                                 linewidth=2,
                                 alpha=0.85)

                # on the last page, disable any leftover empty panels,
                # which exist only if the last page is not completely filled
                if page == npage - 1 and nplot_last_incomplete_page != 0:
                    remainder = nplot_per_page - nplot_last_incomplete_page
                    for ax in axes[-remainder:]:
                        ax.axis('off')

                fig.set_tight_layout(True)
                # ensure the number of digits used for the page index is always the same and large enough,
                # so that sorting works as expected
                page_str = str(page).zfill(len(str(npage)))
                fig_fname = os.path.join(self.output_dir,
                                         f'ranked_{plot_type}_{page_str}.pdf')
                fig.savefig(fig_fname)
        # merge the plots
        output_file = f"{self.output_dir}/CB{self.obs_config['beam']:02d}.pdf"
        merger = PdfFileMerger()
        for plot_type in self.config.plot_types:
            fnames = glob.glob(f'{self.output_dir}/*{plot_type}*.pdf')
            fnames.sort()
            for fname in fnames:
                merger.append(fname)
        merger.write(output_file)
        # copy the file to the central output directory
        self.logger.info(
            f"Saving plots to {self.result_dir}/{os.path.basename(output_file)}"
        )
        copy(output_file, self.result_dir)
Example #5
    def _process_triggers(self):
        """
        Read thresholds (DM, width, S/N) for clustering

        Continuously read AMBER triggers from the queue and start processing for known and/or new sources
        """

        # set observation parameters
        utc_start = Time(self.obs_config['startpacket'] / TIME_UNIT, format='unix')
        datetimesource = self.obs_config['datetimesource']
        dt = TSAMP.to(u.second).value
        chan_width = (BANDWIDTH / float(NCHAN)).to(u.MHz).value
        cent_freq = (self.obs_config['min_freq'] * u.MHz + 0.5 * BANDWIDTH).to(u.GHz).value
        sys_params = {'dt': dt, 'delta_nu_MHz': chan_width, 'nu_GHz': cent_freq}
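        # these system parameters (sampling time, channel width, central frequency)
        # are handed to the trigger processing in _check_triggers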
        pointing = self._get_pointing()
        dmgal = util.get_ymw16(self.obs_config['parset'], self.obs_config['beam'], self.logger)

        # get known source dm and type
        dm_src, src_type, src_name = self._get_source()
        if src_type is not None:
            thresh_src = {'dm_src': dm_src,
                          'src_type': src_type,
                          'src_name': src_name,
                          'dm_min': max(dm_src - self.dm_range, self.dm_min_global),
                          'dm_max': dm_src + self.dm_range,
                          'width_max': np.inf,
                          'snr_min': self.snr_min_global,
                          'pointing': pointing,
                          'dmgal': dmgal
                          }
            self.logger.info("Setting {src_name} trigger DM range to {dm_min} - {dm_max}, "
                             "max downsamp={width_max}, min S/N={snr_min}".format(**thresh_src))

        # set min and max DM for new sources with unknown DM
        thresh_new = {'src_type': None,
                      'src_name': None,
                      'dm_min': max(dmgal * self.thresh_iquv['dm_frac_min'], self.dm_min_global),
                      'dm_max': np.inf,
                      'width_max': self.thresh_iquv['width_max'],
                      'snr_min': self.thresh_iquv['snr_min'],
                      'pointing': pointing,
                      'dmgal': dmgal
                      }
        # for a known source, check whether LOFAR triggering should be enabled for new sources
        if src_type is not None and src_name in self.lofar_trigger_sources:
            thresh_new['skip_lofar'] = not self.thresh_lofar['trigger_on_new_sources']
        else:
            thresh_new['skip_lofar'] = False

        self.logger.info("Setting new source trigger DM range to {dm_min} - {dm_max}, "
                         "max downsamp={width_max}, min S/N={snr_min}, skip LOFAR "
                         "triggering={skip_lofar}".format(**thresh_new))

        # main loop
        while self.observation_running:
            if self.amber_triggers:
                # Copy the triggers so the class-wide list can receive new triggers without any getting lost
                with self.lock:
                    triggers = self.amber_triggers
                    self.amber_triggers = []
                # always check for a header, because one is received from every AMBER instance
                if not self.hdr_mapping:
                    for trigger in triggers:
                        if trigger.startswith('#'):
                            # read header, remove comment symbol
                            header = trigger.split()[1:]
                            self.logger.info("Received header: {}".format(header))
                            # Check if all required params are present and create mapping to col index
                            keys = ['beam_id', 'integration_step', 'time', 'DM', 'SNR']
                            for key in keys:
                                try:
                                    self.hdr_mapping[key] = header.index(key)
                                except ValueError:
                                    self.logger.error("Key missing from clusters header: {}".format(key))
                                    self.hdr_mapping = {}
                                    return

                # header should be present now
                if not self.hdr_mapping:
                    self.logger.error("First clusters received but header not found")
                    continue

                # remove headers from triggers (i.e. any trigger starting with #)
                triggers = [trigger for trigger in triggers if not trigger.startswith('#')]

                # triggers is empty if only header was received
                if not triggers:
                    self.logger.info("Only header received - Canceling processing")
                    continue

                # split strings and convert to numpy array
                try:
                    triggers = np.array(list(map(lambda val: val.split(), triggers)), dtype=float)
                except Exception as e:
                    self.logger.error("Failed to process triggers: {}".format(e))
                    continue

                # pick columns to feed to clustering algorithm
                triggers_for_clustering = triggers[:, (self.hdr_mapping['DM'], self.hdr_mapping['SNR'],
                                                       self.hdr_mapping['time'], self.hdr_mapping['integration_step'],
                                                       self.hdr_mapping['beam_id'])]

                # known source and new source triggering, in thread so clustering itself does not
                # delay next run
                # known source triggering
                if src_type is not None:
                    self.threads['trigger_known_source'] = threading.Thread(target=self._check_triggers,
                                                                            args=(triggers_for_clustering, sys_params,
                                                                                  utc_start, datetimesource),
                                                                            kwargs=thresh_src)
                    self.threads['trigger_known_source'].start()
                # new source triggering
                self.threads['trigger_new_source'] = threading.Thread(target=self._check_triggers,
                                                                      args=(triggers_for_clustering, sys_params,
                                                                            utc_start, datetimesource),
                                                                      kwargs=thresh_new)
                self.threads['trigger_new_source'].start()

            sleep(self.interval)
        self.logger.info("Observation finished")
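
Finally, a minimal self-contained sketch of the header-to-column mapping and trigger parsing used above, with a hypothetical AMBER trigger stream (column names other than the five required keys are made up):

import numpy as np

# hypothetical AMBER output: a commented header line followed by one trigger
lines = [
    "# beam_id batch sample integration_step time DM SNR",
    "22 0 1234 10 0.1234 56.8 12.3",
]
header = lines[0].split()[1:]
hdr_mapping = {key: header.index(key)
               for key in ('beam_id', 'integration_step', 'time', 'DM', 'SNR')}
triggers = [line for line in lines if not line.startswith('#')]
triggers = np.array([line.split() for line in triggers], dtype=float)
print(hdr_mapping)  # {'beam_id': 0, 'integration_step': 3, 'time': 4, 'DM': 5, 'SNR': 6}
print(triggers[:, hdr_mapping['DM']])  # [56.8]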