Code example #1
def setup_dom_tables(
        dom_tables_kind,
        dom_tables_fname_proto,
        gcd,
        angsens_model,
        norm_version,
        use_sd_indices=const.ALL_STRS_DOMS,
        step_length=1.0,
        num_phi_samples=None,
        ckv_sigma_deg=None,
        template_library=None,
        compute_t_indep_exp=True,
        use_directionality=True,
        no_noise=False,
        force_no_mmap=False,
    ):
    """Instantiate and load single-DOM tables

    """
    print('Instantiating and loading single-DOM tables')
    t0 = time.time()

    # TODO: set mmap based on memory?
    if force_no_mmap:
        mmap = False
    else:
        mmap = 'uncompr' in dom_tables_kind

    if dom_tables_kind in ['raw_templ_compr', 'ckv_templ_compr']:
        template_library = np.load(expand(template_library))
    else:
        template_library = None

    gcd = extract_gcd(gcd)

    if no_noise:
        gcd['noise'] = np.zeros_like(gcd['noise'])

    # Instantiate single-DOM tables class
    dom_tables = Retro5DTables(
        table_kind=dom_tables_kind,
        geom=gcd['geo'],
        rde=gcd['rde'],
        noise_rate_hz=gcd['noise'],
        angsens_model=angsens_model,
        compute_t_indep_exp=compute_t_indep_exp,
        use_directionality=use_directionality,
        norm_version=norm_version,
        num_phi_samples=num_phi_samples,
        ckv_sigma_deg=ckv_sigma_deg,
        template_library=template_library,
        use_sd_indices=use_sd_indices
    )

    if '{subdet' in dom_tables_fname_proto:
        doms = const.ALL_DOMS
        for subdet in ['ic', 'dc']:
            if subdet == 'ic':
                strings = const.IC_STRS
            else:
                strings = const.DC_STRS

            for dom in doms:
                fpath = dom_tables_fname_proto.format(
                    subdet=subdet, dom=dom, depth_idx=dom-1
                )
                shared_table_sd_indices = []
                for string in strings:
                    sd_idx = const.get_sd_idx(string=string, dom=dom)
                    if sd_idx not in use_sd_indices:
                        continue
                    shared_table_sd_indices.append(sd_idx)

                if not shared_table_sd_indices:
                    continue

                dom_tables.load_table(
                    fpath=fpath,
                    sd_indices=shared_table_sd_indices,
                    step_length=step_length,
                    mmap=mmap
                )
    elif '{string}' in dom_tables_fname_proto:
        raise NotImplementedError('dom_tables_fname_proto with {string} not'
                                  ' implemented')
    elif '{string_idx}' in dom_tables_fname_proto:
        raise NotImplementedError('dom_tables_fname_proto with {string_idx}'
                                  ' not implemented')
    else:
        stacked_tables_fpath = expand(join(
            dom_tables_fname_proto,
            'stacked_{}.npy'.format(dom_tables.table_name)
        ))
        stacked_tables_meta_fpath = expand(join(
            dom_tables_fname_proto,
            'stacked_{}_meta.pkl'.format(dom_tables.table_name)
        ))
        stacked_t_indep_tables_fpath = expand(join(
            dom_tables_fname_proto,
            'stacked_{}.npy'.format(dom_tables.t_indep_table_name)
        ))
        dom_tables.load_stacked_tables(
            stacked_tables_meta_fpath=stacked_tables_meta_fpath,
            stacked_tables_fpath=stacked_tables_fpath,
            stacked_t_indep_tables_fpath=stacked_t_indep_tables_fpath,
            mmap_t_indep=mmap
        )

    print('  -> {:.3f} s\n'.format(time.time() - t0))

    return dom_tables
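
Usage sketch (not part of the repository): a call might look like the one
below. Every path and the angular-sensitivity model name are hypothetical
placeholders; only the `dom_tables_kind` value and the `norm_version` string
are taken from the code above and from code example #5.

# Hypothetical usage sketch -- all paths and the angsens model name are
# placeholders, not values shipped with icecube/retro.
dom_tables = setup_dom_tables(
    dom_tables_kind='ckv_templ_compr',                 # kind tested for above
    dom_tables_fname_proto='~/tables/{subdet}_{dom}',  # placeholder pattern
    gcd='~/gcd/my_gcd.pkl',                            # placeholder GCD file
    angsens_model='my_angsens_model',                  # placeholder name
    norm_version='binvol2.5',                          # default in example #5
    template_library='~/tables/templates.npy',         # needed for *_templ_compr
)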
Code example #2
def get_hits(event, path, angsens_model=None):
    """From an event, take either pulses or photons (optionally applying
    weights to the latter for angular sensitivity) and create the three
    structured numpy arrays necessary for Retro to process the information as
    "hits".

    Parameters
    ----------
    event : mapping
        Event record from which to extract the pulse or photon series.

    path : sequence of strings
        Key path within `event` locating the series; if `path[0]` is
        'photons', the series is treated as photons.

    angsens_model : str, numpy.polynomial.Polynomial, or None
        If specified and photons are extracted, weights for the photons will be
        applied according to the angular sensitivity model specified.
        Otherwise, each photon will carry a weight of one.

    Returns
    -------
    hits : shape (n_hits,) array of dtype HIT_T
    hits_indexer : shape (n_hit_doms,) array of dtype SD_INDEXER_T
    hits_summary : shape (1,) array of dtype HITS_SUMMARY_T

    """
    photons = path[0] == 'photons'

    series = get_path(event, path)

    if photons:
        time_window_start = 0
        time_window_stop = 0
        if angsens_model is not None:
            if isinstance(angsens_model, string_types):
                angsens_poly, _ = load_angsens_model(angsens_model)
            elif isinstance(angsens_model, np.polynomial.Polynomial):
                angsens_poly = angsens_model
            else:
                raise TypeError('`angsens_model` is {} but must be either'
                                ' string or np.polynomial.Polynomial'
                                .format(type(angsens_model)))

    else:
        trigger_hierarchy = event['triggers']['I3TriggerHierarchy']
        time_window_start = np.inf
        time_window_stop = -np.inf
        for trigger in trigger_hierarchy:
            source = trigger['source']

            # Do not expand the in-ice window based on GLOBAL triggers (of
            # any TypeID)
            if source == SourceID.GLOBAL:
                continue

            tr_type = trigger['type']
            config_id = trigger['config_id']
            tr_time = trigger['time']

            # TODO: rework to _only_ use ConfigID
            trigger_handled = False
            if tr_type == TypeID.SIMPLE_MULTIPLICITY:
                if source == SourceID.IN_ICE:
                    if config_id == ConfigID.SMT8_IN_ICE:
                        trigger_handled = True
                        left_dt = -4e3
                        right_dt = 5e3 + 6e3
                    elif config_id == ConfigID.SMT3_DeepCore:
                        trigger_handled = True
                        left_dt = -4e3
                        right_dt = 2.5e3 + 6e3
            elif tr_type == TypeID.VOLUME:
                if source == SourceID.IN_ICE:
                    trigger_handled = True
                    left_dt = -4e3
                    right_dt = 1e3 + 6e3
            elif tr_type == TypeID.STRING:
                if source == SourceID.IN_ICE:
                    trigger_handled = True
                    left_dt = -4e3
                    right_dt = 1.5e3 + 6e3

            if not trigger_handled:
                raise NotImplementedError(
                    'Trigger TypeID {}, SourceID {}, config_id {} not'
                    ' implemented'
                    .format(TypeID(tr_type).name, # pylint: disable=no-member
                            SourceID(source).name, # pylint: disable=no-member
                            config_id)
                )

            time_window_start = min(time_window_start, tr_time + left_dt)
            time_window_stop = max(time_window_stop, tr_time + right_dt)

    hits = []
    hits_indexer = []
    offset = 0

    for (string, dom, pmt), p in series:
        sd_idx = const.get_sd_idx(string=string, dom=dom)
        num = len(p)
        sd_hits = np.empty(shape=num, dtype=HIT_T)
        sd_hits['time'] = p['time']
        if not photons:
            sd_hits['charge'] = p['charge']
        elif angsens_model:
            sd_hits['charge'] = angsens_poly(p['coszen'])
        else:
            sd_hits['charge'] = 1

        hits.append(sd_hits)
        hits_indexer.append((sd_idx, offset, num))
        offset += num

    hits = np.concatenate(hits)
    if hits.dtype != HIT_T:
        raise TypeError('got dtype {}'.format(hits.dtype))

    hits_indexer = np.array(hits_indexer, dtype=SD_INDEXER_T)

    hit_times = hits['time']
    hit_charges = hits['charge']
    total_charge = np.sum(hit_charges)

    earliest_hit_time = hit_times.min()
    latest_hit_time = hit_times.max()
    average_hit_time = np.sum(hit_times * hit_charges) / total_charge

    total_num_hits = len(hits)
    total_num_doms_hit = len(hits_indexer)

    hits_summary = np.array(
        (
            earliest_hit_time,
            latest_hit_time,
            average_hit_time,
            total_charge,
            total_num_hits,
            total_num_doms_hit,
            time_window_start,
            time_window_stop
        ),
        dtype=HITS_SUMMARY_T
    )

    return hits, hits_indexer, hits_summary
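
The structured dtypes `HIT_T`, `SD_INDEXER_T`, and `HITS_SUMMARY_T` are
imported from elsewhere in retro and not shown here. The sketch below
reconstructs their layout purely from how `get_hits` fills them: the field
order follows the code above, but the field names for the summary and all
numeric precisions are assumptions, not retro's authoritative definitions.

import numpy as np

# Layout inferred from the assignments in get_hits; retro's real
# definitions may use different names, precisions, or extra fields.
HIT_T = np.dtype([('time', np.float32), ('charge', np.float32)])
SD_INDEXER_T = np.dtype([('sd_idx', np.uint32),
                         ('offset', np.uint64),
                         ('num', np.uint32)])
HITS_SUMMARY_T = np.dtype([('earliest_hit_time', np.float32),
                           ('latest_hit_time', np.float32),
                           ('average_hit_time', np.float32),
                           ('total_charge', np.float32),
                           ('num_hits', np.uint32),
                           ('num_doms_hit', np.uint32),
                           ('time_window_start', np.float32),
                           ('time_window_stop', np.float32)])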
Code example #3
def get_hits(event,
             path,
             hit_charge_quant,
             min_hit_charge,
             angsens_model=None):
    """From an event, take either pulses or photons (optionally applying
    weights to the latter for angular sensitivity) and create the three
    structured numpy arrays necessary for Retro to process the information as
    "hits".

    Parameters
    ----------
    event : mapping

    path : sequence of strings

    hit_charge_quant : scalar >= 0
        quantize charge in steps of this size; 0 disables quantization

    min_hit_charge : scalar >= 0
        filter out pulses with charge less than this value; 0 disables minimum
        charge filtering

    angsens_model : str, numpy.polynomial.Polynomial, or None
        If specified and photons are extracted, weights for the photons will be
        applied according to the angular sensitivity model specified.
        Otherwise, each photon will carry a weight of one.

    Returns
    -------
    hits : shape (n_hits,) array of dtype HIT_T
    hits_indexer : shape (n_hit_doms,) array of dtype SD_INDEXER_T
    hits_summary : shape (1,) array of dtype HITS_SUMMARY_T

    """
    photons = path[0] == 'photons'

    series = get_path(event, path)

    if photons:
        time_window_start = 0.
        time_window_stop = 0.
        if angsens_model is not None:
            if isinstance(angsens_model, string_types):
                angsens_poly, _ = load_angsens_model(angsens_model)
            elif isinstance(angsens_model, np.polynomial.Polynomial):
                angsens_poly = angsens_model
            else:
                raise TypeError('`angsens_model` is {} but must be either'
                                ' string or np.polynomial.Polynomial'.format(
                                    type(angsens_model)))

    else:
        trigger_hierarchy = event['triggers']['I3TriggerHierarchy']
        time_window_start = np.inf
        time_window_stop = -np.inf
        for trigger in trigger_hierarchy:
            if 'key' in trigger.dtype.names:  # New (more correct) TRIGGER_T struct
                trigger_key = trigger['key']
                source = trigger_key['source']
                tr_type = trigger_key['type']
                config_id = trigger_key['config_id']
            else:  # old TRIGGER_T had triggerkey fields at same level as trigger
                source = trigger['source']
                tr_type = trigger['type']
                config_id = trigger['config_id']

            # Do not expand the in-ice window based on GLOBAL triggers (of
            # any TriggerTypeID)
            if source == TriggerSourceID.GLOBAL:
                continue

            tr_time = trigger['time']

            # TODO: rework to _only_ use TriggerConfigID?
            # Below values can be extracted by running
            # $I3_SRC/trigger-sim/resources/scripts/print_trigger_configuration.py -g GCDFILE
            trigger_handled = False
            if tr_type == TriggerTypeID.SIMPLE_MULTIPLICITY:
                if source == TriggerSourceID.IN_ICE:
                    if config_id == TriggerConfigID.SMT8_IN_ICE:
                        trigger_handled = True
                        left_dt = -4e3
                        right_dt = 5e3 + 6e3
                    elif config_id == TriggerConfigID.SMT3_DeepCore:
                        trigger_handled = True
                        left_dt = -4e3
                        right_dt = 2.5e3 + 6e3
            elif tr_type == TriggerTypeID.VOLUME:
                if source == TriggerSourceID.IN_ICE:
                    trigger_handled = True
                    left_dt = -4e3
                    right_dt = 1e3 + 6e3
            elif tr_type == TriggerTypeID.STRING:
                if source == TriggerSourceID.IN_ICE:
                    trigger_handled = True
                    left_dt = -4e3
                    right_dt = 1.5e3 + 6e3

            if not trigger_handled:
                raise NotImplementedError(
                    'Trigger TypeID {}, SourceID {}, config_id {} not'
                    ' implemented'.format(
                        TriggerTypeID(tr_type).name,  # pylint: disable=no-member
                        TriggerSourceID(source).name,  # pylint: disable=no-member
                        config_id))

            time_window_start = min(time_window_start, tr_time + left_dt)
            time_window_stop = max(time_window_stop, tr_time + right_dt)

    hits = []
    hits_indexer = []
    offset = 0

    for (string, dom, pmt), hits_ in series:
        # -- Filter the pulses -- #
        if hit_charge_quant > 0:
            hits_["charge"] = QUANTIZE_VEC(hits_["charge"], hit_charge_quant)
        if min_hit_charge > 0:
            hits_ = hits_[hits_["charge"] >= min_hit_charge]

        num = len(hits_)
        if num == 0:
            continue

        sd_idx = const.get_sd_idx(string=string, om=dom, pmt=pmt)
        sd_hits = np.empty(shape=num, dtype=HIT_T)
        sd_hits['time'] = hits_['time']
        if not photons:
            sd_hits['charge'] = hits_['charge']
        elif angsens_model:
            sd_hits['charge'] = angsens_poly(hits_['coszen'])
        else:
            sd_hits['charge'] = 1

        hits.append(sd_hits)
        hits_indexer.append((sd_idx, offset, num))
        offset += num

    if len(hits) == 0:
        hits = np.empty(shape=0, dtype=HIT_T)
        hits_indexer = np.empty(shape=0, dtype=SD_INDEXER_T)
        hits_summary = np.empty(shape=0, dtype=HITS_SUMMARY_T)
        return hits, hits_indexer, hits_summary

    hits = np.concatenate(hits)
    if hits.dtype != HIT_T:
        raise TypeError('got dtype {}'.format(hits.dtype))

    hits_indexer = np.array(hits_indexer, dtype=SD_INDEXER_T)

    hit_times = hits['time']
    hit_charges = hits['charge']
    total_charge = np.sum(hit_charges)

    earliest_hit_time = hit_times.min()
    latest_hit_time = hit_times.max()
    average_hit_time = np.sum(hit_times * hit_charges) / total_charge

    num_hits = len(hits)
    num_doms_hit = len(hits_indexer)

    hits_summary = np.array(
        (
            earliest_hit_time,
            latest_hit_time,
            average_hit_time,
            total_charge,
            num_hits,
            num_doms_hit,
            time_window_start,
            time_window_stop,
        ),
        dtype=HITS_SUMMARY_T,
    )

    return hits, hits_indexer, hits_summary
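
`QUANTIZE_VEC` is imported from retro and not shown in this example. A
plausible stand-in consistent with the docstring ("quantize charge in steps
of this size") is sketched below; the nearest-multiple rounding convention
is an assumption, and retro's actual helper may round differently.

import numpy as np

def quantize_vec(values, quantum):
    # Round each value to the nearest integer multiple of `quantum`.
    # Assumption: nearest-multiple rounding; retro's QUANTIZE_VEC may
    # floor instead, or add a half-quantum offset.
    return np.round(np.asarray(values) / quantum) * quantum

# e.g. quantize_vec([0.12, 0.26, 1.01], 0.05) -> array([0.1 , 0.25, 1.  ])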
Code example #4
File: plot_retro_dom_pdfs.py  Project: icecube/retro
def plot_run_info(files,
                  labels,
                  outdir,
                  fwd_hists=None,
                  data_or_sim_label=None,
                  paired=False,
                  gradient=False,
                  plot=True):
    """Plot `files` using `labels` (one for each file).

    Parameters
    ----------
    files : string or iterable thereof
    labels : string or iterable thereof
    outdir : string
    fwd_hists : string, optional
    data_or_sim_label : string, optional
    paired : bool, optional
    gradient : bool, optional
    plot : bool, optional

    """
    if isinstance(files, string_types):
        files = [files]
    if isinstance(labels, string_types):
        labels = [labels]

    outdir = expand(outdir)

    if fwd_hists is not None:
        fwd_hists = load_pickle(fwd_hists)
        if 'binning' in fwd_hists:
            t_min = fwd_hists['binning']['t_min']
            t_max = fwd_hists['binning']['t_max']
            t_window = t_max - t_min
            num_bins = fwd_hists['binning']['num_bins']
            spacing = fwd_hists['binning']['spacing']
            assert spacing == 'linear', spacing
            fwd_hists_binning = np.linspace(t_min, t_max, num_bins + 1)
        elif 'bin_edges' in fwd_hists:
            fwd_hists_binning = fwd_hists['bin_edges']
            t_window = np.max(fwd_hists_binning) - np.min(fwd_hists_binning)
        else:
            raise ValueError(
                'Need "binning" or "bin_edges" in fwd_hists; keys are {}'.
                format(fwd_hists.keys()))
        hist_bin_widths = np.diff(fwd_hists_binning)
        if 'results' in fwd_hists:
            fwd_hists = fwd_hists['results']
        else:
            raise ValueError('Could not find key "results" in fwd hists!')
    else:
        raise NotImplementedError('Need fwd hists for now.')

    if not isdir(outdir):
        makedirs(outdir)

    run_infos = []
    all_string_dom_pairs = set()
    mc_true_params = None
    for filepath in files:
        filepath = expand(filepath)
        if isdir(filepath):
            filepath = join(filepath, 'run_info.pkl')
        run_info = load_pickle(filepath)
        run_infos.append(run_info)
        pairs = []
        for sd_idx in run_info['sd_indices']:
            pairs.append(get_string_om_pair(sd_idx))
        all_string_dom_pairs.update(pairs)
        if data_or_sim_label is None:
            data_or_sim_label = (
                'Simulation: ' +
                run_info['sim_to_test'].replace('_', ' ').capitalize())

        if mc_true_params is None:
            if 'sim' in run_info:
                mc_true_params = run_info['sim']['mc_true_params']
            else:
                print('mc_true_params not in run_info', filepath)

    params_label = None
    if mc_true_params is not None:
        params_label = []
        for plab, pval in mc_true_params.items():
            units = ''

            if plab == 't':
                pval = format(int(pval), 'd')
                units = r'\, \rm{ ns}'

            elif plab in 'x y z'.split():
                pval = format(pval, '0.1f')
                units = r'\, \rm{ m}'

            elif plab in 'track_energy cascade_energy'.split():
                pval = format(int(pval), 'd')
                plab = r'E_{\rm %s}' % plab.split('_')[0]
                units = r'\, \rm{ GeV}'

            elif plab in 'track_azimuth track_zenith cascade_azimuth cascade_zenith'.split():
                pval = format(pval / np.pi, '.2f')
                if 'azimuth' in plab:
                    ltr = r'\phi'
                elif 'zenith' in plab:
                    ltr = r'\theta'
                plab = ltr + r'_{\rm %s}' % plab.split('_')[0]
                units = r'\, \pi'

            params_label.append('{}={}{}'.format(plab, pval, units))
        params_label = '$' + r',\;'.join(params_label) + '$'

    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10, 8), dpi=72)

    t_indep_tots = []
    tots_incl_noise = []
    tots_excl_noise = []
    kss = []
    ref_tots_incl_noise = []
    ref_tots_excl_noise = []
    ref_areas_incl_noise = []
    for string, dom in reversed(sorted(all_string_dom_pairs)):
        if plot:
            ax.clear()
        all_zeros = True
        xmin = np.inf
        xmax = -np.inf
        ref_y = None
        if fwd_hists:
            if (string, dom) in fwd_hists:
                # Hit rate per nanosecond in each bin (includes noise hit rate)
                ref_y = fwd_hists[(string, dom)] / hist_bin_widths

                # Duplicate first element for plotting via `plt.step`
                ref_y = np.array([ref_y[0]] + ref_y.tolist())

                # Figure out "meaningful" range
                nonzero_mask = ref_y != 0
                if np.any(nonzero_mask):
                    all_zeros = False
                    min_mask = (ref_y - ref_y.min()) >= 0.01 * (ref_y.max() -
                                                                ref_y.min())
                    xmin = min(xmin, fwd_hists_binning[min_mask].min())
                    xmax = max(xmax, fwd_hists_binning[min_mask].max())
            else:
                ref_y = np.zeros_like(fwd_hists_binning)

            ref_y_areas = ref_y[1:] * hist_bin_widths
            ref_y_area = np.sum(ref_y_areas)

            ref_tots_incl_noise.append(ref_y_area)

            # The following only works if the time window is large enough
            # that the expected hits from the event drop to zero somewhere;
            # at that time only noise contributes.
            ref_tots_excl_noise.append(np.sum(ref_y_areas - ref_y_areas.min()))
            ref_areas_incl_noise.append(ref_y_area)

            if plot:
                ax.step(
                    fwd_hists_binning,
                    ref_y,
                    lw=1,
                    label=(r'Fwd: $\Sigma \lambda_q \Delta t$={}'.format(
                        num_fmt(ref_y_area))),
                    clip_on=True,
                )

        colors = ['C%d' % i for i in range(1, 10)]
        linestyles = ['-', '--']
        linewidths = [5, 3, 2, 2, 2, 2, 2]

        for plt_i, (label, run_info) in enumerate(zip(labels, run_infos)):
            sample_hit_times = run_info['hit_times']
            if len(tots_incl_noise) <= plt_i:
                tots_incl_noise.append([])
                tots_excl_noise.append([])
                t_indep_tots.append([])
                kss.append([])

            results = run_info['results']
            # Default so `y_ti` is defined for the label below even when
            # 'exp_p_at_hit_times' is absent from the result
            y_ti = np.nan
            if (string, dom) in pairs:  # NB: `pairs` is from the last file read above
                rslt = results[get_sd_idx(string, dom)]
                if 'exp_p_at_hit_times' in rslt:
                    y = rslt['exp_p_at_hit_times']
                    y_ti = rslt['exp_p_at_all_times']
                    t_indep_tots[plt_i].append(y_ti)
                else:
                    y = rslt['pexp_at_hit_times']

                nonzero_mask = y != y[0]
                if np.any(nonzero_mask):
                    all_zeros = False
                    min_mask = y >= 0.01 * max(y)
                    xmin = min(xmin, sample_hit_times[min_mask].min())
                    xmax = max(xmax, sample_hit_times[min_mask].max())
            else:
                y = np.zeros_like(sample_hit_times)

            masked_y = np.ma.masked_invalid(y * hist_bin_widths)
            tot_excl_noise = np.sum(masked_y - masked_y.min())
            tot_incl_noise = masked_y.sum()
            if tot_excl_noise != 0:
                tots_excl_noise[plt_i].append(tot_excl_noise)
                tots_incl_noise[plt_i].append(tot_incl_noise)
            else:
                tots_excl_noise[plt_i].append(0)
                tots_incl_noise[plt_i].append(0)
            kss[plt_i].append(ks_test(y, ref_y[1:]))

            custom_label = r'{:3s}: $\Sigma \lambda_q \Delta t$={}, ti={}'.format(
                label, num_fmt(tots_incl_noise[plt_i][-1]), num_fmt(y_ti))

            if paired:
                c_idx, ls_idx = divmod(plt_i, 2)
                color = colors[c_idx]
                linestyle = linestyles[ls_idx]
            else:
                color = None
                linestyle = None

            if plot:
                ax.plot(sample_hit_times,
                        y,
                        label=custom_label,
                        color=color,
                        linestyle=linestyle,
                        linewidth=linewidths[plt_i],
                        clip_on=True)

        if all_zeros:
            continue

        if xmin == xmax:
            xmin = np.min(fwd_hists_binning)
            xmax = np.max(fwd_hists_binning)

        if plot:
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(0, ax.get_ylim()[1])

            for pos in 'bottom left top right'.split():
                ax.spines[pos].set_visible(False)

            ax.xaxis.set_ticks_position('none')
            ax.yaxis.set_ticks_position('none')

            ax.xaxis.tick_bottom()
            ax.yaxis.tick_left()

            leg = ax.legend(
                loc='upper right',
                framealpha=0.7,
                prop=dict(family='monospace', size=12))
            plt.setp(leg.get_title(), family='monospace', fontsize=12)
            leg.get_frame().set_linewidth(0)
            ax.set_xlabel('Time from event vertex (ns)', fontsize=14)

            if data_or_sim_label is not None:
                plt.text(0.5,
                         1.1,
                         data_or_sim_label,
                         ha='center',
                         va='bottom',
                         transform=ax.transAxes,
                         fontsize=16)
            if params_label is not None:
                plt.text(0.5,
                         1.05,
                         params_label,
                         ha='center',
                         va='bottom',
                         transform=ax.transAxes,
                         fontsize=12)

            ax.text(0.5,
                    1.0,
                    'String {}, DOM {}'.format(string, dom),
                    ha='center',
                    va='bottom',
                    transform=ax.transAxes,
                    fontsize=14)

            fbasename = 'string_{}_dom_{}'.format(string, dom)
            fig.savefig(join(outdir, fbasename + '.png'))
            sys.stdout.write('({}, {}) '.format(string, dom))
            sys.stdout.flush()
    sys.stdout.write('\n\n')
    sys.stdout.flush()

    ref_tots_incl_noise = np.array(ref_tots_incl_noise)
    ref_tots_excl_noise = np.array(ref_tots_excl_noise)
    ref_areas_incl_noise = np.array(ref_areas_incl_noise)

    ref_tot_incl_noise = np.sum(ref_tots_incl_noise)
    ref_tot_excl_noise = np.sum(ref_tots_excl_noise)
    ref_area_incl_noise = np.sum(ref_areas_incl_noise)

    print('{:9s}  {:9s}  {:16s}  {:16s}  {:16s}  {}'.format(
        'wtd KS'.rjust(9), 'avg KS'.rjust(9), 'Ratio incl noise'.rjust(16),
        'Ratio excl noise'.rjust(16), 't-indep ratio'.rjust(16), 'Label'))
    for label, ks, tot_incl_noise, tot_excl_noise, ti_tot in zip(
            labels, kss, tots_incl_noise, tots_excl_noise, t_indep_tots):
        ks = np.array(ks)
        mask = ~np.isnan(ks)
        ks_avg = np.mean(ks[mask])
        ks_wtd_avg = (np.sum(ks[mask] * ref_tots_excl_noise[mask]) /
                      np.sum(ref_tots_excl_noise[mask]))
        print('{:9s}  {:9s}  {:16s}  {:16s}  {:16s}  {}'.format(
            format(ks_wtd_avg, '.7f').rjust(9),
            format(ks_avg, '.7f').rjust(9),
            format(np.sum(tot_excl_noise) / ref_tot_excl_noise,
                   '.12f').rjust(16),
            format(np.sum(tot_incl_noise) / ref_tot_incl_noise,
                   '.12f').rjust(16),
            format(np.sum(ti_tot) / ref_area_incl_noise, '.12f').rjust(16),
            label))
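
Usage sketch (hypothetical; every path below is a placeholder): the function
expects one `run_info.pkl` per reconstruction configuration, plus a pickled
forward-simulation histogram file containing a 'results' key and either
'binning' or 'bin_edges'.

# Hypothetical invocation of plot_run_info; all paths are placeholders.
plot_run_info(
    files=['~/runs/config_a', '~/runs/config_b'],  # dirs holding run_info.pkl
    labels=['cfg A', 'cfg B'],                     # one label per file
    outdir='~/retro_dom_pdf_plots',
    fwd_hists='~/fwd_hists.pkl',   # must provide 'results' plus binning info
    paired=True,                   # share a color per pair, vary linestyle
)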
Code example #5
def setup_dom_tables(
    dom_tables_kind,
    dom_tables_fname_proto,
    gcd,
    norm_version='binvol2.5',
    use_sd_indices=const.ALL_STRS_DOMS,
    num_phi_samples=None,
    ckv_sigma_deg=None,
    template_library=None,
    compute_t_indep_exp=True,
    no_noise=False,
    force_no_mmap=False,
):
    """Instantiate and load single-DOM tables.

    Parameters
    ----------
    dom_tables_kind : str
    dom_tables_fname_proto : str
    gcd : str
    norm_version : str, optional
    use_sd_indices : sequence, optional
    num_phi_samples : int, optional
    ckv_sigma_deg : float, optional
    template_library : str, optional
    compute_t_indep_exp : bool, optional
    no_noise : bool, optional
    force_no_mmap : bool, optional

    Returns
    -------
    dom_tables : Retro5DTables

    """
    print('Instantiating and loading DOM tables')
    t0 = time.time()

    dom_tables_fname_proto = expand(dom_tables_fname_proto)

    # TODO: set mmap based on memory?
    if force_no_mmap:
        mmap = False
    else:
        mmap = 'uncompr' in dom_tables_kind

    if dom_tables_kind in ['raw_templ_compr', 'ckv_templ_compr']:
        template_library = np.load(expand(template_library))
    else:
        template_library = None

    gcd = extract_gcd(gcd)

    if no_noise:
        gcd['noise'] = np.zeros_like(gcd['noise'])

    # Instantiate single-DOM tables class
    dom_tables = Retro5DTables(
        table_kind=dom_tables_kind,
        geom=gcd['geo'],
        rde=gcd['rde'],
        noise_rate_hz=gcd['noise'],
        compute_t_indep_exp=compute_t_indep_exp,
        norm_version=norm_version,
        num_phi_samples=num_phi_samples,
        ckv_sigma_deg=ckv_sigma_deg,
        template_library=template_library,
        use_sd_indices=use_sd_indices,
    )

    if '{subdet' in dom_tables_fname_proto:
        doms = const.ALL_DOMS
        for subdet in ['ic', 'dc']:
            if subdet == 'ic':
                strings = const.IC_STRS
            else:
                strings = const.DC_STRS

            for dom in doms:
                fpath = dom_tables_fname_proto.format(subdet=subdet,
                                                      dom=dom,
                                                      depth_idx=dom - 1)

                shared_table_sd_indices = []
                for string in strings:
                    sd_idx = const.get_sd_idx(string=string, om=dom, pmt=0)
                    if sd_idx not in use_sd_indices:
                        continue
                    shared_table_sd_indices.append(sd_idx)

                if not shared_table_sd_indices:
                    continue

                dom_tables.load_table(
                    fpath=fpath,
                    sd_indices=shared_table_sd_indices,
                    mmap=mmap,
                )

    elif '{string}' in dom_tables_fname_proto:
        raise NotImplementedError('dom_tables_fname_proto with {string} not'
                                  ' implemented')

    elif '{string_idx}' in dom_tables_fname_proto:
        raise NotImplementedError('dom_tables_fname_proto with {string_idx}'
                                  ' not implemented')

    elif '{cluster_idx}' in dom_tables_fname_proto:
        cluster_idx = -1
        while True:
            cluster_idx += 1
            dpath = dom_tables_fname_proto.format(cluster_idx=cluster_idx)
            if not isdir(dpath):
                print(
                    'failed to find "{}" (this may indicate that all existing'
                    ' tables are loaded)\n'.format(dpath))
                break

            # TODO: make the omkeys field generic to all tables & place
            # loading & intersection ops within the `load_table` method.
            omkeys = np.load(join(dpath, 'omkeys.npy'))
            sd_indices = set(const.omkeys_to_sd_indices(omkeys))
            shared_table_sd_indices = sd_indices.intersection(use_sd_indices)

            dom_tables.load_table(
                fpath=dpath,
                sd_indices=shared_table_sd_indices,
                mmap=mmap,
            )

    else:
        stacked_tables_fpath = expand(
            join(dom_tables_fname_proto,
                 'stacked_{}.npy'.format(dom_tables.table_name)))
        stacked_tables_meta_fpath = expand(
            join(dom_tables_fname_proto,
                 'stacked_{}_meta.pkl'.format(dom_tables.table_name)))
        stacked_t_indep_tables_fpath = expand(
            join(dom_tables_fname_proto,
                 'stacked_{}.npy'.format(dom_tables.t_indep_table_name)))
        dom_tables.load_stacked_tables(
            stacked_tables_meta_fpath=stacked_tables_meta_fpath,
            stacked_tables_fpath=stacked_tables_fpath,
            stacked_t_indep_tables_fpath=stacked_t_indep_tables_fpath,
            mmap_t_indep=mmap,
        )

    for table in dom_tables.tables:
        assert np.all(np.isfinite(table['weight'])), 'table weights not finite!'
        assert np.all(table['weight'] >= 0), 'table weights are negative!'
        assert np.min(table['index']) >= 0, 'table has negative index'
        if dom_tables.template_library is not None:
            assert np.max(table['index']) < dom_tables.template_library.shape[0], \
                'table index exceeds template library size'
    if dom_tables.template_library is not None:
        assert np.all(np.isfinite(dom_tables.template_library)), \
            'templates not finite!'
        assert np.all(dom_tables.template_library >= 0), \
            'templates have negative values!'

    print('  -> {:.3f} s\n'.format(time.time() - t0))

    return dom_tables
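
The assertions above imply that, for the template-compressed table kinds,
each table entry stores an ('index', 'weight') pair referencing
`template_library`. Under that assumption (retro's real decompression lives
in its table-lookup kernels, not here), reconstructing one entry would look
roughly like:

def decompress_entry(table, template_library, flat_idx):
    # Schematic only: rebuilds one entry of a template-compressed table
    # as weight * template, the layout implied by the assertions in
    # setup_dom_tables above.
    entry = table[flat_idx]
    return entry['weight'] * template_library[entry['index']]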