Example #1
    model = torch.nn.Sequential(*modules)  
    # print([x for x in model.children()])
    # print(model)
    return model
    

if __name__ == '__main__':

    # debugging only

    import numpy as np

    from zoobot.pytorch.training import losses
    from zoobot.shared import label_metadata, schemas
    from zoobot.pytorch.estimators import define_model

    channels = 3

    question_answer_pairs = label_metadata.decals_all_campaigns_ortho_pairs
    dependencies = label_metadata.decals_ortho_dependencies
    schema = schemas.Schema(question_answer_pairs, dependencies)

    loss_func = losses.calculate_multiquestion_loss

    model = define_model.ZoobotModel(schema=schema, loss=loss_func, channels=channels, get_architecture=get_resnet, representation_dim=2048)

    x = torch.from_numpy(np.random.rand(16, channels, 224, 224)).float()
    # print(model(x))
    print(model(x).shape)
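
Example #1 is truncated at the top: the start of the architecture-building function (and the module-level import of torch) is not shown. A minimal sketch of a get_resnet helper consistent with the call above, assuming a torchvision resnet50 backbone with the final fc layer dropped so it emits 2048-d features (a hypothetical reconstruction, not necessarily the repository's actual code):

import torch
import torchvision


def get_resnet(channels=3):
    # channels is assumed to be 3 here; other values would need the first conv layer replaced
    # standard resnet50 backbone, randomly initialised (weights=None) for illustration
    resnet = torchvision.models.resnet50(weights=None)
    # keep everything up to and including the global average pool, drop the fc head
    modules = list(resnet.children())[:-1]
    # flatten the pooled (batch, 2048, 1, 1) output to (batch, 2048), matching representation_dim=2048
    modules.append(torch.nn.Flatten())
    return torch.nn.Sequential(*modules)
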
Example #2
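Example #2 is also an excerpt: the module-level imports, plus the helpers it calls (match_predictions_to_catalog, multi_catalog_tweaks, get_loss, replace_dr8_cols_with_dr5_cols, confusion_matrices_split_by_confidence, get_regression_errors), are defined earlier in the original file. As a rough guide, the imports the excerpt relies on look something like the following; the zoobot paths beyond label_metadata and schemas are not shown in the snippet.

import os
import logging

import pandas as pd

from zoobot.shared import label_metadata, schemas
# plus zoobot's dirichlet_stats module (its import path is not shown in this excerpt)
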
def main():

    question_answer_pairs = label_metadata.decals_all_campaigns_ortho_pairs
    dependencies = label_metadata.decals_ortho_dependencies
    schema = schemas.Schema(question_answer_pairs, dependencies)
    # logging.info('Schema: {}'.format(schema))

    """Pick which model's predictions to load"""

    shards = ['dr12', 'dr5', 'dr8']

    model_index = 'm4'

    # checkpoint = f'all_campaigns_ortho_v2_{model_index}'
    # checkpoint = f'all_campaigns_ortho_v2_train_only_dr5_{model_index}'
    checkpoint = f'all_campaigns_ortho_v2_train_only_d12_dr5_{model_index}'  # d12 typo, oops

# /home/walml/repos/gz-decals-classifiers/results/test_shard_dr8_checkpoint_all_campaigns_ortho_v2_train_only_d12_dr5_m0.hdf5  <- the file we have (with the d12 typo)
# /home/walml/repos/gz-decals-classifiers/results/test_shard_dr8_checkpoint_all_campaigns_ortho_v2_train_only_dr12_dr5_m0.hdf5  <- the filename we'd otherwise be looking for

    predictions_hdf5_locs = [f'/home/walml/repos/gz-decals-classifiers/results/test_shard_{shard}_checkpoint_{checkpoint}.hdf5' for shard in shards]
    print(predictions_hdf5_locs)
    predictions_hdf5_locs = [loc for loc in predictions_hdf5_locs if os.path.exists(loc)]
    assert len(predictions_hdf5_locs) > 0
    logging.info('Num. prediction .hdf5 to load: {}'.format(len(predictions_hdf5_locs)))


    """Specify some details for saving"""
    run_name = f'checkpoint_{checkpoint}'
    save_dir = f'/home/walml/repos/gz-decals-classifiers/results/campaign_comparison/{run_name}'
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)

    # normalize_cm_matrices = 'true'
    normalize_cm_matrices = None


    """Load volunteer catalogs and match to predictions"""

    catalog_dr12 = pd.read_parquet('/home/walml/repos/gz-decals-classifiers/catalogs/dr12_ortho_v2_labelled_catalog.parquet')
    catalog_dr5 = pd.read_parquet('/home/walml/repos/gz-decals-classifiers/catalogs/dr5_ortho_v2_labelled_catalog.parquet')
    catalog_dr8 = pd.read_parquet('/home/walml/repos/gz-decals-classifiers/catalogs/dr8_ortho_v2_labelled_catalog.parquet')
    catalog = pd.concat([catalog_dr12, catalog_dr5, catalog_dr8], axis=0).reset_index()

    # possibly, catalogs don't include _fraction cols?! 
    # for question in schema.questions:
    #     for answer in question.answers:
    #         catalog[answer.text + '_fraction'] = catalog[answer.text].astype(float) / catalog[question.text + '_total-votes'].astype(float)

    all_labels, all_concentrations = match_predictions_to_catalog(predictions_hdf5_locs, catalog, save_loc=None)
    # print(len(all_labels))
    # print(len(all_concentrations))
    # print(all_labels.head())

    all_labels = multi_catalog_tweaks(all_labels)

    # not actually used currently
    all_fractions = dirichlet_stats.dirichlet_prob_of_answers(all_concentrations, schema, temperature=None)

    # plt.hist(all_labels['smooth-or-featured-dr12_total-votes'], alpha=.5, label='dr12', range=(0, 80))  # 7.5k retired
    # plt.hist(all_labels['smooth-or-featured_total-votes'], alpha=.5, label='dr5', range=(0, 80))  # 2.2k retired
    # plt.hist(all_labels['smooth-or-featured-dr8_total-votes'], alpha=.5, label='dr8', range=(0, 80))  # half have almost no votes, 4.6k retired
    # plt.show()
    # # exit()

    votes_for_retired = 34
    all_labels['is_retired_in_dr12'] = all_labels['smooth-or-featured-dr12_total-votes'] > votes_for_retired
    all_labels['is_retired_in_dr5'] = all_labels['smooth-or-featured-dr5_total-votes'] > votes_for_retired
    all_labels['is_retired_in_dr8'] = all_labels['smooth-or-featured-dr8_total-votes'] > votes_for_retired
    all_labels['is_retired_in_any_dr'] = all_labels['is_retired_in_dr12'] | all_labels['is_retired_in_dr5'] | all_labels['is_retired_in_dr8']

    retired_concentrations = all_concentrations[all_labels['is_retired_in_any_dr']]
    retired_labels = all_labels.query('is_retired_in_any_dr')
    retired_fractions = dirichlet_stats.dirichlet_prob_of_answers(retired_concentrations, schema, temperature=None)

    logging.info('All concentrations: {}'.format(all_concentrations.shape))
    logging.info('Retired concentrations: {}'.format(retired_concentrations.shape))

    # print(all_labels['is_retired_in_dr5'].sum())
    # exit()

    """Now we're ready to calculate some metrics"""

    # create_paper_metric_tables(retired_labels, retired_fractions, schema)

    # # the least interpretable but maybe most ml-meaningful metric
    # # unlike cm and regression, does not only include retired (i.e. high N) galaxies
    # val_loss, loss_by_q_df = get_loss(all_labels, all_concentrations, schema=schema, save_loc=os.path.join(save_dir, 'val_loss_by_q.csv'))
    # print('Mean val loss: {:.3f}'.format(val_loss.mean()))

    confusion_matrices_split_by_confidence(retired_labels, retired_fractions, schema, save_dir, normalize=normalize_cm_matrices, cm_name='cm')

    # print((retired_labels['smooth-or-featured-dr12_total-votes'] > 20).sum())

    get_regression_errors(
        retired=retired_labels,
        predicted_fractions=retired_fractions,
        schema=schema,
        df_save_loc=os.path.join(save_dir, 'regression_errors.csv'),
        fig_save_loc=os.path.join(save_dir, 'regression_errors_bar_plot.pdf')
    )

    """And we can repeat the process, but using the DR5 predictions for the DR8 answer columns"""

    # pick the rows for galaxies with DR8 votes
    dr8_galaxies = all_labels['has_dr8_votes'].values
    dr8_labels = all_labels[dr8_galaxies]
    dr8_concentrations = all_concentrations[dr8_galaxies]
    dr8_fractions = all_fractions[dr8_galaxies]
    # convert the predictions for dr8 answers to use the dr5 answers instead
    dr8_fractions_with_dr5_head = replace_dr8_cols_with_dr5_cols(dr8_fractions, schema)
    dr8_concentrations_with_dr5_head = replace_dr8_cols_with_dr5_cols(dr8_concentrations, schema)

    # calculate loss on all dr8 galaxies, using dr5 head answers
    val_loss, loss_by_q_df = get_loss(dr8_labels, dr8_concentrations_with_dr5_head, schema=schema, save_loc=os.path.join(save_dir, 'val_loss_by_q_dr8_galaxies_with_dr5_head_for_dr8.csv'))
    logging.info('Mean val loss for DR8 galaxies using DR5 head for DR8: {:.3f}'.format(val_loss.mean()))
    
    # select the retired dr8 galaxies by row
    dr8_retired_galaxies = dr8_labels['is_retired_in_dr8'].values
    dr8_retired_labels = dr8_labels[dr8_retired_galaxies]
    dr8_retired_concentrations_with_dr5_head = dr8_concentrations_with_dr5_head[dr8_retired_galaxies]
    dr8_retired_fractions_with_dr5_head = dr8_fractions_with_dr5_head[dr8_retired_galaxies]

    # this will emit a bunch of warnings: only DR8 galaxies are selected here, so the DR12 and DR5
    # questions won't have enough votes and will give empty confusion matrices
    confusion_matrices_split_by_confidence(
        dr8_retired_labels,
        dr8_retired_fractions_with_dr5_head,
        schema,
        save_dir,
        normalize=normalize_cm_matrices,
        cm_name='cm_dr8_galaxies_with_dr5_head_for_dr8'
    )

    # similarly will throw dr12 and dr5 warnings
    get_regression_errors(
        retired=dr8_retired_labels,
        predicted_fractions=dr8_retired_fractions_with_dr5_head,
        schema=schema,
        df_save_loc=os.path.join(save_dir, 'regression_errors_dr8_galaxies_with_dr5_head_for_dr8.csv'),
        fig_save_loc=os.path.join(save_dir, 'regression_errors_dr8_galaxies_with_dr5_head_for_dr8_bar_plot.pdf')
    )
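
The excerpt ends here; main() is presumably invoked from a standard entry point with logging configured, along the lines of this minimal (assumed) sketch:

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()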