Code Example #1
File: plot.py  Project: chengsoonong/crowdastro
def ra_dec_to_pixels(subject_coords, coords):
    # Lazily initialise the module-level WCS on first use.
    if wcs is None:
        _init_wcs()

    # subject_coords is the centre of the subject; express all coordinates
    # relative to it in mosaic pixels.
    offset, = wcs.all_world2pix([subject_coords], FITS_CONVENTION)
    coords = wcs.all_world2pix(coords, FITS_CONVENTION)
    coords -= offset

    # Rescale from mosaic pixels to the subject image pixel grid.
    coords[:, 0] /= config['surveys']['atlas']['mosaic_scale_x'] * 424 / 200
    coords[:, 1] /= config['surveys']['atlas']['mosaic_scale_y'] * 424 / 200

    # Shift the origin into the patch.
    coords += [40, 40]
    
    return coords
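
This snippet depends on crowdastro module globals (`wcs`, `config`, `FITS_CONVENTION`, `_init_wcs`), so it cannot run on its own. As a point of reference, here is a minimal, self-contained sketch of the underlying `astropy.wcs.WCS.all_world2pix` call; every header value below is invented for illustration:

import numpy
from astropy.wcs import WCS

w = WCS(naxis=2)
w.wcs.ctype = ['RA---TAN', 'DEC--TAN']
w.wcs.crval = [52.5, -28.1]      # reference point (RA, Dec) in degrees
w.wcs.crpix = [100.5, 100.5]     # reference pixel (FITS convention, 1-based)
w.wcs.cdelt = [-2.5e-4, 2.5e-4]  # degrees per pixel

world = numpy.array([[52.5, -28.1], [52.51, -28.11]])
pix = w.all_world2pix(world, 1)  # second argument is the origin: 1 for FITS, 0 for numpy
print(pix)
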
Code Example #2
def create_aperture(prop, filter, pclass, extent=3.5, coord_bat=None, wcs=None):

    if prop is None:
        # No measured source properties: fall back to a fixed circular
        # aperture at the BAT position (requires coord_bat and wcs).
        position = wcs.all_world2pix([[coord_bat.ra.value, coord_bat.dec.value]], 0)
        ap = CircularAperture(position, PS_SRC_APS[filter]/PIX_SIZES[filter])
        type = 'fixed'
    elif pclass == 'E':
        # Extended source: elliptical aperture scaled from the measured
        # source moments.
        position = [prop.xcentroid.value, prop.ycentroid.value]
        sa = prop.semimajor_axis_sigma.value * extent
        sb = prop.semiminor_axis_sigma.value * extent
        theta = prop.orientation.value

        # Only use the ellipse if it exceeds the fixed point-source aperture;
        # otherwise treat the source as a point.
        if sa > PS_SRC_APS[filter]/PIX_SIZES[filter]:
            ap = EllipticalAperture(position, sa, sb, theta=theta)
            type = 'extended'
        else:
            ap = CircularAperture(position, PS_SRC_APS[filter]/PIX_SIZES[filter])
            type = 'point'
    elif pclass == 'P':
        # Point source: fixed circular aperture at the measured centroid.
        position = [prop.xcentroid.value, prop.ycentroid.value]
        ap = CircularAperture(position, PS_SRC_APS[filter]/PIX_SIZES[filter])
        type = 'point'

    return ap, type
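
`create_aperture` relies on the photutils aperture classes together with module constants (`PS_SRC_APS`, `PIX_SIZES`). A minimal sketch of the same classes with arbitrary, illustrative values (assuming a recent photutils, where the classes live in `photutils.aperture`):

import numpy
from photutils.aperture import (CircularAperture, EllipticalAperture,
                                aperture_photometry)

position = [(45.0, 52.0)]                 # (x, y) pixel coordinates
circ = CircularAperture(position, r=4.0)  # fixed circular aperture
ellip = EllipticalAperture(position, a=9.0, b=4.0, theta=0.5)  # extended source

data = numpy.random.random((100, 100))
print(aperture_photometry(data, circ))
print(aperture_photometry(data, ellip))
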
Code Example #3
File: masking.py  Project: nanten2/n2-tools
def mask_inpolygon(hdu, polygon, axis=('x', 'y')):
    logger.info('(mask_inpolygon) polygon={polygon}, axis={axis}'.format(**locals()))
    logger.info('(mask_inpolygon) start calculation')
    wcs = astropy.wcs.WCS(hdu.header)
    polygon_p = wcs.all_world2pix(polygon, 0)
    path = matplotlib.path.Path(polygon_p)
    # Build an (N, 2) array of every (x, y) pixel coordinate in the image,
    # then test each point against the polygon.
    ax1 = numpy.arange(hdu.data.shape[1])
    ax2 = numpy.arange(hdu.data.shape[0])
    ax12 = numpy.array([_.ravel() for _ in numpy.meshgrid(ax1, ax2)]).T
    mask = path.contains_points(ax12).reshape(hdu.data.shape).astype(int)
    logger.info('(mask_inpolygon) done')
    
    new_header = hdu.header.copy()
    new_hdu = astropy.io.fits.PrimaryHDU(mask, new_header)
    return new_hdu
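
The core of the mask construction is `matplotlib.path.Path.contains_points` evaluated at every pixel. A self-contained sketch of just that step, with the polygon given directly in pixel coordinates (skipping the WCS conversion):

import numpy
import matplotlib.path

polygon_p = [(2, 2), (2, 7), (7, 7), (7, 2)]  # polygon vertices in pixels
path = matplotlib.path.Path(polygon_p)

shape = (10, 10)  # toy image shape
ax1 = numpy.arange(shape[1])
ax2 = numpy.arange(shape[0])
ax12 = numpy.array([_.ravel() for _ in numpy.meshgrid(ax1, ax2)]).T
mask = path.contains_points(ax12).reshape(shape).astype(int)
print(mask)
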
Code Example #4
File: rgz_data.py  Project: chengsoonong/crowdastro
def get_potential_hosts(subject, cache_name, convert_to_px=True):
    """Finds the potential hosts for a subject.

    subject: RGZ subject dict.
    cache_name: Name of Gator cache.
    convert_to_px: Whether to convert coordinates to pixels. Default True; if
        False then coordinates will be RA/DEC.
    -> dict mapping (x, y) tuples to
        - flux at 3.6μm for aperture #2
        - flux at 4.5μm for aperture #2
        - flux at 5.8μm for aperture #2
        - flux at 8.0μm for aperture #2
        - flux at 24μm for aperture #2
        - stellarity index at 3.6μm
        - uncertainty in RA
        - uncertainty in DEC
    """

    if subject['metadata']['source'].startswith('C'):
        # CDFS
        catalog = 'chandra_cat_f05'
    else:
        # ELAIS-S1
        catalog = 'elaiss1_cat_f05'
    
    query = {
        'catalog': catalog,
        'spatial': 'box',
        'objstr': '{} {}'.format(*subject['coords']),
        'size': '120',
        'outfmt': '3',
    }
    url = 'http://irsa.ipac.caltech.edu/cgi-bin/Gator/nph-query'

    r = requests.get(url, params=query)
    votable = astropy.io.votable.parse_single_table(io.BytesIO(r.content),
                                                    pedantic=False)
    
    ras = votable.array['ra']
    decs = votable.array['dec']

    if convert_to_px:
        # Convert to px.
        fits = get_ir_fits(subject)
        wcs = astropy.wcs.WCS(fits.header)
        xs, ys = wcs.all_world2pix(ras, decs, 0)
    else:
        xs, ys = ras, decs
    
    # Get the astronomical features.
    out = {}  # Maps (x, y) to astronomical features.
    for x, y, row_idx in zip(xs, ys, range(votable.nrows)):
        row = votable.array[row_idx]
        out[x, y] = {
            'name': row['object'],
            'clon': row['clon'],
            'clat': row['clat'],
            'flux_ap2_36': row['flux_ap2_36'],
            'flux_ap2_45': row['flux_ap2_45'],
            'flux_ap2_58': row['flux_ap2_58'],
            'flux_ap2_80': row['flux_ap2_80'],
            'flux_ap2_24': row['flux_ap2_24'],
            'stell_36': row['stell_36'],
            'unc_ra': row['unc_ra'],
            'unc_dec': row['unc_dec'],
        }

    return out
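
Note that `all_world2pix` is called here with one array per axis (`ras`, `decs`) rather than a single (N, 2) array as in Example #1; astropy supports both calling conventions. A toy demonstration (header values invented):

import numpy
from astropy.wcs import WCS

w = WCS(naxis=2)
w.wcs.ctype = ['RA---TAN', 'DEC--TAN']
w.wcs.crval = [52.5, -28.1]
w.wcs.crpix = [50.0, 50.0]
w.wcs.cdelt = [-1e-3, 1e-3]

ras = numpy.array([52.5, 52.49])
decs = numpy.array([-28.1, -28.09])
xs, ys = w.all_world2pix(ras, decs, 0)  # origin 0: numpy-style pixel indices
print(list(zip(xs, ys)))
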
Code Example #5
def import_classifications(f_h5, test=False):
    """Imports Radio Galaxy Zoo classifications into crowdastro.

    f_h5: An HDF5 file.
    test: Flag to run on only 10 subjects. Default False.
    """
    # TODO(MatthewJA): This only works for ATLAS/CDFS. Generalise.
    from . import rgz_data as data
    atlas_positions = f_h5['/atlas/cdfs/numeric'][:, :2]
    atlas_ids = f_h5['/atlas/cdfs/string']['zooniverse_id']
    classification_positions = []
    classification_combinations = []
    classification_usernames = []

    with astropy.io.fits.open(
            # RGZ only has cdfs classifications
            config['data_sources']['atlas_cdfs_image'],
            ignore_blank=True) as atlas_image:
        wcs = astropy.wcs.WCS(atlas_image[0].header).dropaxis(3).dropaxis(2)

    for obj_index, atlas_id in enumerate(atlas_ids):
        subject = data.get_subject(atlas_id.decode('ascii'))
        assert subject['zooniverse_id'] == atlas_ids[obj_index].decode('ascii')
        classifications = data.get_subject_classifications(subject)
        offset, = wcs.all_world2pix([subject['coords']], FITS_CONVENTION)
        # The coords are of the middle of the subject.
        offset[0] -= (config['surveys']['atlas']['fits_width'] *
                      config['surveys']['atlas']['mosaic_scale_x'] // 2)
        offset[1] -= (config['surveys']['atlas']['fits_height'] *
                      config['surveys']['atlas']['mosaic_scale_y'] // 2)

        for c_index, classification in enumerate(classifications):
            user_name = classification.get('user_name', '').encode(
                    'ascii', errors='ignore')
            # Usernames actually don't have an upper length limit on RGZ(?!) so
            # I'll cap everything at 50 characters for my own sanity.
            if len(user_name) > 50:
                user_name = user_name[:50]

            classification = parse_classification(classification, subject,
                                                  atlas_positions, wcs, offset)
            full_radio = '|'.join(classification.keys())
            for radio, locations in classification.items():
                if not locations:
                    locations = [(None, None)]

                for click_index, location in enumerate(locations):
                    # Flag whether this is the first click; downstream code
                    # assumes only the first click is needed.
                    pos_row = (obj_index, location[0], location[1],
                               click_index == 0)
                    com_row = (obj_index, full_radio, radio)
                    # A little redundancy here with the index, but we can assert
                    # that they are the same later to check integrity.
                    classification_positions.append(pos_row)
                    classification_combinations.append(com_row)
                    classification_usernames.append(user_name)

    combinations_dtype = [('index', 'int'),
                          ('full_signature', '<S{}'.format(
                                    MAX_RADIO_SIGNATURE_LENGTH)),
                          ('signature', '<S{}'.format(
                                    MAX_RADIO_SIGNATURE_LENGTH))]
    classification_positions = numpy.array(classification_positions,
                                           dtype=float)
    classification_combinations = numpy.array(classification_combinations,
                                              dtype=combinations_dtype)

    f_h5['/atlas/cdfs/'].create_dataset('classification_positions',
                                        data=classification_positions,
                                        dtype=float)
    f_h5['/atlas/cdfs/'].create_dataset('classification_usernames',
                                        data=classification_usernames,
                                        dtype='<S50')
    f_h5['/atlas/cdfs/'].create_dataset('classification_combinations',
                                        data=classification_combinations,
                                        dtype=combinations_dtype)
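
The storage pattern here is a structured numpy array written into an HDF5 group. A minimal sketch with h5py (the file name and field sizes are illustrative):

import numpy
import h5py

dtype = [('index', 'int'), ('signature', '<S16')]
rows = numpy.array([(0, b'ir;radio'), (1, b'radio')], dtype=dtype)

with h5py.File('example.h5', 'w') as f:
    grp = f.create_group('/atlas/cdfs')  # intermediate groups are created
    grp.create_dataset('classification_combinations', data=rows, dtype=dtype)
    print(f['/atlas/cdfs/classification_combinations'][:])
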
Code Example #6
def import_atlas(f_h5, test=False, field='cdfs'):
    """Imports the ATLAS dataset into crowdastro, as well as associated SWIRE.

    f_h5: An HDF5 file.
    test: Flag to run on only 10 subjects. Default False.
    """
    from . import rgz_data as data

    # Fetch groups from HDF5.
    cdfs = f_h5['/atlas/{}'.format(field)]

    # First pass, I'll find coords, names, and Zooniverse IDs, as well as how
    # many data points there are.

    coords = []
    names = []
    zooniverse_ids = []

    if (field == 'cdfs'):
        # We need the ATLAS name, but we can only get it by going through the
        # ATLAS catalogue and finding the nearest component.
        # https://github.com/chengsoonong/crowdastro/issues/63
        # Fortunately, @jbanfield has already done this, so we can just load
        # that CSV and match the names.
        # TODO(MatthewJA): This matches the ATLAS component ID, but maybe we
        # should be using the name instead.
        rgz_to_atlas = {}
        with open(config['data_sources']['rgz_to_atlas']) as f:
            reader = csv.DictReader(f)
            for row in reader:
                rgz_to_atlas[row['ID_RGZ']] = row['ID']

        all_subjects = data.get_all_subjects(survey='atlas', field=field)
        if test:
            all_subjects = all_subjects.limit(10)

        for subject in all_subjects:
            ra, dec = subject['coords']
            zooniverse_id = subject['zooniverse_id']

            rgz_source_id = subject['metadata']['source']
            if rgz_source_id not in rgz_to_atlas:
                logging.debug('Skipping %s; no matching ATLAS component.',
                              zooniverse_id)
                continue
            name = rgz_to_atlas[rgz_source_id]

            # Store the results.
            coords.append((ra, dec))
            names.append(name)
            zooniverse_ids.append(zooniverse_id)

    elif (field == 'elais'):
        atlascatalogue = ascii.read(config['data_sources']['atlas_catalogue'])
        ras, decs = atlascatalogue['RA_deg'], atlascatalogue['Dec_deg']
        e_ids = atlascatalogue['ID']
        fields = atlascatalogue['field']

        # Store the results.
        for ra, dec, e_id, field_ in zip(ras, decs, e_ids, fields):
            if (field_ == 'ELAIS-S1'):
                coords.append((ra, dec))
                names.append(e_id)
                zooniverse_ids.append(e_id)

    n_cdfs = len(names)

    # Sort the data by Zooniverse ID.
    coords_to_zooniverse_ids = dict(zip(coords, zooniverse_ids))
    names_to_zooniverse_ids = dict(zip(names, zooniverse_ids))

    coords.sort(key=coords_to_zooniverse_ids.get)
    names.sort(key=names_to_zooniverse_ids.get)
    zooniverse_ids.sort()

    # Begin to store the data. We will have two tables: one for numeric data,
    # and one for strings. We will have to preallocate the numeric table so that
    # we aren't storing huge amounts of image data in memory.

    # Strings.
    dtype = [('zooniverse_id', '<S{}'.format(MAX_ZOONIVERSE_ID_LENGTH)),
             ('name', '<S{}'.format(MAX_NAME_LENGTH))]
    string_data = numpy.array(list(zip(zooniverse_ids, names)), dtype=dtype)
    cdfs.create_dataset('string', data=string_data, dtype=dtype)

    # Numeric.
    image_size = (config['surveys']['atlas']['fits_width'] *
                  config['surveys']['atlas']['fits_height'])
    # RA, DEC, radio, (distance to SWIRE object added later)
    dim = (n_cdfs, 1 + 1 + image_size)
    numeric = cdfs.create_dataset('_numeric', shape=dim, dtype='float32')

    # Load image patches and store numeric data.
    with astropy.io.fits.open(
            config['data_sources']['atlas_{}_image'.format(field)],
            ignore_blank=True) as atlas_image:
        wcs = astropy.wcs.WCS(atlas_image[0].header).dropaxis(3).dropaxis(2)
        pix_coords = wcs.all_world2pix(coords, FITS_CONVENTION)
        assert pix_coords.shape[1] == 2
        logging.debug('Fetching %d ATLAS images.', len(pix_coords))
        for index, (x, y) in enumerate(pix_coords):
            radio = atlas_image[0].data[
                0, 0,  # stokes, freq
                int(y) - config['surveys']['atlas']['fits_height'] // 2:
                int(y) + config['surveys']['atlas']['fits_height'] // 2,
                int(x) - config['surveys']['atlas']['fits_width'] // 2:
                int(x) + config['surveys']['atlas']['fits_width'] // 2]
            numeric[index, 0] = coords[index][0]
            numeric[index, 1] = coords[index][1]
            numeric[index, 2:2 + image_size] = radio.reshape(-1)

    logging.debug('ATLAS imported.')
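
The manual slicing above does not guard against patches that fall off the mosaic edge; astropy's `Cutout2D` handles that case and updates the WCS for the patch. A sketch with fake data (all values invented), shown as an alternative rather than what crowdastro does:

import numpy
from astropy.nddata import Cutout2D
from astropy.wcs import WCS

w = WCS(naxis=2)
w.wcs.ctype = ['RA---TAN', 'DEC--TAN']
w.wcs.crval = [52.5, -28.1]
w.wcs.crpix = [512.0, 512.0]
w.wcs.cdelt = [-2.5e-4, 2.5e-4]

data = numpy.random.random((1024, 1024))
x, y = w.all_world2pix([[52.5, -28.1]], 0)[0]
patch = Cutout2D(data, (x, y), (100, 100), wcs=w)  # size is (ny, nx)
print(patch.data.shape)
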
Code Example #7
def import_wise(f_h5, field='cdfs'):
    """Imports the WISE dataset into crowdastro.

    f_h5: An HDF5 file.
    field: 'cdfs' or 'elais'.
    """
    names = []
    rows = []
    logging.debug('Reading WISE catalogue.')
    with open(
            config['data_sources']['wise_{}_catalogue'.format(field)]) as f_tbl:
        # This isn't a valid ASCII table, so Astropy can't handle it. This means
        # we have to parse it manually.
        for _ in range(105):  # Skip the first 105 lines.
            next(f_tbl)

        # Get the column names.
        columns = [c.strip() for c in next(f_tbl).strip().split('|')][1:-1]
        assert len(columns) == 45

        for _ in range(3):  # Skip the next three lines.
            next(f_tbl)

        for row in f_tbl:
            row = row.strip().split()
            assert len(row) == 45
            row = dict(zip(columns, row))
            name = row['designation']
            ra = float(row['ra'])
            dec = float(row['dec'])
            w1mpro = float(remove_nulls(row['w1mpro']))
            w2mpro = float(remove_nulls(row['w2mpro']))
            w3mpro = float(remove_nulls(row['w3mpro']))
            w4mpro = float(remove_nulls(row['w4mpro']))
            # Extra -1 is so we can store nearest distance later.
            rows.append((ra, dec, w1mpro, w2mpro, w3mpro, w4mpro, -1))
            names.append(name)

    logging.debug('Found %d WISE objects.', len(names))

    # Sort by name.
    rows_to_names = dict(zip(rows, names))
    rows.sort(key=rows_to_names.get)
    names.sort()

    names = numpy.array(names, dtype='<S{}'.format(MAX_NAME_LENGTH))
    rows = numpy.array(rows)

    # Filter on distance - only include image data for WISE objects within a
    # given radius of an ATLAS object. Otherwise, there's way too much data to
    # store.
    wise_positions = rows[:, :2]
    atlas_positions = f_h5['/atlas/{}/_numeric'.format(field)][:, :2]
    logging.debug('Computing WISE k-d tree.')
    wise_tree = sklearn.neighbors.KDTree(wise_positions, metric='euclidean')
    indices = numpy.concatenate(
            wise_tree.query_radius(atlas_positions, CANDIDATE_RADIUS))
    indices = numpy.unique(indices)

    logging.debug('Found %d WISE objects near ATLAS objects.', len(indices))

    names = names[indices]
    rows = rows[indices]
    wise_positions = wise_positions[indices]

    # Get distances.
    logging.debug('Finding ATLAS-WISE object distances.')
    distances = scipy.spatial.distance.cdist(atlas_positions, wise_positions,
                                             'euclidean')
    assert distances.shape[0] == atlas_positions.shape[0]
    assert distances.shape[1] == wise_positions.shape[0]
    logging.debug('Done finding distances.')

    # Write numeric data to HDF5.
    rows[:, 6] = distances.min(axis=0)
    atlas_numeric = f_h5['/atlas/{}/_numeric'.format(field)]
    f_h5['/atlas/{}'.format(field)].create_dataset(
        'numeric', dtype='float32',
        shape=(atlas_numeric.shape[0],
               atlas_numeric.shape[1] + len(indices)))
    numeric_f = f_h5['/atlas/{}/numeric'.format(field)]
    numeric_f[:, :atlas_numeric.shape[1]] = atlas_numeric
    numeric_f[:, atlas_numeric.shape[1]:] = distances

    del f_h5['/atlas/{}/_numeric'.format(field)]

    image_size = (PATCH_RADIUS * 2) ** 2
    dim = (rows.shape[0], rows.shape[1] + image_size)
    numeric = f_h5['/wise/{}'.format(field)].create_dataset(
        'numeric', shape=dim, dtype='float32')
    numeric[:, :rows.shape[1]] = rows
    f_h5['/wise/{}'.format(field)].create_dataset('string', data=names)

    # Load and store radio images.
    logging.debug('Importing radio patches.')
    with astropy.io.fits.open(
            config['data_sources']['atlas_{}_image'.format(field)],
            ignore_blank=True) as atlas_image:
        wcs = astropy.wcs.WCS(atlas_image[0].header).dropaxis(3).dropaxis(2)
        pix_coords = wcs.all_world2pix(wise_positions, FITS_CONVENTION)
        assert pix_coords.shape[1] == 2
        assert pix_coords.shape[0] == len(indices)
        logging.debug('Fetching %d ATLAS patches.', len(indices))

        for index, (x, y) in enumerate(pix_coords):
            radio = atlas_image[0].data[
                0, 0,  # stokes, freq
                int(y) - PATCH_RADIUS:
                int(y) + PATCH_RADIUS,
                int(x) - PATCH_RADIUS:
                int(x) + PATCH_RADIUS]
            numeric[index, -image_size:] = radio.reshape(-1)
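
The candidate filtering uses a scikit-learn k-d tree radius query followed by deduplication. A self-contained sketch with random positions (the radius is arbitrary):

import numpy
import sklearn.neighbors

wise_positions = numpy.random.random((1000, 2))
atlas_positions = numpy.random.random((10, 2))

tree = sklearn.neighbors.KDTree(wise_positions, metric='euclidean')
# query_radius returns one index array per query point; flatten and dedupe.
indices = numpy.unique(numpy.concatenate(
    tree.query_radius(atlas_positions, 0.05)))
print(len(indices), 'objects within radius')
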
Code Example #8
def main(examples=None, classifier='CNN', labeller='Norris'):
    # Load SWIRE stuff.
    swire_names, swire_coords, swire_features = pipeline.generate_swire_features(overwrite=False)
    swire_labels = pipeline.generate_swire_labels(swire_names, swire_coords, overwrite=False)
    _, (_, swire_test_sets) = pipeline.generate_data_sets(swire_coords, overwrite=False)
    swire_tree = KDTree(swire_coords)
    swire_name_to_index = {n: i for i, n in enumerate(swire_names)}
    # Load ATLAS coords.
    table = astropy.io.ascii.read(pipeline.TABLE_PATH)
    atlas_to_coords = {}
    atlas_to_swire_coords = {}
    for row in table:
        name = row['Component Name (Franzen)']
        if not name:
            continue

        atlas_to_coords[name] = row['Component RA (Franzen)'], row['Component DEC (Franzen)']
        index = swire_name_to_index.get(row['Source SWIRE (Norris)'] or '')
        if index is not None:  # 0 is a valid index, so compare against None.
            atlas_to_swire_coords[name] = swire_coords[index]

    ir_stretch = astropy.visualization.LogStretch(0.001)
    if examples is None:
        examples = examples_incorrect.get_examples()
        examples = examples[labeller, classifier, 'All']
    for example in examples:
        print('Plotting {}'.format(example))
        predictor_name = '{}_{}'.format(classifier, labeller)
        cid = example[2]
        # Load FITS stuff.
        try:
            radio_fits = astropy.io.fits.open(CDFS_PATH + cid + '_radio.fits')
        except FileNotFoundError:
            if example[1]:  # Has Zooniverse ID
                print('{} not in RGZ'.format(cid))
            continue
        ir_fits = astropy.io.fits.open(CDFS_PATH + cid + '_ir.fits')
        wcs = astropy.wcs.WCS(radio_fits[0].header)
        # Compute info for contour levels. (also from Enno Middelberg)
        median = numpy.median(radio_fits[0].data)
        mad = numpy.median(numpy.abs(radio_fits[0].data - median))
        sigma = mad / mad2sigma
        # Set up the plot.
        fig = plt.figure()
        ax = astropy.visualization.wcsaxes.WCSAxes(
            fig, [0.1, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
        ax.set_title('{}'.format(example[0]))
        # Show the infrared.
        ax.imshow(ir_stretch(ir_fits[0].data), cmap='cubehelix_r',
                  origin='lower')
        # Show the radio.
        ax.contour(radio_fits[0].data, colors='black',
                    levels=[nsig * sigma * sigmult ** i for i in range(15)],
                    linewidths=1, origin='lower', zorder=1)
        # Plot predictions.
        predictions = get_predictions(swire_tree, swire_coords, swire_test_sets, atlas_to_coords[example[0]], predictor_name)
        if not predictions:
            print('No predictions for {}'.format(example[0]))
            continue
        coords = [p[0] for p in predictions]
        probabilities = [p[1] for p in predictions]
        coords = wcs.all_world2pix(coords, 1)
        ax.scatter(coords[:, 0], coords[:, 1], s=numpy.sqrt(numpy.array(probabilities)) * 200, color='white', edgecolor='black', linewidth=1, alpha=0.9, marker='o', zorder=2)
        choice = numpy.argmax(probabilities)
        ax.scatter(coords[choice, 0], coords[choice, 1], s=200 / numpy.sqrt(2), color='blue', marker='x', zorder=2.5)
        try:
            norris_coords, = wcs.all_world2pix([atlas_to_swire_coords[example[0]]], 1)
        except KeyError:
            print('No Norris cross-identification for {}'.format(example[0]))
            continue
        ax.scatter(norris_coords[0], norris_coords[1], marker='+', s=200, zorder=3, color='green')
        lon, lat = ax.coords
        lon.set_major_formatter('hh:mm:ss')
        lon.set_axislabel('Right Ascension')
        lat.set_axislabel('Declination')
        fn = '{}_{}_{}'.format(classifier, labeller, example[0])
        plt.savefig('/Users/alger/repos/crowdastro-projects/ATLAS-CDFS/images/examples/' + fn + '.png',
            bbox_inches='tight', pad_inches=0)
        plt.savefig('/Users/alger/repos/crowdastro-projects/ATLAS-CDFS/images/examples/' + fn + '.pdf',
            bbox_inches='tight', pad_inches=0)
        plt.clf()
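
The plotting relies on astropy's WCSAxes so that image pixels line up with sky-coordinate axis labels. A minimal, self-contained sketch (toy WCS and random image; in recent astropy/matplotlib, passing `projection=wcs` to `add_subplot` yields a WCSAxes):

import numpy
import matplotlib.pyplot as plt
from astropy.wcs import WCS

w = WCS(naxis=2)
w.wcs.ctype = ['RA---TAN', 'DEC--TAN']
w.wcs.crval = [52.5, -28.1]
w.wcs.crpix = [50.0, 50.0]
w.wcs.cdelt = [-1e-3, 1e-3]

fig = plt.figure()
ax = fig.add_subplot(projection=w)
ax.imshow(numpy.random.random((100, 100)), cmap='cubehelix_r', origin='lower')
lon, lat = ax.coords
lon.set_major_formatter('hh:mm:ss')
lon.set_axislabel('Right Ascension')
lat.set_axislabel('Declination')
plt.show()
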
Code Example #9
File: makequickview.py  Project: rkotulla/sdss
def make_image(img_fn, weight_fn, output_fn, cutout=None, min_max=None, nsigma=[-5,+5], scale='linear'):

    hdulist = fits.open(img_fn)
    data = hdulist[0].data

    if (os.path.isfile(weight_fn)):
        weight_hdulist = fits.open(weight_fn)
        weight_map = weight_hdulist[0].data
    else:
        weight_map = 1

    #
    # Mask out all areas outside the covered f.o.v.
    #
    data[weight_map <= 0] = numpy.NaN
    valid_data = data[weight_map > 0]

    good_data = weight_map > 0

    ####
    if (cutout is not None):
        print("prepping cutout", cutout)
        ra, dec, size, coord = cutout
        wcs = astropy.wcs.WCS(header=hdulist[0].header)
        ra, dec = math.degrees(float(ra)), math.degrees(float(dec))
        x, y = wcs.all_world2pix(ra, dec, 0)
        print(ra, dec, "-->", x, y)

        # Cut a 1000x1000 pixel region centred on the requested position.
        _x = int(x)
        _y = int(y)
        data = data[_y-500:_y+500, _x-500:_x+500]

    if (min_max is None):
        # Iteratively estimate the median and sigma from the 16th/50th/84th
        # percentiles, rejecting pixels beyond 3 sigma each round.
        for iteration in range(3):
            qs = numpy.nanpercentile(data[good_data], q=[16, 50, 84])

            _median = qs[1]
            _sigma = (qs[2] - qs[0]) / 2.
            _mingood = _median - 3 * _sigma
            _maxgood = _median + 3 * _sigma
            good_data[(data < _mingood) | (data > _maxgood)] = False

            print(iteration, _median, _sigma, _mingood, _maxgood)

        #
        # Default display cuts run from nsigma[0] to nsigma[1] around the median.
        #
        print(nsigma)
        min_level = _median + nsigma[0] * _sigma
        max_level = _median + nsigma[1] * _sigma
    else:
        min_level, max_level = min_max

    print(min_level, max_level)
    greyscale = (data - min_level) / (max_level - min_level)

    if (scale == "arcsinh"):
        print("Applying arcsinh contrast adjustment")
        # Assign the result; the original discarded the normalisation.
        greyscale = numpy.arcsinh(greyscale) / numpy.arcsinh(1.)

    greyscale[greyscale < 0.] = 0.
    greyscale[greyscale >= 1.] = 1.

    print(output_fn)
    image = Image.fromarray(numpy.uint8(greyscale * 255))
    image = image.transpose(Image.FLIP_TOP_BOTTOM)
    image.save(output_fn, "JPEG")
    del image

    return (min_level, max_level)
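
The percentile loop above is essentially iterative sigma clipping ((84th - 16th percentile)/2 estimates one sigma for a Gaussian). astropy ships an equivalent helper; a sketch with arbitrary synthetic data:

import numpy
from astropy.stats import sigma_clipped_stats

data = numpy.random.normal(100.0, 5.0, (512, 512))
mean, median, sigma = sigma_clipped_stats(data, sigma=3.0, maxiters=3)
min_level = median - 5 * sigma
max_level = median + 5 * sigma
greyscale = numpy.clip((data - min_level) / (max_level - min_level), 0.0, 1.0)
print(min_level, max_level)
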
Code Example #10
def test(inputs_h5, inputs_csv, training_h5, classifier_path,
         astro_transformer_path, image_transformer_path, use_astro=True,
         use_cnn=True):
    classifier = sklearn.externals.joblib.load(classifier_path)
    astro_transformer = sklearn.externals.joblib.load(astro_transformer_path)
    image_transformer = sklearn.externals.joblib.load(image_transformer_path)

    testing_indices = inputs_h5['/atlas/cdfs/testing_indices'].value
    swire_positions = inputs_h5['/swire/cdfs/catalogue'][:, :2]
    atlas_positions = inputs_h5['/atlas/cdfs/positions'].value
    all_astro_inputs = training_h5['astro'].value
    all_cnn_inputs = training_h5['cnn_outputs'].value
    all_labels = training_h5['labels'].value

    atlas_counts = {}  # ATLAS ID to number of objects in that subject.
    for consensus in inputs_h5['/atlas/cdfs/consensus_objects']:
        atlas_id = int(consensus[0])
        atlas_counts[atlas_id] = atlas_counts.get(atlas_id, 0) + 1

    simple_indices = []
    for atlas_id, count in atlas_counts.items():
        if count == 1 and atlas_id in testing_indices:
            simple_indices.append(atlas_id)
    print(simple_indices)

    atlas_positions = inputs_h5['/atlas/cdfs/positions']
    csvdicts = list(csv.DictReader(inputs_csv))
    for index, position in enumerate(atlas_positions[:100]):
        if index not in testing_indices or index not in simple_indices:
            continue

        # Build a k-d tree over the SWIRE catalogue positions (note: this is
        # rebuilt on every iteration, though the positions do not change).
        swire_positions = training_h5['positions']
        swire_tree = sklearn.neighbors.KDTree(swire_positions)
        neighbours, distances = swire_tree.query_radius([position], ARCMIN,
                                                        return_distance=True)
        neighbours = neighbours[0]
        distances = distances[0]
        nearest_positions = swire_positions.value[neighbours]
        astro_inputs = all_astro_inputs[neighbours]
        astro_inputs[:, -1] = distances
        cnn_inputs = all_cnn_inputs[neighbours]
        labels = all_labels[neighbours]

        astro_inputs = astro_transformer.transform(astro_inputs)
        cnn_inputs = image_transformer.transform(cnn_inputs)

        probs = classifier.predict_proba(numpy.hstack([astro_inputs, cnn_inputs]))
    
        for row in csvdicts:
            if int(row['index']) == index and row['survey'] == 'atlas':
                zid = row['zooniverse_id']
                header = row['header']
                break
        subject = get_subject(zid)
        ir = get_ir(subject)
        wcs = astropy.wcs.WCS(astropy.io.fits.Header.fromstring(header))
        points = wcs.all_world2pix(nearest_positions, 1)
        print(zid)
        plot.figure(figsize=(20, 10))

        plot.subplot(1, 2, 1)
        plot.imshow(ir, cmap='gray', norm=co.LogNorm(vmin=ir.min(), vmax=ir.max()))
        contours(subject)
        plot.xlim((0, 200))
        plot.ylim((0, 200))
        plot.axis('off')
        plot.scatter(points[:, 0] - 151, points[:, 1] - 151, zorder=100, c=probs[:, 1], cmap='cool', s=100)

        plot.subplot(1, 2, 2)
        plot.scatter(range(len(probs)), sorted(probs[:, 1]), c=sorted(probs[:, 1]), cmap='cool', linewidth=0, s=100)
        plot.xlim((0, len(probs)))
        plot.ylim((0, 1))
        plot.xlabel('SWIRE object index')
        plot.ylabel('Classifier probability')
        plot.show()
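
The classifier consumes the astronomical and CNN features horizontally stacked. A toy sketch of that pattern with a stand-in logistic regression (shapes and data are arbitrary):

import numpy
from sklearn.linear_model import LogisticRegression

astro_inputs = numpy.random.random((50, 8))
cnn_inputs = numpy.random.random((50, 32))
labels = numpy.random.randint(0, 2, 50)

features = numpy.hstack([astro_inputs, cnn_inputs])
classifier = LogisticRegression().fit(features, labels)
probs = classifier.predict_proba(features)
print(probs[:5, 1])  # probability of the positive class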