Example #1
0
def rapid_stage(locus_data):
    """Classify a locus with RAPID using placeholder detection flags and print the result.

    Parameters
    ----------
    locus_data
        Object exposing ``get_properties()`` (indexed by ``'alert_id'``,
        ``'ra'`` and ``'dec'``) and ``get_time_series(...)``.
        NOTE(review): the exact locus API is not visible here; assumed to
        match the calls below.
    """
    # get_properties() is indexed by key below, so it must stay a mapping;
    # wrapping it in str() made every lookup raise TypeError.
    locus_properties = locus_data.get_properties()
    objid = locus_properties['alert_id']
    ra = locus_properties['ra']
    dec = locus_properties['dec']
    redshift = 0.  # TODO: placeholder redshift
    mwebv = 0.2  # TODO: placeholder extinction

    alert_id, mjd, ras, decs, passband, mag, magerr, zeropoint, = locus_data.get_time_series(
        'ra', 'dec', 'ztf_fid', 'ztf_magpsf', 'ztf_sigmapsf', 'ztf_magzpsci')

    flux = 10.**(-0.4 * (mag - zeropoint))
    fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

    # Placeholder flags: roughly the first half non-detections (0), one
    # trigger (6144), the rest detections (4096). The counts are derived from
    # len(mjd) so the flags always line up with the observations (the old
    # formula was one short for odd lengths and too long below 6 points).
    n_obs = len(mjd)
    n_pre = max(int(n_obs / 2 - 3), 0)
    n_post = max(n_obs - n_pre - 1, 0)
    photflag = [0] * n_pre + [6144] * min(n_obs, 1) + [4096] * n_post

    passband = np.where((passband == 1) | (passband == '1.0'), 'g', passband)
    passband = np.where((passband == 2) | (passband == '2.0'), 'r', passband)

    # Drop the third ZTF band.
    deleteindexes = np.where((passband == 3) | (passband == '3.0'))
    mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
        deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

    light_curve_list = [(mjd, flux, fluxerr, passband, zeropoint, photflag, ra,
                         dec, objid, redshift, mwebv)]

    classification = Classify(light_curve_list)
    predictions = classification.get_predictions()
    print(predictions)

    # Debug output of the raw time series, including the limiting magnitude.
    print(
        locus_data.get_time_series('ra', 'dec', 'ztf_fid', 'ztf_magpsf',
                                   'ztf_sigmapsf', 'ztf_magzpsci',
                                   'ztf_diffmaglim'))
Example #2
0
def main(graph=None, model=None):
    """
    Run astrorapid on one built-in simulated light curve and plot the class
    probabilities at each observation epoch.

    ``graph`` and ``model`` are forwarded unchanged to ``Classify``; leave them
    as ``None`` unless you handle multithreading yourself (astrorapid already
    parallelises via its keras/tensorflow backend).
    """
    mjd = [57433.4816, 57436.4815, 57439.4817, 57451.4604, 57454.4397, 57459.3963, 57462.418, 57465.4385, 57468.3768,
           57473.3606, 57487.3364, 57490.3341, 57493.3154, 57496.3352, 57505.3144, 57513.2542, 57532.2717, 57536.2531,
           57543.2545, 57546.2703, 57551.2115, 57555.2669, 57558.2769, 57561.1899, 57573.2133, 57433.5019, 57436.4609,
           57439.4587, 57444.4357, 57459.4189, 57468.3142, 57476.355, 57479.3568, 57487.3586, 57490.3562, 57493.3352,
           57496.2949, 57505.3557, 57509.2932, 57513.2934, 57518.2735, 57521.2739, 57536.2321, 57539.2115, 57543.2301,
           57551.1701, 57555.2107, 57558.191, 57573.1923, 57576.1749, 57586.1854]
    flux = [2.0357230e+00, -2.0382695e+00, 1.0084588e+02, 5.5482742e+01, 1.4867026e+01, -6.5136810e+01, 1.6740545e+01,
            -5.7269131e+01, 1.0649184e+02, 1.5505235e+02, 3.2445984e+02, 2.8735449e+02, 2.0898877e+02, 2.8958893e+02,
            1.9793906e+02, -1.3370536e+01, -3.9001358e+01, 7.4040916e+01, -1.7343750e+00, 2.7844931e+01, 6.0861992e+01,
            4.2057487e+01, 7.1565346e+01, -2.6085690e-01, -6.8435440e+01, 17.573107, 41.445435, -110.72664, 111.328964,
            -63.48336, 352.44907, 199.59058, 429.83075, 338.5255, 409.94604, 389.71262, 195.63905, 267.13318, 123.92461,
            200.3431, 106.994514, 142.96387, 56.491238, 55.17521, 97.556946, -29.263103, 142.57687, -20.85057,
            -0.67210346, 63.353024, -40.02601]
    fluxerr = [42.784702, 43.83665, 99.98704, 45.26248, 43.040398, 44.00679, 41.856007, 49.354336, 105.86439, 114.0044,
               45.697918, 44.15781, 60.574158, 93.08788, 66.04482, 44.26264, 91.525085, 42.768955, 43.228336, 44.178196,
               62.15593, 109.270035, 174.49638, 72.6023, 48.021034, 44.86118, 48.659588, 100.97703, 148.94061, 44.98218,
               139.11194, 71.4585, 47.766987, 45.77923, 45.610615, 60.50458, 105.11658, 71.41217, 43.945534, 45.154167,
               43.84058, 52.93122, 44.722775, 44.250145, 43.95989, 68.101326, 127.122025, 124.1893, 49.952255, 54.50728,
               114.91599]
    # The first 25 epochs are g band and the remaining 26 are r band; every
    # point shares the same zeropoint.
    passband = ['g'] * 25 + ['r'] * 26
    zeropoint = [27.5] * 51
    photflag = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096,
                6144, 4096, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

    objid = 'MSIP_01_NONIa-0001_10400862'
    ra, dec = 3.75464531293933, 0.205076187109334
    redshift, mwebv = 0.233557, 0.0228761

    light_curve = (mjd, flux, fluxerr, passband, zeropoint, photflag, ra, dec, objid, redshift, mwebv)

    classifier = Classify([light_curve], known_redshift=True, graph=graph, model=model)

    # Predictions on the classifier's own time grid.
    grid_predictions, _ = classifier.get_predictions(return_predictions_at_obstime=False)
    print(grid_predictions)

    # Predictions evaluated at the observation epochs, used for the plot.
    obs_predictions, obs_times = classifier.get_predictions(return_predictions_at_obstime=True)

    import matplotlib.pyplot as plt
    for column, label in enumerate(classifier.class_names):
        plt.plot(obs_times[0], obs_predictions[0][:, column], label=label)
    plt.legend()
    plt.show()
def rapid_stage(locus_data):
    """Classify a locus with RAPID and store the class probabilities on it.

    Reads the locus properties and ZTF time series, drops points without a
    magnitude, converts magnitudes to fluxes, flags S/N > 5 points as
    detections (4096) with the earliest of them as the trigger (6144), runs
    RAPID, and writes the last-time-step probability of every class back as
    ``rapid_class_probability_<class name>`` properties.

    Returns None early (after printing a message) when fewer than 3 points
    remain, or when no point reaches S/N > 5 (previously ``min()`` of the
    empty detection set raised ValueError).
    """
    locus_properties = locus_data.get_properties()
    objid = locus_properties['alert_id']
    ra = locus_properties['ra']
    dec = locus_properties['dec']
    redshift = 0.  # TODO: Get correct redshift
    mwebv = 0.2  # TODO: Get correct extinction

    alert_id, mjd, ras, decs, passband, mag, magerr, zeropoint, = locus_data.get_time_series(
        'ra', 'dec', 'ztf_fid', 'ztf_magpsf', 'ztf_sigmapsf', 'ztf_magzpsci')

    # Ignore mag Nonetype values. `== None` is intentional: it compares
    # element-wise on arrays, whereas `is None` would test the array itself.
    alert_id, mjd, ras, decs, passband, mag, magerr, zeropoint = delete_indexes(
        np.where(mag == None), alert_id, mjd, ras, decs, passband, mag, magerr,  # noqa: E711
        zeropoint)

    print(
        locus_data.get_time_series('ra', 'dec', 'ztf_fid', 'ztf_magpsf',
                                   'ztf_sigmapsf', 'ztf_magzpsci'))
    if len(mjd) < 3:
        print("less than 3 points")
        return
    # Fill missing zeropoints with the median of the known ones.
    zpt_median = np.median(zeropoint[zeropoint != None])  # noqa: E711
    zeropoint[zeropoint == None] = zpt_median  # noqa: E711
    zeropoint = np.asarray(zeropoint, dtype=np.float64)
    mag = np.asarray(mag, dtype=np.float64)

    flux = 10.**(-0.4 * (mag - zeropoint))
    fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

    passband = np.where((passband == 1) | (passband == '1.0'), 'g', passband)
    passband = np.where((passband == 2) | (passband == '2.0'), 'r', passband)

    # Set photflag detections when S/N > 5; the earliest one is the trigger.
    detected = flux / fluxerr > 5
    if not np.any(detected):
        print("no points with S/N > 5")
        return
    photflag = np.zeros(len(flux))
    photflag[detected] = 4096
    photflag[np.where(mjd == min(mjd[detected]))] = 6144

    # Drop the third ZTF band and any NaN magnitudes.
    deleteindexes = np.where((passband == 3) | (passband == '3.0')
                             | (np.isnan(mag)))
    mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
        deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

    light_curve_list = [(mjd, flux, fluxerr, passband, zeropoint, photflag, ra,
                         dec, objid, redshift, mwebv)]

    # Reset keras state before building a new classifier.
    from keras import backend as K
    K.clear_session()

    classification = Classify(known_redshift=True)
    predictions = classification.get_predictions(light_curve_list)
    print(predictions)

    # predictions[0] is the first light curve's probability matrix; its last
    # row is the most recent time step.
    for i, name in enumerate(classification.class_names):
        locus_data.set_property('rapid_class_probability_{}'.format(name),
                                predictions[0][-1][i])
Example #4
0
def classify_lasair_light_curves(object_names=('ZTF18acsovsw', ), plot=True):
    """Classify Lasair objects with RAPID and optionally plot the results.

    For each object, the light curve is read with ``read_lasair_json`` and
    magnitudes are converted to fluxes. The earliest ``photflag == 4096``
    epoch is marked as the trigger (6144). Band-3 points and non-detections
    after the trigger are removed. Objects that cannot be read, or that have
    no 4096-flagged point, are skipped with a printed message.

    Parameters
    ----------
    object_names : iterable of str
        Lasair object names to classify.
    plot : bool
        Whether to plot light curves and classifications.

    Returns
    -------
    tuple
        ``(classification.orig_lc, classification.timesX,
        classification.y_predict)``.
    """
    light_curve_list = []
    for object_name in object_names:
        try:
            mjd, passband, mag, magerr, photflag, zeropoint, ra, dec, objid, redshift, mwebv = read_lasair_json(
                object_name)
        except Exception as e:
            print(e)
            continue

        flux = 10.**(-0.4 * (mag - zeropoint))
        fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

        passband = np.where((passband == 1) | (passband == '1'), 'g', passband)
        passband = np.where((passband == 2) | (passband == '2'), 'r', passband)

        # The trigger is the first epoch already flagged as a detection.
        # min() of an empty selection raises, which would abort all remaining
        # objects, so skip objects without any detection instead.
        detected_mjd = mjd[photflag == 4096]
        if len(detected_mjd) == 0:
            print("No detections (photflag == 4096) for {}; skipping".format(
                object_name))
            continue
        mjd_first_detection = min(detected_mjd)
        photflag[np.where(mjd == mjd_first_detection)] = 6144

        # Drop band-3 points, and non-detections after the trigger
        # (explicit parentheses; `&` already bound tighter than `|`).
        deleteindexes = np.where(((passband == 3) | (passband == '3'))
                                 | ((mjd > mjd_first_detection)
                                    & (photflag == 0)))
        print("Deleting indexes {} at mjd {} and passband {}".format(
            deleteindexes, mjd[deleteindexes], passband[deleteindexes]))
        mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
            deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

        light_curve_list += [(mjd, flux, fluxerr, passband, zeropoint,
                              photflag, ra, dec, objid, redshift, mwebv)]

    classification = Classify(light_curve_list,
                              known_redshift=True,
                              bcut=False,
                              zcut=None)
    predictions, time_steps = classification.get_predictions(
        return_predictions_at_obstime=False)
    print(predictions)

    if plot:
        classification.plot_light_curves_and_classifications(
            step=True, use_interp_flux=False)

    return classification.orig_lc, classification.timesX, classification.y_predict
 def __init__(self, model: str = None):
     """
     Build the wrapped RAPID classifier.

     Parameters
     ----------
     model: str, optional
         Path of a saved model file to load. When omitted, Classify is
         constructed with only ``known_redshift=True``.
     """
     # Forward model_filepath only when a path was actually given.
     extra_options = {} if model is None else {'model_filepath': model}
     self.classifier = Classify(known_redshift=True, **extra_options)
Example #6
0
    def classify(self, lc):
        """Classify one light curve with RAPID and store the result in MongoDB.

        Every point is flagged as a detection (4096) and the brightest point
        as the trigger (6144). A model is loaded from ``keras_model.hdf5``.

        Parameters
        ----------
        lc : sequence
            ``lc[0]`` must unpack to ``(mjd, flux, fluxerr, passband, ra, dec,
            objid, redshift, mwebv)``.

        Returns
        -------
        The stored document passed through ``dumps`` (presumably
        ``bson.json_util.dumps`` -- confirm at the import site).
        """
        mjd, flux, fluxerr, passband, ra, dec, objid, redshift, mwebv = lc[0]

        photflag = [4096] * len(flux)
        photflag[np.argmax(flux)] = 6144
        photflag = np.array(photflag)

        light_curve_info1 = (mjd, flux, fluxerr, passband, photflag, ra, dec,
                             objid, redshift, mwebv)
        light_curve_list = [
            light_curve_info1,
        ]

        classification = Classify(known_redshift=True,
                                  model_filepath='keras_model.hdf5')
        predictions = classification.get_predictions(light_curve_list)

        # predictions is (probabilities per light curve, time steps per light curve).
        y_predict = predictions[0][0]
        time_steps = predictions[1][0]

        classified_lightcurve = {
            'time': datetime.datetime.utcnow(),
            'obj_id': objid,
            'mjd': mjd,
            'flux': flux,
            'fluxerr': fluxerr,
            'passband': passband,
            'ra': ra,
            'dec': dec,
            'redshift': redshift,
            'mwebv': mwebv,
            'photflag': photflag.tolist(),
            'predicted_y': y_predict.tolist(),
            'timesteps': time_steps.tolist(),
        }

        print(classified_lightcurve['timesteps'])

        client = MongoClient('ampelml-mongo')
        try:
            db = client.ampel_ml
            classified_lightcurves = db.classified_lightcurves

            class_lc_id = classified_lightcurves.insert_one(
                classified_lightcurve).inserted_id

            # Read back exactly the document just inserted; sorting by 'time'
            # could return a document written concurrently by another call.
            stored_class_lc = classified_lightcurves.find_one({'_id': class_lc_id})
        finally:
            # The client was previously never closed, leaking connections.
            client.close()

        return dumps(stored_class_lc)
Example #7
0
def main():
    """Classify one built-in simulated light curve and plot the results.

    The light curves are passed to the ``Classify`` constructor, and
    ``get_predictions()`` is called without arguments.
    """
    mjd = [
        57433.4816, 57436.4815, 57439.4817, 57451.4604, 57454.4397, 57459.3963,
        57462.418, 57465.4385, 57468.3768, 57473.3606, 57487.3364, 57490.3341,
        57493.3154, 57496.3352, 57505.3144, 57513.2542, 57532.2717, 57536.2531,
        57543.2545, 57546.2703, 57551.2115, 57555.2669, 57558.2769, 57561.1899,
        57573.2133, 57433.5019, 57436.4609, 57439.4587, 57444.4357, 57459.4189,
        57468.3142, 57476.355, 57479.3568, 57487.3586, 57490.3562, 57493.3352,
        57496.2949, 57505.3557, 57509.2932, 57513.2934, 57518.2735, 57521.2739,
        57536.2321, 57539.2115, 57543.2301, 57551.1701, 57555.2107, 57558.191,
        57573.1923, 57576.1749, 57586.1854
    ]
    flux = [
        2.0357230e+00, -2.0382695e+00, 1.0084588e+02, 5.5482742e+01,
        1.4867026e+01, -6.5136810e+01, 1.6740545e+01, -5.7269131e+01,
        1.0649184e+02, 1.5505235e+02, 3.2445984e+02, 2.8735449e+02,
        2.0898877e+02, 2.8958893e+02, 1.9793906e+02, -1.3370536e+01,
        -3.9001358e+01, 7.4040916e+01, -1.7343750e+00, 2.7844931e+01,
        6.0861992e+01, 4.2057487e+01, 7.1565346e+01, -2.6085690e-01,
        -6.8435440e+01, 17.573107, 41.445435, -110.72664, 111.328964,
        -63.48336, 352.44907, 199.59058, 429.83075, 338.5255, 409.94604,
        389.71262, 195.63905, 267.13318, 123.92461, 200.3431, 106.994514,
        142.96387, 56.491238, 55.17521, 97.556946, -29.263103, 142.57687,
        -20.85057, -0.67210346, 63.353024, -40.02601
    ]
    fluxerr = [
        42.784702, 43.83665, 99.98704, 45.26248, 43.040398, 44.00679,
        41.856007, 49.354336, 105.86439, 114.0044, 45.697918, 44.15781,
        60.574158, 93.08788, 66.04482, 44.26264, 91.525085, 42.768955,
        43.228336, 44.178196, 62.15593, 109.270035, 174.49638, 72.6023,
        48.021034, 44.86118, 48.659588, 100.97703, 148.94061, 44.98218,
        139.11194, 71.4585, 47.766987, 45.77923, 45.610615, 60.50458,
        105.11658, 71.41217, 43.945534, 45.154167, 43.84058, 52.93122,
        44.722775, 44.250145, 43.95989, 68.101326, 127.122025, 124.1893,
        49.952255, 54.50728, 114.91599
    ]
    # 25 g-band epochs followed by 26 r-band epochs, one shared zeropoint.
    passband = ['g'] * 25 + ['r'] * 26
    zeropoint = [27.5] * 51
    photflag = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 4096, 6144, 4096, 4096, 4096, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0
    ]

    objid = 'MSIP_01_NONIa-0001_10400862'
    ra, dec = 3.75464531293933, 0.205076187109334
    redshift, mwebv = 0.233557, 0.0228761

    light_curve = (mjd, flux, fluxerr, passband, zeropoint, photflag, ra,
                   dec, objid, redshift, mwebv)

    classifier = Classify([light_curve])
    print(classifier.get_predictions())

    classifier.plot_light_curves_and_classifications()
    classifier.plot_classification_animation()
Example #8
0
	def do(self,debug=False):
		"""Photometrically classify active transients with RAPID and save the class.

		Transients whose status is New, Watch, FollowupRequested or Following
		are collected with their g-ZTF / r-ZTF photometry and split by whether
		``t.z_or_hostz()`` returns a truthy redshift.  Each group is classified
		by its own ``Classify`` instance (``known_redshift`` True / False).
		For each transient, the last-time-step probabilities are reduced:
		'Pre-explosion' is dropped, and every class whose name starts with
		'SNIa' is summed into 'SN Ia'.  The most probable remaining class is
		mapped through ``classdict`` and stored on ``t.photo_class``.

		Parameters
		----------
		debug : bool
			If True and the redshift group is non-empty, plot its light curves
			and classifications with matplotlib's MacOSX backend.

		Raises
		------
		RuntimeError
			If the predicted class has no matching ``TransientClass`` row.
		"""
		# run this under Admin
		user = User.objects.get(username='******')

		# get New, Watch, FollowupRequested, Following
		transients_to_classify = Transient.objects.filter(
			Q(status__name='New') | Q(status__name='Watch') |
			Q(status__name='FollowupRequested') | Q(status__name='Following'))

		light_curve_list_z, light_curve_list_noz = [], []
		transient_list_z, transient_list_noz = [], []
		for t in transients_to_classify:
			ra, dec, objid, redshift = t.ra, t.dec, t.name, t.z_or_hostz()
			mwebv = t.mw_ebv if t.mw_ebv is not None else 0.0

			photdata = get_all_phot_for_transient(user, t.id)
			if not photdata: continue
			gobs = photdata.filter(band__name='g-ZTF')
			robs = photdata.filter(band__name='r-ZTF')
			if not len(gobs) and not len(robs): continue
			mjd, passband = np.array([]), np.array([])
			flux, fluxerr, mag, magerr, zeropoint, photflag = [], [], [], [], [], []

			first_detection_set = False
			for obs, filt in zip([gobs.order_by('obs_date'), robs.order_by('obs_date')], ['g', 'r']):
				for p in obs:
					if p.data_quality: continue
					obs_mjd = date_to_mjd(p.obs_date)
					# Skip duplicates: same band within 0.001 d of an epoch already added.
					if np.any((passband == filt) & (np.abs(mjd - obs_mjd) < 0.001)): continue

					# Missing (or zero) magnitude errors default to 0.001 mag.
					mag_err = p.mag_err if p.mag_err else 0.001
					flux_obs = 10**(-0.4*(p.mag-27.5))
					# sigma_F = F * 0.4 * ln(10) * sigma_m, as in the other
					# RAPID helpers.  The factor F was missing, which made the
					# error relative and the S/N > 5 test below meaningless.
					fluxerr_obs = flux_obs*0.4*np.log(10)*mag_err

					mag += [p.mag]
					magerr += [mag_err]
					mjd = np.append(mjd, [obs_mjd])
					flux += [flux_obs]
					fluxerr += [fluxerr_obs]
					zeropoint += [27.5]
					passband = np.append(passband, [filt])

					# The first S/N > 5 point (g band is scanned first) is the trigger.
					is_detection = flux_obs/fluxerr_obs > 5
					if is_detection and not first_detection_set:
						photflag += [6144]
						first_detection_set = True
					elif is_detection:
						photflag += [4096]
					else:
						photflag += [0]

			# Transients and light curves are appended together so their indices stay aligned.
			if redshift:
				transient_list_z += [t]
				light_curve_list_z += [(mjd, flux, fluxerr, passband,
										zeropoint, photflag, ra, dec, objid, redshift, mwebv)]
			else:
				transient_list_noz += [t]
				light_curve_list_noz += [(mjd, flux, fluxerr, passband,
										  zeropoint, photflag, ra, dec, objid, None, mwebv)]

		# Keep each group's transients, classifier and predictions together so
		# that every transient is read from its own group's predictions
		# (no-redshift transients used to be indexed into the redshift group).
		classification_z = None
		groups = []
		if light_curve_list_noz:
			classification_noz = Classify(light_curve_list_noz, known_redshift=False, bcut=False, zcut=None)
			groups.append((transient_list_noz, classification_noz, classification_noz.get_predictions()))
		if light_curve_list_z:
			classification_z = Classify(light_curve_list_z, known_redshift=True, bcut=False, zcut=None)
			# The redshift group is saved first.
			groups.insert(0, (transient_list_z, classification_z, classification_z.get_predictions()))

		if debug and classification_z is not None:
			import matplotlib
			matplotlib.use('MacOSX')
			import matplotlib.pyplot as plt
			plt.ion()
			classification_z.plot_light_curves_and_classifications()

		for transients, classification, predictions in groups:
			class_names = classification.class_names
			for i, t in enumerate(transients):
				# Class probabilities at the last time step of this light curve.
				best_predictions = predictions[0][i][-1, :]

				# Drop 'Pre-explosion' and merge all SNIa* classes into 'SN Ia'.
				outclassnames, adjusted_best_predictions, PIa = [], [], 0
				for name, prob in zip(class_names, best_predictions):
					if name == 'Pre-explosion': continue
					elif name.startswith('SNIa'): PIa += prob
					else:
						outclassnames += [name]
						adjusted_best_predictions += [prob]
				outclassnames += ['SN Ia']
				adjusted_best_predictions += [PIa]

				transient_class = outclassnames[int(np.argmax(adjusted_best_predictions))]
				print(t.name, transient_class)
				photo_class = TransientClass.objects.filter(name=classdict[transient_class])

				if len(photo_class):
					t.photo_class = photo_class[0]
					t.save()
				else:
					print('class %s not in DB'%classdict[transient_class])
					raise RuntimeError('class %s not in DB'%classdict[transient_class])
Example #9
0
def classify_lasair_light_curves(
        object_names=('ZTF18acsovsw', ), plot=True, figdir='.'):
    """Classify Lasair objects with RAPID and optionally plot the results.

    For each object, the light curve is read with ``read_lasair_json``.
    Objects without both g (1) and r (2) points are skipped. Magnitudes are
    converted to fluxes, and the earliest ``photflag == 4096`` epoch becomes
    the trigger (6144). Band-3 points and non-detections after the trigger
    are removed. Objects with no 4096-flagged point are skipped with a
    message.

    Parameters
    ----------
    object_names : iterable of str
        Lasair object names to classify.
    plot : bool
        Whether to plot light curves and classifications.
    figdir : str
        Directory passed to the plotting routine.

    Returns
    -------
    tuple
        ``(classification.orig_lc, classification.timesX,
        classification.y_predict)``.
    """
    light_curve_list = []
    # Per-object accumulators. They are not returned; they hold the raw
    # (mjd-sorted) inputs and peak values for ad-hoc inspection/dumping.
    peakfluxes_g, peakfluxes_r = [], []
    mjds, passbands, mags, magerrs, zeropoints, photflags = [], [], [], [], [], []
    obj_names = []
    ras, decs, objids, redshifts, mwebvs = [], [], [], [], []
    peakmags_g, peakmags_r = [], []
    for object_name in object_names:
        try:
            mjd, passband, mag, magerr, photflag, zeropoint, ra, dec, objid, redshift, mwebv = read_lasair_json(
                object_name)
            # Computed before any append: min() raises ValueError for an object
            # lacking g or r points, which skips it without leaving the
            # accumulators out of step with each other.
            peakmag_g = min(mag[passband == 1])
            peakmag_r = min(mag[passband == 2])

            sortidx = np.argsort(mjd)
            mjds.append(mjd[sortidx])
            passbands.append(passband[sortidx])
            mags.append(mag[sortidx])
            magerrs.append(magerr[sortidx])
            zeropoints.append(zeropoint[sortidx])
            photflags.append(photflag[sortidx])
            obj_names.append(object_name)
            ras.append(ra)
            decs.append(dec)
            objids.append(objid)
            redshifts.append(redshift)
            mwebvs.append(mwebv)
            peakmags_g.append(peakmag_g)
            peakmags_r.append(peakmag_r)

        except Exception as e:
            print(e)
            continue

        flux = 10.**(-0.4 * (mag - zeropoint))
        fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

        passband = np.where((passband == 1) | (passband == '1'), 'g', passband)
        passband = np.where((passband == 2) | (passband == '2'), 'r', passband)

        # The trigger is the first epoch already flagged as a detection.
        # min() of an empty selection raises outside the try above and would
        # abort every remaining object, so skip such objects instead.
        detected_mjd = mjd[photflag == 4096]
        if len(detected_mjd) == 0:
            print("No detections (photflag == 4096) for {}; skipping".format(
                object_name))
            continue
        mjd_first_detection = min(detected_mjd)
        photflag[np.where(mjd == mjd_first_detection)] = 6144

        # Drop band-3 points, and non-detections after the trigger
        # (explicit parentheses; `&` already bound tighter than `|`).
        deleteindexes = np.where(((passband == 3) | (passband == '3'))
                                 | ((mjd > mjd_first_detection)
                                    & (photflag == 0)))
        if deleteindexes[0].size > 0:
            print("Deleting indexes {} at mjd {} and passband {}".format(
                deleteindexes, mjd[deleteindexes], passband[deleteindexes]))
        mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
            deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

        light_curve_list += [(mjd, flux, fluxerr, passband, photflag, ra, dec,
                              objid, redshift, mwebv)]

        flux_g = flux[passband == 'g']
        flux_r = flux[passband == 'r']
        if len(flux_g) == 0 or len(flux_r) == 0:
            print("No g- or r-band flux left for {}".format(object_name))
            continue

        peakfluxes_g.append(max(flux_g))
        peakfluxes_r.append(max(flux_r))

    classification = Classify(known_redshift=True, bcut=False, zcut=None)
    predictions, time_steps = classification.get_predictions(
        light_curve_list, return_predictions_at_obstime=False)
    print(predictions)

    if plot:
        classification.plot_light_curves_and_classifications(
            step=True,
            use_interp_flux=False,
            figdir=figdir,
            plot_matrix_input=True)

    return classification.orig_lc, classification.timesX, classification.y_predict
        1.9793906e+02, -1.3370536e+01, -3.9001358e+01, 7.4040916e+01, -1.7343750e+00, 2.7844931e+01, 6.0861992e+01,
        4.2057487e+01, 7.1565346e+01, -2.6085690e-01, -6.8435440e+01, 17.573107, 41.445435, -110.72664, 111.328964,
        -63.48336, 352.44907, 199.59058, 429.83075, 338.5255, 409.94604, 389.71262, 195.63905, 267.13318, 123.92461,
        200.3431, 106.994514, 142.96387, 56.491238, 55.17521, 97.556946, -29.263103, 142.57687, -20.85057, -0.67210346,
        63.353024, -40.02601]
fluxerr = [42.784702, 43.83665, 99.98704, 45.26248, 43.040398, 44.00679, 41.856007, 49.354336, 105.86439, 114.0044,
           45.697918, 44.15781, 60.574158, 93.08788, 66.04482, 44.26264, 91.525085, 42.768955, 43.228336, 44.178196,
           62.15593, 109.270035, 174.49638, 72.6023, 48.021034, 44.86118, 48.659588, 100.97703, 148.94061, 44.98218,
           139.11194, 71.4585, 47.766987, 45.77923, 45.610615, 60.50458, 105.11658, 71.41217, 43.945534, 45.154167,
           43.84058, 52.93122, 44.722775, 44.250145, 43.95989, 68.101326, 127.122025, 124.1893, 49.952255, 54.50728,
           114.91599]
# The first 25 epochs are g band and the remaining 26 are r band.
passband = ['g'] * 25 + ['r'] * 26
photflag = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096,
            6144, 4096, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Object metadata.
objid = 'transient_1'
ra, dec = 3.75464531293933, 0.205076187109334
redshift, mwebv = 0.233557, 0.0228761

# Light-curve tuple without a zeropoint entry:
# (mjd, flux, fluxerr, passband, photflag, ra, dec, objid, redshift, mwebv)
light_curve_list = [(mjd, flux, fluxerr, passband, photflag, ra, dec, objid, redshift, mwebv)]

classification = Classify()
predictions, time_steps = classification.get_predictions(light_curve_list)
print(predictions)

classification.plot_light_curves_and_classifications()
classification.plot_classification_animation()
class RAPIDModel:
    """Wrapper for the model outlined by the RAPID paper (Muthukrishna et al 2019)"""
    def __init__(self, model: str = None):
        """
        Creates a new instance of RAPID

        Parameters
        ----------
        model: str, optional
            The filepath of the model to be loaded, if any.
        """
        if model is not None:
            self.classifier = Classify(known_redshift=True,
                                       model_filepath=model)
        else:
            self.classifier = Classify(known_redshift=True)

    def set_metadata(self, metadata: pd.DataFrame):
        """Sets the loaded metadata

        Parameters
        ----------
        metadata: pd.DataFrame :
            The new metadata to be loaded
        """
        assert isinstance(metadata, pd.DataFrame)
        self._metadata = metadata

    def set_curves(self, curves: pd.DataFrame):
        """Sets the loaded transient data

        Parameters
        ----------
        curves: pd.DataFrame :
            The new transient data to be loaded
        """
        assert isinstance(curves, pd.DataFrame)
        self._curves = curves

    def set_data(self, curves: pd.DataFrame, metadata: pd.DataFrame):
        """Set both the transient data and the metadata

        Parameters
        ----------
        curves: pd.DataFrame :
            The new transient data to be loaded
        metadata: pd.DataFrame :
            The new metadata to be loaded
        """
        self.set_curves(curves)
        self.set_metadata(metadata)

    def _get_custom_data(self, class_num, data_dir, save_dir, passbands,
                         known_redshift, nprocesses, redo):
        """Build RAPID input from the currently loaded curves and metadata.

        Intended for training. The parameters mirror a data-getter callback
        signature and are ignored here: all loaded curves are converted with
        ``p2r.convert`` and passed through ``read_multiple_light_curves``.

        Notes
        -----
        See astrorapid for API usage.
        """
        light_list, _target_list = p2r.convert(self._curves, self._metadata)
        # now we need to preprocess
        return read_multiple_light_curves(light_list)

    def train(self,
              curves: pd.DataFrame,
              metadata: pd.DataFrame,
              *,
              class_map: dict = p2r.class_map,
              band_map: dict = p2r.band_map,
              save_path: str = "models/",
              file_name: str = None,
              load_model: bool = False):
        """Train a new RAPID classifier on the given data.

        Parameters
        ----------
        curves : pd.DataFrame
            The curve data to train on.
        metadata : pd.DataFrame
            The metadata for each curve.
        class_map : dict, optional
            The class mapping from PLAsTiCC to the model.
        band_map : dict, optional
            The bands and their mapping from PLAsTiCC to models.
        save_path : str, optional
            Directory passed to ``create_custom_classifier`` as ``save_dir``.
            Default "models/".
        file_name : str, optional
            Currently unused. Intended to override the model filename
            (RAPIDModel_yyyy_mm_dd_hh_mm_ss.hdf5).
        load_model : bool, optional
            Currently unused: the trained model is NOT loaded into
            ``self.classifier``. Default False.
        """

        def _get_data(class_num, data_dir, save_dir, passbands, known_redshift,
                      nprocesses, redo, calculate_t0):
            # Data callback for create_custom_classifier: convert the curves
            # of one class, restricted to the requested passbands.
            # Uses the caller-supplied maps (previously p2r's defaults were
            # always used) and compares class numbers by value: `is` on ints
            # only worked by accident of small-int caching.
            classes = {
                key: value
                for key, value in class_map.items() if value == class_num
            }
            bands = {
                key: value
                for key, value in band_map.items() if key in passbands
            }
            light_list, _target_list = p2r.convert(curves,
                                                   metadata,
                                                   classes=classes,
                                                   bands=bands)
            # now we need to preprocess
            return read_multiple_light_curves(light_list)

        create_custom_classifier(
            get_data_func=_get_data,
            data_dir='data/',
            class_nums=tuple(class_map.values()),
            passbands=tuple(band_map.keys()),
            save_dir=save_path,
        )

    def test(self,
             curves: pd.DataFrame,
             metadata: pd.DataFrame,
             return_probabilities: bool = False) -> (list, list):
        """Test the current classifier on the given data.

        Parameters
        ----------
        curves : pd.DataFrame
            The transient data for each object.
        metadata : pd.DataFrame
            The metadata for each object.
        return_probabilities : bool, optional
            If True, return the full prediction arrays instead of the most
            likely class.

        Returns
        -------
        target_list : numpy.ndarray of int
            The true classes, shifted by +1.
        predictions_list : list of int or list of arrays
            With ``return_probabilities`` the raw predictions; otherwise the
            argmax class of each object's last time step.
        """
        logger.info("testing model")
        light_list, target_list = p2r.convert(curves, metadata)
        predictions, steps = self.classifier.get_predictions(light_list)
        assert len(target_list) == len(predictions)
        # Shift the true classes by one to match the prediction indexing.
        target_list = np.add(target_list, 1)
        logger.info("done testing model")
        if return_probabilities:
            return target_list, predictions

        # Most likely class at each object's final time step.
        predictions_list = [np.argmax(pred[-1]) for pred in predictions]
        return target_list, predictions_list