Example #1
0
    def prepare_event(self, source, return_stub=False):
        """Clean, parametrise and reconstruct the events delivered by `source`.

        This is a generator: every event is calibrated, each telescope image
        is cleaned and Hillas-parametrised and, if the event passes all cuts,
        the shower core, direction and height of the shower maximum are
        fitted.

        Parameters
        ----------
        source : iterable
            provides the events; each one has to expose the `dl0`, `dl1`,
            `mc` and `inst` containers accessed below
        return_stub : bool, optional (default: False)
            if False, events that fail an event-level cut or the shower fit
            are dropped silently; if True, ``stub(event)`` is yielded in their
            place so that the caller receives exactly one item per event

        Yields
        ------
        PreparedEvent
            event, Hillas parameters per telescope, telescope counts, summed
            and per-telescope maximum signal and the fit results -- or
            whatever ``stub(event)`` returns for a rejected event
        """

        for event in source:

            self.event_cutflow.count("noCuts")

            if self.event_cutflow.cut("min2Tels trig", len(event.dl0.tels_with_data)):
                if return_stub:
                    yield stub(event)
                # a generator resumes right after the `yield`: without this
                # `continue` the rejected event would be processed (and
                # possibly yielded) a second time
                continue

            # calibrate the event
            self.calib.calibrate(event)

            # telescope loop
            tot_signal = 0
            max_signals = {}
            hillas_dict = {}
            n_tels = {"tot": len(event.dl0.tels_with_data),
                      "LST": 0, "MST": 0, "SST": 0}
            for tel_id in event.dl0.tels_with_data:
                self.image_cutflow.count("noCuts")

                camera = event.inst.subarray.tel[tel_id].camera

                # NOTE(review): `tel_phi` and `tel_theta` are not defined in this
                # method; presumably they are module-level caches keyed by
                # telescope id -- confirm
                if tel_id not in tel_phi:
                    tel_phi[tel_id] = az_to_phi(event.mc.tel[tel_id].azimuth_raw * u.rad)
                    tel_theta[tel_id] = \
                        alt_to_theta(event.mc.tel[tel_id].altitude_raw*u.rad)

                    # the orientation of the camera (i.e. the pixel positions) needs to
                    # be corrected; this overwrites the camera geometry in place and
                    # is therefore only done the first time a telescope is seen
                    camera.pix_x, camera.pix_y = \
                        transform_pixel_position(camera.pix_x, camera.pix_y)

                # count the current telescope according to its size
                tel_type = event.inst.subarray.tel[tel_id].optics.tel_type
                n_tels[tel_type] += 1

                # the camera image as a 1D array
                pmt_signal = event.dl1.tel[tel_id].image

                pmt_signal = self.pick_gain_channel(pmt_signal, camera.cam_id)

                # clean the image
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        pmt_signal, new_geom = \
                            self.cleaner.clean(pmt_signal.copy(), camera)

                    if self.image_cutflow.cut("min pixel", pmt_signal) or \
                       self.image_cutflow.cut("min charge", np.sum(pmt_signal)):
                        continue

                except FileNotFoundError as e:
                    print(e)
                    continue
                except EdgeEvent:
                    continue

                # could this go into `hillas_parameters` ...?
                max_signals[tel_id] = np.max(pmt_signal)

                # do the hillas reconstruction of the images
                try:
                    moments = self.hillas_parameters(new_geom, pmt_signal)

                    # if width and/or length are zero (e.g. when there is only one
                    # pixel or when all pixels are exactly in one row), the
                    # parametrisation won't be very useful: skip
                    if self.image_cutflow.cut("poor moments", moments):
                        continue

                except hillas.HillasParameterizationError as e:
                    print(e)
                    print("ignoring this camera")
                    continue

                hillas_dict[tel_id] = moments
                tot_signal += moments.size

            n_tels["reco"] = len(hillas_dict)
            if self.event_cutflow.cut("min2Tels reco", n_tels["reco"]):
                if return_stub:
                    yield stub(event)
                continue

            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    # telescope loop done, now do the core fit
                    self.shower_reco.get_great_circles(
                            hillas_dict, event.inst.subarray, tel_phi, tel_theta)
                    pos_fit, err_est_pos = self.shower_reco.fit_core_crosses()
                    dir_fit, err_est_dir = self.shower_reco.fit_origin_crosses()
                    h_max = self.shower_reco.fit_h_max(
                            hillas_dict, event.inst.subarray, tel_phi, tel_theta)
            except Exception as e:
                # best effort: a failed fit rejects the event but must not abort
                # the whole loop
                print(e)
                if return_stub:
                    yield stub(event)
                # the fit results are unbound (or left over from the previous
                # event) at this point, so the event must not go any further
                continue

            if self.event_cutflow.cut("position nan", pos_fit) or \
               self.event_cutflow.cut("direction nan", dir_fit):
                if return_stub:
                    yield stub(event)
                continue

            yield PreparedEvent(event=event, hillas_dict=hillas_dict, n_tels=n_tels,
                                tot_signal=tot_signal, max_signals=max_signals,
                                pos_fit=pos_fit, dir_fit=dir_fit, h_max=h_max,
                                err_est_pos=err_est_pos, err_est_dir=err_est_dir
                                )
def main():
    """Run the Hillas shower reconstruction on a set of files and show its performance.

    Steps, in order:

    1. parse the command line and collect the input files
    2. set up the cutflows, image cleaner, shower reconstructor and the
       event preparer
    3. loop over all prepared events, compare the fitted direction and core
       position with the MC values and store the results in a PyTables table
    4. print the resolution summary and the event/image cutflows
    5. if `args.plot` is set, draw the performance histograms and violin plots

    Ctrl-C does not abort the program but only leaves the event loop, so that
    the summary and the plots are still produced.
    """

    # your favourite units here
    energy_unit = u.TeV
    angle_unit = u.deg
    dist_unit = u.m

    parser = make_argparser()
    parser.add_argument('-o', '--outfile', type=str,
                        help="if given, write output file with reconstruction results")
    parser.add_argument('--plot_c', action='store_true',
                        help="plot camera-wise displays")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--proton', action='store_true',
                       help="do protons instead of gammas")
    group.add_argument('--electron', action='store_true',
                       help="do electrons instead of gammas")

    args = parser.parse_args()

    if args.infile_list:
        filenamelist = []
        for f in args.infile_list:
            filenamelist += glob("{}/{}".format(args.indir, f))
        # `channel` is used as plot title further down; it has to be set on this
        # path as well or the plotting section ends in a NameError
        joined_names = " ".join(filenamelist)
        channel = next((name for name in ("gamma", "proton", "electron")
                        if name in joined_names), "unknown")
    elif args.proton:
        filenamelist = glob("{}/proton/*gz".format(args.indir))
        channel = "proton"
    elif args.electron:
        filenamelist = glob("{}/electron/*gz".format(args.indir))
        channel = "electron"
    elif args.gamma:
        filenamelist = glob("{}/gamma/*gz".format(args.indir))
        channel = "gamma"
    else:
        raise ValueError("don't know which input to use...")
    filenamelist.sort()

    if not filenamelist:
        print("no files found; check indir: {}".format(args.indir))
        # same effect as `exit(-1)` but does not rely on the `site` module
        raise SystemExit(-1)
    print("found {} files".format(len(filenamelist)))

    # keeping track of events and where they were rejected
    Eventcutflow = CutFlow("EventCutFlow")
    Imagecutflow = CutFlow("ImageCutFlow")

    # takes care of image cleaning
    cleaner = ImageCleaner(mode=args.mode, cutflow=Imagecutflow,
                           wavelet_options=args.raw,
                           skip_edge_events=args.skip_edge_events, island_cleaning=True)

    # the class that does the shower reconstruction
    shower_reco = HillasReconstructor()

    # NOTE(review): not used anywhere below; kept because the constructor may
    # have side effects -- confirm and remove
    shower_max_estimator = ShowerMaxEstimator("paranal")

    preper = EventPreparer(cleaner=cleaner,
                           hillas_parameters=hillas_parameters, shower_reco=shower_reco,
                           event_cutflow=Eventcutflow, image_cutflow=Imagecutflow,
                           # event/image cuts:
                           allowed_cam_ids=[],  # means: all
                           min_ntel=3,
                           min_charge=args.min_charge,
                           min_pixel=3)

    # a signal handler to abort the event loop but still do the post-processing
    signal_handler = SignalHandler()
    signal.signal(signal.SIGINT, signal_handler)

    try:
        # this class defines the reconstruction parameters to keep track of
        class RecoEvent(tb.IsDescription):
            NTels_trigg = tb.Int16Col(dflt=1, pos=0)
            NTels_clean = tb.Int16Col(dflt=1, pos=1)
            EnMC = tb.Float32Col(dflt=1, pos=2)
            xi = tb.Float32Col(dflt=1, pos=3)
            DeltaR = tb.Float32Col(dflt=1, pos=4)
            ErrEstPos = tb.Float32Col(dflt=1, pos=5)
            ErrEstDir = tb.Float32Col(dflt=1, pos=6)
            h_max = tb.Float32Col(dflt=1, pos=7)

        reco_outfile = tb.open_file(
            args.outfile, mode="w",
            # if we don't want to write the event list to disk, need to add more arguments
            **({} if args.store else {"driver": "H5FD_CORE",
                                      "driver_core_backing_store": False}))
        reco_table = reco_outfile.create_table("/", "reco_event", RecoEvent)
        reco_event = reco_table.row
    except Exception:
        # NOTE(review): `reco_table` is used unconditionally further down, so
        # this fallback does not get very far without a table -- confirm the
        # intended behaviour without PyTables
        reco_event = RecoEvent()
        print("no pytables installed?")

    # ##        #######   #######  ########
    # ##       ##     ## ##     ## ##     ##
    # ##       ##     ## ##     ## ##     ##
    # ##       ##     ## ##     ## ########
    # ##       ##     ## ##     ## ##
    # ##       ##     ## ##     ## ##
    # ########  #######   #######  ##

    # define here which telescopes to loop over
    allowed_tels = None
    # allowed_tels = prod3b_tel_ids("L+F+D")
    for i, filename in enumerate(filenamelist[:args.last]):
        print("file: {i} filename = {filename}".format(i=i, filename=filename))

        source = hessio_event_source(filename,
                                     allowed_tels=allowed_tels,
                                     max_events=args.max_events)

        # loop that cleans and parametrises the images and performs the reconstruction
        for (event, hillas_dict, n_tels,
             tot_signal, max_signal, pos_fit, dir_fit, h_max,
             err_est_pos, err_est_dir) in preper.prepare_event(source):

            shower = event.mc

            # MC direction of the shower, azimuth wrapped into (-180, 180] deg
            org_alt = u.Quantity(shower.alt).to(u.deg)
            org_az = u.Quantity(shower.az).to(u.deg)
            if org_az > 180 * u.deg:
                org_az -= 360 * u.deg

            org_the = alt_to_theta(org_alt)
            org_phi = az_to_phi(org_az)
            if org_phi > 180 * u.deg:
                org_phi -= 360 * u.deg
            if org_phi < -180 * u.deg:
                org_phi += 360 * u.deg

            shower_org = linalg.set_phi_theta(org_phi, org_the)
            shower_core = convert_astropy_array([shower.core_x, shower.core_y])

            # angular distance between fitted and MC direction and distance
            # between fitted and MC core position
            xi = linalg.angle(dir_fit, shower_org).to(angle_unit)
            diff = linalg.length(pos_fit[:2] - shower_core)

            # print some performance
            print()
            print("xi = {:4.3f}".format(xi))
            print("pos = {:4.3f}".format(diff))
            print("h_max reco: {:4.3f}".format(h_max.to(u.km)))
            print("err_est_dir: {:4.3f}".format(err_est_dir.to(angle_unit)))
            print("err_est_pos: {:4.3f}".format(err_est_pos))

            try:
                # store the reconstruction data in the PyTable
                reco_event["NTels_trigg"] = n_tels["tot"]
                reco_event["NTels_clean"] = len(shower_reco.circles)
                reco_event["EnMC"] = event.mc.energy / energy_unit
                reco_event["xi"] = xi / angle_unit
                reco_event["DeltaR"] = diff / dist_unit
                reco_event["ErrEstPos"] = err_est_pos / dist_unit
                reco_event["ErrEstDir"] = err_est_dir / angle_unit
                reco_event["h_max"] = h_max / dist_unit
                reco_event.append()
                reco_table.flush()

                print()
                print("xi res (68-percentile) = {:4.3f} {}"
                      .format(np.percentile(reco_table.cols.xi, 68), angle_unit))
                print("core res (68-percentile) = {:4.3f} {}"
                      .format(np.percentile(reco_table.cols.DeltaR, 68), dist_unit))
                print("h_max (median) = {:4.3f} {}"
                      .format(np.percentile(reco_table.cols.h_max, 50), dist_unit))

            except NoPyTables:
                pass

            if args.plot_c:
                # the import registers the "3d" projection on older matplotlib versions
                from mpl_toolkits.mplot3d import Axes3D
                fig = plt.figure()
                # `fig.gca(projection=...)` is no longer supported by recent
                # matplotlib versions; `add_subplot` works with old and new ones
                ax = fig.add_subplot(111, projection='3d')
                for c in shower_reco.circles.values():
                    points = [c.pos + t * c.a * u.km for t in np.linspace(0, 15, 3)]
                    ax.plot(*np.array(points).T, linewidth=np.sqrt(c.weight) / 10)
                    ax.scatter(*c.pos[:, None].value, s=np.sqrt(c.weight))
                plt.xlabel("x")
                plt.ylabel("y")
                plt.pause(.1)

                # this plots
                # • the MC shower core
                # • the reconstructed shower core
                # • the used telescopes
                # • and the trace of the Hillas plane on the ground
                plt.figure()
                for tel_id, c in shower_reco.circles.items():
                    plt.scatter(c.pos[0], c.pos[1], s=np.sqrt(c.weight))
                    plt.gca().annotate(tel_id, (c.pos[0].value, c.pos[1].value))
                    plt.plot([c.pos[0].value-500*c.norm[1], c.pos[0].value+500*c.norm[1]],
                             [c.pos[1].value+500*c.norm[0], c.pos[1].value-500*c.norm[0]],
                             linewidth=np.sqrt(c.weight)/10)
                plt.scatter(*pos_fit[:2], c="black", marker="*", label="fitted")
                plt.scatter(*shower_core[:2], c="black", marker="P", label="MC")
                plt.legend()
                plt.xlabel("x")
                plt.ylabel("y")
                plt.xlim(-1400, 1400)
                plt.ylim(-1400, 1400)
                plt.show()

            if signal_handler.stop:
                break
        if signal_handler.stop:
            break

    print("\n" + "="*35 + "\n")
    print("xi res (68-percentile) = {:4.3f} {}"
          .format(np.percentile(reco_table.cols.xi, 68), angle_unit))
    print("core res (68-percentile) = {:4.3f} {}"
          .format(np.percentile(reco_table.cols.DeltaR, 68), dist_unit))
    print("h_max (median) = {:4.3f} {}"
          .format(np.percentile(reco_table.cols.h_max, 50), dist_unit))

    # print the cutflows for telescopes and camera images
    print("\n\n")
    Eventcutflow("min2Tels trig")
    print()
    Imagecutflow(sort_column=1)

    # if we don't want to plot anything, we can exit now
    if not args.plot:
        return

    # ########  ##        #######  ########  ######
    # ##     ## ##       ##     ##    ##    ##    ##
    # ##     ## ##       ##     ##    ##    ##
    # ########  ##       ##     ##    ##     ######
    # ##        ##       ##     ##    ##          ##
    # ##        ##       ##     ##    ##    ##    ##
    # ##        ########  #######     ##     ######

    plt.figure()
    plt.hist(reco_table.cols.h_max, bins=np.linspace(0, 15000, 51, True))
    plt.title(channel)
    plt.xlabel("h_max reco")
    plt.pause(.1)

    plt.figure()
    xi_edges = np.linspace(0, 5, 20)
    plt.hist(reco_table.cols.xi, bins=xi_edges, log=True)
    plt.xlabel(r"$\xi$ / deg")
    if args.write:
        save_fig('{}/reco_xi_{}'.format(args.plots_dir, args.mode))
    plt.pause(.1)

    plt.figure()
    plt.hist(reco_table.cols.ErrEstDir[:],
             bins=np.linspace(0, 20, 50))
    plt.title(channel)
    plt.xlabel("beta")
    plt.pause(.1)

    plt.figure()
    plt.hist(np.log10(reco_table.cols.xi[:] / reco_table.cols.ErrEstDir[:]), bins=50)
    plt.title(channel)
    plt.xlabel("log_10(xi / beta)")
    plt.pause(.1)

    # convert the xi-list into a dict with the number of used telescopes as keys
    xi_vs_tel = {}
    for xi, ntel in zip(reco_table.cols.xi, reco_table.cols.NTels_clean):
        xi_vs_tel.setdefault(ntel, []).append(xi)

    print(args.mode)
    for ntel, xis in sorted(xi_vs_tel.items()):
        print("NTel: {} -- median xi: {}".format(ntel, np.median(xis)))
        # print("histogram:", np.histogram(xis, bins=xi_edges))

    # create a list of energy bin-edges and -centres for violin plots;
    # the edges are in log10(E / GeV)
    Energy_edges = np.linspace(2, 8, 13)
    Energy_centres = (Energy_edges[1:]+Energy_edges[:-1])/2.

    def bin_by_energy(values):
        """Group `values` by the energy bin centre of the matching MC energy."""
        binned = {}
        for en, val in zip(reco_table.cols.EnMC, values):
            # "EnMC" is stored in units of `energy_unit` while the bin edges
            # are given in log10(E / GeV): convert before taking the logarithm
            log_en = np.log10((en * energy_unit).to(u.GeV).value)

            # get the bin number this event belongs into
            sbin = np.digitize(log_en, Energy_edges) - 1

            # events outside of the binned range would either wrap around to the
            # last bin (index -1) or run past the end of the list of centres
            if not 0 <= sbin < len(Energy_centres):
                continue

            # the central value of the bin is the key for the dictionary
            binned.setdefault(Energy_centres[sbin], []).append(val)
        return binned

    # convert the xi-list in to an energy-binned dict with the bin centre as keys
    xi_vs_energy = bin_by_energy(reco_table.cols.xi)

    # plotting the angular error as violin plots with binning in
    # number of telescopes and shower energy
    plt.figure()
    plt.subplot(211)
    plt.violinplot([np.log10(a) for a in xi_vs_tel.values()],
                   list(xi_vs_tel.keys()),
                   points=60, widths=.75, showextrema=False, showmedians=True)
    plt.xlabel("Number of Telescopes")
    plt.ylabel(r"log($\xi$ / deg)")
    plt.ylim(-3, 2)
    plt.grid()

    plt.subplot(212)
    # nothing to draw if no event falls into the binned energy range
    if xi_vs_energy:
        plt.violinplot([np.log10(a) for a in xi_vs_energy.values()],
                       list(xi_vs_energy.keys()),
                       points=60, widths=(Energy_edges[1] - Energy_edges[0]) / 1.5,
                       showextrema=False, showmedians=True)
    plt.xlabel(r"log(Energy / GeV)")
    plt.ylabel(r"log($\xi$ / deg)")
    plt.ylim(-3, 2)
    plt.grid()

    plt.tight_layout()
    if args.write:
        save_fig('{}/reco_xi_vs_E_NTel_{}'.format(args.plots_dir, args.mode))

    plt.pause(.1)

    # convert the diffs-list into a dict with the number of used telescopes as keys
    diff_vs_tel = {}
    for diff, ntel in zip(reco_table.cols.DeltaR, reco_table.cols.NTels_clean):
        diff_vs_tel.setdefault(ntel, []).append(diff)

    # convert the diffs-list in to an energy-binned dict with the bin centre as keys
    diff_vs_energy = bin_by_energy(reco_table.cols.DeltaR)

    # plotting the core position error as violin plots with binning in
    # number of telescopes and shower energy
    plt.figure()
    plt.subplot(211)
    plt.violinplot([np.log10(a) for a in diff_vs_tel.values()],
                   list(diff_vs_tel.keys()),
                   points=60, widths=.75, showextrema=False, showmedians=True)
    plt.xlabel("Number of Telescopes")
    plt.ylabel(r"log($\Delta R$ / m)")
    plt.grid()

    plt.subplot(212)
    # nothing to draw if no event falls into the binned energy range
    if diff_vs_energy:
        plt.violinplot([np.log10(a) for a in diff_vs_energy.values()],
                       list(diff_vs_energy.keys()),
                       points=60, widths=(Energy_edges[1] - Energy_edges[0]) / 1.5,
                       showextrema=False, showmedians=True)
    plt.xlabel(r"log(Energy / GeV)")
    plt.ylabel(r"log($\Delta R$ / m)")
    plt.grid()

    plt.tight_layout()
    if args.write:
        save_fig('{}/reco_dist_vs_E_NTel_{}'.format(args.plots_dir, args.mode))
    plt.show()
def main():

    # your favourite units here
    energy_unit = u.TeV
    angle_unit = u.deg
    dist_unit = u.m

    agree_threshold = .5
    min_tel = 3

    parser = make_argparser()
    parser.add_argument('--classifier',
                        type=str,
                        default=expandvars(
                            "$CTA_SOFT/tino_cta/data/classifier_pickle/"
                            "classifier_{mode}_{cam_id}_{classifier}.pkl"))
    parser.add_argument('--regressor',
                        type=str,
                        default=expandvars(
                            "$CTA_SOFT/tino_cta/data/classifier_pickle/"
                            "regressor_{mode}_{cam_id}_{regressor}.pkl"))
    parser.add_argument('-o',
                        '--outfile',
                        type=str,
                        default="",
                        help="location to write the classified events to.")
    parser.add_argument('--wave_dir',
                        type=str,
                        default=None,
                        help="directory where to find mr_filter. "
                        "if not set look in $PATH")
    parser.add_argument(
        '--wave_temp_dir',
        type=str,
        default='/dev/shm/',
        help="directory where mr_filter to store the temporary fits "
        "files")

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--proton',
                       action='store_true',
                       help="do protons instead of gammas")
    group.add_argument('--electron',
                       action='store_true',
                       help="do electrons instead of gammas")

    args = parser.parse_args()

    if args.infile_list:
        filenamelist = []
        for f in args.infile_list:
            filenamelist += glob("{}/{}".format(args.indir, f))
        filenamelist.sort()
    elif args.proton:
        filenamelist = sorted(glob("{}/proton/*gz".format(args.indir)))
    elif args.electron:
        filenamelist = glob("{}/electron/*gz".format(args.indir))
        channel = "electron"
    else:
        filenamelist = sorted(glob("{}/gamma/*gz".format(args.indir)))

    if not filenamelist:
        print("no files found; check indir: {}".format(args.indir))
        exit(-1)

    # keeping track of events and where they were rejected
    Eventcutflow = CutFlow("EventCutFlow")
    Imagecutflow = CutFlow("ImageCutFlow")

    # takes care of image cleaning
    cleaner = ImageCleaner(mode=args.mode,
                           cutflow=Imagecutflow,
                           wavelet_options=args.raw,
                           tmp_files_directory=args.wave_temp_dir,
                           skip_edge_events=False,
                           island_cleaning=True)

    # the class that does the shower reconstruction
    shower_reco = HillasReconstructor()

    preper = EventPreparer(
        cleaner=cleaner,
        hillas_parameters=hillas_parameters,
        shower_reco=shower_reco,
        event_cutflow=Eventcutflow,
        image_cutflow=Imagecutflow,
        # event/image cuts:
        allowed_cam_ids=[],
        min_ntel=2,
        min_charge=args.min_charge,
        min_pixel=3)

    # wrapper for the scikit-learn classifier
    classifier = EventClassifier.load(args.classifier.format(
        **{
            "mode": args.mode,
            "wave_args": "mixed",
            "classifier": 'RandomForestClassifier',
            "cam_id": "{cam_id}"
        }),
                                      cam_id_list=args.cam_ids)

    # wrapper for the scikit-learn regressor
    regressor = EnergyRegressor.load(args.regressor.format(
        **{
            "mode": args.mode,
            "wave_args": "mixed",
            "regressor": "RandomForestRegressor",
            "cam_id": "{cam_id}"
        }),
                                     cam_id_list=args.cam_ids)

    ClassifierFeatures = namedtuple(
        "ClassifierFeatures",
        ("impact_dist", "sum_signal_evt", "max_signal_cam", "sum_signal_cam",
         "N_LST", "N_MST", "N_SST", "width", "length", "skewness", "kurtosis",
         "h_max", "err_est_pos", "err_est_dir"))

    EnergyFeatures = namedtuple(
        "EnergyFeatures",
        ("impact_dist", "sum_signal_evt", "max_signal_cam", "sum_signal_cam",
         "N_LST", "N_MST", "N_SST", "width", "length", "skewness", "kurtosis",
         "h_max", "err_est_pos", "err_est_dir"))

    # catch ctr-c signal to exit current loop and still display results
    signal_handler = SignalHandler()
    signal.signal(signal.SIGINT, signal_handler)

    # this class defines the reconstruction parameters to keep track of
    class RecoEvent(tb.IsDescription):
        Run_ID = tb.Int16Col(dflt=-1, pos=0)
        Event_ID = tb.Int16Col(dflt=-1, pos=1)
        NTels_trig = tb.Int16Col(dflt=0, pos=0)
        NTels_reco = tb.Int16Col(dflt=0, pos=1)
        NTels_reco_lst = tb.Int16Col(dflt=0, pos=2)
        NTels_reco_mst = tb.Int16Col(dflt=0, pos=3)
        NTels_reco_sst = tb.Int16Col(dflt=0, pos=4)
        MC_Energy = tb.Float32Col(dflt=np.nan, pos=5)
        reco_Energy = tb.Float32Col(dflt=np.nan, pos=6)
        reco_phi = tb.Float32Col(dflt=np.nan, pos=7)
        reco_theta = tb.Float32Col(dflt=np.nan, pos=8)
        off_angle = tb.Float32Col(dflt=np.nan, pos=9)
        xi = tb.Float32Col(dflt=np.nan, pos=10)
        DeltaR = tb.Float32Col(dflt=np.nan, pos=11)
        ErrEstPos = tb.Float32Col(dflt=np.nan, pos=12)
        ErrEstDir = tb.Float32Col(dflt=np.nan, pos=13)
        gammaness = tb.Float32Col(dflt=np.nan, pos=14)
        success = tb.BoolCol(dflt=False, pos=15)

    channel = "gamma" if "gamma" in " ".join(filenamelist) else "proton"
    reco_outfile = tb.open_file(
        mode="w",
        # if no outfile name is given (i.e. don't to write the event list to disk),
        # need specify two "driver" arguments
        **({
            "filename": args.outfile
        } if args.outfile else {
            "filename": "no_outfile.h5",
            "driver": "H5FD_CORE",
            "driver_core_backing_store": False
        }))

    reco_table = reco_outfile.create_table("/", "reco_events", RecoEvent)
    reco_event = reco_table.row

    allowed_tels = None  # all telescopes
    allowed_tels = prod3b_tel_ids("L+N+D")
    for i, filename in enumerate(filenamelist[:args.last]):
        # print(f"file: {i} filename = {filename}")

        source = hessio_event_source(filename,
                                     allowed_tels=allowed_tels,
                                     max_events=args.max_events)

        # loop that cleans and parametrises the images and performs the reconstruction
        for (event, hillas_dict, n_tels, tot_signal, max_signals, pos_fit,
             dir_fit, h_max, err_est_pos,
             err_est_dir) in preper.prepare_event(source, True):

            # now prepare the features for the classifier
            cls_features_evt = {}
            reg_features_evt = {}
            if hillas_dict is not None:
                for tel_id in hillas_dict.keys():
                    Imagecutflow.count("pre-features")

                    tel_pos = np.array(event.inst.tel_pos[tel_id][:2]) * u.m

                    moments = hillas_dict[tel_id]

                    impact_dist = linalg.length(tel_pos - pos_fit)
                    cls_features_tel = ClassifierFeatures(
                        impact_dist=impact_dist / u.m,
                        sum_signal_evt=tot_signal,
                        max_signal_cam=max_signals[tel_id],
                        sum_signal_cam=moments.size,
                        N_LST=n_tels["LST"],
                        N_MST=n_tels["MST"],
                        N_SST=n_tels["SST"],
                        width=moments.width / u.m,
                        length=moments.length / u.m,
                        skewness=moments.skewness,
                        kurtosis=moments.kurtosis,
                        h_max=h_max / u.m,
                        err_est_pos=err_est_pos / u.m,
                        err_est_dir=err_est_dir / u.deg)

                    reg_features_tel = EnergyFeatures(
                        impact_dist=impact_dist / u.m,
                        sum_signal_evt=tot_signal,
                        max_signal_cam=max_signals[tel_id],
                        sum_signal_cam=moments.size,
                        N_LST=n_tels["LST"],
                        N_MST=n_tels["MST"],
                        N_SST=n_tels["SST"],
                        width=moments.width / u.m,
                        length=moments.length / u.m,
                        skewness=moments.skewness,
                        kurtosis=moments.kurtosis,
                        h_max=h_max / u.m,
                        err_est_pos=err_est_pos / u.m,
                        err_est_dir=err_est_dir / u.deg)

                    if np.isnan(cls_features_tel).any() or np.isnan(
                            reg_features_tel).any():
                        continue

                    Imagecutflow.count("features nan")

                    cam_id = event.inst.subarray.tel[tel_id].camera.cam_id

                    try:
                        reg_features_evt[cam_id] += [reg_features_tel]
                        cls_features_evt[cam_id] += [cls_features_tel]
                    except KeyError:
                        reg_features_evt[cam_id] = [reg_features_tel]
                        cls_features_evt[cam_id] = [cls_features_tel]

            if cls_features_evt and reg_features_evt:

                predict_energ = regressor.predict_by_event([reg_features_evt
                                                            ])["mean"][0]
                predict_proba = classifier.predict_proba_by_event(
                    [cls_features_evt])
                gammaness = predict_proba[0, 0]

                try:
                    # the MC direction of origin of the simulated particle
                    shower = event.mc
                    shower_core = np.array(
                        [shower.core_x / u.m, shower.core_y / u.m]) * u.m
                    shower_org = linalg.set_phi_theta(az_to_phi(shower.az),
                                                      alt_to_theta(shower.alt))

                    # and how the reconstructed direction compares to that
                    xi = linalg.angle(dir_fit, shower_org)
                    DeltaR = linalg.length(pos_fit[:2] - shower_core)
                except Exception:
                    # naked exception catch, because I'm not sure where
                    # it would break in non-MC files
                    xi = np.nan
                    DeltaR = np.nan

                phi, theta = linalg.get_phi_theta(dir_fit)
                phi = (phi if phi > 0 else phi + 360 * u.deg)

                # TODO: replace with actual array pointing direction
                array_pointing = linalg.set_phi_theta(0 * u.deg, 20. * u.deg)
                # angular offset between the reconstructed direction and the array
                # pointing
                off_angle = linalg.angle(dir_fit, array_pointing)

                reco_event["NTels_trig"] = len(event.dl0.tels_with_data)
                reco_event["NTels_reco"] = len(hillas_dict)
                reco_event["NTels_reco_lst"] = n_tels["LST"]
                reco_event["NTels_reco_mst"] = n_tels["MST"]
                reco_event["NTels_reco_sst"] = n_tels["SST"]
                reco_event["reco_Energy"] = predict_energ.to(energy_unit).value
                reco_event["reco_phi"] = phi / angle_unit
                reco_event["reco_theta"] = theta / angle_unit
                reco_event["off_angle"] = off_angle / angle_unit
                reco_event["xi"] = xi / angle_unit
                reco_event["DeltaR"] = DeltaR / dist_unit
                reco_event["ErrEstPos"] = err_est_pos / dist_unit
                reco_event["ErrEstDir"] = err_est_dir / angle_unit
                reco_event["gammaness"] = gammaness
                reco_event["success"] = True
            else:
                reco_event["success"] = False

            # save basic event infos
            reco_event["MC_Energy"] = event.mc.energy.to(energy_unit).value
            reco_event["Event_ID"] = event.r1.event_id
            reco_event["Run_ID"] = event.r1.run_id

            reco_table.flush()
            reco_event.append()

            if signal_handler.stop:
                break
        if signal_handler.stop:
            break

    # make sure everything gets written out nicely
    reco_table.flush()

    try:
        print()
        Eventcutflow()
        print()
        Imagecutflow()

        # do some simple event selection
        # and print the corresponding selection efficiency
        N_selected = len([
            x for x in reco_table.where(
                """(NTels_reco > min_tel) & (gammaness > agree_threshold)""")
        ])
        N_total = len(reco_table)
        print("\nfraction selected events:")
        print("{} / {} = {} %".format(N_selected, N_total,
                                      N_selected / N_total * 100))

    except ZeroDivisionError:
        pass

    print("\nlength filenamelist:", len(filenamelist[:args.last]))

    # do some plotting if so desired
    if args.plot:
        gammaness = [x['gammaness'] for x in reco_table]
        NTels_rec = [x['NTels_reco'] for x in reco_table]
        NTel_bins = np.arange(np.min(NTels_rec), np.max(NTels_rec) + 2) - .5

        NTels_rec_lst = [x['NTels_reco_lst'] for x in reco_table]
        NTels_rec_mst = [x['NTels_reco_mst'] for x in reco_table]
        NTels_rec_sst = [x['NTels_reco_sst'] for x in reco_table]

        reco_energy = np.array([x['reco_Energy'] for x in reco_table])
        mc_energy = np.array([x['MC_Energy'] for x in reco_table])

        fig = plt.figure(figsize=(15, 5))
        plt.suptitle(" ** ".join(
            [args.mode, "protons" if args.proton else "gamma"]))
        plt.subplots_adjust(left=0.05, right=0.97, hspace=0.39, wspace=0.2)

        ax = plt.subplot(131)
        histo = np.histogram2d(NTels_rec,
                               gammaness,
                               bins=(NTel_bins, np.linspace(0, 1, 11)))[0].T
        histo_normed = histo / histo.max(axis=0)
        im = ax.imshow(
            histo_normed,
            interpolation='none',
            origin='lower',
            aspect='auto',
            # extent=(*NTel_bins[[0, -1]], 0, 1),
            cmap=plt.cm.inferno)
        ax.set_xlabel("NTels")
        ax.set_ylabel("drifted gammaness")
        plt.title("Total Number of Telescopes")

        # next subplot

        ax = plt.subplot(132)
        histo = np.histogram2d(NTels_rec_sst,
                               gammaness,
                               bins=(NTel_bins, np.linspace(0, 1, 11)))[0].T
        histo_normed = histo / histo.max(axis=0)
        im = ax.imshow(
            histo_normed,
            interpolation='none',
            origin='lower',
            aspect='auto',
            # extent=(*NTel_bins[[0, -1]], 0, 1),
            cmap=plt.cm.inferno)
        ax.set_xlabel("NTels")
        plt.setp(ax.get_yticklabels(), visible=False)
        plt.title("Number of SSTs")

        # next subplot

        ax = plt.subplot(133)
        histo = np.histogram2d(NTels_rec_mst,
                               gammaness,
                               bins=(NTel_bins, np.linspace(0, 1, 11)))[0].T
        histo_normed = histo / histo.max(axis=0)
        im = ax.imshow(
            histo_normed,
            interpolation='none',
            origin='lower',
            aspect='auto',
            # extent=(*NTel_bins[[0, -1]], 0, 1),
            cmap=plt.cm.inferno)
        cb = fig.colorbar(im, ax=ax)
        ax.set_xlabel("NTels")
        plt.setp(ax.get_yticklabels(), visible=False)
        plt.title("Number of MSTs")

        plt.subplots_adjust(wspace=0.05)

        # plot the energy migration matrix
        plt.figure()
        plt.hist2d(np.log10(reco_energy),
                   np.log10(mc_energy),
                   bins=20,
                   cmap=plt.cm.inferno)
        plt.xlabel("E_MC / TeV")
        plt.ylabel("E_rec / TeV")
        plt.colorbar()

        plt.show()
def main():
    """Run the Hillas-based shower reconstruction and report its performance.

    Parses the command line, collects the input files and sends every event
    through ``EventPreparer.prepare_event``.  The fitted direction and core
    position are compared with the MC truth found in ``event.mc``; the
    per-event results are appended to a PyTables table (held in memory only
    unless ``args.store`` is set) and resolution summaries are printed.  If
    ``args.plot`` is set, a number of diagnostic figures is shown at the end.

    A SIGINT only aborts the event loop; the post-processing still runs.
    """

    # your favourite units here
    energy_unit = u.TeV
    angle_unit = u.deg
    dist_unit = u.m

    parser = make_argparser()
    parser.add_argument(
        '-o',
        '--outfile',
        type=str,
        help="if given, write output file with reconstruction results")
    parser.add_argument('--plot_c',
                        action='store_true',
                        help="plot camera-wise displays")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--proton',
                       action='store_true',
                       help="do protons instead of gammas")
    group.add_argument('--electron',
                       action='store_true',
                       help="do electrons instead of gammas")

    args = parser.parse_args()

    if args.infile_list:
        filenamelist = []
        for f in args.infile_list:
            filenamelist += glob("{}/{}".format(args.indir, f))
        # an explicit file list carries no particle type; `channel` still has
        # to exist because it is used as plot title further down
        channel = "custom"
    elif args.proton:
        filenamelist = glob("{}/proton/*gz".format(args.indir))
        channel = "proton"
    elif args.electron:
        filenamelist = glob("{}/electron/*gz".format(args.indir))
        channel = "electron"
    elif args.gamma:
        filenamelist = glob("{}/gamma/*gz".format(args.indir))
        channel = "gamma"
    else:
        raise ValueError("don't know which input to use...")
    filenamelist.sort()

    if not filenamelist:
        print("no files found; check indir: {}".format(args.indir))
        # the builtin `exit` is only injected by the `site` module, so raise
        # the exception it would raise directly
        raise SystemExit(-1)
    print("found {} files".format(len(filenamelist)))

    # keeping track of events and where they were rejected
    Eventcutflow = CutFlow("EventCutFlow")
    Imagecutflow = CutFlow("ImageCutFlow")

    # takes care of image cleaning
    cleaner = ImageCleaner(mode=args.mode,
                           cutflow=Imagecutflow,
                           wavelet_options=args.raw,
                           skip_edge_events=args.skip_edge_events,
                           island_cleaning=True)

    # the class that does the shower reconstruction
    shower_reco = HillasReconstructor()

    # NOTE(review): not referenced again in this function; kept in case the
    # constructor has side effects -- confirm and drop if it does not
    shower_max_estimator = ShowerMaxEstimator("paranal")

    preper = EventPreparer(
        cleaner=cleaner,
        hillas_parameters=hillas_parameters,
        shower_reco=shower_reco,
        event_cutflow=Eventcutflow,
        image_cutflow=Imagecutflow,
        # event/image cuts:
        allowed_cam_ids=[],  # means: all
        min_ntel=3,
        min_charge=args.min_charge,
        min_pixel=3)

    # a signal handler to abort the event loop but still do the post-processing
    signal_handler = SignalHandler()
    signal.signal(signal.SIGINT, signal_handler)

    try:
        # this class defines the reconstruction parameters to keep track of
        class RecoEvent(tb.IsDescription):
            NTels_trigg = tb.Int16Col(dflt=1, pos=0)
            NTels_clean = tb.Int16Col(dflt=1, pos=1)
            EnMC = tb.Float32Col(dflt=1, pos=2)
            xi = tb.Float32Col(dflt=1, pos=3)
            DeltaR = tb.Float32Col(dflt=1, pos=4)
            ErrEstPos = tb.Float32Col(dflt=1, pos=5)
            ErrEstDir = tb.Float32Col(dflt=1, pos=6)
            h_max = tb.Float32Col(dflt=1, pos=7)

        reco_outfile = tb.open_file(
            args.outfile,
            mode="w",
            # if we don't want to write the event list to disk, need to add more arguments
            **({} if args.store else {
                "driver": "H5FD_CORE",
                "driver_core_backing_store": False
            }))
        reco_table = reco_outfile.create_table("/", "reco_event", RecoEvent)
        reco_event = reco_table.row
    except Exception as err:
        # only real errors are handled here (no bare `except`), and the cause
        # is reported instead of being guessed at
        print("no pytables installed?")
        print(err)
        # NOTE(review): `reco_table` stays undefined on this path although it
        # is used unconditionally after the event loop -- confirm whether this
        # fallback is still supported
        reco_event = RecoEvent()

    # ---------------------------------------------------------------- LOOP

    # define here which telescopes to loop over
    allowed_tels = None
    # allowed_tels = prod3b_tel_ids("L+F+D")
    for i, filename in enumerate(filenamelist[:args.last]):
        print("file: {i} filename = {filename}".format(i=i, filename=filename))

        source = hessio_event_source(filename,
                                     allowed_tels=allowed_tels,
                                     max_events=args.max_events)

        # loop that cleans and parametrises the images and performs the reconstruction
        for (event, hillas_dict, n_tels, tot_signal, max_signal, pos_fit,
             dir_fit, h_max, err_est_pos,
             err_est_dir) in preper.prepare_event(source):

            shower = event.mc

            # bring the MC direction into the (-180, 180] deg range
            org_alt = u.Quantity(shower.alt).to(u.deg)
            org_az = u.Quantity(shower.az).to(u.deg)
            if org_az > 180 * u.deg:
                org_az -= 360 * u.deg

            org_the = alt_to_theta(org_alt)
            org_phi = az_to_phi(org_az)
            if org_phi > 180 * u.deg:
                org_phi -= 360 * u.deg
            if org_phi < -180 * u.deg:
                org_phi += 360 * u.deg

            shower_org = linalg.set_phi_theta(org_phi, org_the)
            shower_core = convert_astropy_array([shower.core_x, shower.core_y])

            # angular distance and core distance between fit and MC truth
            xi = linalg.angle(dir_fit, shower_org).to(angle_unit)
            diff = linalg.length(pos_fit[:2] - shower_core)

            # print some performance
            print()
            print("xi = {:4.3f}".format(xi))
            print("pos = {:4.3f}".format(diff))
            print("h_max reco: {:4.3f}".format(h_max.to(u.km)))
            print("err_est_dir: {:4.3f}".format(err_est_dir.to(angle_unit)))
            print("err_est_pos: {:4.3f}".format(err_est_pos))

            try:
                # store the reconstruction data in the PyTable
                reco_event["NTels_trigg"] = n_tels["tot"]
                reco_event["NTels_clean"] = len(shower_reco.circles)
                reco_event["EnMC"] = event.mc.energy / energy_unit
                reco_event["xi"] = xi / angle_unit
                reco_event["DeltaR"] = diff / dist_unit
                reco_event["ErrEstPos"] = err_est_pos / dist_unit
                reco_event["ErrEstDir"] = err_est_dir / angle_unit
                reco_event["h_max"] = h_max / dist_unit
                reco_event.append()
                reco_table.flush()

                # running resolution over everything stored so far
                print()
                print("xi res (68-percentile) = {:4.3f} {}".format(
                    np.percentile(reco_table.cols.xi, 68), angle_unit))
                print("core res (68-percentile) = {:4.3f} {}".format(
                    np.percentile(reco_table.cols.DeltaR, 68), dist_unit))
                print("h_max (median) = {:4.3f} {}".format(
                    np.percentile(reco_table.cols.h_max, 50), dist_unit))

            except NoPyTables:
                pass

            if args.plot_c:
                # imported for its side effect: on older matplotlib versions
                # this registers the "3d" projection
                from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
                fig = plt.figure()
                # `Figure.gca` no longer accepts keyword arguments in recent
                # matplotlib releases; `add_subplot` works on old and new ones
                ax = fig.add_subplot(111, projection='3d')
                for c in shower_reco.circles.values():
                    points = [
                        c.pos + t * c.a * u.km for t in np.linspace(0, 15, 3)
                    ]
                    ax.plot(*np.array(points).T,
                            linewidth=np.sqrt(c.weight) / 10)
                    ax.scatter(*c.pos[:, None].value, s=np.sqrt(c.weight))
                plt.xlabel("x")
                plt.ylabel("y")
                plt.pause(.1)

                # this plots
                # • the MC shower core
                # • the reconstructed shower core
                # • the used telescopes
                # • and the trace of the Hillas plane on the ground
                plt.figure()
                for tel_id, c in shower_reco.circles.items():
                    plt.scatter(c.pos[0], c.pos[1], s=np.sqrt(c.weight))
                    plt.gca().annotate(tel_id,
                                       (c.pos[0].value, c.pos[1].value))
                    plt.plot([
                        c.pos[0].value - 500 * c.norm[1],
                        c.pos[0].value + 500 * c.norm[1]
                    ], [
                        c.pos[1].value + 500 * c.norm[0],
                        c.pos[1].value - 500 * c.norm[0]
                    ],
                             linewidth=np.sqrt(c.weight) / 10)
                plt.scatter(*pos_fit[:2],
                            c="black",
                            marker="*",
                            label="fitted")
                plt.scatter(*shower_core[:2],
                            c="black",
                            marker="P",
                            label="MC")
                plt.legend()
                plt.xlabel("x")
                plt.ylabel("y")
                plt.xlim(-1400, 1400)
                plt.ylim(-1400, 1400)
                plt.show()

            if signal_handler.stop:
                break
        if signal_handler.stop:
            break

    print("\n" + "=" * 35 + "\n")
    print("xi res (68-percentile) = {:4.3f} {}".format(
        np.percentile(reco_table.cols.xi, 68), angle_unit))
    print("core res (68-percentile) = {:4.3f} {}".format(
        np.percentile(reco_table.cols.DeltaR, 68), dist_unit))
    print("h_max (median) = {:4.3f} {}".format(
        np.percentile(reco_table.cols.h_max, 50), dist_unit))

    # print the cutflows for telescopes and camera images
    print("\n\n")
    Eventcutflow("min2Tels trig")
    print()
    Imagecutflow(sort_column=1)

    # if we don't want to plot anything, we can exit now
    if not args.plot:
        return

    # --------------------------------------------------------------- PLOTS

    plt.figure()
    plt.hist(reco_table.cols.h_max, bins=np.linspace(0, 15000, 51, True))
    plt.title(channel)
    plt.xlabel("h_max reco")
    plt.pause(.1)

    plt.figure()
    xi_edges = np.linspace(0, 5, 20)
    plt.hist(reco_table.cols.xi, bins=xi_edges, log=True)
    plt.xlabel(r"$\xi$ / deg")
    if args.write:
        save_fig('{}/reco_xi_{}'.format(args.plots_dir, args.mode))
    plt.pause(.1)

    plt.figure()
    plt.hist(reco_table.cols.ErrEstDir[:], bins=np.linspace(0, 20, 50))
    plt.title(channel)
    plt.xlabel("beta")
    plt.pause(.1)

    plt.figure()
    plt.hist(np.log10(reco_table.cols.xi[:] / reco_table.cols.ErrEstDir[:]),
             bins=50)
    plt.title(channel)
    plt.xlabel("log_10(xi / beta)")
    plt.pause(.1)

    # convert the xi-list into a dict with the number of used telescopes as keys
    xi_vs_tel = {}
    for xi, ntel in zip(reco_table.cols.xi, reco_table.cols.NTels_clean):
        xi_vs_tel.setdefault(ntel, []).append(xi)

    print(args.mode)
    for ntel, xis in sorted(xi_vs_tel.items()):
        print("NTel: {} -- median xi: {}".format(ntel, np.median(xis)))
        # print("histogram:", np.histogram(xis, bins=xi_edges))

    # create a list of energy bin-edges and -centres for violin plots
    # NOTE(review): `EnMC` is written above in units of `energy_unit` (TeV),
    # whereas the axis labels below say log(Energy / GeV); the edges look like
    # log10(GeV) values as well -- confirm that the binning is still right
    Energy_edges = np.linspace(2, 8, 13)
    Energy_centres = (Energy_edges[1:] + Energy_edges[:-1]) / 2.

    # convert the xi-list in to an energy-binned dict with the bin centre as keys
    xi_vs_energy = {}
    for en, xi in zip(reco_table.cols.EnMC, reco_table.cols.xi):

        # get the bin number this event belongs into
        sbin = np.digitize(np.log10(en), Energy_edges) - 1

        # the central value of the bin is the key for the dictionary
        xi_vs_energy.setdefault(Energy_centres[sbin], []).append(xi)

    # plotting the angular error as violin plots with binning in
    # number of telescopes and shower energy
    plt.figure()
    plt.subplot(211)
    plt.violinplot([np.log10(a) for a in xi_vs_tel.values()],
                   list(xi_vs_tel),
                   points=60,
                   widths=.75,
                   showextrema=False,
                   showmedians=True)
    plt.xlabel("Number of Telescopes")
    plt.ylabel(r"log($\xi$ / deg)")
    plt.ylim(-3, 2)
    plt.grid()

    plt.subplot(212)
    plt.violinplot([np.log10(a) for a in xi_vs_energy.values()],
                   list(xi_vs_energy),
                   points=60,
                   widths=(Energy_edges[1] - Energy_edges[0]) / 1.5,
                   showextrema=False,
                   showmedians=True)
    plt.xlabel(r"log(Energy / GeV)")
    plt.ylabel(r"log($\xi$ / deg)")
    plt.ylim(-3, 2)
    plt.grid()

    plt.tight_layout()
    if args.write:
        save_fig('{}/reco_xi_vs_E_NTel_{}'.format(args.plots_dir, args.mode))

    plt.pause(.1)

    # convert the diffs-list into a dict with the number of used telescopes as keys
    diff_vs_tel = {}
    for diff, ntel in zip(reco_table.cols.DeltaR, reco_table.cols.NTels_clean):
        diff_vs_tel.setdefault(ntel, []).append(diff)

    # convert the diffs-list in to an energy-binned dict with the bin centre as keys
    diff_vs_energy = {}
    for en, diff in zip(reco_table.cols.EnMC, reco_table.cols.DeltaR):

        # get the bin number this event belongs into
        sbin = np.digitize(np.log10(en), Energy_edges) - 1

        # the central value of the bin is the key for the dictionary
        diff_vs_energy.setdefault(Energy_centres[sbin], []).append(diff)

    # plotting the core position error as violin plots with binning in
    # number of telescopes and shower energy
    plt.figure()
    plt.subplot(211)
    plt.violinplot([np.log10(a) for a in diff_vs_tel.values()],
                   list(diff_vs_tel),
                   points=60,
                   widths=.75,
                   showextrema=False,
                   showmedians=True)
    plt.xlabel("Number of Telescopes")
    plt.ylabel(r"log($\Delta R$ / m)")
    plt.grid()

    plt.subplot(212)
    plt.violinplot([np.log10(a) for a in diff_vs_energy.values()],
                   list(diff_vs_energy),
                   points=60,
                   widths=(Energy_edges[1] - Energy_edges[0]) / 1.5,
                   showextrema=False,
                   showmedians=True)
    plt.xlabel(r"log(Energy / GeV)")
    plt.ylabel(r"log($\Delta R$ / m)")
    plt.grid()

    plt.tight_layout()
    if args.write:
        save_fig('{}/reco_dist_vs_E_NTel_{}'.format(args.plots_dir, args.mode))
    plt.show()
Example #5
0
    def prepare_event(self, source, return_stub=False):
        """Clean, parametrise and reconstruct every event of ``source``.

        For each event the images of all telescopes with data are calibrated,
        cleaned and Hillas-parametrised.  If the event survives the cuts,
        ``self.shower_reco`` fits the core position, the shower direction and
        the height of the shower maximum.

        Parameters
        ----------
        source : iterable
            Yields the events to process.
        return_stub : bool, optional
            If true, an event that fails an event-level cut or the shower fit
            yields ``stub(event)`` instead of being dropped silently.

        Yields
        ------
        PreparedEvent
            The event together with its Hillas parameters, telescope counts,
            image signals and fit results -- or ``stub(event)`` for a rejected
            event when ``return_stub`` is set.
        """

        for event in source:

            self.event_cutflow.count("noCuts")

            if self.event_cutflow.cut("min2Tels trig",
                                      len(event.dl0.tels_with_data)):
                # a rejected event must never fall through to the processing
                # below, whether or not a stub was handed out for it
                if return_stub:
                    yield stub(event)
                continue

            # calibrate the event
            self.calib.calibrate(event)

            # telescope loop
            tot_signal = 0
            max_signals = {}
            hillas_dict = {}
            n_tels = {
                "tot": len(event.dl0.tels_with_data),
                "LST": 0,
                "MST": 0,
                "SST": 0
            }
            for tel_id in event.dl0.tels_with_data:
                self.image_cutflow.count("noCuts")

                camera = event.inst.subarray.tel[tel_id].camera

                # can this be improved?
                # NOTE(review): `tel_phi` / `tel_theta` are not defined in
                # this method -- presumably module-level dicts caching the
                # telescope pointing; confirm
                if tel_id not in tel_phi:
                    tel_phi[tel_id] = az_to_phi(
                        event.mc.tel[tel_id].azimuth_raw * u.rad)
                    tel_theta[tel_id] = \
                        alt_to_theta(event.mc.tel[tel_id].altitude_raw*u.rad)

                    # the orientation of the camera (i.e. the pixel positions) needs to
                    # be corrected
                    camera.pix_x, camera.pix_y = \
                        transform_pixel_position(camera.pix_x, camera.pix_y)

                # count the current telescope according to its size
                tel_type = event.inst.subarray.tel[tel_id].optics.tel_type
                n_tels[tel_type] += 1

                # the camera image as a 1D array
                pmt_signal = event.dl1.tel[tel_id].image

                pmt_signal = self.pick_gain_channel(pmt_signal, camera.cam_id)

                # clean the image
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        pmt_signal, new_geom = \
                            self.cleaner.clean(pmt_signal.copy(), camera)

                    if self.image_cutflow.cut("min pixel", pmt_signal) or \
                       self.image_cutflow.cut("min charge", np.sum(pmt_signal)):
                        continue

                except FileNotFoundError as e:
                    print(e)
                    continue
                except EdgeEvent:
                    continue

                # could this go into `hillas_parameters` ...?
                max_signals[tel_id] = np.max(pmt_signal)

                # do the hillas reconstruction of the images
                try:
                    moments = self.hillas_parameters(new_geom, pmt_signal)

                    # if width and/or length are zero (e.g. when there is only one
                    # pixel or when all pixels are exactly in one row), the
                    # parametrisation won't be very useful: skip
                    if self.image_cutflow.cut("poor moments", moments):
                        continue

                except hillas.HillasParameterizationError as e:
                    print(e)
                    print("ignoring this camera")
                    continue

                hillas_dict[tel_id] = moments
                tot_signal += moments.size

            n_tels["reco"] = len(hillas_dict)
            if self.event_cutflow.cut("min2Tels reco", n_tels["reco"]):
                if return_stub:
                    yield stub(event)
                continue

            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    # telescope loop done, now do the core fit
                    self.shower_reco.get_great_circles(hillas_dict,
                                                       event.inst.subarray,
                                                       tel_phi, tel_theta)
                    pos_fit, err_est_pos = self.shower_reco.fit_core_crosses()
                    dir_fit, err_est_dir = \
                        self.shower_reco.fit_origin_crosses()
                    h_max = self.shower_reco.fit_h_max(hillas_dict,
                                                       event.inst.subarray,
                                                       tel_phi, tel_theta)
            except Exception as e:
                print(e)
                # the fit results are unusable here (unbound, or left over
                # from an earlier event), so always move on to the next event
                if return_stub:
                    yield stub(event)
                continue

            if self.event_cutflow.cut("position nan", pos_fit) or \
               self.event_cutflow.cut("direction nan", dir_fit):
                if return_stub:
                    yield stub(event)
                continue

            yield PreparedEvent(event=event,
                                hillas_dict=hillas_dict,
                                n_tels=n_tels,
                                tot_signal=tot_signal,
                                max_signals=max_signals,
                                pos_fit=pos_fit,
                                dir_fit=dir_fit,
                                h_max=h_max,
                                err_est_pos=err_est_pos,
                                err_est_dir=err_est_dir)
def _plot_reco_summary(reco_table, mode, particle):
    """Show diagnostic plots for the events stored in ``reco_table``.

    The first figure holds three 2D histograms of gammaness versus the number of
    reconstructed telescopes (all, SSTs, MSTs); every telescope-number column is
    normalised to its own maximum.  The second figure is the energy migration
    matrix (log10 of the reconstructed versus log10 of the MC energy).

    Parameters
    ----------
    reco_table : iterable of rows
        table whose rows provide the columns ``gammaness``, ``NTels_reco``,
        ``NTels_reco_sst``, ``NTels_reco_mst``, ``reco_Energy`` and ``MC_Energy``
    mode : str
        first part of the figure title
    particle : str
        second part of the figure title
    """

    def column(name):
        # collect one column by iterating over the rows of the table
        return np.array([row[name] for row in reco_table])

    gammaness = column('gammaness')
    ntels_reco = column('NTels_reco')
    # half-integer bin edges so that every telescope count sits in a bin centre
    ntel_bins = np.arange(np.min(ntels_reco), np.max(ntels_reco) + 2) - .5
    gammaness_bins = np.linspace(0, 1, 11)

    fig = plt.figure(figsize=(15, 5))
    plt.suptitle(" ** ".join([mode, particle]))
    plt.subplots_adjust(left=0.05, right=0.97, hspace=0.39, wspace=0.2)

    panels = (
        (131, ntels_reco, "Total Number of Telescopes"),
        (132, column('NTels_reco_sst'), "Number of SSTs"),
        (133, column('NTels_reco_mst'), "Number of MSTs"),
    )
    for panel_index, (position, ntels, title) in enumerate(panels):
        ax = plt.subplot(position)
        histo = np.histogram2d(ntels, gammaness,
                               bins=(ntel_bins, gammaness_bins))[0].T
        histo_normed = histo / histo.max(axis=0)
        im = ax.imshow(histo_normed, interpolation='none', origin='lower',
                       aspect='auto',
                       # extent=(*ntel_bins[[0, -1]], 0, 1),
                       cmap=plt.cm.inferno)
        ax.set_xlabel("NTels")
        if panel_index == 0:
            ax.set_ylabel("drifted gammaness")
        else:
            # the panels share the gammaness axis of the left-most one
            plt.setp(ax.get_yticklabels(), visible=False)
        ax.set_title(title)

    # one colour bar, attached to the right-most panel
    fig.colorbar(im, ax=ax)
    plt.subplots_adjust(wspace=0.05)

    # plot the energy migration matrix: MC energy on x, reconstructed energy on y
    with np.errstate(divide="ignore", invalid="ignore"):
        log_mc = np.log10(column('MC_Energy'))
        log_reco = np.log10(column('reco_Energy'))
    # events without an energy estimate hold NaN; a non-finite value breaks the
    # automatic range detection of the histogram, so drop those events here
    finite = np.isfinite(log_mc) & np.isfinite(log_reco)

    plt.figure()
    plt.hist2d(log_mc[finite], log_reco[finite], bins=20,
               cmap=plt.cm.inferno)
    plt.xlabel("log10(E_MC / TeV)")
    plt.ylabel("log10(E_rec / TeV)")
    plt.colorbar()

    plt.show()


def main():
    """Reconstruct and classify the events of a list of simulation files.

    The command line selects the input files (explicit list, protons, electrons
    or -- by default -- gammas).  Every event yielded by
    ``EventPreparer.prepare_event`` is turned into per-camera feature lists, fed
    to the energy regressor and the event classifier, and written as one row to
    the ``reco_events`` table of an HDF5 file.  Events for which no usable
    features could be built are stored with ``success = False``.  Afterwards the
    cut flows and a simple selection efficiency are printed and, if requested
    through ``args.plot``, diagnostic plots are shown.

    If no ``--outfile`` is given, the table only lives in memory.
    """

    # your favourite units here
    energy_unit = u.TeV
    angle_unit = u.deg
    dist_unit = u.m

    # thresholds of the selection query further down; they are not table columns,
    # PyTables looks them up by name in this function's local namespace
    agree_threshold = .5
    min_tel = 3

    parser = make_argparser()
    parser.add_argument('--classifier', type=str,
                        default=expandvars("$CTA_SOFT/tino_cta/data/classifier_pickle/"
                                           "classifier_{mode}_{cam_id}_{classifier}.pkl"))
    parser.add_argument('--regressor', type=str,
                        default=expandvars("$CTA_SOFT/tino_cta/data/classifier_pickle/"
                                           "regressor_{mode}_{cam_id}_{regressor}.pkl"))
    parser.add_argument('-o', '--outfile', type=str, default="",
                        help="location to write the classified events to.")
    parser.add_argument('--wave_dir', type=str, default=None,
                        help="directory where to find mr_filter. "
                             "if not set look in $PATH")
    parser.add_argument('--wave_temp_dir', type=str, default='/dev/shm/',
                        help="directory where mr_filter to store the temporary fits "
                             "files")

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--proton', action='store_true',
                       help="do protons instead of gammas")
    group.add_argument('--electron', action='store_true',
                       help="do electrons instead of gammas")

    args = parser.parse_args()

    # every branch sorts the list so that the `[:args.last]` slice below always
    # selects the same files
    if args.infile_list:
        filenamelist = []
        for f in args.infile_list:
            filenamelist += glob("{}/{}".format(args.indir, f))
        filenamelist.sort()
    elif args.proton:
        filenamelist = sorted(glob("{}/proton/*gz".format(args.indir)))
    elif args.electron:
        filenamelist = sorted(glob("{}/electron/*gz".format(args.indir)))
    else:
        filenamelist = sorted(glob("{}/gamma/*gz".format(args.indir)))

    if not filenamelist:
        print("no files found; check indir: {}".format(args.indir))
        exit(-1)

    # keeping track of events and where they were rejected
    Eventcutflow = CutFlow("EventCutFlow")
    Imagecutflow = CutFlow("ImageCutFlow")

    # takes care of image cleaning
    cleaner = ImageCleaner(mode=args.mode, cutflow=Imagecutflow,
                           wavelet_options=args.raw,
                           tmp_files_directory=args.wave_temp_dir,
                           skip_edge_events=False, island_cleaning=True)

    # the class that does the shower reconstruction
    shower_reco = HillasReconstructor()

    preper = EventPreparer(
        cleaner=cleaner, hillas_parameters=hillas_parameters,
        shower_reco=shower_reco,
        event_cutflow=Eventcutflow, image_cutflow=Imagecutflow,
        # event/image cuts:
        allowed_cam_ids=[],
        min_ntel=2, min_charge=args.min_charge, min_pixel=3)

    # wrapper for the scikit-learn classifier
    # ("cam_id" is formatted to itself so that the placeholder survives)
    classifier = EventClassifier.load(
        args.classifier.format(**{
            "mode": args.mode,
            "wave_args": "mixed",
            "classifier": 'RandomForestClassifier',
            "cam_id": "{cam_id}"}),
        cam_id_list=args.cam_ids)

    # wrapper for the scikit-learn regressor
    regressor = EnergyRegressor.load(
        args.regressor.format(**{
            "mode": args.mode,
            "wave_args": "mixed",
            "regressor": "RandomForestRegressor",
            "cam_id": "{cam_id}"}),
        cam_id_list=args.cam_ids)

    # classifier and regressor are fed the same set of features
    feature_names = (
        "impact_dist",
        "sum_signal_evt",
        "max_signal_cam",
        "sum_signal_cam",
        "N_LST",
        "N_MST",
        "N_SST",
        "width",
        "length",
        "skewness",
        "kurtosis",
        "h_max",
        "err_est_pos",
        "err_est_dir")
    ClassifierFeatures = namedtuple("ClassifierFeatures", feature_names)
    EnergyFeatures = namedtuple("EnergyFeatures", feature_names)

    # catch ctr-c signal to exit current loop and still display results
    signal_handler = SignalHandler()
    signal.signal(signal.SIGINT, signal_handler)

    # this class defines the reconstruction parameters to keep track of
    class RecoEvent(tb.IsDescription):
        # NOTE(review): pos=0 and pos=1 are used again by NTels_trig/NTels_reco
        # below -- confirm the intended column order
        # NOTE(review): Int16 holds at most 32767 -- confirm that run and event
        # ids stay within that range
        Run_ID = tb.Int16Col(dflt=-1, pos=0)
        Event_ID = tb.Int16Col(dflt=-1, pos=1)
        NTels_trig = tb.Int16Col(dflt=0, pos=0)
        NTels_reco = tb.Int16Col(dflt=0, pos=1)
        NTels_reco_lst = tb.Int16Col(dflt=0, pos=2)
        NTels_reco_mst = tb.Int16Col(dflt=0, pos=3)
        NTels_reco_sst = tb.Int16Col(dflt=0, pos=4)
        MC_Energy = tb.Float32Col(dflt=np.nan, pos=5)
        reco_Energy = tb.Float32Col(dflt=np.nan, pos=6)
        reco_phi = tb.Float32Col(dflt=np.nan, pos=7)
        reco_theta = tb.Float32Col(dflt=np.nan, pos=8)
        off_angle = tb.Float32Col(dflt=np.nan, pos=9)
        xi = tb.Float32Col(dflt=np.nan, pos=10)
        DeltaR = tb.Float32Col(dflt=np.nan, pos=11)
        ErrEstPos = tb.Float32Col(dflt=np.nan, pos=12)
        ErrEstDir = tb.Float32Col(dflt=np.nan, pos=13)
        gammaness = tb.Float32Col(dflt=np.nan, pos=14)
        success = tb.BoolCol(dflt=False, pos=15)

    reco_outfile = tb.open_file(
        mode="w",
        # if no outfile name is given (i.e. don't write the event list to disk),
        # need to specify two "driver" arguments
        **({"filename": args.outfile} if args.outfile else
           {"filename": "no_outfile.h5",
            "driver": "H5FD_CORE", "driver_core_backing_store": False}))

    # from here on the file is open; close it again no matter how we leave
    try:
        reco_table = reco_outfile.create_table("/", "reco_events", RecoEvent)
        reco_event = reco_table.row

        # set to `None` to use all telescopes
        allowed_tels = prod3b_tel_ids("L+N+D")
        for filename in filenamelist[:args.last]:

            source = hessio_event_source(filename,
                                         allowed_tels=allowed_tels,
                                         max_events=args.max_events)

            # loop that cleans and parametrises the images and performs the
            # reconstruction
            for (event, hillas_dict, n_tels,
                 tot_signal, max_signals, pos_fit, dir_fit, h_max,
                 err_est_pos, err_est_dir) in preper.prepare_event(source, True):

                # now prepare the features for the classifier
                cls_features_evt = {}
                reg_features_evt = {}

                # hillas_dict can be None (prepare_event is called with
                # return_stub=True; presumably stubs carry no image parameters)
                telescopes = hillas_dict.items() if hillas_dict is not None else ()
                for tel_id, moments in telescopes:
                    Imagecutflow.count("pre-features")

                    tel_pos = np.array(event.inst.tel_pos[tel_id][:2]) * u.m
                    impact_dist = linalg.length(tel_pos - pos_fit)

                    features = dict(
                        impact_dist=impact_dist / u.m,
                        sum_signal_evt=tot_signal,
                        max_signal_cam=max_signals[tel_id],
                        sum_signal_cam=moments.size,
                        N_LST=n_tels["LST"],
                        N_MST=n_tels["MST"],
                        N_SST=n_tels["SST"],
                        width=moments.width / u.m,
                        length=moments.length / u.m,
                        skewness=moments.skewness,
                        kurtosis=moments.kurtosis,
                        h_max=h_max / u.m,
                        err_est_pos=err_est_pos / u.m,
                        err_est_dir=err_est_dir / u.deg
                    )
                    cls_features_tel = ClassifierFeatures(**features)
                    reg_features_tel = EnergyFeatures(**features)

                    # both tuples hold the same values, one check covers both
                    if np.isnan(cls_features_tel).any():
                        continue

                    Imagecutflow.count("features nan")

                    # the features are grouped by camera type
                    cam_id = event.inst.subarray.tel[tel_id].camera.cam_id
                    reg_features_evt.setdefault(cam_id, []).append(reg_features_tel)
                    cls_features_evt.setdefault(cam_id, []).append(cls_features_tel)

                if cls_features_evt and reg_features_evt:

                    predict_energ = \
                        regressor.predict_by_event([reg_features_evt])["mean"][0]
                    predict_proba = \
                        classifier.predict_proba_by_event([cls_features_evt])
                    gammaness = predict_proba[0, 0]

                    try:
                        # the MC direction of origin of the simulated particle
                        shower = event.mc
                        shower_core = np.array([shower.core_x / u.m,
                                                shower.core_y / u.m]) * u.m
                        shower_org = linalg.set_phi_theta(az_to_phi(shower.az),
                                                          alt_to_theta(shower.alt))

                        # and how the reconstructed direction compares to that
                        xi = linalg.angle(dir_fit, shower_org)
                        DeltaR = linalg.length(pos_fit[:2] - shower_core)
                    except Exception:
                        # naked exception catch, because I'm not sure where
                        # it would break in non-MC files
                        xi = np.nan
                        DeltaR = np.nan

                    phi, theta = linalg.get_phi_theta(dir_fit)
                    phi = (phi if phi > 0 else phi + 360 * u.deg)

                    # TODO: replace with actual array pointing direction
                    array_pointing = linalg.set_phi_theta(0 * u.deg, 20. * u.deg)
                    # angular offset between the reconstructed direction and the
                    # array pointing
                    off_angle = linalg.angle(dir_fit, array_pointing)

                    reco_event["NTels_trig"] = len(event.dl0.tels_with_data)
                    reco_event["NTels_reco"] = len(hillas_dict)
                    reco_event["NTels_reco_lst"] = n_tels["LST"]
                    reco_event["NTels_reco_mst"] = n_tels["MST"]
                    reco_event["NTels_reco_sst"] = n_tels["SST"]
                    reco_event["reco_Energy"] = predict_energ.to(energy_unit).value
                    reco_event["reco_phi"] = phi / angle_unit
                    reco_event["reco_theta"] = theta / angle_unit
                    reco_event["off_angle"] = off_angle / angle_unit
                    reco_event["xi"] = xi / angle_unit
                    reco_event["DeltaR"] = DeltaR / dist_unit
                    reco_event["ErrEstPos"] = err_est_pos / dist_unit
                    reco_event["ErrEstDir"] = err_est_dir / angle_unit
                    reco_event["gammaness"] = gammaness
                    reco_event["success"] = True
                else:
                    reco_event["success"] = False

                # save basic event infos
                reco_event["MC_Energy"] = event.mc.energy.to(energy_unit).value
                reco_event["Event_ID"] = event.r1.event_id
                reco_event["Run_ID"] = event.r1.run_id

                # append first, then flush, so that the row of the current event
                # is the one that gets written out
                reco_event.append()
                reco_table.flush()

                if signal_handler.stop:
                    break
            if signal_handler.stop:
                break

        # make sure everything gets written out nicely
        reco_table.flush()

        try:
            print()
            Eventcutflow()
            print()
            Imagecutflow()

            # do some simple event selection
            # and print the corresponding selection efficiency
            N_selected = len([x for x in reco_table.where(
                """(NTels_reco > min_tel) & (gammaness > agree_threshold)""")])
            N_total = len(reco_table)
            print("\nfraction selected events:")
            print("{} / {} = {} %".format(N_selected, N_total,
                                          N_selected / N_total * 100))

        except ZeroDivisionError:
            # nothing to report for an empty table
            pass

        print("\nlength filenamelist:", len(filenamelist[:args.last]))

        # do some plotting if so desired
        if args.plot:
            if args.proton:
                particle = "protons"
            elif args.electron:
                particle = "electrons"
            else:
                particle = "gamma"
            _plot_reco_summary(reco_table, args.mode, particle)
    finally:
        reco_outfile.close()