def evaluate_ms_ssim(conf):
    """Score predictions with MS-SSIM and plot a histogram of the scores.

    A DataLoader is built from ``conf`` and every batch is run through the
    pretrained model. When ``conf["model_path_2"]`` is not ``"none"``, the
    output of a second model is concatenated along dim 1. Predictions and
    targets are transformed with ``get_ifft`` and every image pair is scored
    with ``ms_ssim``, using the maximum of the true image as ``data_range``.
    The histogram is written to an ``evaluation`` directory next to the model
    file and the mean score is echoed.

    Parameters
    ----------
    conf : dict
        Evaluation configuration. Keys read: ``data_path``, ``fourier``,
        ``source_list``, ``batch_size``, ``model_path``, ``arch_name``,
        ``model_path_2``, ``arch_name_2`` (only if ``model_path_2`` is not
        ``"none"``), ``amp_phase`` and ``format``.
    """
    # create DataLoader
    loader = create_databunch(conf["data_path"], conf["fourier"],
                              conf["source_list"], conf["batch_size"])
    model_path = conf["model_path"]
    out_path = Path(model_path).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(conf["arch_name"], conf["model_path"],
                                  img_size)
    use_second_model = conf["model_path_2"] != "none"
    if use_second_model:
        model_2 = load_pretrained_model(conf["arch_name_2"],
                                        conf["model_path_2"], img_size)

    small_images = img_size < 160
    if small_images:
        # Implicit string concatenation: a backslash continuation inside the
        # literal would embed the source indentation in the echoed message.
        click.echo("\nThis is only a placeholder! "
                   "Images too small for meaningful ms ssim calculations.\n")

    vals = []

    # iterate through DataLoader
    for img_test, img_true in tqdm(loader):
        pred = eval_model(img_test, model)
        if use_second_model:
            pred_2 = eval_model(img_test, model_2)
            pred = torch.cat((pred, pred_2), dim=1)

        # get_ifft output is turned into a tensor on BOTH paths; previously
        # only the small-image path did this, so the >= 160 px path called
        # .unsqueeze on the raw get_ifft result.
        ifft_truth = torch.as_tensor(
            get_ifft(img_true, amp_phase=conf["amp_phase"]))
        ifft_pred = torch.as_tensor(
            get_ifft(pred, amp_phase=conf["amp_phase"]))

        if small_images:
            ifft_truth = pad_unsqueeze(ifft_truth)
            ifft_pred = pad_unsqueeze(ifft_pred)
        else:
            # NOTE(review): mirrors the axis pad_unsqueeze is presumed to add
            # (dim 1, a channel axis), so both paths yield per-image tensors
            # of the same rank — confirm against pad_unsqueeze.
            ifft_truth = ifft_truth.unsqueeze(1)
            ifft_pred = ifft_pred.unsqueeze(1)

        # Score every image on its own; unsqueeze(0) restores a batch axis.
        vals.extend(
            ms_ssim(img_pred.unsqueeze(0),
                    img_truth.unsqueeze(0),
                    data_range=img_truth.max())
            for img_pred, img_truth in zip(ifft_pred, ifft_truth)
        )

    click.echo("\nCreating ms-ssim histogram.\n")
    vals = torch.tensor(vals)
    histogram_ms_ssim(
        vals,
        out_path,
        plot_format=conf["format"],
    )

    click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")


def evaluate_point(conf):
    """Compare predicted and true point-source fluxes and plot the results.

    Each batch from the DataLoader yields test images, true images and a
    source list. The model prediction is concatenated along dim 1 with a
    second model's output when ``conf["model_path_2"]`` is not ``"none"``.
    The prediction and the truth are transformed with ``get_ifft`` and passed
    to ``flux_comparison``.

    The relative flux deviation ``(pred - truth) / truth * 100`` is collected
    together with the lengths returned by ``flux_comparison``. Entries with a
    length below 10 form the mask handed to ``hist_point`` and
    ``plot_length_point``. Plots go to an ``evaluation`` directory next to
    the model file.

    Parameters
    ----------
    conf : dict
        Evaluation configuration. Keys read: ``data_path``, ``fourier``,
        ``source_list``, ``batch_size``, ``model_path``, ``arch_name``,
        ``model_path_2``, ``arch_name_2`` (only if ``model_path_2`` is not
        ``"none"``), ``amp_phase`` and ``format``.
    """
    loader = create_databunch(
        conf["data_path"],
        conf["fourier"],
        conf["source_list"],
        conf["batch_size"],
    )
    out_path = Path(conf["model_path"]).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(
        conf["arch_name"], conf["model_path"], img_size
    )
    two_models = conf["model_path_2"] != "none"
    if two_models:
        second_model = load_pretrained_model(
            conf["arch_name_2"], conf["model_path_2"], img_size
        )

    rel_diffs = []
    extents = []

    for img_test, img_true, sources in tqdm(loader):
        prediction = eval_model(img_test, model)
        if two_models:
            prediction = torch.cat(
                (prediction, eval_model(img_test, second_model)), dim=1
            )

        truth_img = get_ifft(img_true, amp_phase=conf["amp_phase"])
        pred_img = get_ifft(prediction, amp_phase=conf["amp_phase"])

        flux_pred, flux_true, extent = flux_comparison(
            pred_img, truth_img, sources
        )
        # relative flux deviation in percent
        rel_diffs.extend((flux_pred - flux_true) / flux_true * 100)
        extents.extend(extent)

    rel_diffs = np.concatenate(rel_diffs).ravel()
    extents = np.array(extents, dtype="object")
    mask = extents < 10

    click.echo("\nCreating pointsources histogram.\n")
    hist_point(rel_diffs, mask, out_path, plot_format=conf["format"])
    click.echo(f"\nThe mean flux difference is {rel_diffs.mean()}.\n")
    click.echo("\nCreating linear extent-mean flux diff plot.\n")
    plot_length_point(
        extents, rel_diffs, mask, out_path, plot_format=conf["format"]
    )


def evaluate_mean_diff(conf):
    """Compute the relative mean-flux difference of the first component.

    For every image pair (after ``get_ifft``), ``calc_blobs`` is run. The
    first entry of the true-image blobs is used by ``crop_first_component``
    to crop both images. The value ``(mean(pred) - mean(truth)) / mean(truth)``
    is then collected. Values are scaled by 100, plotted with
    ``histogram_mean_diff`` into an ``evaluation`` directory next to the model
    file, and their mean is echoed.

    Parameters
    ----------
    conf : dict
        Evaluation configuration. Keys read: ``data_path``, ``fourier``,
        ``source_list``, ``batch_size``, ``model_path``, ``arch_name``,
        ``model_path_2``, ``arch_name_2`` (only if ``model_path_2`` is not
        ``"none"``), ``amp_phase`` and ``format``.
    """
    loader = create_databunch(
        conf["data_path"],
        conf["fourier"],
        conf["source_list"],
        conf["batch_size"],
    )
    out_path = Path(conf["model_path"]).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(
        conf["arch_name"], conf["model_path"], img_size
    )
    two_models = conf["model_path_2"] != "none"
    if two_models:
        second_model = load_pretrained_model(
            conf["arch_name_2"], conf["model_path_2"], img_size
        )

    rel_diffs = []

    for img_test, img_true in tqdm(loader):
        prediction = eval_model(img_test, model)
        if two_models:
            prediction = torch.cat(
                (prediction, eval_model(img_test, second_model)), dim=1
            )

        truth_imgs = get_ifft(img_true, amp_phase=conf["amp_phase"])
        pred_imgs = get_ifft(prediction, amp_phase=conf["amp_phase"])

        for pred_img, truth_img in zip(pred_imgs, truth_imgs):
            # only the blobs of the true image are used below
            _, truth_blobs = calc_blobs(pred_img, truth_img)
            flux_pred, flux_true = crop_first_component(
                pred_img, truth_img, truth_blobs[0]
            )
            true_mean = flux_true.mean()
            rel_diffs.append((flux_pred.mean() - true_mean) / true_mean)

    click.echo("\nCreating mean_diff histogram.\n")
    rel_diffs = torch.tensor(rel_diffs) * 100
    histogram_mean_diff(rel_diffs, out_path, plot_format=conf["format"])

    click.echo(f"\nThe mean difference is {rel_diffs.mean()}.\n")


def evaluate_dynamic_range(conf):
    """Compare dynamic ranges of true and predicted images.

    Every batch is run through the pretrained model. When
    ``conf["model_path_2"]`` is not ``"none"``, a second model's output is
    concatenated along dim 1. Truth and prediction are transformed with
    ``get_ifft`` and passed to ``calc_dr``, whose first two return values
    (dynamic ranges of truth and prediction) are collected. The rounded means
    are echoed and a histogram is written to an ``evaluation`` directory next
    to the model file.

    Parameters
    ----------
    conf : dict
        Evaluation configuration. Keys read: ``data_path``, ``fourier``,
        ``source_list``, ``batch_size``, ``model_path``, ``arch_name``,
        ``model_path_2``, ``arch_name_2`` (only if ``model_path_2`` is not
        ``"none"``), ``amp_phase`` and ``format``.
    """
    # create DataLoader
    loader = create_databunch(conf["data_path"], conf["fourier"],
                              conf["source_list"], conf["batch_size"])
    model_path = conf["model_path"]
    out_path = Path(model_path).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(conf["arch_name"], conf["model_path"],
                                  img_size)
    use_second_model = conf["model_path_2"] != "none"
    if use_second_model:
        model_2 = load_pretrained_model(conf["arch_name_2"],
                                        conf["model_path_2"], img_size)

    # Collect per-batch pieces and concatenate once at the end: np.append in
    # the loop copied the whole accumulated array every batch (quadratic).
    dr_truth_parts = []
    dr_pred_parts = []

    # iterate through DataLoader
    for img_test, img_true in tqdm(loader):
        pred = eval_model(img_test, model)
        if use_second_model:
            pred_2 = eval_model(img_test, model_2)
            pred = torch.cat((pred, pred_2), dim=1)

        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])

        dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred)
        # np.append flattened its input; ravel keeps that behaviour
        dr_truth_parts.append(np.ravel(dr_truth))
        dr_pred_parts.append(np.ravel(dr_pred))

    # The leading empty float array reproduces np.append's dtype promotion
    # and yields an empty 1-D array when the loader has no batches.
    dr_truths = np.concatenate([np.array([]), *dr_truth_parts])
    dr_preds = np.concatenate([np.array([]), *dr_pred_parts])

    # Implicit concatenation instead of a backslash continuation, which put
    # the source indentation into the echoed text.
    click.echo("\nMean dynamic range for true source distributions: "
               f"{round(dr_truths.mean())}\n")
    click.echo("\nMean dynamic range for predicted source distributions: "
               f"{round(dr_preds.mean())}\n")

    click.echo("\nCreating histogram of dynamic ranges.\n")
    histogram_dynamic_ranges(
        dr_truths,
        dr_preds,
        out_path,
        plot_format=conf["format"],
    )


def evaluate_viewing_angle(conf):
    """Compare jet angles of true and predicted images.

    Every batch is run through the pretrained model. When
    ``conf["model_path_2"]`` is not ``"none"``, a second model's output is
    concatenated along dim 1. Truth and prediction are transformed with
    ``get_ifft``, converted to tensors and passed to ``calc_jet_angle``. The
    absolute values of its third return value are collected for both truth
    and prediction and plotted with ``histogram_jet_angles`` into an
    ``evaluation`` directory next to the model file.

    Parameters
    ----------
    conf : dict
        Evaluation configuration. Keys read: ``data_path``, ``fourier``,
        ``source_list``, ``batch_size``, ``model_path``, ``arch_name``,
        ``model_path_2``, ``arch_name_2`` (only if ``model_path_2`` is not
        ``"none"``), ``amp_phase`` and ``format``.
    """
    loader = create_databunch(
        conf["data_path"],
        conf["fourier"],
        conf["source_list"],
        conf["batch_size"],
    )
    out_path = Path(conf["model_path"]).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(
        conf["arch_name"], conf["model_path"], img_size
    )
    two_models = conf["model_path_2"] != "none"
    if two_models:
        second_model = load_pretrained_model(
            conf["arch_name_2"], conf["model_path_2"], img_size
        )

    true_angles = []
    pred_angles = []

    for img_test, img_true in tqdm(loader):
        prediction = eval_model(img_test, model)
        if two_models:
            prediction = torch.cat(
                (prediction, eval_model(img_test, second_model)), dim=1
            )

        truth_img = get_ifft(img_true, amp_phase=conf["amp_phase"])
        pred_img = get_ifft(prediction, amp_phase=conf["amp_phase"])

        # calc_jet_angle returns three values; only the angle is kept
        _, _, angle_true = calc_jet_angle(torch.tensor(truth_img))
        _, _, angle_pred = calc_jet_angle(torch.tensor(pred_img))

        true_angles.extend(abs(angle_true))
        pred_angles.extend(abs(angle_pred))

    click.echo("\nCreating histogram of jet angles.\n")
    histogram_jet_angles(
        torch.tensor(true_angles),
        torch.tensor(pred_angles),
        out_path,
        plot_format=conf["format"],
    )


def evaluate_area(conf):
    """Compute contour-area values for every image pair and plot them.

    Every batch is run through the pretrained model. When
    ``conf["model_path_2"]`` is not ``"none"``, a second model's output is
    concatenated along dim 1. Truth and prediction are transformed with
    ``get_ifft`` and every image pair is passed to ``area_of_contour``. The
    collected values are plotted with ``histogram_area`` into an
    ``evaluation`` directory next to the model file, and their mean is
    echoed as the mean area ratio.

    Parameters
    ----------
    conf : dict
        Evaluation configuration. Keys read: ``data_path``, ``fourier``,
        ``source_list``, ``batch_size``, ``model_path``, ``arch_name``,
        ``model_path_2``, ``arch_name_2`` (only if ``model_path_2`` is not
        ``"none"``), ``amp_phase`` and ``format``.
    """
    loader = create_databunch(
        conf["data_path"],
        conf["fourier"],
        conf["source_list"],
        conf["batch_size"],
    )
    out_path = Path(conf["model_path"]).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(
        conf["arch_name"], conf["model_path"], img_size
    )
    two_models = conf["model_path_2"] != "none"
    if two_models:
        second_model = load_pretrained_model(
            conf["arch_name_2"], conf["model_path_2"], img_size
        )

    ratios = []

    for img_test, img_true in tqdm(loader):
        prediction = eval_model(img_test, model)
        if two_models:
            prediction = torch.cat(
                (prediction, eval_model(img_test, second_model)), dim=1
            )

        truth_imgs = get_ifft(img_true, amp_phase=conf["amp_phase"])
        pred_imgs = get_ifft(prediction, amp_phase=conf["amp_phase"])

        ratios.extend(
            area_of_contour(pred_img, truth_img)
            for pred_img, truth_img in zip(pred_imgs, truth_imgs)
        )

    click.echo("\nCreating eval_area histogram.\n")
    ratios = torch.tensor(ratios)
    histogram_area(ratios, out_path, plot_format=conf["format"])

    click.echo(f"\nThe mean area ratio is {ratios.mean()}.\n")