def rapid_stage(locus_data):
    """Classify one locus with RAPID and print the predictions.

    Parameters
    ----------
    locus_data
        Object exposing ``get_properties()`` (a mapping containing
        ``'alert_id'``, ``'ra'`` and ``'dec'``) and ``get_time_series(...)``.

    Notes
    -----
    Redshift and Milky Way extinction are hard-coded placeholders, and the
    photflag vector is a synthetic placeholder rather than one derived from
    the signal-to-noise of the data.
    """
    # Keep the properties as a mapping: wrapping them in str() made every
    # key lookup below raise TypeError.
    locus_properties = locus_data.get_properties()
    objid = locus_properties['alert_id']
    ra = locus_properties['ra']
    dec = locus_properties['dec']
    redshift = 0.
    mwebv = 0.2

    # get_time_series is unpacked into two leading columns (alert id, mjd)
    # followed by the six requested columns.
    _alert_id, mjd, _ras, _decs, passband, mag, magerr, zeropoint = locus_data.get_time_series(
        'ra', 'dec', 'ztf_fid', 'ztf_magpsf', 'ztf_sigmapsf', 'ztf_magzpsci')

    # Convert magnitudes to fluxes and propagate the magnitude errors.
    flux = 10.**(-0.4 * (mag - zeropoint))
    fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

    # Placeholder flags: non-detections, then one trigger (6144), then
    # detections (4096). Truncating to n_obs keeps the vector the same length
    # as the light curve for odd-length and very short inputs as well.
    n_obs = len(mjd)
    n_before = max(n_obs // 2 - 3, 0)
    photflag = ([0] * n_before + [6144] + [4096] * n_obs)[:n_obs]

    # Map the numeric filter ids onto band letters and drop filter id 3.
    passband = np.where((passband == 1) | (passband == '1.0'), 'g', passband)
    passband = np.where((passband == 2) | (passband == '2.0'), 'r', passband)
    deleteindexes = np.where((passband == 3) | (passband == '3.0'))
    mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
        deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

    light_curve_list = [(mjd, flux, fluxerr, passband, zeropoint, photflag,
                         ra, dec, objid, redshift, mwebv)]

    classification = Classify(light_curve_list)
    predictions = classification.get_predictions()
    print(predictions)

    print(
        locus_data.get_time_series('ra', 'dec', 'ztf_fid', 'ztf_magpsf',
                                   'ztf_sigmapsf', 'ztf_magzpsci',
                                   'ztf_diffmaglim'))
def main(graph=None, model=None):
    """Run astrorapid on a single hard-coded example light curve.

    The ``graph`` and ``model`` arguments are forwarded to ``Classify`` and
    can be left at ``None`` unless you are doing your own multithreading
    (astrorapid already performs its own parallelisation based on a keras
    and tensorflow backend).
    """
    mjd = [
        57433.4816, 57436.4815, 57439.4817, 57451.4604, 57454.4397,
        57459.3963, 57462.418, 57465.4385, 57468.3768, 57473.3606,
        57487.3364, 57490.3341, 57493.3154, 57496.3352, 57505.3144,
        57513.2542, 57532.2717, 57536.2531, 57543.2545, 57546.2703,
        57551.2115, 57555.2669, 57558.2769, 57561.1899, 57573.2133,
        57433.5019, 57436.4609, 57439.4587, 57444.4357, 57459.4189,
        57468.3142, 57476.355, 57479.3568, 57487.3586, 57490.3562,
        57493.3352, 57496.2949, 57505.3557, 57509.2932, 57513.2934,
        57518.2735, 57521.2739, 57536.2321, 57539.2115, 57543.2301,
        57551.1701, 57555.2107, 57558.191, 57573.1923, 57576.1749,
        57586.1854,
    ]
    flux = [
        2.0357230e+00, -2.0382695e+00, 1.0084588e+02, 5.5482742e+01,
        1.4867026e+01, -6.5136810e+01, 1.6740545e+01, -5.7269131e+01,
        1.0649184e+02, 1.5505235e+02, 3.2445984e+02, 2.8735449e+02,
        2.0898877e+02, 2.8958893e+02, 1.9793906e+02, -1.3370536e+01,
        -3.9001358e+01, 7.4040916e+01, -1.7343750e+00, 2.7844931e+01,
        6.0861992e+01, 4.2057487e+01, 7.1565346e+01, -2.6085690e-01,
        -6.8435440e+01, 17.573107, 41.445435, -110.72664, 111.328964,
        -63.48336, 352.44907, 199.59058, 429.83075, 338.5255, 409.94604,
        389.71262, 195.63905, 267.13318, 123.92461, 200.3431, 106.994514,
        142.96387, 56.491238, 55.17521, 97.556946, -29.263103, 142.57687,
        -20.85057, -0.67210346, 63.353024, -40.02601,
    ]
    fluxerr = [
        42.784702, 43.83665, 99.98704, 45.26248, 43.040398, 44.00679,
        41.856007, 49.354336, 105.86439, 114.0044, 45.697918, 44.15781,
        60.574158, 93.08788, 66.04482, 44.26264, 91.525085, 42.768955,
        43.228336, 44.178196, 62.15593, 109.270035, 174.49638, 72.6023,
        48.021034, 44.86118, 48.659588, 100.97703, 148.94061, 44.98218,
        139.11194, 71.4585, 47.766987, 45.77923, 45.610615, 60.50458,
        105.11658, 71.41217, 43.945534, 45.154167, 43.84058, 52.93122,
        44.722775, 44.250145, 43.95989, 68.101326, 127.122025, 124.1893,
        49.952255, 54.50728, 114.91599,
    ]
    # The first 25 epochs are g-band, the remaining 26 are r-band.
    passband = ['g'] * 25 + ['r'] * 26
    # Every one of the 51 epochs shares the same zeropoint.
    zeropoint = [27.5] * 51
    photflag = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 4096, 6144, 4096, 4096, 4096, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0,
    ]
    objid = 'MSIP_01_NONIa-0001_10400862'
    ra = 3.75464531293933
    dec = 0.205076187109334
    redshift = 0.233557
    mwebv = 0.0228761

    example_curve = (mjd, flux, fluxerr, passband, zeropoint, photflag,
                     ra, dec, objid, redshift, mwebv)

    classifier = Classify([example_curve], known_redshift=True,
                          graph=graph, model=model)

    # First pass: predictions on the classifier's own time grid.
    y_pred, steps = classifier.get_predictions(
        return_predictions_at_obstime=False)
    print(y_pred)

    # Optional built-in visualisations:
    # classifier.plot_light_curves_and_classifications(step=False)
    # classifier.plot_classification_animation()
    # classifier.plot_classification_animation_step()

    # Second pass: predictions at the observation times, plotted per class.
    y_pred, steps = classifier.get_predictions(
        return_predictions_at_obstime=True)

    import matplotlib.pyplot as plt

    for column, label in enumerate(classifier.class_names):
        plt.plot(steps[0], y_pred[0][:, column], label=label)
    plt.legend()
    plt.show()
def rapid_stage(locus_data):
    """Classify one locus with RAPID and store the class probabilities on it.

    For every class name exposed by the classifier, the probability at the
    final time step is written back with
    ``locus_data.set_property('rapid_class_probability_<name>', value)``.

    Parameters
    ----------
    locus_data
        Object exposing ``get_properties()`` (a mapping containing
        ``'alert_id'``, ``'ra'`` and ``'dec'``), ``get_time_series(...)`` and
        ``set_property(name, value)``.

    Returns
    -------
    None
        Returns early, without classifying, when fewer than three usable
        points remain or when no point has a signal-to-noise above 5.
    """
    locus_properties = locus_data.get_properties()
    objid = locus_properties['alert_id']
    ra = locus_properties['ra']
    dec = locus_properties['dec']
    redshift = 0.  # TODO: Get correct redshift
    mwebv = 0.2  # TODO: Get correct extinction

    alert_id, mjd, ras, decs, passband, mag, magerr, zeropoint = locus_data.get_time_series(
        'ra', 'dec', 'ztf_fid', 'ztf_magpsf', 'ztf_sigmapsf', 'ztf_magzpsci')

    # Ignore mag Nonetype values. The `== None` / `!= None` comparisons in
    # this function are deliberate: they are meant to be evaluated
    # element-wise by numpy, which `is None` cannot do.
    alert_id, mjd, ras, decs, passband, mag, magerr, zeropoint = delete_indexes(
        np.where(mag == None), alert_id, mjd, ras, decs, passband, mag,  # noqa: E711
        magerr, zeropoint)

    print(
        locus_data.get_time_series('ra', 'dec', 'ztf_fid', 'ztf_magpsf',
                                   'ztf_sigmapsf', 'ztf_magzpsci'))

    if len(mjd) < 3:
        print("less than 3 points")
        return

    # Fill missing zeropoints with the median of the known ones.
    zpt_median = np.median(zeropoint[zeropoint != None])  # noqa: E711
    zeropoint[zeropoint == None] = zpt_median  # noqa: E711
    zeropoint = np.asarray(zeropoint, dtype=np.float64)
    mag = np.asarray(mag, dtype=np.float64)

    # Convert magnitudes to fluxes and propagate the magnitude errors.
    flux = 10.**(-0.4 * (mag - zeropoint))
    fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

    passband = np.where((passband == 1) | (passband == '1.0'), 'g', passband)
    passband = np.where((passband == 2) | (passband == '2.0'), 'r', passband)

    # Set photflag detections when S/N > 5
    photflag = np.zeros(len(flux))
    photflag[flux / fluxerr > 5] = 4096
    detections = photflag == 4096
    if not detections.any():
        # Without a detection there is no trigger epoch to flag, and min()
        # over the empty selection below would raise ValueError.
        print("no detections with S/N > 5")
        return
    # The earliest detection is the trigger.
    photflag[np.where(mjd == min(mjd[detections]))] = 6144

    deleteindexes = np.where((passband == 3) | (passband == '3.0') |
                             (np.isnan(mag)))
    mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
        deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

    light_curve_list = [(mjd, flux, fluxerr, passband, zeropoint, photflag,
                         ra, dec, objid, redshift, mwebv)]

    from keras import backend as K
    K.clear_session()

    classification = Classify(known_redshift=True)
    predictions = classification.get_predictions(light_curve_list)
    print(predictions)

    # get_predictions may return a (predictions, time_steps) tuple; indexing
    # that tuple directly would pick a time step instead of a class.
    # Accept either a tuple or a bare list of per-object arrays.
    y_predict = predictions[0] if isinstance(predictions, tuple) else predictions
    latest = y_predict[0][-1]  # only object, final time step
    for i, name in enumerate(classification.class_names):
        locus_data.set_property('rapid_class_probability_{}'.format(name),
                                latest[i])
def classify_lasair_light_curves(object_names=('ZTF18acsovsw', ), plot=True):
    """Read Lasair light curves, classify them with RAPID and optionally plot.

    Parameters
    ----------
    object_names : iterable of str
        Object names passed one by one to ``read_lasair_json``.
    plot : bool
        When true, call ``plot_light_curves_and_classifications`` on the
        classifier after predicting.

    Returns
    -------
    tuple
        ``(classification.orig_lc, classification.timesX,
        classification.y_predict)``.

    Notes
    -----
    Objects that cannot be read, or that have no epoch flagged as a detection
    (``photflag == 4096``), are reported on stdout and skipped.
    """
    light_curve_list = []
    for object_name in object_names:
        try:
            (mjd, passband, mag, magerr, photflag, zeropoint, ra, dec, objid,
             redshift, mwebv) = read_lasair_json(object_name)
        except Exception as e:
            print(e)
            continue

        # Convert magnitudes to fluxes and propagate the magnitude errors.
        flux = 10.**(-0.4 * (mag - zeropoint))
        fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

        passband = np.where((passband == 1) | (passband == '1'), 'g', passband)
        passband = np.where((passband == 2) | (passband == '2'), 'r', passband)

        # The first flagged detection becomes the trigger (6144). Skip objects
        # with no detection: min() over an empty selection raises ValueError,
        # which previously aborted the whole batch.
        detections = photflag == 4096
        if not np.any(detections):
            print("No detections for {}; skipping".format(object_name))
            continue
        mjd_first_detection = min(mjd[detections])
        photflag[np.where(mjd == mjd_first_detection)] = 6144

        # Drop filter id 3 and any non-detection after the trigger.
        # `&` binds tighter than `|`, so this is A | (B & C).
        deleteindexes = np.where(((passband == 3) | (passband == '3')) |
                                 (mjd > mjd_first_detection) & (photflag == 0))
        print("Deleting indexes {} at mjd {} and passband {}".format(
            deleteindexes, mjd[deleteindexes], passband[deleteindexes]))
        mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
            deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

        light_curve_list += [(mjd, flux, fluxerr, passband, zeropoint,
                              photflag, ra, dec, objid, redshift, mwebv)]

    classification = Classify(light_curve_list, known_redshift=True,
                              bcut=False, zcut=None)
    predictions, time_steps = classification.get_predictions(
        return_predictions_at_obstime=False)
    print(predictions)

    if plot:
        classification.plot_light_curves_and_classifications(
            step=True, use_interp_flux=False)

    return classification.orig_lc, classification.timesX, classification.y_predict
def __init__(self, model: str = None):
    """Build the wrapped RAPID classifier.

    Parameters
    ----------
    model: str, optional
        Filepath of a model to load. When omitted, ``Classify`` is created
        without a ``model_filepath`` argument.
    """
    # Redshift is always treated as known; the model path is only forwarded
    # when the caller supplied one.
    classifier_kwargs = {'known_redshift': True}
    if model is not None:
        classifier_kwargs['model_filepath'] = model
    self.classifier = Classify(**classifier_kwargs)
def classify(self, lc):
    """Classify one light curve, store the result in MongoDB and return it.

    Parameters
    ----------
    lc
        Sequence whose first element is the tuple
        ``(mjd, flux, fluxerr, passband, ra, dec, objid, redshift, mwebv)``.

    Returns
    -------
    The stored document serialised with ``dumps``.
    """
    mjd, flux, fluxerr, passband, ra, dec, objid, redshift, mwebv = lc[0]

    # Every epoch is flagged as a detection (4096); the epoch of maximum
    # flux is flagged as the trigger (6144).
    photflag = np.array([4096] * len(flux))
    photflag[np.argmax(flux)] = 6144

    light_curve_list = [
        (mjd, flux, fluxerr, passband, photflag, ra, dec, objid, redshift,
         mwebv),
    ]

    classification = Classify(known_redshift=True,
                              model_filepath='keras_model.hdf5')
    predictions = classification.get_predictions(light_curve_list)
    y_predict = predictions[0][0]
    time_steps = predictions[1][0]

    classified_lightcurve = {
        # Timezone-aware replacement for the deprecated utcnow().
        'time': datetime.datetime.now(datetime.timezone.utc),
        'obj_id': objid,
        'mjd': mjd,
        'flux': flux,
        'fluxerr': fluxerr,
        'passband': passband,
        'ra': ra,
        'dec': dec,
        'redshift': redshift,
        'mwebv': mwebv,
        'photflag': photflag.tolist(),
        'predicted_y': y_predict.tolist(),
        'timesteps': time_steps.tolist(),
    }
    print(classified_lightcurve['timesteps'])

    client = MongoClient('ampelml-mongo')
    try:
        classified_lightcurves = client.ampel_ml.classified_lightcurves
        inserted_id = classified_lightcurves.insert_one(
            classified_lightcurve).inserted_id
        # Read back exactly the document that was just written. Selecting
        # "the newest by time" can return another writer's document when
        # inserts happen concurrently.
        stored = classified_lightcurves.find_one({'_id': inserted_id})
    finally:
        # Release the connection pool even if the insert or read fails.
        client.close()

    return dumps(stored)
def main():
    """Classify one hard-coded example light curve and show the plots."""
    mjd = [
        57433.4816, 57436.4815, 57439.4817, 57451.4604, 57454.4397,
        57459.3963, 57462.418, 57465.4385, 57468.3768, 57473.3606,
        57487.3364, 57490.3341, 57493.3154, 57496.3352, 57505.3144,
        57513.2542, 57532.2717, 57536.2531, 57543.2545, 57546.2703,
        57551.2115, 57555.2669, 57558.2769, 57561.1899, 57573.2133,
        57433.5019, 57436.4609, 57439.4587, 57444.4357, 57459.4189,
        57468.3142, 57476.355, 57479.3568, 57487.3586, 57490.3562,
        57493.3352, 57496.2949, 57505.3557, 57509.2932, 57513.2934,
        57518.2735, 57521.2739, 57536.2321, 57539.2115, 57543.2301,
        57551.1701, 57555.2107, 57558.191, 57573.1923, 57576.1749,
        57586.1854,
    ]
    flux = [
        2.0357230e+00, -2.0382695e+00, 1.0084588e+02, 5.5482742e+01,
        1.4867026e+01, -6.5136810e+01, 1.6740545e+01, -5.7269131e+01,
        1.0649184e+02, 1.5505235e+02, 3.2445984e+02, 2.8735449e+02,
        2.0898877e+02, 2.8958893e+02, 1.9793906e+02, -1.3370536e+01,
        -3.9001358e+01, 7.4040916e+01, -1.7343750e+00, 2.7844931e+01,
        6.0861992e+01, 4.2057487e+01, 7.1565346e+01, -2.6085690e-01,
        -6.8435440e+01, 17.573107, 41.445435, -110.72664, 111.328964,
        -63.48336, 352.44907, 199.59058, 429.83075, 338.5255, 409.94604,
        389.71262, 195.63905, 267.13318, 123.92461, 200.3431, 106.994514,
        142.96387, 56.491238, 55.17521, 97.556946, -29.263103, 142.57687,
        -20.85057, -0.67210346, 63.353024, -40.02601,
    ]
    fluxerr = [
        42.784702, 43.83665, 99.98704, 45.26248, 43.040398, 44.00679,
        41.856007, 49.354336, 105.86439, 114.0044, 45.697918, 44.15781,
        60.574158, 93.08788, 66.04482, 44.26264, 91.525085, 42.768955,
        43.228336, 44.178196, 62.15593, 109.270035, 174.49638, 72.6023,
        48.021034, 44.86118, 48.659588, 100.97703, 148.94061, 44.98218,
        139.11194, 71.4585, 47.766987, 45.77923, 45.610615, 60.50458,
        105.11658, 71.41217, 43.945534, 45.154167, 43.84058, 52.93122,
        44.722775, 44.250145, 43.95989, 68.101326, 127.122025, 124.1893,
        49.952255, 54.50728, 114.91599,
    ]
    # The first 25 epochs are g-band, the remaining 26 are r-band.
    passband = ['g'] * 25 + ['r'] * 26
    # Every one of the 51 epochs shares the same zeropoint.
    zeropoint = [27.5] * 51
    photflag = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 4096, 6144, 4096, 4096, 4096, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0,
    ]
    objid = 'MSIP_01_NONIa-0001_10400862'
    ra = 3.75464531293933
    dec = 0.205076187109334
    redshift = 0.233557
    mwebv = 0.0228761

    example_curve = (mjd, flux, fluxerr, passband, zeropoint, photflag,
                     ra, dec, objid, redshift, mwebv)

    classifier = Classify([example_curve])
    result = classifier.get_predictions()
    print(result)

    classifier.plot_light_curves_and_classifications()
    classifier.plot_classification_animation()
def do(self, debug=False):
    """Photometrically classify active transients with RAPID and save the class.

    Transients whose status is New, Watch, FollowupRequested or Following are
    collected, their g-ZTF / r-ZTF photometry is converted to fluxes, and they
    are classified in two groups: those with a redshift
    (``known_redshift=True``) and those without. For each transient the most
    probable class at the final time step is stored in ``t.photo_class``.

    Parameters
    ----------
    debug : bool
        When true, plot the light curves and classifications of the
        known-redshift group (uses the MacOSX matplotlib backend).

    Raises
    ------
    RuntimeError
        If the winning class has no matching ``TransientClass`` row.
    """
    # run this under Admin
    user = User.objects.get(username='******')

    # get New, Watch, FollowupRequested, Following
    transients_to_classify = Transient.objects.filter(
        Q(status__name='New') | Q(status__name='Watch') |
        Q(status__name='FollowupRequested') | Q(status__name='Following'))

    light_curve_list_z, light_curve_list_noz = [], []
    transient_list_z, transient_list_noz = [], []
    for t in transients_to_classify:
        ra, dec, objid, redshift = t.ra, t.dec, t.name, t.z_or_hostz()
        mwebv = t.mw_ebv if t.mw_ebv is not None else 0.0

        photdata = get_all_phot_for_transient(user, t.id)
        if not photdata:
            continue
        gobs = photdata.filter(band__name='g-ZTF')
        robs = photdata.filter(band__name='r-ZTF')
        if not len(gobs) and not len(robs):
            continue

        mjd, passband = np.array([]), np.array([])
        flux, fluxerr, zeropoint, photflag = [], [], [], []

        first_detection_set = False
        for obs, filt in zip([gobs.order_by('obs_date'), robs.order_by('obs_date')],
                             ['g', 'r']):
            for p in obs:
                if p.data_quality:
                    continue
                obs_mjd = date_to_mjd(p.obs_date)
                # Skip a point in the same band within 0.001 d of one
                # already accepted.
                if np.any((filt == passband) & (np.abs(mjd - obs_mjd) < 0.001)):
                    continue

                # A missing/zero magnitude error falls back to 0.001 mag.
                mag_err = p.mag_err if p.mag_err else 0.001
                flux_obs = 10**(-0.4 * (p.mag - 27.5))
                # sigma_F = 0.4 ln(10) * F * sigma_m. The flux factor was
                # missing, so the S/N test below compared a flux against a
                # relative error.
                fluxerr_obs = 0.4 * np.log(10) * mag_err * flux_obs

                mjd = np.append(mjd, [obs_mjd])
                flux.append(flux_obs)
                fluxerr.append(fluxerr_obs)
                zeropoint.append(27.5)
                passband = np.append(passband, [filt])

                # The first S/N > 5 point encountered is the trigger (6144),
                # later ones are detections (4096).
                if flux_obs / fluxerr_obs > 5:
                    photflag.append(4096 if first_detection_set else 6144)
                    first_detection_set = True
                else:
                    photflag.append(0)

        # Transient and light-curve lists are appended in lockstep so that
        # index i of a prediction maps back to index i of the transients.
        # NOTE(review): this assumes Classify keeps every input light curve,
        # in order - confirm.
        if redshift:
            transient_list_z.append(t)
            light_curve_list_z.append(
                (mjd, flux, fluxerr, passband, zeropoint, photflag,
                 ra, dec, objid, redshift, mwebv))
        else:
            transient_list_noz.append(t)
            light_curve_list_noz.append(
                (mjd, flux, fluxerr, passband, zeropoint, photflag,
                 ra, dec, objid, None, mwebv))

    classification_noz = predictions_noz = None
    if light_curve_list_noz:
        classification_noz = Classify(light_curve_list_noz, known_redshift=False,
                                      bcut=False, zcut=None)
        predictions_noz = classification_noz.get_predictions()

    classification_z = predictions_z = None
    if light_curve_list_z:
        classification_z = Classify(light_curve_list_z, known_redshift=True,
                                    bcut=False, zcut=None)
        predictions_z = classification_z.get_predictions()

    if debug and classification_z is not None:
        import matplotlib
        matplotlib.use('MacOSX')
        import matplotlib.pyplot as plt
        plt.ion()
        classification_z.plot_light_curves_and_classifications()

    # Pair every transient group with ITS OWN classifier output. Previously
    # the known-redshift predictions were applied to both groups, and an
    # empty known-redshift group raised NameError.
    groups = (
        (transient_list_z, classification_z, predictions_z),
        (transient_list_noz, classification_noz, predictions_noz),
    )
    for transients, classification, predictions in groups:
        for i, t in enumerate(transients):
            # Class probabilities at the final time step for transient i.
            best_predictions = predictions[0][i][-1, :]

            # Drop 'Pre-explosion' and merge all SNIa* subtypes into a single
            # 'SN Ia' entry placed last.
            outclassnames, class_probs, p_ia = [], [], 0
            for class_name, prob in zip(classification.class_names,
                                        best_predictions):
                if class_name == 'Pre-explosion':
                    continue
                if class_name.startswith('SNIa'):
                    p_ia += prob
                else:
                    outclassnames.append(class_name)
                    class_probs.append(prob)
            outclassnames.append('SN Ia')
            class_probs.append(p_ia)

            # argmax returns the first maximum, matching the previous
            # "first entry equal to the max" selection.
            transient_class = outclassnames[int(np.argmax(class_probs))]
            print(t.name, transient_class)

            photo_class = TransientClass.objects.filter(
                name=classdict[transient_class])
            if len(photo_class):
                t.photo_class = photo_class[0]
                t.save()
            else:
                print('class %s not in DB' % classdict[transient_class])
                raise RuntimeError('class %s not in DB' % classdict[transient_class])
def classify_lasair_light_curves(
        object_names=('ZTF18acsovsw', ), plot=True, figdir='.'):
    """Read Lasair light curves, classify them with RAPID and optionally plot.

    Parameters
    ----------
    object_names : iterable of str
        Object names passed one by one to ``read_lasair_json``.
    plot : bool
        When true, call ``plot_light_curves_and_classifications`` on the
        classifier after predicting.
    figdir : str
        Forwarded to the plotting call as ``figdir``.

    Returns
    -------
    tuple
        ``(classification.orig_lc, classification.timesX,
        classification.y_predict)``.

    Notes
    -----
    Objects that cannot be read, that lack points in filter 1 or filter 2, or
    that have no epoch flagged as a detection (``photflag == 4096``) are
    reported on stdout and skipped.
    """
    light_curve_list = []
    # Per-object raw data and peak values. They are collected but not
    # returned by this function.
    peakfluxes_g, peakfluxes_r = [], []
    mjds, passbands, mags, magerrs, zeropoints, photflags = [], [], [], [], [], []
    obj_names = []
    ras, decs, objids, redshifts, mwebvs = [], [], [], [], []
    peakmags_g, peakmags_r = [], []

    for object_name in object_names:
        try:
            (mjd, passband, mag, magerr, photflag, zeropoint, ra, dec, objid,
             redshift, mwebv) = read_lasair_json(object_name)
            sortidx = np.argsort(mjd)
            mjds.append(mjd[sortidx])
            passbands.append(passband[sortidx])
            mags.append(mag[sortidx])
            magerrs.append(magerr[sortidx])
            zeropoints.append(zeropoint[sortidx])
            photflags.append(photflag[sortidx])
            obj_names.append(object_name)
            ras.append(ra)
            decs.append(dec)
            objids.append(objid)
            redshifts.append(redshift)
            mwebvs.append(mwebv)
            # min() over an empty selection raises, so an object without
            # points in filter 1 or filter 2 is skipped by the handler below.
            peakmags_g.append(min(mag[passband == 1]))
            peakmags_r.append(min(mag[passband == 2]))
        except Exception as e:
            print(e)
            continue

        # Convert magnitudes to fluxes and propagate the magnitude errors.
        flux = 10.**(-0.4 * (mag - zeropoint))
        fluxerr = np.abs(flux * magerr * (np.log(10.) / 2.5))

        passband = np.where((passband == 1) | (passband == '1'), 'g', passband)
        passband = np.where((passband == 2) | (passband == '2'), 'r', passband)

        # The first flagged detection becomes the trigger (6144). Skip objects
        # with no detection: min() over an empty selection raises ValueError,
        # which previously aborted the whole batch.
        detections = photflag == 4096
        if not np.any(detections):
            print("No detections for {}; skipping".format(object_name))
            continue
        mjd_first_detection = min(mjd[detections])
        photflag[np.where(mjd == mjd_first_detection)] = 6144

        # Drop filter id 3 and any non-detection after the trigger.
        # `&` binds tighter than `|`, so this is A | (B & C).
        deleteindexes = np.where(((passband == 3) | (passband == '3')) |
                                 (mjd > mjd_first_detection) & (photflag == 0))
        if deleteindexes[0].size > 0:
            print("Deleting indexes {} at mjd {} and passband {}".format(
                deleteindexes, mjd[deleteindexes], passband[deleteindexes]))
        # NOTE(review): called unconditionally, assuming delete_indexes is a
        # no-op for an empty index set - confirm.
        mjd, passband, flux, fluxerr, zeropoint, photflag = delete_indexes(
            deleteindexes, mjd, passband, flux, fluxerr, zeropoint, photflag)

        light_curve_list += [(mjd, flux, fluxerr, passband, photflag, ra, dec,
                              objid, redshift, mwebv)]

        # Record both peak fluxes, or neither if a band ended up empty.
        try:
            peak_g = max(flux[passband == 'g'])
            peak_r = max(flux[passband == 'r'])
        except Exception as e:
            print(e)
            continue
        peakfluxes_g.append(peak_g)
        peakfluxes_r.append(peak_r)

    classification = Classify(known_redshift=True, bcut=False, zcut=None)
    predictions, time_steps = classification.get_predictions(
        light_curve_list, return_predictions_at_obstime=False)
    print(predictions)

    if plot:
        classification.plot_light_curves_and_classifications(
            step=True, use_interp_flux=False, figdir=figdir,
            plot_matrix_input=True)

    return classification.orig_lc, classification.timesX, classification.y_predict
1.9793906e+02, -1.3370536e+01, -3.9001358e+01, 7.4040916e+01, -1.7343750e+00, 2.7844931e+01, 6.0861992e+01, 4.2057487e+01, 7.1565346e+01, -2.6085690e-01, -6.8435440e+01, 17.573107, 41.445435, -110.72664, 111.328964, -63.48336, 352.44907, 199.59058, 429.83075, 338.5255, 409.94604, 389.71262, 195.63905, 267.13318, 123.92461, 200.3431, 106.994514, 142.96387, 56.491238, 55.17521, 97.556946, -29.263103, 142.57687, -20.85057, -0.67210346, 63.353024, -40.02601] fluxerr = [42.784702, 43.83665, 99.98704, 45.26248, 43.040398, 44.00679, 41.856007, 49.354336, 105.86439, 114.0044, 45.697918, 44.15781, 60.574158, 93.08788, 66.04482, 44.26264, 91.525085, 42.768955, 43.228336, 44.178196, 62.15593, 109.270035, 174.49638, 72.6023, 48.021034, 44.86118, 48.659588, 100.97703, 148.94061, 44.98218, 139.11194, 71.4585, 47.766987, 45.77923, 45.610615, 60.50458, 105.11658, 71.41217, 43.945534, 45.154167, 43.84058, 52.93122, 44.722775, 44.250145, 43.95989, 68.101326, 127.122025, 124.1893, 49.952255, 54.50728, 114.91599] passband = ['g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r'] photflag = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4096, 6144, 4096, 4096, 4096, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] objid = 'transient_1' ra = 3.75464531293933 dec = 0.205076187109334 redshift = 0.233557 mwebv = 0.0228761 light_curve_list = [(mjd, flux, fluxerr, passband, photflag, ra, dec, objid, redshift, mwebv)] classification = Classify() predictions, time_steps = classification.get_predictions(light_curve_list) print(predictions) classification.plot_light_curves_and_classifications() classification.plot_classification_animation()
class RAPIDModel:
    """Wrapper for the model outlined by the RAPID paper (Muthukrishna et al 2019)"""

    def __init__(self, model: str = None):
        """
        Creates a new instance of RAPID

        Parameters
        ----------
        model: str, optional
            The filepath of the model to be loaded, if any.
        """
        if model is not None:
            self.classifier = Classify(known_redshift=True, model_filepath=model)
        else:
            self.classifier = Classify(known_redshift=True)

    def set_metadata(self, metadata: pd.DataFrame):
        """Sets the loaded metadata

        Parameters
        ----------
        metadata: pd.DataFrame :
            The new metadata to be loaded
        """
        assert isinstance(metadata, pd.DataFrame)
        self._metadata = metadata

    def set_curves(self, curves: pd.DataFrame):
        """Sets the loaded transient data

        Parameters
        ----------
        curves: pd.DataFrame :
            The new transient data to be loaded
        """
        assert isinstance(curves, pd.DataFrame)
        self._curves = curves

    def set_data(self, curves: pd.DataFrame, metadata: pd.DataFrame):
        """Set both the transient data and the metadata

        Parameters
        ----------
        curves: pd.DataFrame :
            The new transient data to be loaded
        metadata: pd.DataFrame :
            The new metadata to be loaded
        """
        self.set_curves(curves)
        self.set_metadata(metadata)

    def _get_custom_data(self, class_num, data_dir, save_dir, passbands,
                         known_redshift, nprocesses, redo):
        """Function for training purposes.

        Converts the curves and metadata previously stored with
        ``set_curves`` / ``set_metadata`` and preprocesses them. The
        arguments are accepted but not used here.

        Notes
        -----
        See astrorapid for API usage.
        """
        light_list, target_list = p2r.convert(self._curves, self._metadata)
        # now we need to preprocess
        return read_multiple_light_curves(light_list)

    def train(self,
              curves: pd.DataFrame,
              metadata: pd.DataFrame,
              *,
              class_map: dict = p2r.class_map,
              band_map: dict = p2r.band_map,
              save_path: str = "models/",
              file_name: str = None,
              load_model: bool = False):
        """Train the model on the loaded data.

        Parameters
        ----------
        curves : pd.DataFrame
            The curve data to train on.
        metadata : pd.DataFrame
            The metadata for each curve.
        class_map : dict, optional
            The class mapping from PLAsTiCC to the model.
        band_map : dict, optional
            The bands and their mapping from PLAsTiCC to models.
        save_path : str, optional
            Where the model should be saved. Default "models/"
        file_name : str, optional
            Override the filename of the model. Defaults to
            RAPIDModel_yyyy_mm_dd_hh_mm_ss.hdf5. Currently accepted but not
            used by this method.
        load_model : bool, optional
            Whether the model should be loaded into the class after training.
            Default False. Currently accepted but not used by this method.
        """

        # we need to create a new model and replace the classifier with it.
        # use currying to return a function.
        def _get_data(class_num, data_dir, save_dir, passbands, known_redshift,
                      nprocesses, redo, calculate_t0):
            # This is a function RAPID needs to call to get the data.
            # Select the entries of the caller-supplied maps for this class
            # and these passbands. The previous version shadowed the maps
            # with the module defaults and compared class numbers by
            # identity (`is`) rather than by value.
            selected_classes = {
                key: value
                for (key, value) in class_map.items() if value == class_num
            }
            selected_bands = {
                key: value
                for (key, value) in band_map.items() if key in passbands
            }
            light_list, target_list = p2r.convert(curves,
                                                  metadata,
                                                  classes=selected_classes,
                                                  bands=selected_bands)
            # now we need to preprocess
            return read_multiple_light_curves(light_list)

        create_custom_classifier(
            get_data_func=_get_data,
            data_dir='data/',
            class_nums=tuple(class_map.values()),
            passbands=tuple(band_map.keys()),
            save_dir=save_path,
        )

    def test(self,
             curves: pd.DataFrame,
             metadata: pd.DataFrame,
             return_probabilities: bool = False) -> (list, list):
        """Tests the model on the currently loaded data.

        Parameters
        ----------
        curves : pd.DataFrame
            The transient data for each object.
        metadata : pd.DataFrame
            The metadata for each object.
        return_probabilities : bool, optional
            If the predictions should be a probability of arrays as opposed
            to the most likely class

        Returns
        -------
        target_list : list of int
            A list of true classes, ordered.
        predictions_list : list of int or list of list of int
            A list of predictions, ordered. Depending on the value of
            `return_probabilities`, can either be an int
        """
        logger.info("testing model")
        light_list, target_list = p2r.convert(curves, metadata)
        predictions, steps = self.classifier.get_predictions(light_list)
        assert len(target_list) == len(predictions)
        # Shift the true labels by one before comparing with predictions.
        target_list = np.add(target_list, 1)
        logger.info("done testing model")
        if return_probabilities:
            return target_list, predictions

        # Most likely class at the final time step of every object.
        predictions_list = [np.argmax(pred[-1]) for pred in predictions]
        return target_list, predictions_list