コード例 #1
0
def get_lightcurve_alerts(username, password, list_names, max_attempts=2):
    """Query the light curve for a list of candidates.

    Runs a ``find`` query on the ``ZTF_alerts`` catalog for every alert whose
    ``objectId`` is in ``list_names`` and returns the projected fields.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    list_names : iterable of str
        ZTF object identifiers to look up.
    max_attempts : int, optional
        How many times the query is sent when the response has no ``data``
        field. The default of 2 (one retry) matches the historical behaviour.

    Returns
    -------
    list of dict or None
        The alerts found under ``r['data']``. Returns ``None`` when no alert
        matched, or when every attempt returned a response without ``data``.
    """
    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts",
            "filter": {
                'objectId': {
                    '$in': list(list_names)
                }
            },
            "projection": {
                "objectId": 1,
                "candidate.jd": 1,
                "candidate.ra": 1,
                "candidate.dec": 1,
                "candidate.magpsf": 1,
                "candidate.isdiffpos": 1,
                "candidate.fid": 1,
                "candidate.sigmapsf": 1,
                "candidate.programid": 1,
                "candidate.magzpsci": 1,
                "candidate.magzpsciunc": 1,
                "candidate.sgscore1": 1,
                "candidate.sgscore2": 1,
                "candidate.sgscore3": 1,
                "candidate.distpsnr1": 1,
                "candidate.distpsnr2": 1,
                "candidate.distpsnr3": 1,
                "candidate.field": 1,
                "candidate.rcid": 1,
                "candidate.pid": 1
            }
        },
        "kwargs": {
            "hint": "objectId_1"
        }
    }

    for _ in range(max_attempts):
        r = k.query(query=q)
        try:
            data = r['data']
        except KeyError:
            # The response carried no 'data' field: send the query again.
            continue
        if data == []:
            print("No candidates to be checked?")
            return None
        return data
    return None
コード例 #2
0
def check_history(list_sources, radius=1.):
    '''Query ZTF and ZUDS alerts and select only those sources without a
    negative detection.

    A source is rejected when a cone search of ``radius`` arcsec around its
    position finds at least one ``ZTF_alerts`` or ``ZUDS_alerts`` alert with
    ``candidate.isdiffpos`` equal to ``'f'`` or ``0``.

    Parameters
    ----------
    list_sources : iterable of str
        Source names to check.
    radius : float, optional
        Cone-search radius in arcsec.

    Returns
    -------
    list of str
        Names of the sources without a negative subtraction, in the order
        returned by ``query_kowalski_coords``.

    NOTE(review): ``username`` and ``password`` are read from module-level
    globals, not from parameters.
    NOTE(review): the ``query_kowalski_coords`` variants visible alongside
    this function take no ``catalog`` argument. Confirm which implementation
    is in scope here.
    '''
    # Get the coordinates of the candidates
    sources = query_kowalski_coords(username,
                                    password,
                                    list_sources,
                                    catalog='ZUDS_alerts')
    if not sources:
        # Nothing to check: skip the cone search entirely.
        return []
    coords_arr = [(c['ra'], c['dec']) for c in sources]

    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "cone_search",
        "object_coordinates": {
            "radec": f"{coords_arr}",
            "cone_search_radius": f"{radius}",
            "cone_search_unit": "arcsec"
        },
        "catalogs": {
            "ZTF_alerts": {
                "filter": {
                    'candidate.isdiffpos': {
                        '$in': ['f', 0]
                    }
                },
                "projection": {
                    "objectId": 1,
                    "candidate.ra": 1,
                    "candidate.dec": 1
                }
            },
            "ZUDS_alerts": {
                "filter": {
                    'candidate.isdiffpos': {
                        '$in': ['f', 0]
                    }
                },
                "projection": {
                    "objectId": 1,
                    "candidate.ra": 1,
                    "candidate.dec": 1
                }
            }
        }
    }

    r = k.query(query=q)
    ztf_matches = r['result_data']['ZTF_alerts']
    zuds_matches = r['result_data']['ZUDS_alerts']

    # NOTE(review): assumes the cone-search result keys come back in the
    # same order as coords_arr (and hence as sources). Confirm with the
    # Kowalski API.
    with_neg_sub = set()
    for s, coord_key in zip(sources, ztf_matches.keys()):
        if ztf_matches[coord_key] or zuds_matches[coord_key]:
            with_neg_sub.add(s['name'])

    return [s['name'] for s in sources if s['name'] not in with_neg_sub]
コード例 #3
0
def get_jdstarthist_kowalski(source_names, username, password):
    '''Look up ``candidate.jdstarthist`` for a set of ZTF sources.

    Every ``ZTF_alerts`` document whose ``objectId`` is in ``source_names``
    contributes one ``(objectId, jdstarthist)`` pair. Duplicate pairs are
    collapsed.

    Returns
    -------
    tuple of (list, list)
        ``(names_out, jdstarthist_list)``: parallel lists read from the same
        de-duplicated set of pairs. The order is arbitrary but identical for
        both lists. An object can appear more than once if its alerts report
        different ``jdstarthist`` values.
    '''
    kowalski = Kowalski(username=username, password=password, verbose=False)
    query = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts",
            "filter": {"objectId": {"$in": list(source_names)}},
            "projection": {
                "_id": 0,
                "objectId": 1,
                "candidate.jdstarthist": 1,
            },
        },
    }
    response = kowalski.query(query=query)
    alerts = response['result_data']['query_result']

    unique_pairs = {(alert['objectId'], alert['candidate']['jdstarthist'])
                    for alert in alerts}

    # Both outputs are read from the same set so their elements stay aligned.
    names_out = [name for name, _ in unique_pairs]
    jdstarthist_list = [start for _, start in unique_pairs]
    return names_out, jdstarthist_list
コード例 #4
0
def get_lightcurve_alerts_aux(username, password, list_names):
    """Query the light curve for a list of candidates from ``ZTF_alerts_aux``.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    list_names : iterable of str
        Values matched against the ``_id`` field of ``ZTF_alerts_aux``.

    Returns
    -------
    list of dict or None
        One ``{'objectId': <_id>, 'candidate': <prv_candidate>}`` entry for
        every previous candidate whose ``magpsf`` is present and not
        ``None``, i.e. the detections. Returns ``None`` (after printing a
        message) when the query returned no documents.
    """
    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts_aux",
            "filter": {
                '_id': {
                    '$in': list(list_names)
                }
            },
            "projection": {}
        },
        "kwargs": {
            "hint": "_id_"
        }
    }
    r = k.query(query=q)
    query_result = r['result_data']['query_result']
    if query_result == []:
        print("No candidates to be checked?")
        return None

    out = []
    for doc in query_result:
        for prv in doc.get('prv_candidates') or []:
            # Upper limits can carry the key with a None value, so testing
            # for the key alone would let non-detections through.
            if prv.get('magpsf') is not None:
                out.append({'objectId': doc['_id'], 'candidate': prv})

    return out
コード例 #5
0
def query_kowalski_coords(username, password, names):
    '''Query kowalski to get the coordinates of given ZTF sources.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    names : iterable of str
        ZTF ``objectId`` values to look up.

    Returns
    -------
    list of dict
        One ``{"name", "ra", "dec", "candid"}`` dict per entry of ``names``,
        in the same order. Values are taken from the first alert returned
        for that object.

    Raises
    ------
    IndexError
        If no alert was returned for one of the names.
    '''
    names = list(names)
    k = Kowalski(username=username, password=password, verbose=False)

    q = {"query_type": "find",
         "query": {
                   "catalog": "ZTF_alerts",
                   "filter": {"objectId": {"$in": names}},
                   "projection": {"_id": 0,
                                  "candid": 1,
                                  "objectId": 1,
                                  "candidate.ra": 1,
                                  "candidate.dec": 1
                                  },
                   }
         }
    results_all = k.query(query=q)
    results = results_all.get('data') or []

    # Keep the first alert per objectId. One pass replaces three full
    # rescans of the results for every name.
    first_alert = {}
    for alert in results:
        first_alert.setdefault(alert["objectId"], alert)

    sources = []
    for n in names:
        try:
            alert = first_alert[n]
        except KeyError:
            raise IndexError(f"No alert returned by Kowalski for {n}") from None
        sources.append({
            "name": n,
            "ra": alert["candidate"]["ra"],
            "dec": alert["candidate"]["dec"],
            "candid": alert["candid"],
        })

    return sources
コード例 #6
0
def match_kowalski_clu(username, password, list_in, catalog='ZUDS_alerts_aux',
                       clu_catalog='CLU_20190625'):
    '''Query kowalski and apply the CLU filter.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    list_in : iterable
        Values matched against the ``_id`` field of ``catalog``.
    catalog : str, optional
        Kowalski collection to query.
    clu_catalog : str, optional
        Name of the cross-match entry to test (``cross_matches.<clu_catalog>``).

    Returns
    -------
    list
        ``_id`` of every document with at least one cross-match in
        ``clu_catalog``. Documents without that cross-match entry count as
        unmatched.
    '''
    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "find",
        "query": {
            "catalog": catalog,
            "filter": {
                "_id": {
                    '$in': list(list_in)
                }
            },
            "projection": {
                "objectId": 1,
                f"cross_matches.{clu_catalog}": 1
            }
        }
    }

    r = k.query(query=q)

    query_results = r['result_data']['query_result']
    list_out = [t['_id'] for t in query_results
                if (t.get('cross_matches') or {}).get(clu_catalog)]

    return list_out
コード例 #7
0
def check_lightcurve_alerts(username, password, list_names, min_days,
                            max_days):
    """Re-query light curve info for a list of candidates and check that
    their full/updated duration is consistent with the time limits provided.

    The duration of an alert is ``jdendhist - jdstarthist``.

    Returns
    -------
    set of str or None
        objectIds that have at least one alert lasting ``>= min_days``, minus
        any object with an alert (that also passed the ``min_days`` cut)
        lasting ``> max_days``. Returns ``None`` when the query returned no
        alerts.
    """
    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts",
            "filter": {
                'objectId': {
                    '$in': list(list_names)
                }
            },
            "projection": {
                "objectId": 1,
                "candidate.jd": 1,
                "candidate.ndethist": 1,
                "candidate.jdstarthist": 1,
                "candidate.jdendhist": 1,
                "candidate.magpsf": 1,
                "candidate.sigmapsf": 1,
                "candidate.programid": 1,
            }
        },
        "kwargs": {
            "hint": "objectId_1"
        }
    }

    r = k.query(query=q)
    if r['data'] == []:
        print("No candidates to be checked?")
        return None

    old = set()
    selected = set()
    for info in r['data']:
        object_id = info['objectId']
        if object_id in old:
            continue
        duration = (info['candidate']['jdendhist'] -
                    info['candidate']['jdstarthist'])
        if duration < min_days:
            continue
        if duration > max_days:
            old.add(object_id)
        selected.add(object_id)

    # Remove those objects considered old
    return selected - old
コード例 #8
0
def query_kowalski_clu(username, password, clu):
    '''Fetch the ``distmpc`` field of every document in ``CLU_20180513``.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    clu : object
        Not used by this function.

    Returns
    -------
    The raw response of ``Kowalski.query``.
    '''
    kowalski = Kowalski(username=username, password=password, verbose=False)
    clu_query = {
        "query_type": "general_search",
        "query": "db['CLU_20180513'].find({},{'distmpc': 1})",
    }
    return kowalski.query(query=clu_query)
コード例 #9
0
def xmatch(_radecs, batch_size: int = 100, verbose: int = 0,
           username: str = '******', password: str = '******'):
    """Cone-search positions against ``Gaia_DR2`` in batches and time it.

    Parameters
    ----------
    _radecs : sequence
        Positions to search. Each batch slice is sent as its string form in
        ``object_coordinates.radec``.
    batch_size : int, optional
        Number of positions per query.
    verbose : int, optional
        0 is silent. 1 prints progress and min/median/max timings. 2 also
        prints the raw ``result_data`` of every batch.
    username, password : str, optional
        Kowalski credentials. The defaults are the placeholders this
        function previously hard-coded.

    Returns
    -------
    None
    """
    k = Kowalski(username=username, password=password, verbose=False)

    num_obj = len(_radecs)

    if verbose:
        print(f'Total entries: {num_obj}')

    num_batches = int(np.ceil(num_obj / batch_size))

    times = []

    for nb in range(num_batches):
        batch = _radecs[nb * batch_size: (nb + 1) * batch_size]
        q = {
            "query_type": "cone_search",
            "object_coordinates": {
                "radec": f"{batch}",
                "cone_search_radius": "1",
                "cone_search_unit": "arcsec"
            },
            "catalogs": {
                "Gaia_DR2": {
                    "filter": {},
                    "projection": {
                        "_id": 1
                    }
                }
            }
        }

        tic = time()
        r = k.query(query=q)
        toc = time()
        times.append(toc - tic)
        if verbose:
            # Report the real batch length: the last batch may be short.
            print(
                f'Fetching batch {nb + 1}/{num_batches} with {len(batch)} sources/LCs took: {toc - tic:.3f} seconds'
            )

        # Cross-match results for this batch.
        data = r['result_data']
        # TODO: process the cross-match results here
        if verbose == 2:
            print(data)

    # Without any batch there are no timings, and np.min([]) would raise.
    if verbose and times:
        print(f'min: {np.min(times)}')
        print(f'median: {np.median(times)}')
        print(f'max: {np.max(times)}')
コード例 #10
0
def query_db(coords, r=3., catalog='ZTF_20181220'):
    """Given a set of coordinates, get the matchIDs from the database.

    Everything inside the aperture around any of the positions is returned.

    Parameters
    ----------
    coords : iterable of (ra, dec)
        Positions in degrees. RA is shifted by -180 before querying, so it is
        expected in [0, 360).
    r : float
        Matching radius in arcseconds.
    catalog : str
        Name of the Kowalski collection to search.

    Returns
    -------
    matchIDs : list of str
        ``str(_id)`` of every matching document. Empty when ``coords`` is
        empty, in which case no query is sent.
    """
    coords = list(coords)
    if not coords:
        # MongoDB rejects an empty '$or', so there is nothing to query.
        return []

    k = Kowalski(username='******',
                 password='******',
                 verbose=False)

    # $centerSphere expects the radius in radians: arcsec -> radians.
    cone_search_radius = float(r) * np.pi / 180.0 / 3600.

    # Build the query: RA and Dec in degrees, with RA shifted into [-180, 180).
    query = {'$or': []}
    for ra, dec in coords:
        # float() keeps numpy scalars from rendering as e.g. "np.float64(...)"
        # in the repr-built query string below.
        obj_query = {
            'coordinates.radec_geojson': {
                '$geoWithin': {
                    '$centerSphere': [[float(ra) - 180., float(dec)],
                                      cone_search_radius]
                }
            }
        }
        query['$or'].append(obj_query)

    q = {
        "query_type": "general_search",
        "query": "db[%r].find(%s)" % (catalog, query)
    }

    # execute query
    output = k.query(query=q)

    # get matchids
    return [str(doc['_id']) for doc in output['result_data']['query_result']]
コード例 #11
0
def query_kowalski_coords(username, password, names):
    '''Query kowalski to get the coordinates of given ZTF sources.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    names : iterable of str
        ZTF ``objectId`` values to look up.

    Returns
    -------
    list of dict
        One ``{"name", "ra", "dec"}`` dict per entry of ``names``, in the
        same order, taken from the first alert returned for that object.

    Raises
    ------
    IndexError
        If no alert was returned for one of the names.
    '''
    # Materialize once: the names are both serialized into the query and
    # iterated again below.
    names = list(names)
    k = Kowalski(username=username, password=password, verbose=False)

    q = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts",
            "filter": {
                "objectId": {
                    "$in": names
                }
            },
            "projection": {
                "_id": 0,
                "candid": 1,
                "objectId": 1,
                "candidate.ra": 1,
                "candidate.dec": 1
            },
        }
    }
    results_all = k.query(query=q)
    results = results_all['result_data']['query_result']

    # Keep the first alert per objectId in a single pass.
    first_alert = {}
    for alert in results:
        first_alert.setdefault(alert["objectId"], alert)

    sources = []
    for n in names:
        try:
            alert = first_alert[n]
        except KeyError:
            raise IndexError(f"No alert returned by Kowalski for {n}") from None
        sources.append({
            "name": n,
            "ra": alert["candidate"]["ra"],
            "dec": alert["candidate"]["dec"],
        })

    return sources
コード例 #12
0
def get_cutouts(name, username, password):
    """Query kowalski to get the candidate stamps.

    Parameters
    ----------
    name : str or list/tuple of str
        One ZTF ``objectId`` or several.
    username, password : str
        Kowalski credentials.

    Returns
    -------
    list of dict or None
        Alerts (under ``r['data']``) including the ``cutoutScience``,
        ``cutoutTemplate`` and ``cutoutDifference`` fields. Returns ``None``
        if ``name`` has an unsupported type (checked before any login) or if
        no alert was found.
    """
    if isinstance(name, str):
        list_names = [name]
    elif isinstance(name, (list, tuple)):
        list_names = list(name)
    else:
        print(f"{name} must be a list or a string")
        return None

    from penquins import Kowalski

    k = Kowalski(username=username, password=password, verbose=False)

    q = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts",
            "filter": {
                'objectId': {
                    '$in': list_names
                }
            },
            "projection": {
                "objectId": 1,
                "candidate.jd": 1,
                "candidate.ra": 1,
                "candidate.dec": 1,
                "candidate.magpsf": 1,
                "candidate.fid": 1,
                "candidate.sigmapsf": 1,
                "candidate.programid": 1,
                "candidate.field": 1,
                "candidate.rcid": 1,
                "cutoutScience": 1,
                "cutoutTemplate": 1,
                "cutoutDifference": 1,
            }
        },
        "kwargs": {
            "hint": "objectId_1"
        }
    }

    r = k.query(query=q)

    if r['data'] == []:
        print("No candidates to be checked?")
        return None
    return r['data']
コード例 #13
0
def get_index_info(catalog):
    """Print the indexes Kowalski has on ``catalog``.

    Kowalski can use these indexes to speed up queries on that catalog.
    Each index is printed as its number, its quoted name, and its ``key``
    specification. Nothing is returned.

    NOTE(review): ``username`` and ``password`` are read from module-level
    globals.
    """
    index_query = {
        "query_type": "info",
        "query": {"command": "index_info", "catalog": catalog},
    }
    kowalski = Kowalski(username=username, password=password, verbose=False)
    response = kowalski.query(query=index_query)
    index_map = response['result_data']['query_result']
    for number, (index_name, index_spec) in enumerate(index_map.items(), start=1):
        print(f'index #{number}: "{index_name}"\n{index_spec["key"]}\n')
コード例 #14
0
def get_ztf(filename, name, username, password):
    """Write the ZTF light curve of ``name`` to the text file ``filename``.

    The first ``ZTF_alerts`` document returned for ``name`` is used: its own
    candidate plus all its ``prv_candidates``. One line per epoch is written,
    sorted by JD, as ``<isot> <filter> <mag> <magerr>``. Epochs without a
    ``magpsf`` use ``diffmaglim`` as the magnitude. Epochs without a
    ``sigmapsf`` get ``inf`` as the error.

    If no alert is found, a message is printed and no file is written.
    """
    k = Kowalski(username=username, password=password, verbose=True)

    # repr() quotes and escapes the name, so it cannot alter the server-side
    # query expression.
    q = {"query_type": "general_search",
         "query": "db['ZTF_alerts'].find({'objectId': {'$eq': " +
                  repr(str(name)) + "}})"
         }
    r = k.query(query=q, timeout=10)
    query_result = r['result_data']['query_result']
    if len(query_result) == 0:
        print(f"No ZTF alerts found for {name}")
        return None

    alert = query_result[0]
    latest = alert['candidate']
    jd = [latest['jd']]
    mag = [latest['magpsf']]
    magerr = [latest['sigmapsf']]
    filt = [latest['fid']]

    for prv in alert.get('prv_candidates') or []:
        jd.append(prv['jd'])
        if prv.get('magpsf') is not None:
            mag.append(prv['magpsf'])
        else:
            mag.append(prv['diffmaglim'])
        if prv.get('sigmapsf') is not None:
            magerr.append(prv['sigmapsf'])
        else:
            magerr.append(np.inf)
        filt.append(prv['fid'])

    # fid 1/2/3 -> g/r/i. Unknown ids are written as their number so that
    # filtname stays index-aligned with jd/mag/magerr.
    filter_names = {1: 'g', 2: 'r', 3: 'i'}
    filtname = [filter_names.get(f, str(f)) for f in filt]

    idx = np.argsort(jd)

    with open(filename, 'w') as out_file:
        for ii in idx:
            t = Time(jd[ii], format='jd').isot
            out_file.write('%s %s %.5f %.5f\n' %
                           (t, filtname[ii], mag[ii], magerr[ii]))
コード例 #15
0
def get_lightcurve_alerts_aux(username, password, list_names):
    """Query the light curve for a list of candidates from ``ZTF_alerts_aux``.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    list_names : iterable of str
        Values matched against the ``_id`` field of ``ZTF_alerts_aux``.

    Returns
    -------
    list of dict or None
        One ``{'objectId': <_id>, 'candidate': <prv_candidate>}`` entry for
        every previous candidate whose ``magpsf`` is present and not
        ``None``, i.e. the detections. Returns ``None`` (after printing a
        message) when the query returned no documents.
    """
    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "find",
        "query": {
            "catalog": "ZTF_alerts_aux",
            "filter": {
                '_id': {
                    '$in': list(list_names)
                }
            },
            "projection": {}
        },
        "kwargs": {
            "hint": "_id_"
        }
    }
    r = k.query(query=q)
    query_result = r['result_data']['query_result']
    if query_result == []:
        print("No candidates to be checked?")
        return None

    out = []
    for doc in query_result:
        for prv in doc.get('prv_candidates') or []:
            # Upper limits can carry the key with a None value, so testing
            # for the key alone would let non-detections through.
            if prv.get('magpsf') is not None:
                out.append({'objectId': doc['_id'], 'candidate': prv})

    return out
コード例 #16
0
    ra_header_.append(hdu[0].header['CRVAL1'])
    dec_header_.append(hdu[0].header['CRVAL2'])
    #filt_.append(hdu[0].header['filter'])
    filt_.append(opts.filter)
    date = Time(hdu[0].header['DATE-OBS'], format='isot', scale='utc')
    date_.append(date)
    date_jd_.append(date.jd)
    name_.append(opts.transient)

    if opts.doKowalski:
        q = {
            "query_type": "general_search",
            "query":
            "db['ZTF_alerts'].find({'objectId': {'$eq': '" + ID + "'}})"
        }
        r = ko.query(query=q, timeout=30)

        if len(r['result_data']['query_result']) > 0:
            # getting metadata
            candidate = r['result_data']['query_result'][0]
            ra, dec = candidate['candidate']['ra'], candidate['candidate'][
                'dec']
            ra_all_tran.append(ra)
            dec_all_tran.append(dec)
#           print("query worked")
#           print('solve-field '+files[i][:-5]+'_red.fits --ra '+str(ra)+' --dec '+str(dec)+' --dir /home/roboao/Tomas/output --scale-units arcsecperpix --scale-low 0.255 --scale-high 0.26 --radius 0.04 --overwrite')
        else:
            ra_all_tran.append(np.nan)
            dec_all_tran.append(np.nan)
    else:
        idx = np.where(opts.transient == transients["name"])[0]
コード例 #17
0
def query_kowalski_frb(args, t):
    """Query kowalski with cone searches centered at given FRB locations.

    Parameters
    ----------
    args : object
        Must provide ``frb_names`` (collection or None), ``cat_name``,
        ``search_radius`` (cone radius, sent with unit ``arcmin``),
        ``ndethist`` and ``reject_neg``. Presumably an argparse namespace;
        confirm with the caller.
    t : iterable of rows
        Each row exposes ``frb_name``, ``rop_raj`` (hourangle), ``rop_decj``
        (deg) and ``utc``. ``utc`` is parsed as ISO after replacing ``/``
        with ``-``.

    Returns
    -------
    dict or None
        ``{frb_name: {'ra', 'dec', 'id_coords', 'jd', 'candidates'}}`` where
        ``candidates`` lists the selected ZTF objectIds. Returns ``None``
        when no FRB in ``t`` passes the ``args.frb_names`` selection.

    NOTE(review): ``username`` and ``password`` are read from module-level
    globals.
    """

    def _is_stellar(cand):
        # Nearest catalogued source closer than 2 and star-like (sgscore1 >= 0.5).
        try:
            return (np.abs(cand['distpsnr1']) < 2.
                    and cand['sgscore1'] >= 0.5)
        except (KeyError, TypeError):
            return False

    def _near_bright_star(cand):
        # True if any of the three nearest catalogued sources is star-like,
        # bright (0 < srmagN < 15) and has distpsnrN < 15.
        for i in (1, 2, 3):
            try:
                if (np.abs(cand[f'distpsnr{i}']) < 15.
                        and cand[f'srmag{i}'] < 15.
                        and cand[f'srmag{i}'] > 0.
                        and cand[f'sgscore{i}'] >= 0.5):
                    return True
            except (KeyError, TypeError):
                pass
        return False

    # Prepare a dictionary for each source
    dict_sources = {}
    for s in t:
        if args.frb_names is not None and s['frb_name'] not in args.frb_names:
            continue
        try:
            coords = SkyCoord(ra=s["rop_raj"],
                              dec=s["rop_decj"],
                              unit=(u.hourangle, u.deg),
                              frame='icrs')
        except ValueError:
            # Skip the row rather than reuse stale or undefined coordinates.
            print(f"Could not parse the coordinates of {s['frb_name']}, skipping")
            continue
        id_ra = f"{str(coords.ra.deg).replace('.','_')}"
        id_dec = f"{str(coords.dec.deg).replace('.','_')}"
        id_coords = f"({id_ra}, {id_dec})"
        date = Time(s['utc'].replace('/', '-'), format='iso')
        dict_sources[s['frb_name']] = {
            'ra': coords.ra.deg,
            'dec': coords.dec.deg,
            'id_coords': id_coords,
            'jd': date.jd,
            'candidates': []
        }

    # Check that there is at least one source
    if len(dict_sources) == 0:
        print("No FRBs correspond to the given input.")
        if args.frb_names is not None:
            print(
                f"No FRB among {args.frb_names} are present in {args.cat_name}"
            )
        return None

    coords_arr = [(src['ra'], src['dec']) for src in dict_sources.values()]

    kowalski = Kowalski(username=username, password=password, verbose=False)

    q = {
        "query_type": "cone_search",
        "object_coordinates": {
            "radec": f"{coords_arr}",
            "cone_search_radius": args.search_radius,
            "cone_search_unit": "arcmin"
        },
        "catalogs": {
            "ZTF_alerts": {
                "filter": {
                    "candidate.drb": {
                        '$gt': 0.5
                    },
                    "candidate.ndethist": {
                        '$gte': args.ndethist
                    },
                    "classifications.braai": {
                        '$gt': 0.5
                    },
                    "candidate.ssdistnr": {
                        '$gt': 10
                    },
                    "candidate.magpsf": {
                        '$gt': 10
                    }
                },
                "projection": {
                    "objectId": 1,
                    "candidate.rcid": 1,
                    "candidate.drb": 1,
                    "candidate.ra": 1,
                    "candidate.dec": 1,
                    "candidate.jd": 1,
                    "candidate.magpsf": 1,
                    "candidate.sigmapsf": 1,
                    "candidate.fid": 1,
                    "candidate.sgscore1": 1,
                    "candidate.distpsnr1": 1,
                    "candidate.srmag1": 1,
                    "candidate.sgscore2": 1,
                    "candidate.distpsnr2": 1,
                    "candidate.srmag2": 1,
                    "candidate.sgscore3": 1,
                    "candidate.distpsnr3": 1,
                    "candidate.srmag3": 1,
                    "candidate.ssdistnr": 1,
                    "candidate.isdiffpos": 1
                }
            }
        },
        "kwargs": {
            "hint": "gw01"
        }
    }

    r = kowalski.query(query=q)

    def _frb_for(idcoords):
        # Map a cone-search result key back to the FRB that produced it.
        return [name for name, src in dict_sources.items()
                if src['id_coords'] == idcoords][0]

    for idcoords, alerts in r['result_data']['ZTF_alerts'].items():
        # No sources
        if len(alerts) == 0:
            key = _frb_for(idcoords)
            dict_sources[key]['candidates'] = []
            print(f"No candidates for {key}")
            continue

        selected = set()
        with_neg_sub = set()
        stellar = set()
        for info in alerts:
            object_id = info['objectId']
            if object_id in stellar or object_id in with_neg_sub:
                continue
            cand = info['candidate']
            if cand['isdiffpos'] in ['f', 0]:
                with_neg_sub.add(object_id)
            if _is_stellar(cand):
                stellar.add(object_id)
            if _near_bright_star(cand):
                continue
            selected.add(object_id)

        # Remove objects with negative subtraction
        if args.reject_neg:
            selected -= with_neg_sub

        # Remove stellar objects
        selected -= stellar

        # Add the list of ZTF candidates to the FRB list
        key = _frb_for(idcoords)
        dict_sources[key]['candidates'] = list(selected)
        print(f"{len(selected)}/{len(alerts)} candidates selected for {key}")

    return dict_sources
コード例 #18
0
ファイル: fetch_lc.py プロジェクト: renlliang3/kowalski
    # num_batches = 100

    times = []

    for nb in range(num_batches):
        qu = {
            "query_type":
            "general_search",
            "query":
            "db['ZTF_sources_20190412'].find({}, " +
            "{'_id': 1, 'data.programid': 1, 'data.hjd': 1, " +
            f"'data.mag': 1, 'data.magerr': 1}}).skip({nb*batch_size}).limit({batch_size})"
        }

        # print(qu)
        tic = time()
        r = k.query(query=qu)
        toc = time()
        times.append(toc - tic)
        print(
            f'Fetching batch {nb+1}/{num_batches} with {batch_size} sources/LCs took: {toc-tic:.3f} seconds'
        )

        # Light curves are here:
        # print(r['result_data']['query_result'])
        # Must filter out data.programid == 1 data

    print(f'min: {np.min(times)}')
    print(f'median: {np.median(times)}')
    print(f'max: {np.max(times)}')
コード例 #19
0
def query_kowalski(username, password, list_fields, min_days, max_days,
                   ndethist_min, jd, jd_gap=50.):
    '''Query kowalski and apply the selection criteria.

    For each ZTF field, alerts with ``jd < candidate.jd < jd + jd_gap`` that
    pass the server-side cuts are fetched. Objects are then rejected client
    side for any of these reasons:
      - the solar-system distance ``abs(ssdistnr) < 10``;
      - ``drb < 0.5``;
      - negative subtraction;
      - a detection history shorter than ``min_days`` (alert skipped) or
        longer than ``max_days`` (object rejected);
      - a nearby star-like source;
      - a bright star-like neighbour (alert skipped).

    Returns
    -------
    set of str
        objectIds selected over all fields.
    '''
    def _is_stellar(cand):
        # Nearest catalogued source closer than 2 and star-like (sgscore1 >= 0.5).
        try:
            return (np.abs(cand['distpsnr1']) < 2.
                    and cand['sgscore1'] >= 0.5)
        except (KeyError, TypeError):
            return False

    def _near_bright_star(cand):
        # True if any of the three nearest catalogued sources is star-like,
        # bright (0 < srmagN < 15) and has distpsnrN < 15.
        for i in (1, 2, 3):
            try:
                if (np.abs(cand[f'distpsnr{i}']) < 15.
                        and cand[f'srmag{i}'] < 15.
                        and cand[f'srmag{i}'] > 0.
                        and cand[f'sgscore{i}'] >= 0.5):
                    return True
            except (KeyError, TypeError):
                pass
        return False

    k = Kowalski(username=username, password=password, verbose=False)

    # The query uses a strict '$gt', so subtract one to keep sources with
    # exactly ndethist_min detections.
    ndethist_min_corrected = int(ndethist_min - 1)

    jd_start = jd
    jd_end = jd + jd_gap

    # Results accumulated over all fields
    set_objectId_all = set()
    for field in list_fields:
        q = {"query_type": "find",
             "query": {
                 "catalog": "ZTF_alerts",
                 "filter": {
                     'candidate.jd': {'$gt': jd_start, '$lt': jd_end},
                     'candidate.field': int(field),
                     'candidate.drb': {'$gt': 0.9},
                     'classifications.braai': {'$gt': 0.8},
                     'candidate.ndethist': {'$gt': ndethist_min_corrected},
                     'candidate.magpsf': {'$gt': 12}
                 },
                 "projection": {
                     "objectId": 1,
                     "candidate.rcid": 1,
                     "candidate.ra": 1,
                     "candidate.dec": 1,
                     "candidate.jd": 1,
                     "candidate.ndethist": 1,
                     "candidate.jdstarthist": 1,
                     "candidate.jdendhist": 1,
                     "candidate.magpsf": 1,
                     "candidate.sigmapsf": 1,
                     "candidate.fid": 1,
                     "candidate.programid": 1,
                     "candidate.isdiffpos": 1,
                     "candidate.ssdistnr": 1,
                     "candidate.rb": 1,
                     "candidate.drb": 1,
                     "candidate.distpsnr1": 1,
                     "candidate.sgscore1": 1,
                     "candidate.srmag1": 1,
                     "candidate.distpsnr2": 1,
                     "candidate.sgscore2": 1,
                     "candidate.srmag2": 1,
                     "candidate.distpsnr3": 1,
                     "candidate.sgscore3": 1,
                     "candidate.srmag3": 1
                 }
             },
             "kwargs": {"hint": "jd_field_rb_drb_braai_ndethhist_magpsf_isdiffpos"}
             }

        # Perform the query
        r = k.query(query=q)
        print(f"Search completed for field {field}, "
              f"{Time(jd, format='jd').iso} + {jd_gap:.1f} days.")

        try:
            query_result = r['result_data']['query_result']
        except KeyError:
            print(f"ERROR! jd={jd}, field={field}")
            continue
        if query_result == []:
            print("No candidates")
            continue

        selected = set()
        with_neg_sub = set()
        old = set()
        stellar = set()

        for info in query_result:
            object_id = info['objectId']
            if object_id in old or object_id in stellar:
                continue
            cand = info['candidate']
            if np.abs(cand['ssdistnr']) < 10:
                continue
            try:
                if cand['drb'] < 0.5:
                    continue
            except KeyError:
                pass
            if cand['isdiffpos'] in ['f', 0]:
                with_neg_sub.add(object_id)
            duration = cand['jdendhist'] - cand['jdstarthist']
            if duration < min_days:
                continue
            if duration > max_days:
                old.add(object_id)
            if _is_stellar(cand):
                stellar.add(object_id)
            if _near_bright_star(cand):
                continue
            selected.add(object_id)

        # Remove objects with negative subtraction, stellar objects and
        # objects considered old.
        selected -= with_neg_sub | stellar | old

        set_objectId_all |= selected

        print("Field", field, len(set_objectId_all))

    return set_objectId_all
コード例 #20
0
ファイル: run_tails.py プロジェクト: dmitryduev/tails
def run(
    cleanup: str = "none",
    checkpoint: str = "../models/tails-20210107/tails",
    config: str = "../config.yaml",
    date: Optional[str] = None,
    nthreads: int = N_CPU,
    output_base_path: str = "./",
    score_threshold: float = 0.6,
    twilight: bool = False,
    single_image: Optional[str] = None,
):
    """🚀 Run Tails on ZTF data

    Either a single ccd-quad image (``single_image``) or every ccd-quad of every
    fileroot found in Kowalski's ``ZTF_ops`` collection for ``date`` is fetched
    (in parallel) and then processed one by one. Arguments are validated before
    the config and model weights are loaded, so invalid input fails fast.

    :param cleanup: Delete raw data: ref|sci|all|none (case-insensitive)
    :param checkpoint: Pre-trained model weights
    :param config: Path to yaml file with configs and secrets
    :param date: UTC date string YYYYMMDD; defaults to the current UTC date.
        Ignored when ``single_image`` is given (the date is taken from its id)
    :param nthreads: Number of threads for image re-projecting (1 <= nthreads <= N_CPU)
    :param output_base_path: Base path for output; the directory
        <base>/runs/<YYYYMMDD> is created and handed to the processing step
    :param score_threshold: score threshold for declaring a candidate plausible (0 <= score_threshold <= 1)
    :param single_image: Run on single ccd-quad image, feed id in format ztf_20200810193681_000635_zr_c09_o_q2
    :param twilight: Run on the Twilight survey data only
    :raises ValueError: on out-of-range score_threshold / nthreads or an unknown cleanup mode

    :return:
    """
    # validate the cheap arguments first: no point in loading the config and
    # the model weights for a run that is going to be rejected anyway
    if not (0 <= score_threshold <= 1):
        raise ValueError("score_threshold must be (0 <= score_threshold <=1)")

    if not (1 <= nthreads <= N_CPU):
        raise ValueError(f"nthreads must be (1 <= nthreads <={N_CPU})")

    cleanup = cleanup.lower()
    if cleanup not in ("all", "none", "ref", "sci"):
        raise ValueError("cleanup value not in ('all', 'none', 'ref', 'sci')")

    p_base = pathlib.Path(output_base_path)

    config = load_config(config)

    # build model and load weights
    model = Tails()
    model.load_weights(checkpoint).expect_partial()

    if single_image:
        # image ids look like ztf_YYYYMMDD..., so characters 4:12 carry the date
        datestr = single_image[4:12]
    elif date:
        datestr = date
    else:
        datestr = datetime.datetime.utcnow().strftime("%Y%m%d")

    date = datetime.datetime.strptime(datestr, "%Y%m%d")
    print(date)

    p_date = p_base / "runs" / datestr
    p_date.mkdir(parents=True, exist_ok=True)

    if single_image:
        names = [single_image]
    else:
        # setup
        kowalski = Kowalski(
            username=config["kowalski"]["username"],
            password=config["kowalski"]["password"],
        )

        # ZTF_ops entries whose jd_start falls within one day after `date`
        jd_start = Time(date).jd
        q = {
            "query_type": "find",
            "query": {
                "catalog": "ZTF_ops",
                "filter": {"jd_start": {"$gt": jd_start, "$lt": jd_start + 1}},
                "projection": {"_id": 0, "fileroot": 1},
            },
        }

        if twilight:
            q["query"]["filter"]["qcomment"] = {"$regex": "Twilight"}

        # a response without "data" yields no images to process
        r = kowalski.query(q).get("data", [])
        fileroots = sorted(e["fileroot"] for e in r)

        # one name per ccd (1..16) x quadrant (1..4) of every fileroot
        names = [
            f"{fileroot}_c{ccd:02d}_o_q{quad:1d}"
            for fileroot in fileroots
            for ccd in range(1, 17)
            for quad in range(1, 5)
        ]

    # fetch data first (in parallel), then process the images sequentially
    nsp = [(name, config, p_base) for name in names]
    with mp.Pool(processes=N_CPU) as pool:
        list(tqdm(pool.imap(fetch_data, nsp), total=len(nsp)))

    for name in tqdm(names):
        process_ccd_quad(
            name=name,
            p_date=p_date,
            checkpoint=checkpoint,
            model=model,
            config=config,
            nthreads=nthreads,
            score_threshold=score_threshold,
            cleanup=cleanup,
        )
コード例 #21
0
    # Alternative to the hard-coded list below: fetch the field list from the
    # service (`q_get_fields` is not defined in this excerpt).
    # fields = k.query(q_get_fields)['result_data']['query_result']

    # Hard-coded field ids to inspect, and every rcid value 0..63
    # (presumably the ZTF readout-channel ids — confirm).
    fields = [245, 246, 744]
    rcids = list(range(0, 64))
    # print(rcids)

    # field, rcid = 245, 7
    # For each (field, rcid) pair: print how many ZTF_alerts documents match,
    # then load their jd/magpsf/sigmapsf into a DataFrame and print it.
    # NOTE(review): `k` is a client created outside this excerpt (presumably a
    # Kowalski instance); the 'result_data'/'query_result' response layout is
    # taken from how the result is indexed here, not verified.
    for field in fields:
        for rcid in rcids:
            # Query string that counts the alerts for this pair; the doubled
            # braces in the f-string produce literal `{` / `}`.
            q_count = {
                "query_type":
                "general_search",
                "query":
                f"db['ZTF_alerts'].count_documents({{'candidate.field': {field}, 'candidate.rcid': {rcid}}})"
            }
            num_alerts = k.query(q_count)['result_data']['query_result']
            print(f'field: {field}, rcid: {rcid}, num_alerts: {num_alerts}')

            # Same filter, projecting only jd/magpsf/sigmapsf (and no _id).
            # The second string is a plain literal (not an f-string), so its
            # braces stay literal; it is implicitly concatenated to the first.
            q = {
                "query_type":
                "general_search",
                "query":
                f"db['ZTF_alerts'].find({{'candidate.field': {field}, 'candidate.rcid': {rcid}}}, "
                "{'candidate.jd': 1, 'candidate.magpsf': 1, 'candidate.sigmapsf': 1, '_id': 0})"
            }
            # print(q)
            data = k.query(q)['result_data']['query_result']
            # Keep only the nested 'candidate' sub-document of each alert.
            data = [d['candidate'] for d in data]
            df = pd.DataFrame(data)
            print(df)
コード例 #22
0
class Scope:
    """SCoPe workflows sharing one configuration and an optional Kowalski client.

    On construction ``config.yaml`` (next to this file) is loaded and a Kowalski
    client is created if a token is configured, either in the config or via the
    ``KOWALSKI_TOKEN`` environment variable. Methods that query Kowalski raise
    ``ConnectionError`` when no client is available.
    """

    def __init__(self):
        # check configuration
        with status("Checking configuration"):
            check_configs(config_wildcards=["config.*yaml"])

            self.config = load_config(
                pathlib.Path(__file__).parent.absolute() / "config.yaml")

            # use token specified as env var (if exists); it takes precedence
            # over the token in the config file
            kowalski_token_env = os.environ.get("KOWALSKI_TOKEN")
            if kowalski_token_env is not None:
                self.config["kowalski"]["token"] = kowalski_token_env

        # try setting up K connection if token is available
        if self.config["kowalski"]["token"] is not None:
            with status("Setting up Kowalski connection"):
                self.kowalski = Kowalski(
                    token=self.config["kowalski"]["token"],
                    protocol=self.config["kowalski"]["protocol"],
                    host=self.config["kowalski"]["host"],
                    port=self.config["kowalski"]["port"],
                )
        else:
            # keep going without Kowalski; methods that need it check for None
            self.kowalski = None
            print("Kowalski not available")

    @staticmethod
    def _first_matches(response: dict, catalog: str) -> list:
        """Return the first hit of every position in a "near" query response.

        Positions without any hit are skipped, so the result may be shorter
        than the number of queried positions (or empty).
        """
        return [
            matches[0]
            for matches in response.get("data").get(catalog).values()
            if len(matches) > 0
        ]

    @staticmethod
    def _f1_score(precision: float, recall: float) -> float:
        """Harmonic mean of precision and recall; 0.0 when both are zero."""
        total = precision + recall
        return 2 * precision * recall / total if total > 0 else 0.0

    def _get_features(
        self,
        positions: Sequence[Sequence[float]],
        catalog: str = "ZTF_source_features_20210401",
        max_distance: Union[float, int] = 5.0,
        distance_units: str = "arcsec",
    ) -> pd.DataFrame:
        """Get nearest source in feature set for a set of given positions

        :param positions: R.A./Decl. [deg] pairs
        :param catalog: feature catalog to query; if explicitly None, falls back
            to config["kowalski"]["collections"]["features"]
        :param max_distance: search radius around each position
        :param distance_units: arcsec | arcmin | deg | rad
        :return: one row (period, ra, dec, ...) per position with a match;
            may be empty
        :raises ConnectionError: if no Kowalski client is available
        """
        if self.kowalski is None:
            raise ConnectionError("Kowalski connection not established.")
        if catalog is None:
            catalog = self.config["kowalski"]["collections"]["features"]
        query = {
            "query_type": "near",
            "query": {
                "max_distance": max_distance,
                "distance_units": distance_units,
                "radec": positions,
                "catalogs": {
                    catalog: {
                        "filter": {},
                        "projection": {
                            "period": 1,
                            "ra": 1,
                            "dec": 1,
                        },
                    }
                },
            },
        }
        response = self.kowalski.query(query=query)
        features_nearest = self._first_matches(response, catalog)
        return pd.DataFrame.from_records(features_nearest)

    def _get_nearest_gaia(
        self,
        positions: Sequence[Sequence[float]],
        catalog: str = None,
        max_distance: Union[float, int] = 5.0,
        distance_units: str = "arcsec",
    ) -> pd.DataFrame:
        """Get nearest Gaia source for a set of given positions

        :param positions: R.A./Decl. [deg] pairs
        :param catalog: Gaia catalog to query; defaults to
            config["kowalski"]["collections"]["gaia"]
        :param max_distance: search radius around each position
        :param distance_units: arcsec | arcmin | deg | rad
        :return: one row per position with a match, with the projected Gaia
            columns plus derived "M", "Ml" and "BP-RP". If nothing matched, an
            empty frame that still carries those columns.
        :raises ConnectionError: if no Kowalski client is available
        """
        if self.kowalski is None:
            raise ConnectionError("Kowalski connection not established.")
        if catalog is None:
            catalog = self.config["kowalski"]["collections"]["gaia"]
        projection = {
            "parallax": 1,
            "parallax_error": 1,
            "pmra": 1,
            "pmra_error": 1,
            "pmdec": 1,
            "pmdec_error": 1,
            "phot_g_mean_mag": 1,
            "phot_bp_mean_mag": 1,
            "phot_rp_mean_mag": 1,
            "ra": 1,
            "dec": 1,
        }
        query = {
            "query_type": "near",
            "query": {
                "max_distance": max_distance,
                "distance_units": distance_units,
                "radec": positions,
                "catalogs": {
                    catalog: {
                        "filter": {},
                        "projection": projection,
                    }
                },
            },
            "kwargs": {
                "limit": 1
            },
        }
        response = self.kowalski.query(query=query)
        gaia_nearest = self._first_matches(response, catalog)
        if not gaia_nearest:
            # an empty frame has no columns, so the derived columns below would
            # raise KeyError; return the expected (empty) columns instead
            return pd.DataFrame(columns=[*projection, "M", "Ml", "BP-RP"])

        df = pd.DataFrame.from_records(gaia_nearest)

        # G-band absolute magnitude via m + 5*log10(parallax) + 5; the *0.001
        # presumably converts the parallax from mas to arcsec
        df["M"] = df["phot_g_mean_mag"] + 5 * np.log10(
            df["parallax"] * 0.001) + 5
        # same, using parallax + parallax_error
        df["Ml"] = (df["phot_g_mean_mag"] + 5 * np.log10(
            (df["parallax"] + df["parallax_error"]) * 0.001) + 5)
        df["BP-RP"] = df["phot_bp_mean_mag"] - df["phot_rp_mean_mag"]

        return df

    def _get_light_curve_data(
        self,
        ra: float,
        dec: float,
        catalog: str = "ZTF_sources_20201201",
        cone_search_radius: Union[float, int] = 2,
        cone_search_unit: str = "arcsec",
        filter_flagged_data: bool = True,
    ) -> pd.DataFrame:
        """Get light curve data from Kowalski

        :param ra: R.A. in deg
        :param dec: Decl. in deg
        :param catalog: collection name on Kowalski
        :param cone_search_radius: search radius around (ra, dec)
        :param cone_search_unit: arcsec | arcmin | deg | rad
        :param filter_flagged_data: drop points whose catflags != 0
        :return: flattened light curve data as pd.DataFrame, one row per data
            point, with the light curve's _id/filter/field copied onto every
            row; an empty frame with the same columns if nothing was found
        :raises ConnectionError: if no Kowalski client is available
        """
        if self.kowalski is None:
            raise ConnectionError("Kowalski connection not established.")
        query = {
            "query_type": "cone_search",
            "query": {
                "object_coordinates": {
                    "cone_search_radius": cone_search_radius,
                    "cone_search_unit": cone_search_unit,
                    "radec": {
                        "target": [ra, dec]
                    },
                },
                "catalogs": {
                    catalog: {
                        "filter": {},
                        "projection": {
                            "_id": 1,
                            "filter": 1,
                            "field": 1,
                            "data.hjd": 1,
                            "data.fid": 1,
                            "data.mag": 1,
                            "data.magerr": 1,
                            "data.ra": 1,
                            "data.dec": 1,
                            "data.programid": 1,
                            "data.catflags": 1,
                        },
                    }
                },
            },
        }
        response = self.kowalski.query(query=query)
        # results are keyed by the name given to the position in "radec" above
        light_curves_raw = (
            response.get("data").get(catalog).get("target") or [])

        light_curves = []
        for light_curve in light_curves_raw:
            df = pd.DataFrame.from_records(light_curve["data"])
            # broadcast to all data points:
            df["_id"] = light_curve["_id"]
            df["filter"] = light_curve["filter"]
            df["field"] = light_curve["field"]
            light_curves.append(df)

        if not light_curves:
            # nothing within the cone: pd.concat([]) would raise ValueError
            return pd.DataFrame(columns=[
                "hjd", "fid", "mag", "magerr", "ra", "dec", "programid",
                "catflags", "_id", "filter", "field",
            ])

        df = pd.concat(light_curves, ignore_index=True)

        if filter_flagged_data:
            mask_flagged_data = df["catflags"] != 0
            df = df.loc[~mask_flagged_data]

        return df

    @staticmethod
    def develop():
        """Install developer tools (the pre-commit git hooks)"""
        subprocess.run(["pre-commit", "install"])

    @classmethod
    def lint(cls):
        """Lint sources with pre-commit, installing it first if needed

        Exits the process with status 1 if any hook fails.
        """
        try:
            import pre_commit  # noqa: F401
        except ImportError:
            cls.develop()

        try:
            subprocess.run(["pre-commit", "run", "--all-files"], check=True)
        except subprocess.CalledProcessError:
            sys.exit(1)

    def doc(self):
        """Build docs

        Writes the taxonomy visualization to doc/_static. If Kowalski is
        reachable, it then generates the Field Guide figures into doc/data and
        runs ``make html`` in the doc directory; otherwise it prints a message
        and returns after the taxonomy step.
        """

        from scope.utils import (
            make_tdtax_taxonomy,
            plot_gaia_density,
            plot_gaia_hr,
            plot_light_curve_data,
            plot_periods,
        )

        path_root = pathlib.Path(__file__).parent.absolute()
        path_doc = path_root / "doc"
        path_doc_data = path_doc / "data"
        # Golden sets are stored as ra/decs in csv format under data/golden
        golden_sets = path_root / "data" / "golden"

        # generate taxonomy.html
        with status("Generating taxonomy visualization"):
            path_static = path_doc / "_static"
            path_static.mkdir(parents=True, exist_ok=True)
            tdtax.write_viz(
                make_tdtax_taxonomy(self.config["taxonomy"]),
                outname=path_static / "taxonomy.html",
            )

        # generate images for the Field Guide
        if (self.kowalski is None) or (not self.kowalski.ping()):
            print("Kowalski connection not established, cannot generate docs.")
            return

        period_limits = {
            "cepheid": [1.0, 100.0],
            "delta_scuti": [0.03, 0.3],
            "beta_lyr": [0.3, 25],
            "rr_lyr": [0.2, 1.0],
            "w_uma": [0.2, 0.8],
        }
        period_loglimits = {
            "cepheid": True,
            "delta_scuti": False,
            "beta_lyr": True,
            "rr_lyr": False,
            "w_uma": False,
        }

        # example periods
        with status("Generating example period histograms"):
            for golden_set in golden_sets.glob("*.csv"):
                golden_set_name = golden_set.stem
                positions = pd.read_csv(golden_set).to_numpy().tolist()
                features = self._get_features(positions=positions)

                if len(features) == 0:
                    print(f"No features for {golden_set_name}")
                    continue

                limits = period_limits.get(golden_set_name)
                loglimits = period_loglimits.get(golden_set_name)

                plot_periods(
                    features=features,
                    limits=limits,
                    loglimits=loglimits,
                    save=path_doc_data / f"period__{golden_set_name}",
                )

        # example skymaps for all Golden sets
        with status("Generating skymaps diagrams for Golden sets"):
            path_gaia_density = path_root / "data" / "Gaia_hp8_densitymap.fits"
            for golden_set in golden_sets.glob("*.csv"):
                golden_set_name = golden_set.stem
                positions = pd.read_csv(golden_set).to_numpy().tolist()

                plot_gaia_density(
                    positions=positions,
                    path_gaia_density=path_gaia_density,
                    save=path_doc_data / f"radec__{golden_set_name}",
                )

        # example light curves
        with status("Generating example light curves"):
            field_guide = self.config["docs"]["field_guide"]
            for sample_object_name, sample_object in field_guide.items():
                sample_light_curves = self._get_light_curve_data(
                    ra=sample_object["coordinates"][0],
                    dec=sample_object["coordinates"][1],
                    catalog=self.config["kowalski"]["collections"]["sources"],
                )
                plot_light_curve_data(
                    light_curve_data=sample_light_curves,
                    period=sample_object.get("period"),
                    title=sample_object.get("title"),
                    save=path_doc_data / sample_object_name,
                )

        # example HR diagrams for all Golden sets
        with status("Generating HR diagrams for Golden sets"):
            path_gaia_hr_histogram = path_doc_data / "gaia_hr_histogram.dat"
            for golden_set in golden_sets.glob("*.csv"):
                golden_set_name = golden_set.stem
                positions = pd.read_csv(golden_set).to_numpy().tolist()
                gaia_sources = self._get_nearest_gaia(positions=positions)

                plot_gaia_hr(
                    gaia_data=gaia_sources,
                    path_gaia_hr_histogram=path_gaia_hr_histogram,
                    save=path_doc_data / f"hr__{golden_set_name}",
                )

        # build docs; the doc dir is resolved from this file, not the CWD
        subprocess.run(["make", "html"], cwd=path_doc, check=True)

    @staticmethod
    def fetch_models(gcs_path: str = "gs://ztf-scope/models"):
        """
        Fetch SCoPe models from GCP

        Copies the ``*.csv`` objects under ``gcs_path`` into ./models (next to
        this file) with ``gsutil``, skipping files that already exist locally.

        :raises subprocess.CalledProcessError: if gsutil exits with an error
        """
        path_models = pathlib.Path(__file__).parent / "models"
        path_models.mkdir(parents=True, exist_ok=True)

        command = [
            "gsutil",
            "-m",
            "cp",
            "-n",
            "-r",
            os.path.join(gcs_path, "*.csv"),
            str(path_models),
        ]
        # check=True already raises on a non-zero exit status
        subprocess.run(command, check=True)

    @staticmethod
    def fetch_datasets(gcs_path: str = "gs://ztf-scope/datasets"):
        """
        Fetch SCoPe datasets from GCP

        Copies the ``*.csv`` objects under ``gcs_path`` into ./data/training
        (next to this file) with ``gsutil``, skipping files that already exist
        locally.

        :raises subprocess.CalledProcessError: if gsutil exits with an error
        """
        path_datasets = pathlib.Path(__file__).parent / "data" / "training"
        path_datasets.mkdir(parents=True, exist_ok=True)

        command = [
            "gsutil",
            "-m",
            "cp",
            "-n",
            "-r",
            os.path.join(gcs_path, "*.csv"),
            str(path_datasets),
        ]
        # check=True already raises on a non-zero exit status
        subprocess.run(command, check=True)

    def train(
        self,
        tag: str,
        path_dataset: Union[str, pathlib.Path],
        gpu: Optional[int] = None,
        verbose: bool = False,
        **kwargs,
    ):
        """Train classifier

        Most dataset/training hyperparameters are resolved as: value in
        ``kwargs`` > value in ``config["training"]["classes"][tag]`` > default.
        Unless ``test=True`` is passed, the run is logged to Weights & Biases.

        :param tag: classifier designation, refers to "class" in config.taxonomy
        :param path_dataset: local path to csv file with the dataset
        :param gpu: GPU id to use, zero-based. check tf.config.list_physical_devices('GPU') for available devices;
            if None, training is restricted to the CPU
        :param verbose: print model summary and evaluation stats
        :param kwargs: refer to utils.DNN.setup and utils.Dataset.make
        :return: the time tag the model was saved under if ``save`` is truthy,
            otherwise None
        """

        import tensorflow as tf

        if gpu is not None:
            # specified a GPU to run on?
            gpus = tf.config.list_physical_devices("GPU")
            tf.config.experimental.set_visible_devices(gpus[gpu], "GPU")
        else:
            # otherwise run on CPU
            tf.config.experimental.set_visible_devices([], "GPU")

        import wandb
        from wandb.keras import WandbCallback

        from scope.nn import DNN
        from scope.utils import Dataset

        train_config = self.config["training"]["classes"][tag]

        features = self.config["features"][train_config["features"]]

        ds = Dataset(
            tag=tag,
            path_dataset=path_dataset,
            features=features,
            verbose=verbose,
            **kwargs,
        )

        label = train_config["label"]

        # values from kwargs override those defined in config. if latter is absent, use reasonable default
        threshold = kwargs.get("threshold", train_config.get("threshold", 0.5))
        balance = kwargs.get("balance", train_config.get("balance", None))
        weight_per_class = kwargs.get(
            "weight_per_class", train_config.get("weight_per_class", False))
        scale_features = kwargs.get("scale_features", "min_max")

        test_size = kwargs.get("test_size", train_config.get("test_size", 0.1))
        val_size = kwargs.get("val_size", train_config.get("val_size", 0.1))
        random_state = kwargs.get("random_state",
                                  train_config.get("random_state", 42))
        feature_stats = self.config.get("feature_stats", None)

        batch_size = kwargs.get("batch_size",
                                train_config.get("batch_size", 64))
        shuffle_buffer_size = kwargs.get(
            "shuffle_buffer_size", train_config.get("shuffle_buffer_size",
                                                    512))
        epochs = kwargs.get("epochs", train_config.get("epochs", 100))

        datasets, indexes, steps_per_epoch, class_weight = ds.make(
            target_label=label,
            threshold=threshold,
            balance=balance,
            weight_per_class=weight_per_class,
            scale_features=scale_features,
            test_size=test_size,
            val_size=val_size,
            random_state=random_state,
            feature_stats=feature_stats,
            batch_size=batch_size,
            shuffle_buffer_size=shuffle_buffer_size,
            epochs=epochs,
        )

        # set up and train model
        dense_branch = kwargs.get("dense_branch", True)
        conv_branch = kwargs.get("conv_branch", True)
        loss = kwargs.get("loss", "binary_crossentropy")
        optimizer = kwargs.get("optimizer", "adam")
        lr = float(kwargs.get("lr", 3e-4))
        momentum = float(kwargs.get("momentum", 0.9))
        monitor = kwargs.get("monitor", "val_loss")
        patience = int(kwargs.get("patience", 20))
        callbacks = kwargs.get("callbacks",
                               ("reduce_lr_on_plateau", "early_stopping"))
        run_eagerly = kwargs.get("run_eagerly", False)
        pre_trained_model = kwargs.get("pre_trained_model")
        save = kwargs.get("save", False)

        # parse boolean args (they may arrive as strings from the CLI)
        dense_branch = forgiving_true(dense_branch)
        conv_branch = forgiving_true(conv_branch)
        run_eagerly = forgiving_true(run_eagerly)
        save = forgiving_true(save)

        classifier = DNN(name=tag)

        classifier.setup(
            dense_branch=dense_branch,
            features_input_shape=(len(features), ),
            conv_branch=conv_branch,
            dmdt_input_shape=(26, 26, 1),
            loss=loss,
            optimizer=optimizer,
            learning_rate=lr,
            momentum=momentum,
            monitor=monitor,
            patience=patience,
            callbacks=callbacks,
            run_eagerly=run_eagerly,
        )

        if verbose:
            # summary() prints itself and returns None
            classifier.model.summary()

        if pre_trained_model is not None:
            classifier.load(pre_trained_model)

        time_tag = datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S")

        if not kwargs.get("test", False):
            wandb.login(key=self.config["wandb"]["token"])
            wandb.init(
                project=self.config["wandb"]["project"],
                tags=[tag],
                name=f"{tag}-{time_tag}",
                config={
                    "tag": tag,
                    "label": label,
                    "dataset": pathlib.Path(path_dataset).name,
                    "scale_features": scale_features,
                    "learning_rate": lr,
                    "epochs": epochs,
                    "patience": patience,
                    "random_state": random_state,
                    "batch_size": batch_size,
                    "architecture": "scope-net",
                    "dense_branch": dense_branch,
                    "conv_branch": conv_branch,
                },
            )
            classifier.meta["callbacks"].append(WandbCallback())

        classifier.train(
            datasets["train"],
            datasets["val"],
            steps_per_epoch["train"],
            steps_per_epoch["val"],
            epochs=epochs,
            class_weight=class_weight,
            verbose=verbose,
        )

        if verbose:
            print("Evaluating on test set:")
        stats = classifier.evaluate(datasets["test"], verbose=verbose)
        if verbose:
            print(stats)

        # order in which classifier.evaluate reports its metrics
        param_names = (
            "loss",
            "tp",
            "fp",
            "tn",
            "fn",
            "accuracy",
            "precision",
            "recall",
            "auc",
        )
        if not kwargs.get("test", False):
            # log model performance on the test set
            for param, value in zip(param_names, stats):
                wandb.run.summary[f"test_{param}"] = value
            wandb.run.summary["test_f1"] = self._f1_score(
                wandb.run.summary["test_precision"],
                wandb.run.summary["test_recall"],
            )

        if datasets["dropped_samples"] is not None:
            # log model performance on the dropped samples
            if verbose:
                print("Evaluating on samples dropped from the training set:")
            stats = classifier.evaluate(datasets["dropped_samples"],
                                        verbose=verbose)
            if verbose:
                print(stats)

            if not kwargs.get("test", False):
                for param, value in zip(param_names, stats):
                    wandb.run.summary[f"dropped_samples_{param}"] = value
                wandb.run.summary["dropped_samples_f1"] = self._f1_score(
                    wandb.run.summary["dropped_samples_precision"],
                    wandb.run.summary["dropped_samples_recall"],
                )

        if save:
            output_path = str(
                pathlib.Path(__file__).parent.absolute() / "models" / tag)
            if verbose:
                print(f"Saving model to {output_path}")
            classifier.save(
                output_path=output_path,
                output_format="tf",
                tag=time_tag,
            )

            return time_tag

    def test(self):
        """Smoke-test the training workflow on a random mock dataset

        Writes a 1000-row mock csv to data/training, trains the "vnv"
        classifier for 3 epochs with ``save=True`` and ``test=True`` (so nothing
        is logged to wandb), then deletes the saved model and the mock csv.
        """
        import uuid
        import shutil

        # create a mock dataset and check that the training pipeline works
        dataset = f"{uuid.uuid4().hex}.csv"
        path_mock = pathlib.Path(
            __file__).parent.absolute() / "data" / "training"

        try:
            path_mock.mkdir(parents=True, exist_ok=True)

            feature_names = self.config["features"]["ontological"]
            class_names = [
                self.config["training"]["classes"][class_name]["label"]
                for class_name in self.config["training"]["classes"]
            ]

            entries = []
            for i in range(1000):
                entry = {
                    **{
                        feature_name: np.random.normal(0, 0.1)
                        for feature_name in feature_names
                    },
                    **{
                        class_name: np.random.choice([0, 1])
                        for class_name in class_names
                    },
                    **{
                        "non-variable": np.random.choice([0, 1])
                    },
                    **{
                        "dmdt": np.abs(np.random.random((26, 26))).tolist()
                    },
                }
                entries.append(entry)

            df_mock = pd.DataFrame.from_records(entries)
            df_mock.to_csv(path_mock / dataset, index=False)

            tag = "vnv"
            time_tag = self.train(
                tag=tag,
                path_dataset=path_mock / dataset,
                batch_size=32,
                epochs=3,
                verbose=True,
                save=True,
                test=True,
            )
            path_model = (pathlib.Path(__file__).parent.absolute() / "models" /
                          tag / time_tag)
            shutil.rmtree(path_model)
        finally:
            # clean up after thyself; the csv does not exist if we failed
            # before writing it, and that must not mask the original error
            try:
                (path_mock / dataset).unlink()
            except FileNotFoundError:
                pass
コード例 #23
0
def query_kowalski(username,
                   password,
                   ra_center,
                   dec_center,
                   radius,
                   jd_trigger,
                   min_days,
                   max_days,
                   slices,
                   ndethist_min,
                   within_days,
                   after_trigger=True,
                   verbose=True):
    '''Query kowalski and apply the selection criteria.

    The coordinate lists are split into ``slices`` chunks of roughly equal
    size. For each chunk a ``cone_search`` on ``ZTF_alerts`` is run and the
    returned alerts are filtered. The objectIds surviving all cuts are
    accumulated over the chunks.

    :param username: Kowalski username
    :param password: Kowalski password
    :param ra_center: right ascensions of the cone centres [deg]
    :param dec_center: declinations of the cone centres [deg], aligned
        element-by-element with ``ra_center``
    :param radius: cone-search radius [arcmin]
    :param jd_trigger: trigger time [JD]; replaced by 0 when
        ``after_trigger`` is False
    :param min_days: candidates with ``jdendhist - jdstarthist`` below this
        are skipped
    :param max_days: candidates whose detection history is longer than this
        (or, if ``after_trigger``, whose last detection is later than
        ``jd_trigger + max_days``) are rejected
    :param slices: number of chunks the coordinate list is split into
    :param ndethist_min: minimum number of historical detections
    :param within_days: the first detection must be within this many days
        after ``jd_trigger``
    :param after_trigger: if False, also query before the trigger
    :param verbose: print a message for chunks without candidates
    :return: set of objectIds passing all selection criteria
    '''

    def _near_bright_star(cand, n):
        # True if PS1 counterpart ``n`` is close (distpsnr < 15), bright
        # (0 < srmag < 15) and star-like (sgscore >= 0.5).
        try:
            return bool(np.abs(cand[f'distpsnr{n}']) < 15.
                        and cand[f'srmag{n}'] < 15.
                        and cand[f'srmag{n}'] > 0.
                        and cand[f'sgscore{n}'] >= 0.5)
        except (KeyError, TypeError, ValueError):
            # Missing/null cross-match information: do not reject.
            return False

    k = Kowalski(username=username, password=password, verbose=False)
    # Initialize a set for the results
    set_objectId_all = set()

    # The ndethist filter uses '$gt', hence the -1 to keep
    # sources with at least ``ndethist_min`` detections.
    ndethist_min_corrected = int(ndethist_min - 1)

    # Correct the jd_trigger if the user specifies to query
    # also before the trigger.
    # NOTE(review): with jd_trigger = 0 the jdstarthist '$lt' filter becomes
    # 0 + within_days and the within_days check below flags everything as
    # old; confirm the intended semantics of after_trigger=False.
    if after_trigger is False:
        jd_trigger = 0

    # Slice boundaries: slice j covers [starts[j], ends[j]); the last slice
    # extends to the end of the coordinate list.
    n_coords = len(ra_center)
    starts = [int(s) for s in np.linspace(0, n_coords, slices + 1)[:-1]]
    ends = starts[1:] + [n_coords]

    for start, end in zip(starts, ends):
        coords_arr = []
        for ra, dec in zip(ra_center[start:end], dec_center[start:end]):
            try:
                # Remove points too far south for ZTF.
                # Say, keep only Dec>-40 deg to be conservative
                if dec < -40.:
                    continue
                coords = SkyCoord(ra=float(ra) * u.deg, dec=float(dec) * u.deg)
                coords_arr.append((coords.ra.deg, coords.dec.deg))
            except ValueError:
                print("Problems with the galaxy coordinates?")
                continue

        print(f"slice: {start}:{end}")
        if not coords_arr:
            # Nothing to search around in this slice (e.g. all too far south)
            if verbose is True:
                print("No valid coordinates in this slice")
            continue

        q = {
            "query_type": "cone_search",
            "query": {
                "object_coordinates": {
                    "radec": f"{coords_arr}",
                    "cone_search_radius": f"{radius}",
                    "cone_search_unit": "arcmin"
                },
                "catalogs": {
                    "ZTF_alerts": {
                        "filter": {
                            "candidate.jd": {
                                '$gt': jd_trigger
                            },
                            "candidate.drb": {
                                '$gt': 0.8
                            },
                            "candidate.ndethist": {
                                '$gt': ndethist_min_corrected
                            },
                            "candidate.jdstarthist": {
                                '$gt': jd_trigger,
                                '$lt': jd_trigger + within_days
                            }
                        },
                        "projection": {
                            "objectId": 1,
                            "candidate.rcid": 1,
                            "candidate.ra": 1,
                            "candidate.dec": 1,
                            "candidate.jd": 1,
                            "candidate.ndethist": 1,
                            "candidate.jdstarthist": 1,
                            "candidate.jdendhist": 1,
                            "candidate.magpsf": 1,
                            "candidate.sigmapsf": 1,
                            "candidate.fid": 1,
                            "candidate.programid": 1,
                            "candidate.isdiffpos": 1,
                            "candidate.ssdistnr": 1,
                            "candidate.rb": 1,
                            "candidate.drb": 1,
                            "candidate.distpsnr1": 1,
                            "candidate.sgscore1": 1,
                            "candidate.srmag1": 1,
                            "candidate.distpsnr2": 1,
                            "candidate.sgscore2": 1,
                            "candidate.srmag2": 1,
                            "candidate.distpsnr3": 1,
                            "candidate.sgscore3": 1,
                            "candidate.srmag3": 1
                        }
                    }
                },
                "kwargs": {
                    "hint": "gw01"
                }
            }
        }

        # Try to query kowalski up to 5 times, re-issuing the query after
        # every malformed or failed response.
        alert_lists = []
        no_candidates = False
        for attempt in range(1, 6):
            r = k.query(query=q)
            try:
                if r['data'] == []:
                    no_candidates = True
                else:
                    alert_lists = list(r['data']['ZTF_alerts'].values())
                break
            except (AttributeError, KeyError, TypeError):
                print(f"failed attempt {attempt}")
        else:
            print(f"SKIPPING slice {start}:{end} after 5 attempts")
            continue
        print('Search completed for this slice.')

        if no_candidates:
            if verbose is True:
                print("No candidates")
            continue

        kept = set()
        with_neg_sub = set()
        old = set()
        out_of_time_window = set()
        stellar = set()

        for all_info in alert_lists:
            for info in all_info:
                oid = info['objectId']
                cand = info['candidate']
                if oid in old or oid in stellar:
                    continue
                if np.abs(cand['ssdistnr']) < 10:
                    continue
                if cand['isdiffpos'] in ['f', 0]:
                    with_neg_sub.add(oid)
                duration = cand['jdendhist'] - cand['jdstarthist']
                if duration < min_days:
                    continue
                if duration > max_days:
                    old.add(oid)
                if (cand['jdstarthist'] - jd_trigger) > within_days:
                    old.add(oid)
                if after_trigger is True:
                    if (cand['jdendhist'] - jd_trigger) > max_days:
                        out_of_time_window.add(oid)
                elif duration > max_days:
                    out_of_time_window.add(oid)
                # Star-like PS1 source very close to the candidate
                try:
                    if (np.abs(cand['distpsnr1']) < 1.5
                            and cand['sgscore1'] > 0.50):
                        stellar.add(oid)
                except (KeyError, TypeError, ValueError):
                    pass
                if any(_near_bright_star(cand, n) for n in (1, 2, 3)):
                    continue

                kept.add(oid)

        # Remove objects with negative subtraction, stellar objects, old
        # objects and those whose alerts go beyond the allowed time window.
        set_objectId = (kept - with_neg_sub - stellar - old
                        - out_of_time_window)
        print(set_objectId)

        set_objectId_all |= set_objectId
        print("Cumulative:", set_objectId_all)

    return set_objectId_all
コード例 #24
0
def fetch_lc_radecs(_radecs, username='******', password='******',
                    batch_size=100):
    """Fetch ZTF source light curves around a list of positions, in batches.

    Each batch of at most ``batch_size`` positions is sent as one 2-arcsec
    ``cone_search`` against ``ZTF_sources_20190412``. The query time of every
    batch is recorded, and min/median/max timings are printed at the end.

    :param _radecs: sliceable sequence of positions; each batch slice is sent
        as its ``str()`` representation in ``object_coordinates.radec``
    :param username: Kowalski username (defaults to the original placeholder)
    :param password: Kowalski password (defaults to the original placeholder)
    :param batch_size: maximum number of positions per query (>= 1)
    :return: set of the ``_id`` values of all matched sources
    :raises ValueError: if ``batch_size`` < 1
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    k = Kowalski(username=username, password=password, verbose=False)

    num_obj = len(_radecs)
    print(f'Total entries: {num_obj}')

    num_batches = int(np.ceil(num_obj / batch_size))
    times = []
    ids = set()

    for nb in range(num_batches):
        batch = _radecs[nb * batch_size:(nb + 1) * batch_size]
        q = {
            "query_type": "cone_search",
            "object_coordinates": {
                "radec": f"{batch}",
                "cone_search_radius": "2",
                "cone_search_unit": "arcsec"
            },
            "catalogs": {
                "ZTF_sources_20190412": {
                    "filter": {},
                    "projection": {
                        "_id": 1,
                        "filter": 1,
                        "data.expid": 1,
                        "data.ra": 1,
                        "data.dec": 1,
                        "data.programid": 1,
                        "data.hjd": 1,
                        "data.mag": 1,
                        "data.magerr": 1
                    }
                }
            }
        }

        tic = time()
        r = k.query(query=q)
        toc = time()
        times.append(toc - tic)
        # Report the real batch length: the last batch may be partial.
        print(
            f'Fetching batch {nb + 1}/{num_batches} with {len(batch)} sources/LCs took: {toc - tic:.3f} seconds'
        )

        # Light curves are here:
        data = r['result_data']
        for sources in data['ZTF_sources_20190412'].values():
            ids.update(s['_id'] for s in sources)
        print(len(ids))
        # FIXME: Must filter out data.programid == 1 data

    # np.min/np.max raise on an empty list, so guard the no-input case.
    if times:
        print(f'min: {np.min(times)}')
        print(f'median: {np.median(times)}')
        print(f'max: {np.max(times)}')
    else:
        print('No entries to fetch.')

    return ids
コード例 #25
0
def query_kowalski(username,
                   password,
                   list_fields,
                   min_days,
                   max_days,
                   ndethist_min,
                   jd_gap=50.,
                   jd_start=2458650.0):
    '''Query kowalski and apply the selection criteria.

    For every ZTF field in ``list_fields``, ZTF_alerts is searched for alerts
    with ``jd_start < jd < jd_start + jd_gap`` that pass the rb/drb/ndethist/
    magpsf cuts. The returned alerts are then filtered, and the surviving
    objectIds are accumulated over all fields.

    :param username: Kowalski username
    :param password: Kowalski password
    :param list_fields: iterable of ZTF field ids
    :param min_days: candidates with ``jdendhist - jdstarthist`` below this
        are skipped
    :param max_days: candidates with ``jdendhist - jdstarthist`` above this
        are rejected as old
    :param ndethist_min: minimum number of historical detections
    :param jd_gap: length of the searched time window [days]
    :param jd_start: start of the searched time window [JD]
    :return: set of objectIds passing all selection criteria
    '''

    def _near_bright_star(cand, n):
        # True if PS1 counterpart ``n`` is close (distpsnr < 15), bright
        # (srmag < 15) and star-like (sgscore >= 0.5).
        # NOTE(review): the cone-search variant of this query also requires
        # srmag > 0 (excluding non-positive/sentinel magnitudes); confirm
        # whether that guard should apply here too.
        try:
            return bool(np.abs(cand[f'distpsnr{n}']) < 15.
                        and cand[f'srmag{n}'] < 15.
                        and cand[f'sgscore{n}'] >= 0.5)
        except (KeyError, TypeError, ValueError):
            # Missing/null cross-match information: do not reject.
            return False

    k = Kowalski(username=username, password=password, verbose=False)
    # Initialize a set for the results
    set_objectId_all = set()

    # The ndethist filter uses '$gt', hence the -1 to keep
    # sources with at least ``ndethist_min`` detections.
    ndethist_min_corrected = int(ndethist_min - 1)
    jd_end = jd_start + jd_gap

    for field in list_fields:
        q = {
            "query_type": "find",
            "query": {
                "catalog": "ZTF_alerts",
                "filter": {
                    'candidate.jd': {
                        '$gt': jd_start,
                        '$lt': jd_end
                    },
                    'candidate.field': field,
                    'candidate.rb': {
                        '$gt': 0.5
                    },
                    'candidate.drb': {
                        '$gt': 0.5
                    },
                    'candidate.ndethist': {
                        '$gt': ndethist_min_corrected
                    },
                    'candidate.magpsf': {
                        '$gt': 16
                    }
                },
                "projection": {
                    "objectId": 1,
                    "candidate.rcid": 1,
                    "candidate.ra": 1,
                    "candidate.dec": 1,
                    "candidate.jd": 1,
                    "candidate.ndethist": 1,
                    "candidate.jdstarthist": 1,
                    "candidate.jdendhist": 1,
                    "candidate.magpsf": 1,
                    "candidate.sigmapsf": 1,
                    "candidate.fid": 1,
                    "candidate.programid": 1,
                    "candidate.isdiffpos": 1,
                    "candidate.ssdistnr": 1,
                    "candidate.rb": 1,
                    "candidate.drb": 1,
                    "candidate.distpsnr1": 1,
                    "candidate.sgscore1": 1,
                    "candidate.srmag1": 1,
                    "candidate.distpsnr2": 1,
                    "candidate.sgscore2": 1,
                    "candidate.srmag2": 1,
                    "candidate.distpsnr3": 1,
                    "candidate.sgscore3": 1,
                    "candidate.srmag3": 1
                }
            },
            "kwargs": {
                "hint": "jd_field_rb_drb_braai_ndethhist_magpsf_isdiffpos"
            }
        }

        # Perform the query
        r = k.query(query=q)
        print(f"Search completed for field {field}.")

        try:
            results = r['result_data']['query_result']
        except (KeyError, TypeError):
            # Malformed/failed response: report it instead of crashing the
            # whole multi-field run.
            print(f"Query failed for field {field}, skipping.")
            continue

        if not results:
            print("No candidates")
            continue

        kept = set()
        with_neg_sub = set()
        old = set()
        stellar = set()

        for info in results:
            oid = info['objectId']
            cand = info['candidate']
            if oid in old or oid in stellar:
                continue
            if np.abs(cand['ssdistnr']) < 10:
                continue
            if cand['isdiffpos'] in ['f', 0]:
                with_neg_sub.add(oid)
            duration = cand['jdendhist'] - cand['jdstarthist']
            if duration < min_days:
                continue
            if duration > max_days:
                old.add(oid)
            # Star-like PS1 source very close to the candidate
            try:
                if (np.abs(cand['distpsnr1']) < 2.
                        and cand['sgscore1'] >= 0.76):
                    stellar.add(oid)
            except (KeyError, TypeError, ValueError):
                pass
            if any(_near_bright_star(cand, n) for n in (1, 2, 3)):
                continue

            kept.add(oid)

        # Remove objects with negative subtraction, stellar and old objects
        set_objectId = kept - with_neg_sub - stellar - old
        print(set_objectId)

        set_objectId_all |= set_objectId
        print("Cumulative:", set_objectId_all)

    return set_objectId_all
コード例 #26
0
def get_ztf(filename,
            name,
            username,
            password,
            filetype="default",
            z=0.0,
            zerr=0.0001,
            SN_Type="Ia"):
    """Fetch the ZTF alert photometry of ``name`` and export it.

    The first alert returned for ``name`` is used: its own candidate plus its
    ``prv_candidates`` history. History entries without ``magpsf`` use
    ``diffmaglim`` as magnitude; entries without ``sigmapsf`` get an infinite
    error. Points are ordered by JD.

    :param filename: output file (not used when ``filetype == "lc"``)
    :param name: ZTF objectId
    :param username: Kowalski username
    :param password: Kowalski password
    :param filetype: "lc" to return arrays; "default" writes
        ``isot filter mag magerr`` lines; "snmachine" writes the snmachine
        text format. Any other value creates an empty file.
    :param z: host photo-z (snmachine header only)
    :param zerr: host photo-z error (snmachine header only)
    :param SN_Type: SN type written into the snmachine header
    :return: for ``filetype == "lc"`` a tuple ``(mjds, fluxs, fluxerrs,
        passband)``; otherwise None. If no alert is found, the outputs are
        empty.
    """
    k = Kowalski(username=username, password=password, verbose=True)

    # repr() yields a properly quoted/escaped string literal, so a name
    # containing quotes cannot break out of the query expression. For plain
    # objectIds the query string is identical to a simple '...' quoting.
    q = {
        "query_type": "general_search",
        "query": "db['ZTF_alerts'].find({'objectId': {'$eq': " +
                 repr(str(name)) + "}})"
    }
    r = k.query(query=q, timeout=10)

    # Start empty so a missing object yields empty output instead of a
    # NameError further down.
    jd, mag, magerr, filt = [], [], [], []
    results = r['result_data']['query_result']
    if len(results) > 0:
        alert = results[0]
        prevcandidates = alert['prv_candidates']

        print(alert, prevcandidates)

        latest = alert['candidate']
        jd.append(latest['jd'])
        mag.append(latest['magpsf'])
        magerr.append(latest['sigmapsf'])
        filt.append(latest['fid'])

        # Guard against a null prv_candidates list.
        for prv in prevcandidates or []:
            jd.append(prv['jd'])
            # Entries without magpsf (presumably non-detections) fall back
            # to diffmaglim.
            if prv['magpsf'] is not None:
                mag.append(prv['magpsf'])
            else:
                mag.append(prv['diffmaglim'])
            if prv['sigmapsf'] is not None:
                magerr.append(prv['sigmapsf'])
            else:
                magerr.append(np.inf)
            filt.append(prv['fid'])
    else:
        print(f"No ZTF alerts found for {name}")

    # fid 1/2/3 -> g/r/i. Unknown ids keep their string value so that
    # filtname stays index-aligned with jd/mag/magerr.
    filter_names = {1: 'g', 2: 'r', 3: 'i'}
    filtname = [filter_names.get(f, str(f)) for f in filt]
    idx = np.argsort(jd)

    if filetype == "lc":
        mjds, fluxs, fluxerrs, passband = [], [], [], []
        for ii in idx:
            t = Time(jd[ii], format='jd').mjd
            # 48.60: AB magnitude zero point
            flux = 10**((mag[ii] + 48.60) / (-2.5))
            # NOTE(review): linear approximation of the flux error
            # (the exact factor would be ln(10)/2.5); kept as-is.
            fluxerr = magerr[ii] * flux
            mjds.append(t)
            fluxs.append(flux)
            fluxerrs.append(fluxerr)
            passband.append(filtname[ii])
        return mjds, fluxs, fluxerrs, passband

    # Context manager: the file is closed even if a write fails.
    with open(filename, 'w') as fid:
        if filetype == "default":
            for ii in idx:
                t = Time(jd[ii], format='jd').isot
                fid.write('%s %s %.5f %.5f\n' %
                          (t, filtname[ii], mag[ii], magerr[ii]))
        elif filetype == "snmachine":
            fid.write('HOST_GALAXY_PHOTO-Z:   %.4f  +- %.4f\n' % (z, zerr))
            fid.write('SIM_COMMENT:  SN Type = %s\n' % SN_Type)
            for ii in idx:
                t = Time(jd[ii], format='jd').mjd
                flux = 10**((mag[ii] + 48.60) / (-2.5))
                fluxerr = magerr[ii] * flux
                fid.write('OBS: %.3f %s NULL %.3e %.3e %.2f %.5f %.5f\n' %
                          (t, filtname[ii], flux, fluxerr, flux / fluxerr,
                           mag[ii], magerr[ii]))
コード例 #27
0
def agn_b_scores(name, username, password, colors=False):
    """Score a ZTF object on its WISE colours and Galactic latitude.

    The position of the first alert returned for ``name`` is cone-matched
    (2 arcsec) against the IRSA catalog ``allwise_p3as_psd``.
    - No match gives 6 points.
    - Otherwise W1-W2 / W2-W3 colour cuts (with errors < 0.5) give -2, 0, 2
      or 4 points, from most to least AGN-like, or 6 if no cut is passed.
    Sources at |b| < 15 deg get ``b_temp_points = -10``, otherwise 0.

    :param name: ZTF objectId
    :param username: Kowalski username
    :param password: Kowalski password
    :param colors: if True, return the WISE colours instead of the latitude
        points
    :return: ``(temp_points, agn, [w1 - w2, w2 - w3])`` if ``colors`` is
        truthy, otherwise ``(temp_points, b_temp_points)``. Without a WISE
        match the colours are ``[nan, nan]``.
    """
    k = Kowalski(username=username, password=password, verbose=False)
    q = {
        "query_type": "find",
        "query": {
            "catalog": 'ZTF_alerts',
            "filter": {
                "objectId": name
            },
            "projection": {
                "_id": 0,
                "cutoutScience": 0,
                "cutoutTemplate": 0,
                "cutoutDifference": 0
            },
        }
    }
    r = k.query(query=q)
    alerts = r['data']
    ra, dec = alerts[0]['candidate']['ra'], alerts[0]['candidate']['dec']

    cc = SkyCoord(ra, dec, unit=(u.deg, u.deg))
    table = Irsa.query_region(coordinates=cc,
                              catalog="allwise_p3as_psd",
                              spatial="Cone",
                              radius=2 * u.arcsec)

    # Colours default to NaN so that colors=True also works without a WISE
    # match (previously a NameError).
    w1_w2 = w2_w3 = np.nan

    # AGN WISE
    if len(table['w1mpro']) == 0:
        agn = False
        temp_points = 6
    else:
        w1, w1_err = table['w1mpro'], table['w1sigmpro']
        w2, w2_err = table['w2mpro'], table['w2sigmpro']
        w3, w3_err = table['w3mpro'], table['w3sigmpro']
        w1_w2 = w1 - w2
        w2_w3 = w2 - w3
        # NOTE(review): the comparisons below take the truth value of table
        # columns, which only works for a single match; confirm how several
        # matches within 2 arcsec should be handled.
        # Tiers (W1-W2 cut, W2-W3 cut, agn flag, points), checked from most
        # to least AGN-like; within a tier W1-W2 is tested before W2-W3.
        # The first satisfied cut wins.
        tiers = [
            (0.8 + 0.1, 2.5 + 0.1, True, -2),
            (0.8, 2.5, True, 0),
            (0.8 - 0.2, 2.5 - 0.3, False, 2),
            (0.8 - 0.5, 2.5 - 0.5, False, 4),
        ]
        agn, temp_points = False, 6
        for cut_12, cut_23, tier_agn, tier_points in tiers:
            if w1_w2 > cut_12 and w2_err < 0.5 and w1_err < 0.5:
                agn, temp_points = tier_agn, tier_points
                break
            if w2_w3 > cut_23 and w2_err < 0.5 and w3_err < 0.5:
                agn, temp_points = tier_agn, tier_points
                break

    # low b
    b_temp_points = -10 if np.abs(cc.galactic.b.value) < 15 else 0

    if colors:
        return temp_points, agn, [w1_w2, w2_w3]
    return temp_points, b_temp_points
コード例 #28
0
ファイル: fetch.py プロジェクト: dmitryduev/ztf-wd
class WhiteDwarf(object):
    def __init__(self, config_file: str):
        try:
            ''' load config data '''
            self.config = self.get_config(_config_file=config_file)
            ''' set up logging at init '''
            self.logger, self.logger_utc_date = self.set_up_logging(
                _name='archive', _mode='a')

            # make dirs if necessary:
            for _pp in ('app', 'alerts', 'tmp', 'logs'):
                _path = self.config['path']['path_{:s}'.format(_pp)]
                if not os.path.exists(_path):
                    os.makedirs(_path)
                    self.logger.debug('Created {:s}'.format(_path))
            ''' init connection to Kowalski '''
            self.kowalski = Kowalski(username=secrets['kowalski']['user'],
                                     password=secrets['kowalski']['password'])
            # host='localhost', port=8082, protocol='http'
            ''' init db if necessary '''
            self.init_db()
            ''' connect to db: '''
            self.db = None
            # will exit if this fails
            self.connect_to_db()

        except Exception as e:
            print(e)
            traceback.print_exc()
            sys.exit()

    @staticmethod
    def get_config(_config_file):
        """
            Load config JSON file
        """
        ''' script absolute location '''
        abs_path = os.path.dirname(inspect.getfile(inspect.currentframe()))

        if _config_file[0] not in ('/', '~'):
            if os.path.isfile(os.path.join(abs_path, _config_file)):
                config_path = os.path.join(abs_path, _config_file)
            else:
                raise IOError('Failed to find config file')
        else:
            if os.path.isfile(_config_file):
                config_path = _config_file
            else:
                raise IOError('Failed to find config file')

        with open(config_path) as cjson:
            config_data = json.load(cjson)
            # config must not be empty:
            if len(config_data) > 0:
                return config_data
            else:
                raise Exception('Failed to load config file')

    def set_up_logging(self, _name='ztf_wd', _mode='w'):
        """ Set up logging

            :param _name:
            :param _level: DEBUG, INFO, etc.
            :param _mode: overwrite log-file or append: w or a
            :return: logger instance
            """
        # 'debug', 'info', 'warning', 'error', or 'critical'
        if self.config['misc']['logging_level'] == 'debug':
            _level = logging.DEBUG
        elif self.config['misc']['logging_level'] == 'info':
            _level = logging.INFO
        elif self.config['misc']['logging_level'] == 'warning':
            _level = logging.WARNING
        elif self.config['misc']['logging_level'] == 'error':
            _level = logging.ERROR
        elif self.config['misc']['logging_level'] == 'critical':
            _level = logging.CRITICAL
        else:
            raise ValueError(
                'Config file error: logging level must be ' +
                '\'debug\', \'info\', \'warning\', \'error\', or \'critical\'')

        # get path to logs from config:
        _path = self.config['path']['path_logs']

        if not os.path.exists(_path):
            os.makedirs(_path)
        utc_now = datetime.datetime.utcnow()

        # http://www.blog.pythonlibrary.org/2012/08/02/python-101-an-intro-to-logging/
        _logger = logging.getLogger(_name)

        _logger.setLevel(_level)
        # create the logging file handler
        fh = logging.FileHandler(os.path.join(
            _path, '{:s}.{:s}.log'.format(_name, utc_now.strftime('%Y%m%d'))),
                                 mode=_mode)
        logging.Formatter.converter = time.gmtime

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        # formatter = logging.Formatter('%(asctime)s %(message)s')
        fh.setFormatter(formatter)

        # add handler to logger object
        _logger.addHandler(fh)

        return _logger, utc_now.strftime('%Y%m%d')

    def shut_down_logger(self):
        """
            Prevent writing to multiple log-files after 'manual rollover'
        :return:
        """
        handlers = self.logger.handlers[:]
        for handler in handlers:
            handler.close()
            self.logger.removeHandler(handler)

    def check_logging(self):
        """
            Check if a new log file needs to be started and start it if necessary
        """
        if datetime.datetime.utcnow().strftime(
                '%Y%m%d') != self.logger_utc_date:
            # reset
            self.shut_down_logger()
            self.logger, self.logger_utc_date = self.set_up_logging(
                _name='ztf_wd', _mode='a')

    def init_db(self):
        """
            Initialize db if new Mongo instance

            Connects with the admin credentials from the config. If the
            configured user does not yet exist for the configured database,
            it is created with the 'readWrite' role. The admin connection is
            closed afterwards.
        :return:
        """
        _client = pymongo.MongoClient(
            username=self.config['database']['admin'],
            password=self.config['database']['admin_pwd'],
            host=self.config['database']['host'],
            port=self.config['database']['port'])
        try:
            # _id: db_name.user_name
            user_ids = [
                _u['_id']
                for _u in _client.admin.system.users.find({}, {'_id': 1})
            ]

            db_name = self.config['database']['db']
            username = self.config['database']['user']

            if f'{db_name}.{username}' not in user_ids:
                _client[db_name].command('createUser',
                                         username,
                                         pwd=self.config['database']['pwd'],
                                         roles=['readWrite'])
                self.logger.info('Successfully initialized db')
        finally:
            # This admin connection is only needed here; do not leak it.
            _client.close()

    def connect_to_db(self):
        """
            Connect to MongoDB-powered database

            On success ``self.db`` becomes ``{'client': MongoClient,
            'db': Database}``.
        :return:
        :raises ConnectionRefusedError: if the client cannot be created or
            authentication fails (the original error is chained as cause)
        """
        _config = self.config
        host = _config['database']['host']
        port = _config['database']['port']
        try:
            if self.logger is not None:
                self.logger.debug(
                    'Connecting to the database at {:s}:{:d}'.format(
                        host, port))
            _client = pymongo.MongoClient(host=host, port=port)
            # grab main database:
            _db = _client[_config['database']['db']]

        except Exception as _e:
            if self.logger is not None:
                self.logger.error(_e)
                self.logger.error(
                    'Failed to connect to the database at {:s}:{:d}'.format(
                        host, port))
            # raise error
            raise ConnectionRefusedError from _e
        try:
            # authenticate
            # NOTE(review): Database.authenticate was removed in PyMongo 4;
            # there, credentials must be passed to MongoClient instead.
            _db.authenticate(_config['database']['user'],
                             _config['database']['pwd'])
            if self.logger is not None:
                self.logger.debug(
                    'Successfully authenticated with the database at {:s}:{:d}'
                    .format(host, port))
        except Exception as _e:
            if self.logger is not None:
                self.logger.error(_e)
                self.logger.error(
                    'Authentication failed for the database at {:s}:{:d}'.
                    format(host, port))
            # the client is unusable; release its resources before failing
            _client.close()
            raise ConnectionRefusedError from _e

        if self.logger is not None:
            self.logger.debug(
                'Successfully connected to database at {:s}:{:d}'.format(
                    host, port))

        # (re)define self.db
        self.db = dict()
        self.db['client'] = _client
        self.db['db'] = _db

    # @timeout(seconds_before_timeout=120)
    def disconnect_from_db(self):
        """
            Close the MongoDB client stored in ``self.db`` (if any) and reset
            ``self.db`` to None. Errors raised while closing are logged, not raised.
        :return: None
        """
        log = self.logger
        log.debug('Disconnecting from the database.')
        if self.db is None:
            log.debug('No connection found.')
            return
        try:
            client = self.db['client']
            client.close()
            log.debug('Successfully disconnected from the database.')
        except Exception as err:
            log.error('Failed to disconnect from the database.')
            log.error(err)
        finally:
            # drop the handle even if close() failed
            self.db = None

    # @timeout(seconds_before_timeout=120)
    def check_db_connection(self):
        """
            Check whether the MongoDB connection is alive, connecting first if needed.

            If no connection exists yet (``self.db is None``), try ``self.connect_to_db()``.
            Otherwise ping the server with ``server_info()``, which forces a round-trip.
        :return: True if the connection is OK, False otherwise
        """
        self.logger.debug('Checking database connection.')
        if self.db is None:
            try:
                self.connect_to_db()
            except Exception as e:
                self.logger.error('Lost database connection.')
                self.logger.error(e)
                return False
        else:
            try:
                # force connection on a request as the connect=True parameter of MongoClient seems
                # to be useless here
                self.db['client'].server_info()
            except pymongo.errors.ConnectionFailure as e:
                # ConnectionFailure is the parent of ServerSelectionTimeoutError as well as of
                # AutoReconnect / NetworkTimeout, which are raised when an established link drops
                self.logger.error('Lost database connection.')
                self.logger.error(e)
                return False

        return True

    def insert_db_entry(self, _collection=None, _db_entry=None):
        """
            Insert a single document into a collection of the database.

            Failures are logged (with a traceback) and swallowed.
        :param _collection: name of the target collection
        :param _db_entry: document to insert
        :return: None
        """
        assert _collection is not None, 'Must specify collection'
        assert _db_entry is not None, 'Must specify document'
        try:
            self.db['db'][_collection].insert_one(_db_entry)
        except Exception as _e:
            # the document may lack '_id' (or not be a mapping at all): the error
            # report must not raise itself and mask the original failure
            try:
                _id = _db_entry['_id']
            except (KeyError, TypeError, IndexError):
                _id = None
            self.logger.info('Error inserting {:s} into {:s}'.format(
                str(_id), str(_collection)))
            traceback.print_exc()
            self.logger.error(_e)

    def insert_multiple_db_entries(self, _collection=None, _db_entries=None):
        """
            Insert several documents into a collection of the database in one bulk call.

            The insert is unordered (``ordered=False``), so every document is attempted
            even if some fail (e.g. an already existing ``_id``); bulk-write failures are
            logged at info level, any other error is logged with a traceback.
            An empty batch is a no-op.
        :param _collection: name of the target collection
        :param _db_entries: documents to insert
        :return: None
        """
        assert _collection is not None, 'Must specify collection'
        assert _db_entries is not None, 'Must specify documents'
        if not _db_entries:
            # insert_many rejects an empty batch; skip it instead of logging a bogus error
            return
        try:
            # ordered=False ensures that every insert operation will be attempted
            # so that if, e.g., a document already exists, it will be simply skipped
            self.db['db'][_collection].insert_many(_db_entries, ordered=False)
        except pymongo.errors.BulkWriteError as bwe:
            self.logger.info(bwe.details)
        except Exception as _e:
            traceback.print_exc()
            self.logger.error(_e)

    def replace_db_entry(self, _collection=None, _filter=None, _db_entry=None):
        """
            Replace the document matching ``_filter`` in a collection, inserting it if
            no document matches (``replace_one(..., upsert=True)``).

            Failures are logged (with a traceback) and swallowed.
        :param _collection: name of the target collection
        :param _filter: query selecting the document to replace
        :param _db_entry: replacement document
        :return: None
        """
        assert _collection is not None, 'Must specify collection'
        assert _db_entry is not None, 'Must specify document'
        try:
            self.db['db'][_collection].replace_one(_filter,
                                                   _db_entry,
                                                   upsert=True)
        except Exception as _e:
            # the document may lack '_id' (or not be a mapping at all): the error
            # report must not raise itself and mask the original failure
            try:
                _id = _db_entry['_id']
            except (KeyError, TypeError, IndexError):
                _id = None
            self.logger.info('Error replacing {:s} in {:s}'.format(
                str(_id), str(_collection)))
            traceback.print_exc()
            self.logger.error(_e)

    def cross_match(self,
                    _jd_start,
                    _jd_end,
                    _stars: dict,
                    _fov_size_ref_arcsec=2,
                    retries=3) -> dict:
        """
            Cone-search the ZTF_alerts catalog on Kowalski around a batch of positions.

        :param _jd_start: exclusive lower bound on ``candidate.jd``
        :param _jd_end: exclusive upper bound on ``candidate.jd``
        :param _stars: {name: (ra, dec)} positions to search around
        :param _fov_size_ref_arcsec: cone search radius in arcsec
        :param retries: number of query attempts before giving up
        :return: {name: alerts} only for names with at least one alert;
                 {} if every attempt failed
        """
        for attempt in range(1, retries + 1):
            try:
                self.logger.debug(f'Querying Kowalski, attempt {attempt}')
                jd_window = f'{{"candidate.jd": {{"$gt": {_jd_start}, "$lt": {_jd_end}}}}}'
                cone = {
                    "radec": str(_stars),
                    "cone_search_radius": str(_fov_size_ref_arcsec),
                    "cone_search_unit": "arcsec",
                }
                query = {
                    "query_type": "cone_search",
                    "object_coordinates": cone,
                    "catalogs": {
                        "ZTF_alerts": {
                            "filter": jd_window,
                            "projection": "{}",
                        }
                    },
                    "kwargs": {"save": False},
                }
                response = self.kowalski.query(query=query, timeout=300)
                alerts_by_star = response['result_data']['ZTF_alerts']

                # keep only the positions that actually matched something
                return {
                    star: alerts
                    for star, alerts in alerts_by_star.items()
                    if alerts is not None and len(alerts) > 0
                }
            except Exception as _e:
                # any failure (network, malformed response, ...) -> log and retry
                self.logger.error(_e)

        return {}

    def get_doc_by_id(self, _coll: str, _ids: list, retries=3) -> dict:
        """
            Fetch documents by ``_id`` from a Kowalski collection.

        :param _coll: collection name on Kowalski
        :param _ids: ``_id`` values to look up
        :param retries: number of query attempts before giving up
        :return: {_id: document} for the documents returned; {} if every attempt failed
        """
        for attempt in range(1, retries + 1):
            try:
                self.logger.debug(f'Querying Kowalski, attempt {attempt}')
                # NOTE(review): _coll and _ids are interpolated into an expression that the
                # server evaluates; only pass trusted values here
                find_expr = f"db['{_coll}'].find({{'_id': {{'$in': {_ids}}}}})"
                response = self.kowalski.query(
                    query={
                        "query_type": "general_search",
                        "query": find_expr,
                        "kwargs": {"save": False},
                    },
                    timeout=300)
                docs = response['result_data']['query_result']

                # index the returned documents by their id
                return dict((doc['_id'], doc) for doc in docs)
            except Exception as _e:
                self.logger.error(_e)

        return {}

    def dump_lightcurve(self, alert, time_label='days_ago'):
        """
            Plot the light curve of an alert and save it as
            ``<path_alerts>/<alert _id>/lightcurve.jpg``.

            For each filter (fid 1/2/3, drawn green/red/pink), detections with
            ``distnr <= 5`` are converted from difference photometry to DC flux using
            the reference magnitude (zero point 27): ``dc_flux = ref_flux +/- diff_flux``,
            with the sign taken from ``isdiffpos``. Non-detections (``diffmaglim > 0``)
            are drawn as faint limits ``ref_flux +/- limit_flux``, but only for filters
            that also have a detection (otherwise there is no reference flux).
            Points with non-positive DC flux cannot be converted to magnitudes and are
            dropped.

        :param alert: alert document; needs '_id' plus whatever ``make_dataframe`` reads
        :param time_label: x axis: 'days_ago' (JD relative to now), 'jd' or 'datetime'
        :return: None
        :raises ValueError: if ``time_label`` is not one of the supported values
        """
        # validate up front: an unknown label used to fail later with UnboundLocalError
        if time_label not in ('days_ago', 'jd', 'datetime'):
            raise ValueError(
                "time_label must be one of 'days_ago', 'jd', 'datetime'; "
                f"got {time_label!r}")

        path_out = os.path.join(self.config['path']['path_alerts'],
                                alert['_id'])

        if not os.path.exists(path_out):
            os.makedirs(path_out)

        dflc = make_dataframe(alert)

        filter_color = {1: 'green', 2: 'red', 3: 'pink'}
        if time_label == 'days_ago':
            now = Time.now().jd
            t = dflc.jd - now
            xlabel = f'Days Before {str(datetime.datetime.utcnow())} UTC'
        elif time_label == 'jd':
            t = dflc.jd
            xlabel = 'Date (JD)'
        else:  # 'datetime'
            t = Time(dflc.jd, format='jd').datetime
            xlabel = 'Date (UTC)'

        plt.close('all')
        fig = plt.figure()
        ax = fig.add_subplot(111)
        for fid, color in filter_color.items():
            # reference flux for this filter; stays None without a usable detection
            ref_flux = None

            # detections in this filter:
            w = (dflc.fid == fid) & ~dflc.magpsf.isnull() & (dflc.distnr <= 5)
            if np.sum(w):
                # we want to plot (reference_flux + sign*difference_flux) -> mag;
                # sign is +1 where isdiffpos == 't', -1 otherwise
                sign = 2 * (dflc.loc[w, 'isdiffpos'].values == 't') - 1
                # take the reference mag AND its error from the same row of this
                # filter (the error used to come from the first row of any filter)
                ref_row = dflc.loc[w].iloc[0]
                ref_mag = np.float64(ref_row['magnr'])
                ref_flux = np.float64(10**(0.4 * (27 - ref_mag)))
                # sigma_flux / flux = sigma_mag / 1.0857  (1.0857 = 2.5 / ln 10)
                ref_sigflux = np.float64(ref_row['sigmagnr'] / 1.0857 *
                                         ref_flux)

                difference_flux = np.float64(
                    10**(0.4 * (27 - dflc.loc[w, 'magpsf'].values)))
                difference_sigflux = np.float64(
                    dflc.loc[w, 'sigmapsf'].values / 1.0857 * difference_flux)

                # guard against scalar results so boolean masking below works
                if not isinstance(difference_flux, np.ndarray):
                    difference_flux = np.array([difference_flux])
                if not isinstance(difference_sigflux, np.ndarray):
                    difference_sigflux = np.array([difference_sigflux])

                dc_flux = ref_flux + sign * difference_flux
                dc_sigflux = np.sqrt(difference_sigflux**2 + ref_sigflux**2)

                if not isinstance(dc_flux, np.ndarray):
                    dc_flux = np.array([dc_flux])
                if not isinstance(dc_sigflux, np.ndarray):
                    dc_sigflux = np.array([dc_sigflux])

                # mask bad values (log10 needs positive flux):
                w_good = dc_flux > 0

                dc_mag = 27 - 2.5 * np.log10(dc_flux[w_good])
                dc_sigmag = dc_sigflux[w_good] / dc_flux[w_good] * 1.0857

                ax.errorbar(t[w][w_good],
                            dc_mag,
                            dc_sigmag,
                            fmt='.',
                            color=color)

            wnodet = (dflc.fid
                      == fid) & dflc.magpsf.isnull() & (dflc.diffmaglim > 0)
            if np.sum(wnodet) and (ref_flux is not None):
                # if we have a non-detection that means that there's no flux +/- 5 sigma from
                # the ref flux (unless it's a bad subtraction)
                difference_fluxlim = np.float64(
                    10**(0.4 * (27 - dflc.loc[wnodet, 'diffmaglim'].values)))
                dc_flux_ulim = ref_flux + difference_fluxlim
                dc_flux_llim = ref_flux - difference_fluxlim

                if not isinstance(dc_flux_ulim, np.ndarray):
                    dc_flux_ulim = np.array([dc_flux_ulim])
                if not isinstance(dc_flux_llim, np.ndarray):
                    dc_flux_llim = np.array([dc_flux_llim])

                # mask bad values:
                w_u_good = dc_flux_ulim > 0
                w_l_good = dc_flux_llim > 0

                dc_mag_ulim = 27 - 2.5 * np.log10(dc_flux_ulim[w_u_good])
                dc_mag_llim = 27 - 2.5 * np.log10(dc_flux_llim[w_l_good])
                ax.scatter(t[wnodet][w_u_good],
                           dc_mag_ulim,
                           marker='v',
                           color=color,
                           alpha=0.25)
                ax.scatter(t[wnodet][w_l_good],
                           dc_mag_llim,
                           marker='^',
                           color=color,
                           alpha=0.25)

        # magnitudes: brighter is up
        ax.invert_yaxis()
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Magnitude')

        fig.savefig(os.path.join(path_out, 'lightcurve.jpg'),
                    bbox_inches="tight",
                    pad_inches=0,
                    dpi=200)
        # release the figure instead of leaving it open until the next call
        plt.close(fig)

    def dump_cutout(self, alert, save_fits=False):
        """
            Save JPEG previews of an alert's science, template and difference
            cutouts as ``<path_alerts>/<alert _id>/<tag>.jpg``.

            Each stamp is first decoded as gzip-compressed FITS (optionally also
            written out as ``<tag>.fits``). If anything on that path fails, the raw
            stamp bytes are handed to PIL instead (old stamp format). Errors are
            printed/logged rather than raised.
        :param alert: alert document with an '_id' and 'cutoutScience',
                      'cutoutTemplate' and 'cutoutDifference' entries carrying 'stampData'
        :param save_fits: also save the decompressed FITS stamps
        :return: None
        """
        path_out = os.path.join(self.config['path']['path_alerts'],
                                alert['_id'])
        if not os.path.exists(path_out):
            os.makedirs(path_out)

        for tag in ('science', 'template', 'difference'):
            jpg_path = os.path.join(path_out, f'{tag}.jpg')

            stamp = io.BytesIO()
            stamp.write(alert['cutout' + tag.capitalize()]['stampData'])
            stamp.seek(0)

            try:
                # new format: loss-less gzip-compressed FITS
                with fits.open(gzip.GzipFile(fileobj=stamp, mode='rb')) as hdul:
                    if save_fits:
                        hdul.writeto(os.path.join(path_out, f'{tag}.fits'),
                                     overwrite=True)
                    raw = hdul[0].data

                    plt.close('all')
                    fig = plt.figure()
                    fig.set_size_inches(4, 4, forward=False)
                    # a single axes-free panel covering the whole figure
                    ax = plt.Axes(fig, [0., 0., 1., 1.])
                    ax.set_axis_off()
                    fig.add_axes(ax)

                    # remove NaNs
                    pixels = np.nan_to_num(np.array(raw))

                    if tag == 'difference':
                        plt.imshow(pixels, cmap='gray', origin='lower')
                    else:
                        # log stretch: replace non-positive pixels by the median
                        pixels[pixels <= 0] = np.median(pixels)
                        plt.imshow(pixels,
                                   cmap='gray',
                                   norm=LogNorm(),
                                   origin='lower')
                    plt.savefig(jpg_path, dpi=50)
            except Exception as _e:
                # not readable as gzipped FITS: fall back to the old jpg format
                traceback.print_exc()
                self.logger.error(str(_e))
                try:
                    stamp.seek(0)
                    Image.open(stamp).save(jpg_path)
                except Exception as _e:
                    traceback.print_exc()
                    self.logger.error(str(_e))
                    self.logger.error(
                        f'Failed to save stamp: {alert["_id"]} {tag}')

    def get_ps1_image(self, alert):
        """
            Placeholder for fetching a Pan-STARRS (PS1) image for an alert.

            Not implemented yet: does nothing and returns None.
        :param alert: alert document (currently unused)
        :return: None
        """
        # TODO: get PanSTARRS image
        return None

    def run(self, _all=False):
        """
            Run one cycle: cross-match the white dwarf list against ZTF alerts,
            generate previews, ingest the matched alerts and (re)create indices.

            The white dwarfs are read from the JSON file at
            ``config['path']['path_wd_db']`` (its 'query_result' list of entries with
            '_id', 'ra', 'dec') and cone-searched in chunks of 1000 with a 2 arcsec
            radius. Each matched alert gets
            ``xmatch['nearest_within_5_arcsec']['Gaia_DR2_WD']`` set to the full
            Gaia_DR2_WD document fetched from Kowalski; matches whose document cannot
            be fetched are skipped and logged. All matched alerts are inserted into
            ``config['database']['collection_obs']``.

        :param _all: if False, only alerts with JD between today 00:00 UTC and +1 day;
                     otherwise everything since 2017-09-01
        :return: None
        """
        # compute current UTC. the script is run everyday at 19:00 UTC (~noon in LA)
        now = datetime.datetime.utcnow()
        utc_date = datetime.datetime(now.year, now.month, now.day)

        # convert to jd
        jd_date = Time(utc_date).jd
        self.logger.info('Starting cycle: {} {}'.format(
            str(utc_date), str(jd_date)))

        if _all:
            # grab everything since the start of the survey
            jd_start = Time(datetime.datetime(2017, 9, 1)).jd
        else:
            # grab the current UTC day only
            jd_start = jd_date
        jd_end = jd_date + 1

        with open(self.config['path']['path_wd_db']) as wdjson:
            wds = json.load(wdjson)['query_result']

        total_detected = 0
        matches_to_ingest = []

        # for batch_size run a cross match with ZTF_alerts for current UTC
        for ic, chunk in enumerate(chunks(wds, 1000)):
            self.logger.info(f'Chunk #{ic}')

            # {name: (ra, dec)}
            stars = {c['_id']: (c['ra'], c['dec']) for c in chunk}

            # run cone search on the batch
            matches = self.cross_match(_jd_start=jd_start,
                                       _jd_end=jd_end,
                                       _stars=stars,
                                       _fov_size_ref_arcsec=2,
                                       retries=3)
            self.logger.debug(list(matches.keys()))

            total_detected += len(matches)
            self.logger.info(
                f'total # of white dwarfs detected so far: {total_detected}')

            if not matches:
                continue

            # full WD info for the matched objects (separate name: do not shadow `wds`)
            wd_docs = self.get_doc_by_id(_coll='Gaia_DR2_WD',
                                         _ids=list(map(int, matches.keys())),
                                         retries=3)

            for match, alerts in matches.items():
                wd_doc = wd_docs.get(int(match))
                if wd_doc is None:
                    # get_doc_by_id returns {} once its retries are exhausted; don't abort
                    # the whole cycle (and lose everything matched so far) over it
                    self.logger.error(
                        f'No Gaia_DR2_WD document for {match}; skipping its alerts')
                    continue

                for alert in alerts:
                    alert['xmatch'] = {
                        'nearest_within_5_arcsec': {
                            'Gaia_DR2_WD': wd_doc
                        }
                    }
                    self.logger.debug('{} {}'.format(alert['_id'],
                                                     wd_doc['_id']))

                    matches_to_ingest.append(alert)

                    # generate previews for the endpoint
                    self.dump_cutout(alert, save_fits=False)
                    self.dump_lightcurve(alert)

        collection_obs = self.config['database']['collection_obs']

        # ingest every matched object into own db. It's not that many, so just dump everything
        if matches_to_ingest:
            self.insert_multiple_db_entries(_collection=collection_obs,
                                            _db_entries=matches_to_ingest)

        self.logger.info(f'total # of white dwarfs detected: {total_detected}')

        self.logger.info('Creating indices')
        coll = self.db['db'][collection_obs]
        for key, index_type in (('coordinates.radec_geojson', '2dsphere'),
                                ('objectId', pymongo.ASCENDING),
                                ('candid', pymongo.ASCENDING),
                                ('candidate.programid', pymongo.ASCENDING),
                                ('candidate.jd', pymongo.ASCENDING)):
            coll.create_index([(key, index_type)], background=True)

        self.logger.info('All done')

    def shutdown(self):
        """
            Close the Kowalski client held in ``self.kowalski``.
        :return: None
        """
        kowalski_client = self.kowalski
        kowalski_client.close()
コード例 #29
0
ファイル: watcher.py プロジェクト: dmitryduev/tails
def sentinel(
    utc_start: Optional[str] = None,
    utc_stop: Optional[str] = None,
    twilight: Optional[bool] = False,
    test: Optional[bool] = False,
    verbose: Optional[bool] = False,
):
    """
    ZTF Sentinel service

    - Monitors the ZTF_ops collection on Kowalski for new ZTF data (all exposures by
      default; only those of the ZTF Twilight survey if ``twilight`` is set).
    - Uses dask.distributed to process individual ZTF image frames (ccd-quads).
      Each worker is initialized with a TailsWorker instance that maintains a Fritz connection and preloads Tails.
      The candidate comet detections, if any, are posted to Fritz together with auto-annotations
      (cross-matches from the MPC and SkyBot) and auxiliary data.
    - Tracks frames in the sentinel MongoDB collection: a frame is marked "processing" when
      submitted, and frames already marked "processing" or "success" are not resubmitted.
      Leftover "processing" entries are removed at startup.

    :param utc_start: UTC start date/time in arrow-parsable format. If not set, defaults to (now - 24h)
    :param utc_stop: UTC stop date/time in arrow-parsable format. If not set, defaults to (now + 1h).
                     If set, program runs once
    :param twilight: process only the data of the ZTF Twilight survey
    :param test: run in test mode: submit a single hard-coded frame and return
    :param verbose: verbose?
    :return: True in test mode; otherwise None after a single pass when utc_stop is set
             (without utc_stop it polls every 60 s and never returns)
    """
    if verbose:
        log("Setting up MongoDB connection")

    init_db(config=config, verbose=verbose)

    mongo = Mongo(
        host=config["sentinel"]["database"]["host"],
        port=config["sentinel"]["database"]["port"],
        username=config["sentinel"]["database"]["username"],
        password=config["sentinel"]["database"]["password"],
        db=config["sentinel"]["database"]["db"],
        verbose=verbose,
    )
    if verbose:
        log("Set up MongoDB connection")

    collection = config["sentinel"]["database"]["collection"]

    # remove dangling entries in the db at startup
    mongo.db[collection].delete_many({"status": "processing"})

    # Configure dask client
    if verbose:
        log("Initializing dask.distributed client")
    dask_client = dask.distributed.Client(
        address=
        f"{config['sentinel']['dask']['host']}:{config['sentinel']['dask']['scheduler_port']}"
    )

    # init each worker with Worker instance
    if verbose:
        log("Initializing dask.distributed workers")
    worker_initializer = WorkerInitializer()
    dask_client.register_worker_plugin(worker_initializer, name="worker-init")

    def submit(frame: str):
        """Mark a ccd-quad frame as "processing" in the db and submit it to dask."""
        with timer(f"Submitting frame {frame} for processing", verbose):
            mongo.db[collection].update_one({"_id": frame},
                                            {"$set": {
                                                "status": "processing"
                                            }},
                                            upsert=True)
            future = dask_client.submit(process_frame, frame, pure=True)
            # keep the task running even though the local future is released
            dask.distributed.fire_and_forget(future)
            future.release()

    if test:
        submit("ztf_20191014495961_000570_zr_c05_o_q3")
        return True

    if verbose:
        log("Setting up Kowalski connection")
    kowalski = Kowalski(
        token=config["kowalski"]["token"],
        protocol=config["kowalski"]["protocol"],
        host=config["kowalski"]["host"],
        port=config["kowalski"]["port"],
        verbose=verbose,
    )
    if verbose:
        log(f"Kowalski connection OK: {kowalski.ping()}")

    while True:
        try:
            # monitor the past 24 hours as sometimes there are data processing/posting delays at IPAC
            start = (arrow.get(utc_start) if utc_start is not None else
                     arrow.utcnow().shift(hours=-24))
            stop = (arrow.get(utc_stop)
                    if utc_stop is not None else arrow.utcnow().shift(hours=1))

            if (stop - start).total_seconds() < 0:
                raise ValueError("utc_stop must be greater than utc_start")

            if verbose:
                log(f"Looking for ZTF exposures between {start} and {stop}")

            kowalski_query = {
                "query_type": "find",
                "query": {
                    "catalog": "ZTF_ops",
                    "filter": {
                        "jd_start": {
                            "$gt": Time(start.datetime).jd,
                            "$lt": Time(stop.datetime).jd,
                        }
                    },
                    "projection": {
                        "_id": 0,
                        "fileroot": 1
                    },
                },
            }

            if twilight:
                kowalski_query["query"]["filter"]["qcomment"] = {
                    "$regex": "Twilight"
                }

            # tolerate a response without (or with a null) "data" field
            response = kowalski.query(query=kowalski_query).get("data") or []
            file_roots = sorted(entry["fileroot"] for entry in response)

            # every exposure has 16 ccds x 4 quadrants
            frame_names = [
                f"{file_root}_c{ccd:02d}_o_q{quad:1d}"
                for file_root in file_roots for ccd in range(1, 17)
                for quad in range(1, 5)
            ]

            if verbose:
                log(f"Found {len(frame_names)} ccd-quad frames")
                log(frame_names)

            processed_frames = [
                frame["_id"] for frame in mongo.db[collection].find(
                    {
                        "_id": {
                            "$in": frame_names
                        },
                        "status": {
                            "$in": ["processing", "success"]
                        },
                    },
                    {"_id": 1},
                )
            ]
            if verbose:
                log(processed_frames)

            # sorted for a deterministic submission order
            for frame in sorted(set(frame_names) - set(processed_frames)):
                submit(frame)

        except Exception as e:
            log(e)

        # run once if utc_stop is set
        if utc_stop is not None:
            break
        log("Heartbeat")
        time.sleep(60)
コード例 #30
0
def _kowalski_alert_fails_cuts(cand, min_days):
    """Return True if one alert's ``candidate`` dict fails any selection cut.

    Every cut is evaluated independently.  A cut whose inputs are missing
    or malformed (KeyError, ValueError or TypeError while evaluating it,
    including while converting its result to bool) is treated as passed,
    i.e. the alert is not rejected because of that cut.

    Parameters
    ----------
    cand : dict or None
        The ``candidate`` sub-document of a Kowalski alert.
    min_days : float or None
        Minimum ``jdendhist_single - jdstarthist_single`` required for
        alerts with ``alert_type == 'single'``.  ``None`` disables the cut.
    """

    def _single_history_too_short():
        # Single (non-stack) alerts need enough time separation between
        # first and last observation.
        if min_days is None or cand['alert_type'] != 'single':
            return False
        span = cand['jdendhist_single'] - cand['jdstarthist_single']
        return not span > min_days

    def _near_bright_source(i):
        # i-th catalogued counterpart is nearby, bright and star-like.
        return (np.abs(cand[f'distpsnr{i}']) < 15.
                and cand[f'srmag{i}'] < 15.
                and cand[f'srmag{i}'] > 0.
                and cand[f'sgscore{i}'] >= 0.49)

    cuts = (
        _single_history_too_short,
        # Stellarity cut on the closest counterpart.
        lambda: (np.abs(cand['distpsnr1']) < 1.
                 and cand['sgscore1'] >= 0.60),
        lambda: (np.abs(cand['distpsnr1']) < 1.
                 and cand['lsz1'] < 999.
                 and cand['lszspec1'] > 0.1),
        lambda: (np.abs(cand['distpsnr1']) < 1.
                 and cand['lsz1'] < 21.
                 and cand['lszphotl681'] > 0.1),
        lambda: _near_bright_source(1),
        lambda: _near_bright_source(2),
        lambda: _near_bright_source(3),
    )
    for cut in cuts:
        try:
            if cut():
                return True
        except (KeyError, ValueError, TypeError):
            # Best effort: a cut that cannot be evaluated does not reject.
            continue
    return False


def query_kowalski_alerts(username,
                          password,
                          date_start,
                          date_end,
                          catalog='ZUDS_alerts',
                          min_days=None,
                          starthist=None):
    '''Query alerts with kowalski and apply the selection criteria.

    Parameters
    ----------
    username, password : str
        Kowalski credentials.
    date_start, date_end :
        Objects exposing a ``.jd`` attribute (presumably astropy ``Time``;
        confirm with callers).  Only alerts with
        ``date_start.jd < candidate.jd < date_end.jd`` are queried.
    catalog : str
        Kowalski alert catalog to query.
    min_days : float or None
        Minimum history span (same units as the ``jd`` difference) for
        ``alert_type == 'single'`` alerts.  ``None`` disables this cut.
    starthist : object with ``.jd`` or None
        If given, require ``candidate.jdstarthist_single > starthist.jd``.
        ``None`` disables this filter.

    Returns
    -------
    list of dict
        ``{'name': objectId, 'candid': candid}`` for every source that is
        matched to CLU, survives the stellarity/bright-source cuts and has
        no historical negative detection according to ``check_history``.
    '''

    k = Kowalski(username=username, password=password, verbose=False)

    query_filter = {
        "candidate.jd": {
            '$gt': date_start.jd,
            '$lt': date_end.jd
        },
        "candidate.drb": {
            '$gt': 0.6
        },
        "classifications.braai": {
            '$gt': 0.6
        },
        "candidate.fwhm": {
            '$gt': 0.5,
            '$lt': 8
        },
    }
    # starthist defaults to None, so only filter on it when provided.
    if starthist is not None:
        query_filter["candidate.jdstarthist_single"] = {
            '$gt': starthist.jd
        }

    q = {
        "query_type": "find",
        "query": {
            "catalog": catalog,
            "filter": query_filter,
            "projection": {
                "objectId": 1,
                "candid": 1,
                "candidate.rcid": 1,
                "candidate.ra": 1,
                "candidate.dec": 1,
                "candidate.jd": 1,
                "candidate.ndethist": 1,
                "candidate.jdstarthist_single": 1,
                "candidate.jdstarthist_stack": 1,
                "candidate.jdendhist_single": 1,
                "candidate.jdendhist_stack": 1,
                "candidate.magpsf": 1,
                "candidate.sigmapsf": 1,
                "candidate.fid": 1,
                "candidate.programid": 1,
                "candidate.isdiffpos": 1,
                "candidate.ssdistnr": 1,
                "candidate.rb": 1,
                "candidate.drb": 1,
                "candidate.distpsnr1": 1,
                "candidate.sgscore1": 1,
                "candidate.srmag1": 1,
                "candidate.distpsnr2": 1,
                "candidate.sgscore2": 1,
                "candidate.srmag2": 1,
                "candidate.distpsnr3": 1,
                "candidate.sgscore3": 1,
                "candidate.srmag3": 1,
                "candidate.fwhm": 1,
                "candidate.lstype1": 1,
                "candidate.lszspec1": 1,
                "candidate.lsz1": 1,
                "candidate.lszphotl681": 1,
                "candidate.alert_type": 1
            }
        },
        "kwargs": {}  # {"limit": 3}
    }

    # Perform the query
    r = k.query(query=q)
    alerts = r['result_data']['query_result']

    set_names = {c['objectId'] for c in alerts}
    print(f"There are {len(set_names)} sources found")

    # Match with CLU
    list_clu = match_kowalski_clu(username, password, set_names)
    print(f"{len(list_clu)} sources were found matched with CLU galaxies")
    clu_names = set(list_clu)  # O(1) membership inside the loop

    # Stricter selection.  Each objectId is decided by the first of its
    # alerts that is examined: once accepted or rejected it is skipped.
    rejected = set()
    done = []
    done_names = set()

    for info in alerts:
        name = info['objectId']
        if name not in clu_names or name in done_names or name in rejected:
            continue
        if _kowalski_alert_fails_cuts(info.get('candidate'), min_days):
            rejected.add(name)
            continue
        done_names.add(name)
        done.append((name, info['candid']))

    # Check that no source was kept among the rejected ones
    checked = [d for d in done if d[0] not in rejected]
    checked_names = [c[0] for c in checked]

    print(f"{len(done)} sources survived stellarity and bright sources cuts")

    # Check history for negative subtractions in ZUDS and ZTF alerts
    list_selected = check_history(checked_names)
    print(
        f"{len(list_selected)} sources have no historical negative detections")

    sources = [{
        'name': c[0],
        'candid': c[1]
    } for c in checked if c[0] in list_selected]
    return sources