Beispiel #1
0
def get_predictions(swire_tree,
                    swire_coords,
                    swire_names,
                    swire_test_sets,
                    atlas_coords,
                    predictor_name,
                    radius=1 / 60):
    """Return (SWIRE name, predicted probability) pairs near a position.

    Predictions for the 'RGZ & Norris' dataset are loaded for quadrants 0-3
    from ``pipeline.WORKING_DIR + predictor_name + '_predictions'``. The first
    quadrant whose RGZ test set contains at least one SWIRE object within
    ``radius`` of ``atlas_coords`` supplies the result.

    swire_tree: Spatial index over all SWIRE objects exposing
        ``query_ball_point`` (presumably a scipy KD-tree -- TODO confirm).
    swire_coords: Array of SWIRE coordinates, indexed like the tree.
    swire_names: Array of SWIRE names, indexed like the tree.
    swire_test_sets: Boolean array indexed as [object, set, quadrant].
    atlas_coords: Query position.
    predictor_name: Prefix of the serialised predictions file.
    radius: Search radius, in the units of the tree's coordinates
        (presumably degrees, i.e. 1 arcmin by default -- TODO confirm).

    Returns a list of (name, probability) tuples, or None if no quadrant's
    test set contains a nearby object.

    Raises AssertionError if the selected objects and predictions differ in
    length.
    """
    predictions_ = pipeline.unserialise_predictions(
        pipeline.WORKING_DIR + predictor_name + '_predictions', [0, 1, 2, 3],
        ['RGZ & Norris'])

    # The neighbourhood depends only on the query position, so compute it
    # once rather than once per quadrant.
    nearby = swire_tree.query_ball_point(atlas_coords,
                                         radius)  # all-SWIRE indices
    nearby_bool = numpy.zeros((swire_test_sets.shape[0], ), dtype=bool)
    nearby_bool[nearby] = True

    for predictions in predictions_:
        set_ = swire_test_sets[:, pipeline.SET_NAMES['RGZ'],
                               predictions.quadrant]  # all-SWIRE indices, mask
        nearby_in_set = nearby_bool[set_]  # quadrant + dataset indices
        if not nearby_in_set.any():
            # Wrong quadrant.
            continue
        nearby_predictions = predictions.probabilities[nearby_in_set]
        selected = nearby_bool & set_
        nearby_coords = swire_coords[selected]
        nearby_names = swire_names[selected]
        # Internal consistency check; no debugger is started on failure so
        # that batch runs fail fast instead of blocking on a prompt.
        assert len(nearby_coords) == len(nearby_predictions), \
            'selected {} SWIRE objects but {} predictions'.format(
                len(nearby_coords), len(nearby_predictions))
        return list(zip(nearby_names, nearby_predictions))

    # No quadrant's test set contained a nearby object.
    return None
def get_predictions(swire_tree, swire_coords, swire_test_sets, atlas_coords, predictor_name, radius=1 / 60):
    """Return (SWIRE coordinates, predicted probability) pairs near a position.

    Predictions for the 'RGZ & Norris' dataset are loaded for quadrants 0-3
    from ``pipeline.WORKING_DIR + predictor_name + '_predictions'``. The first
    quadrant whose RGZ test set contains at least one SWIRE object within
    ``radius`` of ``atlas_coords`` supplies the result.

    swire_tree: Spatial index over all SWIRE objects exposing
        ``query_ball_point`` (presumably a scipy KD-tree -- TODO confirm).
    swire_coords: Array of SWIRE coordinates, indexed like the tree.
    swire_test_sets: Boolean array indexed as [object, set, quadrant].
    atlas_coords: Query position.
    predictor_name: Prefix of the serialised predictions file.
    radius: Search radius, in the units of the tree's coordinates
        (presumably degrees, i.e. 1 arcmin by default -- TODO confirm).

    Returns a list of (coords, probability) tuples, or None if no quadrant's
    test set contains a nearby object.

    Raises AssertionError if the selected objects and predictions differ in
    length.
    """
    predictions_ = pipeline.unserialise_predictions(pipeline.WORKING_DIR + predictor_name + '_predictions', [0, 1, 2, 3], ['RGZ & Norris'])

    # The neighbourhood depends only on the query position, so compute it
    # once rather than once per quadrant.
    nearby = swire_tree.query_ball_point(atlas_coords, radius)  # all-SWIRE indices
    nearby_bool = numpy.zeros((swire_test_sets.shape[0],), dtype=bool)
    nearby_bool[nearby] = True

    for predictions in predictions_:
        set_ = swire_test_sets[:, pipeline.SET_NAMES['RGZ'], predictions.quadrant]  # all-SWIRE indices, mask
        nearby_in_set = nearby_bool[set_]  # quadrant + dataset indices
        if not nearby_in_set.any():
            # Wrong quadrant.
            continue
        nearby_predictions = predictions.probabilities[nearby_in_set]
        nearby_coords = swire_coords[nearby_bool & set_]
        # Internal consistency check; no debugger is started on failure so
        # that batch runs fail fast instead of blocking on a prompt.
        assert len(nearby_coords) == len(nearby_predictions), \
            'selected {} SWIRE objects but {} predictions'.format(len(nearby_coords), len(nearby_predictions))
        return list(zip(nearby_coords, nearby_predictions))

    # No quadrant's test set contained a nearby object.
    return None
Beispiel #3
0
def print_table(field='cdfs'):
    """Write a per-SWIRE-object table of classifier predictions.

    Loads the serialised LR, RF and CNN predictions for both labellers,
    matches each prediction to the SWIRE object it was made for, and writes
    one row per object (name, RA, Dec, expert/RGZ host flags and one column
    per predictor) as CSV and as LaTeX.

    field: Field name used in the prediction file names and output file
        names. For 'cdfs' the per-quadrant RGZ test set is used; for any
        other field the test set at index [:, 0, 0] is used.

    NOTE: the output directory is hard-coded below.
    """
    titlemap = {
        'RGZ & Norris & compact': 'Compact',
        'RGZ & Norris & resolved': 'Resolved',
        'RGZ & Norris': 'All',
        'RGZ & compact': 'Compact',
        'RGZ & resolved': 'Resolved',
        'RGZ': 'All',
    }

    def load_predictions(classifier):
        # Norris-labelled predictions first, then RGZ-labelled ones.
        return itertools.chain(
            pipeline.unserialise_predictions(
                pipeline.WORKING_DIR + '{}_norris_{}_predictions'.format(classifier, field)),
            pipeline.unserialise_predictions(
                pipeline.WORKING_DIR + '{}_rgz_{}_predictions'.format(classifier, field)))

    lr_predictions = load_predictions('LogisticRegression')
    rf_predictions = load_predictions('RandomForestClassifier')
    cnn_predictions = load_predictions('CNN')

    swire_names, swire_coords, _ = pipeline.generate_swire_features(overwrite=False, field=field)
    swire_labels = pipeline.generate_swire_labels(swire_names, swire_coords, overwrite=False, field=field)
    _, (_, swire_test_sets) = pipeline.generate_data_sets(swire_coords, swire_labels, overwrite=False, field=field)

    swire_names = numpy.array(swire_names)
    swire_coords = numpy.array(swire_coords)

    predictions_map = collections.defaultdict(dict)  # SWIRE -> predictor -> probability
    swire_coords_map = {}
    swire_expert_map = {}
    swire_rgz_map = {}
    known_predictors = set()

    for classifier, predictions_ in [['LR', lr_predictions], ['CNN', cnn_predictions], ['RF', rf_predictions]]:
        for predictions in predictions_:
            dataset_name = predictions.dataset_name
            labeller = predictions.labeller
            if labeller == 'rgz' and 'Norris' in dataset_name:
                # RGZ-labelled predictions on Norris datasets are not tabulated.
                continue
            labeller = labeller.title() if labeller == 'norris' else labeller.upper()
            predictor_name = '{}({} / {})'.format(classifier, labeller, titlemap[dataset_name])

            # Select the test set these predictions were made on (computed once).
            if field == 'cdfs':
                test_mask = swire_test_sets[:, pipeline.SET_NAMES['RGZ'], predictions.quadrant]
            else:
                test_mask = swire_test_sets[:, 0, 0]
            swire_names_ = swire_names[test_mask]
            swire_coords_ = swire_coords[test_mask]
            swire_labels_ = swire_labels[test_mask]

            assert predictions.probabilities.shape[0] == len(swire_names_), \
                'expected {}, got {}'.format(predictions.probabilities.shape[0], len(swire_names_))
            for name, coords, prediction, label in zip(swire_names_, swire_coords_, predictions.probabilities, swire_labels_):
                predictions_map[name][predictor_name] = prediction
                swire_coords_map[name] = coords
                swire_expert_map[name] = label[0]
                swire_rgz_map[name] = label[1]
            known_predictors.add(predictor_name)

    known_predictors = sorted(known_predictors)

    swires = sorted(predictions_map)
    ras = []
    decs = []
    is_expert_host = []
    is_rgz_host = []
    predictor_columns = collections.defaultdict(list)
    for swire in swires:
        for predictor in known_predictors:
            # Objects a predictor never saw get an empty cell.
            predictor_columns[predictor].append(predictions_map[swire].get(predictor, ''))
        ras.append(swire_coords_map[swire][0])
        decs.append(swire_coords_map[swire][1])
        # Truthiness test instead of list indexing: the labels may be NumPy
        # booleans, which should not be relied on as list indices.
        is_expert_host.append('yes' if swire_expert_map[swire] else 'no')
        is_rgz_host.append('yes' if swire_rgz_map[swire] else 'no')

    table = astropy.table.Table(
        data=[swires, ras, decs, is_expert_host, is_rgz_host] + [predictor_columns[p] for p in known_predictors],
        names=['SWIRE', 'RA', 'Dec', 'Expert host', 'RGZ host'] + known_predictors)
    table.write('/Users/alger/data/Crowdastro/predicted_swire_table_{}_21_03_18.csv'.format(field), format='csv')
    # The display format is set after the CSV has been written, before the LaTeX write.
    for p in known_predictors:
        table[p].format = '{:.4f}'
    table.write('/Users/alger/data/Crowdastro/predicted_swire_table_{}_21_03_18.tex'.format(field), format='latex')
def plot_predictions(cut=0.95,
                     labeller='norris',
                     dataset_name=None,
                     classifier=None):
    """Plot colour-colour diagram for predicted host galaxies.

    labeller in {'norris', 'rgz'}
    dataset_name in {'RGZ & Norris', ...}

    Two modes:
    - dataset_name given: plot test-set objects whose predicted probability
      (from the '{classifier}_{labeller}_cdfs_predictions' file) exceeds
      ``cut``, coloured by that probability.
    - dataset_name empty/None: plot test-set objects labelled as hosts by
      ``labeller`` in a single colour. No predictions are loaded in this
      mode because none are used.

    Only objects with both 5.8 and 8.0 flux values different from -99 are
    plotted.

    Raises ValueError if no dataset_name is given and labeller is not
    'norris' or 'rgz'.
    """
    # Column of swire_labels holding each labeller's host flag.
    label_columns = {'norris': 0, 'rgz': 1}
    if not dataset_name and labeller not in label_columns:
        raise ValueError(
            "labeller must be 'norris' or 'rgz', got {!r}".format(labeller))

    with h5py.File(CROWDASTRO_PATH, 'r') as f:
        swire_numeric_cdfs = f['/swire/cdfs/numeric'][:, 2:2 + 4]

    f_36 = swire_numeric_cdfs[:, 0]
    f_45 = swire_numeric_cdfs[:, 1]
    f_58 = swire_numeric_cdfs[:, 2]
    f_80 = swire_numeric_cdfs[:, 3]

    predictions = {}
    if dataset_name:
        p = pipeline.unserialise_predictions(
            pipeline.WORKING_DIR +
            '{}_{}_cdfs_predictions'.format(classifier, labeller))
        for i in p:
            predictions[i.dataset_name, i.quadrant] = i

    swire_names, swire_coords, _ = pipeline.generate_swire_features(
        overwrite=False, field='cdfs')
    swire_labels = pipeline.generate_swire_labels(swire_names,
                                                  swire_coords,
                                                  overwrite=False,
                                                  field='cdfs')
    _, (_, swire_test_sets) = pipeline.generate_data_sets(swire_coords,
                                                          swire_labels,
                                                          overwrite=False,
                                                          field='cdfs')

    xs = []
    ys = []
    colours = []
    for q in range(4):
        swire_set = swire_test_sets[:, pipeline.SET_NAMES['RGZ'], q]
        if dataset_name:
            # NOTE(review): previously only labeller == 'norris' reached this
            # branch; other labellers crashed on unbound locals. The
            # predictions file is already selected by labeller, so the same
            # selection is applied to every labeller.
            quadrant_probabilities = predictions[dataset_name,
                                                 q].probabilities
            predictions_set = quadrant_probabilities > cut
            f_36_ = f_36[swire_set][predictions_set]
            f_45_ = f_45[swire_set][predictions_set]
            f_58_ = f_58[swire_set][predictions_set]
            f_80_ = f_80[swire_set][predictions_set]
            probabilities = quadrant_probabilities[predictions_set]
        else:
            host_mask = swire_set & swire_labels[:, label_columns[labeller]]
            f_36_ = f_36[host_mask]
            f_45_ = f_45[host_mask]
            f_58_ = f_58[host_mask]
            f_80_ = f_80[host_mask]
            # No probabilities in label mode; points are drawn in one colour.
            probabilities = None

        # -99 is treated as "no detection" in the 5.8 and 8.0 columns.
        detection_all_ = (f_58_ != -99) & (f_80_ != -99)

        ratio_58_36 = numpy.log10(f_58_[detection_all_] /
                                  f_36_[detection_all_])
        ratio_80_45 = numpy.log10(f_80_[detection_all_] /
                                  f_45_[detection_all_])
        xs.extend(ratio_58_36)
        ys.extend(ratio_80_45)
        if probabilities is not None:
            colours.extend(probabilities[detection_all_])

    assert len(xs) == len(ys)
    if dataset_name:
        assert len(xs) == len(colours)

    plot_basic()
    if dataset_name:
        plt.scatter(xs,
                    ys,
                    s=20,
                    marker='^',
                    linewidth=0,
                    alpha=0.5,
                    c=numpy.array(colours),
                    cmap='winter')
    else:
        plt.scatter(xs, ys, s=25, c='r', marker='^', linewidth=0)
    plt.xlim((-0.75, 1.0))
    plt.ylim((-0.75, 1.0))
    plt.xlabel('$\\log_{10}(S_{5.8}/S_{3.6})$')
    plt.ylabel('$\\log_{10}(S_{8.0}/S_{4.5})$')
    plt.subplots_adjust(left=0.2, bottom=0.15, right=0.95, top=0.95)
    if dataset_name:
        # A colour bar only makes sense when points are coloured by value.
        plt.colorbar()
    plt.show()
def plot_grid(field='cdfs'):
    """Plot per-quadrant balanced accuracies and write a LaTeX summary table.

    For each classifier (LR, CNN, RF) the serialised predictions for both
    labellers are loaded and their balanced accuracies collected per dataset
    and quadrant. For the 'cdfs' field an extra 'Labels' entry is added by
    scoring the RGZ labels against the Norris labels. The non-compact
    Norris-labelled sets are drawn as a two-panel scatter plot, and the mean
    and standard deviation over quadrants are written to a LaTeX table.

    Uses names defined outside this function: ``norris_labelled_sets``,
    ``titlemap``, ``fullmap``, ``balanced_accuracy`` and ``defaultdict``.

    field: Field name used in the prediction file names and output paths;
        'cdfs' additionally enables the 'Labels' comparison.

    Writes ../{field}_accuracy_table.tex and ../images/{field}_ba_grid.pdf/.png.
    """
    is_cdfs = field == 'cdfs'

    # Load predictions and convert to the format we need:
    # classifier -> [norris accuracies, rgz accuracies], each of the form
    # {'RGZ' -> [acc, acc, acc, acc]}.
    accuracies = {}
    for short_name, file_prefix in [('LR', 'LogisticRegression'),
                                    ('RF', 'RandomForestClassifier'),
                                    ('CNN', 'CNN')]:
        norris_accuracies = {sstr: [0] * 4 for sstr in pipeline.SET_NAMES}
        rgz_accuracies = {sstr: [0] * 4 for sstr in pipeline.SET_NAMES}
        for file_labeller in ('norris', 'rgz'):
            loaded = pipeline.unserialise_predictions(
                pipeline.WORKING_DIR + '{}_{}_{}_predictions'.format(file_prefix, file_labeller, field))
            for predictions in loaded:
                # Sort by the labeller recorded on the prediction itself.
                target = norris_accuracies if predictions.labeller == 'norris' else rgz_accuracies
                target[predictions.dataset_name][predictions.quadrant] = predictions.balanced_accuracy
        accuracies[short_name] = [norris_accuracies, rgz_accuracies]

    if is_cdfs:
        # Load RGZ cross-identifications and compute a balanced accuracy with them.
        swire_names, swire_coords, _ = pipeline.generate_swire_features(overwrite=False, field=field)
        swire_labels = pipeline.generate_swire_labels(swire_names, swire_coords, overwrite=False, field=field)
        _, (_, swire_test_sets) = pipeline.generate_data_sets(swire_coords, swire_labels, overwrite=False, field=field)
        label_rgz_accuracies = {sstr: [0] * 4 for sstr in pipeline.SET_NAMES}
        label_norris_accuracies = {sstr: [1] * 4 for sstr in pipeline.SET_NAMES}  # By definition.
        for dataset_name in pipeline.SET_NAMES:
            for quadrant in range(4):
                test_set = swire_test_sets[:, pipeline.SET_NAMES[dataset_name], quadrant]
                # Column 1 (RGZ) is scored against column 0 (Norris).
                label_rgz_accuracies[dataset_name][quadrant] = balanced_accuracy(
                    swire_labels[test_set, 0], swire_labels[test_set, 1])

    colours = ['grey', 'magenta', 'blue', 'orange']
    markers = ['o', '^', 'x', 's']
    handles = {}
    plt.figure(figsize=(5, 5))

    accuracy_map = defaultdict(lambda: defaultdict(dict))  # For table output.
    output_sets = [
        ('LR', accuracies['LR']),
        ('CNN', accuracies['CNN']),
        ('RF', accuracies['RF']),
    ]
    if is_cdfs:
        output_sets.append(('Labels', [label_norris_accuracies, label_rgz_accuracies]))
    for j, (classifier_name, classifier_set) in enumerate(output_sets):
        # Horizontal offset separating the classifiers in the RGZ column.
        rgz_offset = ((j - 1.5) / 6) if is_cdfs else (j - 1) / 5
        for i, set_name in enumerate(norris_labelled_sets):
            norris_values = classifier_set[0][set_name]
            rgz_values = classifier_set[1][fullmap[set_name]]

            if 'compact' not in set_name:  # Compact sets are tabulated but not plotted.
                ax = plt.subplot(2, 1, {'RGZ & Norris & resolved': 1, 'RGZ & Norris': 2}[set_name])
                ax.set_ylim((80, 100))
                ax.set_xlim((-0.5, 1.5))
                ax.set_xticks([0, 1])
                ax.set_xticklabels(['Norris', 'RGZ'], rotation='horizontal')
                if i == 2:
                    plt.xlabel('Labels')
                plt.ylabel('{}\nBalanced accuracy\n(per cent)'.format(titlemap[set_name]))

                ax.title.set_fontsize(16)
                ax.xaxis.label.set_fontsize(12)
                ax.yaxis.label.set_fontsize(9)
                for tick in ax.get_xticklabels() + ax.get_yticklabels():
                    tick.set_fontsize(10)

                ax.grid(which='major', axis='y', color='#EEEEEE')

                for k in range(4):
                    if j != 3:  # !Labels
                        ax.scatter([0 + (j - 1) / 5], norris_values[k] * 100,
                                   color=colours[j], marker=markers[j], linewidth=1, edgecolor='k')
                    handles[j] = ax.scatter([1 + rgz_offset], rgz_values[k] * 100,
                                            color=colours[j], marker=markers[j], linewidth=1, edgecolor='k')

            # Compute for table.
            for labeller, values in (('Norris', norris_values), ('RGZ', rgz_values)):
                accuracy_map[labeller][classifier_name][titlemap[set_name]] = '${:.02f} \\pm {:.02f}$'.format(
                    numpy.mean(values) * 100, numpy.std(values) * 100)

    # Assemble table.
    col_labeller = []
    col_classifier = []
    col_compact = []
    col_resolved = []
    col_all = []
    # Parenthesised so that only 'Labels' is conditional on the field; without
    # the parentheses the whole list became empty for non-'cdfs' fields.
    table_classifiers = ['CNN', 'LR', 'RF'] + (['Labels'] if is_cdfs else [])
    for labeller in ['Norris', 'RGZ']:
        for classifier in table_classifiers:
            col_labeller.append(labeller)
            col_classifier.append(classifier)
            col_compact.append(accuracy_map[labeller][classifier]['Compact'])
            col_resolved.append(accuracy_map[labeller][classifier]['Resolved'])
            col_all.append(accuracy_map[labeller][classifier]['All'])
    out_table = astropy.table.Table([col_labeller, col_classifier, col_compact, col_resolved, col_all],
                                    names=['Labeller', 'Classifier', "Mean `Compact' accuracy\\\\(per cent)",
                                           "Mean `Resolved' accuracy\\\\(per cent)",
                                           "Mean `All' accuracy\\\\(per cent)"])
    out_table.write('../{}_accuracy_table.tex'.format(field), format='latex')

    plt.figlegend([handles[j] for j in sorted(handles)], ['LR', 'CNN', 'RF'] + (['Labels'] if is_cdfs else []),
                  loc='lower center', ncol=4, fontsize=10)
    plt.subplots_adjust(bottom=0.2, hspace=0.25)
    plt.savefig('../images/{}_ba_grid.pdf'.format(field),
                bbox_inches='tight', pad_inches=0)
    plt.savefig('../images/{}_ba_grid.png'.format(field),
                bbox_inches='tight', pad_inches=0)
Beispiel #6
0
#!/usr/bin/env python3
"""Find good/bad examples of cross-identification.

Input files:
- ???

Output files:
- images/cdfs_ba_grid.pdf
- images/cdfs_ba_grid.png

Matthew Alger <*****@*****.**>
Research School of Astronomy and Astrophysics
The Australian National University
2017
"""
import matplotlib.pyplot as plt

import configure_plotting
import pipeline

# Apply the project's plotting configuration (see configure_plotting) before
# any figure is created.
configure_plotting.configure()

# Load predictions.
# Both calls run when the module is executed/imported; the paths are built by
# appending the file name to pipeline.WORKING_DIR.
lr_predictions = pipeline.unserialise_predictions(
    pipeline.WORKING_DIR + 'LogisticRegression_predictions')
rf_predictions = pipeline.unserialise_predictions(
    pipeline.WORKING_DIR + 'RandomForestClassifier_predictions')
def plot_predictions(cut=0.95, labeller='norris', dataset_name=None, classifier=None):
    """Plot colour-colour diagram for predicted host galaxies.

    labeller in {'norris', 'rgz'}
    dataset_name in {'RGZ & Norris', ...}

    Two modes:
    - dataset_name given: plot test-set objects whose predicted probability
      (from the '{classifier}_{labeller}_cdfs_predictions' file) exceeds
      ``cut``, coloured by that probability.
    - dataset_name empty/None: plot test-set objects labelled as hosts by
      ``labeller`` in a single colour. No predictions are loaded in this
      mode because none are used.

    Only objects with both 5.8 and 8.0 flux values different from -99 are
    plotted.

    Raises ValueError if no dataset_name is given and labeller is not
    'norris' or 'rgz'.
    """
    # Column of swire_labels holding each labeller's host flag.
    label_columns = {'norris': 0, 'rgz': 1}
    if not dataset_name and labeller not in label_columns:
        raise ValueError("labeller must be 'norris' or 'rgz', got {!r}".format(labeller))

    with h5py.File(CROWDASTRO_PATH, 'r') as f:
        swire_numeric_cdfs = f['/swire/cdfs/numeric'][:, 2:2 + 4]

    f_36 = swire_numeric_cdfs[:, 0]
    f_45 = swire_numeric_cdfs[:, 1]
    f_58 = swire_numeric_cdfs[:, 2]
    f_80 = swire_numeric_cdfs[:, 3]

    predictions = {}
    if dataset_name:
        p = pipeline.unserialise_predictions(
                pipeline.WORKING_DIR + '{}_{}_cdfs_predictions'.format(classifier, labeller))
        for i in p:
            predictions[i.dataset_name, i.quadrant] = i

    swire_names, swire_coords, _ = pipeline.generate_swire_features(overwrite=False, field='cdfs')
    swire_labels = pipeline.generate_swire_labels(swire_names, swire_coords, overwrite=False, field='cdfs')
    _, (_, swire_test_sets) = pipeline.generate_data_sets(swire_coords, swire_labels, overwrite=False, field='cdfs')

    xs = []
    ys = []
    colours = []
    for q in range(4):
        swire_set = swire_test_sets[:, pipeline.SET_NAMES['RGZ'], q]
        if dataset_name:
            # NOTE(review): previously only labeller == 'norris' reached this
            # branch; other labellers crashed on unbound locals. The
            # predictions file is already selected by labeller, so the same
            # selection is applied to every labeller.
            quadrant_probabilities = predictions[dataset_name, q].probabilities
            predictions_set = quadrant_probabilities > cut
            f_36_ = f_36[swire_set][predictions_set]
            f_45_ = f_45[swire_set][predictions_set]
            f_58_ = f_58[swire_set][predictions_set]
            f_80_ = f_80[swire_set][predictions_set]
            probabilities = quadrant_probabilities[predictions_set]
        else:
            host_mask = swire_set & swire_labels[:, label_columns[labeller]]
            f_36_ = f_36[host_mask]
            f_45_ = f_45[host_mask]
            f_58_ = f_58[host_mask]
            f_80_ = f_80[host_mask]
            # No probabilities in label mode; points are drawn in one colour.
            probabilities = None

        # -99 is treated as "no detection" in the 5.8 and 8.0 columns.
        detection_all_ = (f_58_ != -99) & (f_80_ != -99)

        ratio_58_36 = numpy.log10(f_58_[detection_all_] / f_36_[detection_all_])
        ratio_80_45 = numpy.log10(f_80_[detection_all_] / f_45_[detection_all_])
        xs.extend(ratio_58_36)
        ys.extend(ratio_80_45)
        if probabilities is not None:
            colours.extend(probabilities[detection_all_])

    assert len(xs) == len(ys)
    if dataset_name:
        assert len(xs) == len(colours)

    plot_basic()
    if dataset_name:
        plt.scatter(xs, ys, s=20, marker='^', linewidth=0, alpha=0.5, c=numpy.array(colours), cmap='winter')
    else:
        plt.scatter(xs, ys, s=25, c='r', marker='^', linewidth=0)
    plt.xlim((-0.75, 1.0))
    plt.ylim((-0.75, 1.0))
    plt.xlabel('$\\log_{10}(S_{5.8}/S_{3.6})$')
    plt.ylabel('$\\log_{10}(S_{8.0}/S_{4.5})$')
    plt.subplots_adjust(left=0.2, bottom=0.15, right=0.95, top=0.95)
    if dataset_name:
        # A colour bar only makes sense when points are coloured by value.
        plt.colorbar()
    plt.show()
Beispiel #8
0
def plot_grid(field='cdfs'):
    """Plot per-quadrant balanced accuracies and write a LaTeX summary table.

    For each classifier (LR, CNN, RF) the serialised predictions for both
    labellers are loaded and their balanced accuracies collected per dataset
    and quadrant. For the 'cdfs' field an extra 'Labels' entry is added by
    scoring the RGZ labels against the Norris labels. The non-compact
    Norris-labelled sets are drawn as a two-panel scatter plot, and the mean
    and standard deviation over quadrants are written to a LaTeX table.

    Uses names defined outside this function: ``norris_labelled_sets``,
    ``titlemap``, ``fullmap``, ``balanced_accuracy`` and ``defaultdict``.

    field: Field name used in the prediction file names and output paths;
        'cdfs' additionally enables the 'Labels' comparison.

    Writes ../{field}_accuracy_table.tex and
    ../images/{field}_ba_grid.pdf/.png.
    """
    is_cdfs = field == 'cdfs'

    # Load predictions and convert to the format we need:
    # classifier -> [norris accuracies, rgz accuracies], each of the form
    # {'RGZ' -> [acc, acc, acc, acc]}.
    accuracies = {}
    for short_name, file_prefix in [('LR', 'LogisticRegression'),
                                    ('RF', 'RandomForestClassifier'),
                                    ('CNN', 'CNN')]:
        norris_accuracies = {sstr: [0] * 4 for sstr in pipeline.SET_NAMES}
        rgz_accuracies = {sstr: [0] * 4 for sstr in pipeline.SET_NAMES}
        for file_labeller in ('norris', 'rgz'):
            loaded = pipeline.unserialise_predictions(
                pipeline.WORKING_DIR + '{}_{}_{}_predictions'.format(
                    file_prefix, file_labeller, field))
            for predictions in loaded:
                # Sort by the labeller recorded on the prediction itself.
                if predictions.labeller == 'norris':
                    target = norris_accuracies
                else:
                    target = rgz_accuracies
                target[predictions.dataset_name][
                    predictions.quadrant] = predictions.balanced_accuracy
        accuracies[short_name] = [norris_accuracies, rgz_accuracies]

    if is_cdfs:
        # Load RGZ cross-identifications and compute a balanced accuracy with them.
        swire_names, swire_coords, _ = pipeline.generate_swire_features(
            overwrite=False, field=field)
        swire_labels = pipeline.generate_swire_labels(swire_names,
                                                      swire_coords,
                                                      overwrite=False,
                                                      field=field)
        _, (_, swire_test_sets) = pipeline.generate_data_sets(
            swire_coords, swire_labels, overwrite=False, field=field)
        label_rgz_accuracies = {sstr: [0] * 4 for sstr in pipeline.SET_NAMES}
        label_norris_accuracies = {
            sstr: [1] * 4
            for sstr in pipeline.SET_NAMES
        }  # By definition.
        for dataset_name in pipeline.SET_NAMES:
            for quadrant in range(4):
                test_set = swire_test_sets[:, pipeline.SET_NAMES[dataset_name],
                                           quadrant]
                # Column 1 (RGZ) is scored against column 0 (Norris).
                label_rgz_accuracies[dataset_name][
                    quadrant] = balanced_accuracy(swire_labels[test_set, 0],
                                                  swire_labels[test_set, 1])

    colours = ['grey', 'magenta', 'blue', 'orange']
    markers = ['o', '^', 'x', 's']
    handles = {}
    plt.figure(figsize=(5, 5))

    accuracy_map = defaultdict(lambda: defaultdict(dict))  # For table output.
    output_sets = [
        ('LR', accuracies['LR']),
        ('CNN', accuracies['CNN']),
        ('RF', accuracies['RF']),
    ]
    if is_cdfs:
        output_sets.append(
            ('Labels', [label_norris_accuracies, label_rgz_accuracies]))
    for j, (classifier_name, classifier_set) in enumerate(output_sets):
        # Horizontal offset separating the classifiers in the RGZ column.
        rgz_offset = ((j - 1.5) / 6) if is_cdfs else (j - 1) / 5
        for i, set_name in enumerate(norris_labelled_sets):
            norris_values = classifier_set[0][set_name]
            rgz_values = classifier_set[1][fullmap[set_name]]

            # Compact sets are tabulated but not plotted.
            if 'compact' not in set_name:
                ax = plt.subplot(2, 1, {
                    'RGZ & Norris & resolved': 1,
                    'RGZ & Norris': 2
                }[set_name])
                ax.set_ylim((80, 100))
                ax.set_xlim((-0.5, 1.5))
                ax.set_xticks([0, 1])
                ax.set_xticklabels(['Norris', 'RGZ'], rotation='horizontal')
                if i == 2:
                    plt.xlabel('Labels')
                plt.ylabel('{}\nBalanced accuracy\n(per cent)'.format(
                    titlemap[set_name]))

                ax.title.set_fontsize(16)
                ax.xaxis.label.set_fontsize(12)
                ax.yaxis.label.set_fontsize(9)
                for tick in ax.get_xticklabels() + ax.get_yticklabels():
                    tick.set_fontsize(10)

                ax.grid(which='major', axis='y', color='#EEEEEE')

                for k in range(4):
                    if j != 3:  # !Labels
                        ax.scatter([0 + (j - 1) / 5],
                                   norris_values[k] * 100,
                                   color=colours[j],
                                   marker=markers[j],
                                   linewidth=1,
                                   edgecolor='k')
                    handles[j] = ax.scatter([1 + rgz_offset],
                                            rgz_values[k] * 100,
                                            color=colours[j],
                                            marker=markers[j],
                                            linewidth=1,
                                            edgecolor='k')

            # Compute for table.
            for labeller, values in (('Norris', norris_values),
                                     ('RGZ', rgz_values)):
                accuracy_map[labeller][classifier_name][
                    titlemap[set_name]] = '${:.02f} \\pm {:.02f}$'.format(
                        numpy.mean(values) * 100,
                        numpy.std(values) * 100)

    # Assemble table.
    col_labeller = []
    col_classifier = []
    col_compact = []
    col_resolved = []
    col_all = []
    # Parenthesised so that only 'Labels' is conditional on the field; without
    # the parentheses the whole list became empty for non-'cdfs' fields.
    table_classifiers = ['CNN', 'LR', 'RF'] + (['Labels'] if is_cdfs else [])
    for labeller in ['Norris', 'RGZ']:
        for classifier in table_classifiers:
            col_labeller.append(labeller)
            col_classifier.append(classifier)
            col_compact.append(accuracy_map[labeller][classifier]['Compact'])
            col_resolved.append(accuracy_map[labeller][classifier]['Resolved'])
            col_all.append(accuracy_map[labeller][classifier]['All'])
    out_table = astropy.table.Table(
        [col_labeller, col_classifier, col_compact, col_resolved, col_all],
        names=[
            'Labeller', 'Classifier', "Mean `Compact' accuracy\\\\(per cent)",
            "Mean `Resolved' accuracy\\\\(per cent)",
            "Mean `All' accuracy\\\\(per cent)"
        ])
    out_table.write('../{}_accuracy_table.tex'.format(field), format='latex')

    plt.figlegend([handles[j] for j in sorted(handles)], ['LR', 'CNN', 'RF'] +
                  (['Labels'] if is_cdfs else []),
                  loc='lower center',
                  ncol=4,
                  fontsize=10)
    plt.subplots_adjust(bottom=0.2, hspace=0.25)
    plt.savefig('../images/{}_ba_grid.pdf'.format(field),
                bbox_inches='tight',
                pad_inches=0)
    plt.savefig('../images/{}_ba_grid.png'.format(field),
                bbox_inches='tight',
                pad_inches=0)