def save_image(image, fname=None, dir_path='debug/images/'):
    if fname is None:
        # Auto-number unnamed images via a counter stored on the function object.
        fname = 'image_{}.png'.format(save_image.image_counter)
        save_image.image_counter += 1
    make_dir(dir_path)
    fpath = os.path.join(dir_path, fname)
    save_image_core(image, fpath)


save_image.image_counter = 0


def save_plot(plt, fname=None, save_dir='debug/plots/'):
    plt.tight_layout()
    plt.draw()
    if fname is None:
        # Auto-number unnamed plots via a counter stored on the function object.
        fname = 'plot_{}.png'.format(save_plot.plt_counter)
        save_plot.plt_counter += 1

    make_dir(save_dir)
    if 'png' not in fname and 'pdf' not in fname:
        fname = fname + '.png'

    save_path = os.path.join(save_dir, fname)
    plt.savefig(save_path)
    print('[-] Saved plot to {}'.format(save_path))
    plt.clf()
    plt.close()
    gc.collect()


save_plot.plt_counter = 0
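A minimal usage sketch, assuming `make_dir` creates missing directories and `save_image_core` accepts a NumPy array (both helpers live elsewhere in the repo); the counters initialized above number anonymous files automatically:

import numpy as np
import matplotlib.pyplot as plt

save_image(np.zeros((64, 64, 3), dtype=np.uint8))  # -> debug/images/image_0.png

plt.plot([0, 1, 2], [0, 1, 4])
save_plot(plt, fname='loss_curve')                 # -> debug/plots/loss_curve.png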
Example #3
def validate(cfg, valid_loader, GENERATORS, cri_mse, device, epoch, save_path,
             is_visual):
    print("=" * 30)
    print("START VALIDATON")
    psnr_result = OrderedDict()
    # switch to evaluate mode
    for name, models in GENERATORS.items():
        if models is not None:
            models.eval()
            psnr_result['{}'.format(name)] = AverageMeter()

    for _, data in enumerate(valid_loader):
        gt_img, gt_filename, blur_img, blur_filename = data

        batch_size = gt_img.size(0)
        gt_img, blur_img = prepare([gt_img, blur_img], device)

        with torch.no_grad():
            outputs = GENERATORS['netG'](blur_img)

        mse = cri_mse(gt_img, outputs[-1])
        psnr = 10 * log10(1 / mse.item())
        # record PSNR under the generator that produced the outputs
        psnr_result['netG'].update(psnr, batch_size)

        if is_visual:
            gt_filename = gt_filename[0]
            blur_filename = blur_filename[0]

            save_out_path = os.path.join(save_path, 'output')
            make_dir(save_out_path)
            save_ep_path = os.path.join(save_out_path, 'ep_{}'.format(epoch))
            make_dir(save_ep_path)

            output_list = [o[0, :, :, :] for o in outputs]
            output_list = tensor2img_list(output_list)
            for i in range(len(output_list)):
                save_name = os.path.join(
                    save_ep_path, '{}_out{}.png'.format(blur_filename, i))
                cv2.imwrite(save_name, output_list[i])

    return psnr_result
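The function relies on `AverageMeter`, `prepare`, `tensor2img_list`, and `make_dir` from elsewhere in the repo. A minimal sketch of the `AverageMeter` interface the code assumes (an `update(val, n)` method accumulating a weighted running average):

class AverageMeter:
    """Minimal sketch: keeps a running weighted average (assumed interface, not the repo's exact class)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count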
Example #4
def set_dirs(cfg):
    # 1. Make Dirs
    save_dir = cfg.SAVE_DIR
    make_dir(save_dir)

    model_dir = os.path.join(save_dir, 'checkpoints')
    make_dir(model_dir)

    valid_dir = os.path.join(save_dir, 'valid')
    make_dir(valid_dir)

    log_dir = os.path.join(save_dir, 'logs')
    make_dir(log_dir)

    return save_dir, model_dir, valid_dir, log_dir
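A minimal usage sketch, assuming `cfg` is any object exposing a `SAVE_DIR` attribute (the attribute name comes from the code above; the config object here is hypothetical):

from types import SimpleNamespace

cfg = SimpleNamespace(SAVE_DIR='experiments/run_001')
save_dir, model_dir, valid_dir, log_dir = set_dirs(cfg)
# creates experiments/run_001/ with checkpoints/, valid/, and logs/ subdirectories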
        "save_npy": True,  # 是否将评估结果到npy文件中,该文件可用来绘制pr和fm曲线
        # 保存曲线指标数据的文件路径
        "qualitative_npy_path": os.path.join(
            output_path, data_type + "_" + "qualitative_results.npy"
        ),
        "quantitative_npy_path": os.path.join(
            output_path, data_type + "_" + "quantitative_results.npy"
        ),
        "axes_setting": {  # 不同曲线的绘图配置
            "pr": {  # pr曲线的配置
                "x_label": "Recall",  # 横坐标标签
                "y_label": "Precision",  # 纵坐标标签
                "x_lim": (0.1, 1),  # 横坐标显示范围
                "y_lim": (0.1, 1),  # 纵坐标显示范围
            },
            "fm": {  # fm曲线的配置
                "x_label": "Threshold",  # 横坐标标签
                "y_label": r"F$_{\beta}$",  # 纵坐标标签
                "x_lim": (0, 1),  # 横坐标显示范围
                "y_lim": (0, 0.9),  # 纵坐标显示范围
            },
        },
        "bit_num": 3,  # 评估结果保留的小数点后数据的位数
        "resume_record": True,  # 是否保留之前的评估记录(针对record_path文件有效)
        "skipped_names": [],
    }

    make_dir(output_path)
    cal_all_metrics()
    # draw_pr_fm_curve(for_pr=True)
Example #6
def get_args():
    parser = argparse.ArgumentParser(
        description=textwrap.dedent(r"""
    INCLUDE:

    - F-measure-Threshold Curve
    - Precision-Recall Curve
    - MAE
    - weighted F-measure
    - S-measure
    - max/average/adaptive F-measure
    - max/average/adaptive E-measure
    - max/average Precision
    - max/average Sensitivity
    - max/average Specificity
    - max/average F-measure
    - max/average Dice
    - max/average IoU

    NOTE:

    - Our method automatically calculates the intersection of `pre` and `gt`.
    - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`

    EXAMPLES:

    python eval_all.py \
        --dataset-json configs/datasets/json/rgbd_sod.json \
        --method-json configs/methods/json/rgbd_other_methods.json configs/methods/json/rgbd_our_method.json \
        --metric-npy output/rgbd_metrics.npy \
        --curves-npy output/rgbd_curves.npy \
        --record-txt output/rgbd_results.txt
    """),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--dataset-json",
                        required=True,
                        type=str,
                        help="Json file for datasets.")
    parser.add_argument("--method-json",
                        required=True,
                        nargs="+",
                        type=str,
                        help="Json file for methods.")
    parser.add_argument("--metric-npy",
                        type=str,
                        help="Npy file for saving metric results.")
    parser.add_argument("--curves-npy",
                        type=str,
                        help="Npy file for saving curve results.")
    parser.add_argument("--record-txt",
                        type=str,
                        help="Txt file for saving metric results.")
    parser.add_argument("--to-overwrite",
                        action="store_true",
                        help="To overwrite the txt file.")
    parser.add_argument("--record-xlsx",
                        type=str,
                        help="Xlsx file for saving metric results.")
    parser.add_argument(
        "--include-methods",
        type=str,
        nargs="+",
        help="Names of only specific methods you want to evaluate.",
    )
    parser.add_argument(
        "--exclude-methods",
        type=str,
        nargs="+",
        help="Names of some specific methods you do not want to evaluate.",
    )
    parser.add_argument(
        "--include-datasets",
        type=str,
        nargs="+",
        help="Names of only specific datasets you want to evaluate.",
    )
    parser.add_argument(
        "--exclude-datasets",
        type=str,
        nargs="+",
        help="Names of some specific datasets you do not want to evaluate.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of workers for multi-threading or multi-processing. Default: 4",
    )
    parser.add_argument(
        "--num-bits",
        type=int,
        default=3,
        help="Number of decimal places for showing results. Default: 3",
    )
    parser.add_argument(
        "--metric-names",
        type=str,
        nargs="+",
        default=["mae", "fm", "em", "sm", "wfm"],
        choices=METRIC_MAPPING.keys(),
        help="Names of metrics",
    )
    args = parser.parse_args()

    if args.metric_npy is not None:
        make_dir(os.path.dirname(args.metric_npy))
    if args.curves_npy is not None:
        make_dir(os.path.dirname(args.curves_npy))
    if args.record_txt is not None:
        make_dir(os.path.dirname(args.record_txt))
    if args.record_xlsx is not None:
        make_dir(os.path.dirname(args.record_xlsx))
    if args.to_overwrite and not args.record_txt:
        warnings.warn("--to-overwrite only works with a valid --record-txt")
    return args
def cal_cosod_matrics(
    data_type: str = "rgb_sod",
    txt_path: str = "",
    to_append: bool = True,
    xlsx_path: str = "",
    drawing_info: dict = None,
    dataset_info: dict = None,
    save_npy: bool = True,
    curves_npy_path: str = "./curves.npy",
    metrics_npy_path: str = "./metrics.npy",
    num_bits: int = 3,
):
    """
    Save the results of all models on different datasets in a `npy` file in the form of a
    dictionary.

    ::

        {
          dataset1:{
            method1:[fm, em, p, r],
            method2:[fm, em, p, r],
            .....
          },
          dataset2:{
            method1:[fm, em, p, r],
            method2:[fm, em, p, r],
            .....
          },
          ....
        }

    :param data_type: the type of data
    :param txt_path: the path of the txt for saving results
    :param to_append: whether to append results to the original record
    :param xlsx_path: the path of the xlsx file for saving results
    :param drawing_info: the method information for plotting figures
    :param dataset_info: the dataset information
    :param save_npy: whether to save results into npy files
    :param curves_npy_path: the npy file path for saving curve data
    :param metrics_npy_path: the npy file path for saving metric values
    :param num_bits: the number of decimal places used to format results
    """
    curves = defaultdict(dict)  # Two curve metrics
    metrics = defaultdict(dict)  # Six numerical metrics

    txt_recorder = TxtRecorder(
        txt_path=txt_path,
        to_append=to_append,
        max_method_name_width=max([len(x) for x in drawing_info.keys()]),  # show full method names
    )
    excel_recorder = MetricExcelRecorder(
        xlsx_path=xlsx_path,
        sheet_name=data_type,
        row_header=["methods"],
        dataset_names=sorted(list(dataset_info.keys())),
        metric_names=["sm", "wfm", "mae", "adpf", "avgf", "maxf", "adpe", "avge", "maxe"],
    )

    for dataset_name, dataset_path in dataset_info.items():
        txt_recorder.add_row(row_name="Dataset", row_data=dataset_name, row_start_str="\n")

        # get the ground-truth image information
        gt_info = dataset_path["mask"]
        gt_root = gt_info["path"]
        gt_ext = gt_info["suffix"]
        # list of ground-truth file names
        gt_index_file = dataset_path.get("index_file")
        if gt_index_file:
            gt_name_list = get_name_with_group_list(data_path=gt_index_file, file_ext=gt_ext)
        else:
            gt_name_list = get_name_with_group_list(data_path=gt_root, file_ext=gt_ext)
        assert len(gt_name_list) > 0, "there is no ground truth."

        # ==>> test the intersection between pre and gt for each method <<==
        for method_name, method_info in drawing_info.items():
            method_root = method_info["path_dict"]
            method_dataset_info = method_root.get(dataset_name, None)
            if method_dataset_info is None:
                colored_print(
                    msg=f"{method_name} does not have results on {dataset_name}", mode="warning"
                )
                continue

            # file names and extension of the images under the prediction result directory
            pre_ext = method_dataset_info["suffix"]
            pre_root = method_dataset_info["path"]
            pre_name_list = get_name_with_group_list(data_path=pre_root, file_ext=pre_ext)

            # get the intersection
            eval_name_list = sorted(list(set(gt_name_list).intersection(set(pre_name_list))))
            num_names = len(eval_name_list)

            if num_names == 0:
                colored_print(
                    msg=f"{method_name} does not have results on {dataset_name}", mode="warning"
                )
                continue

            grouped_data = group_names(names=eval_name_list)
            num_groups = len(grouped_data)

            colored_print(
                f"Evaluating {method_name} with {num_names} images and {num_groups} groups"
                f" (G:{len(gt_name_list)},P:{len(pre_name_list)}) images on dataset {dataset_name}"
            )

            group_metric_recorder = GroupedMetricRecorder()
            inter_group_bar = tqdm(
                grouped_data.items(),
                total=num_groups,
                leave=False,
                ncols=79,
                desc=f"[{dataset_name}]",
            )
            for group_name, names_in_group in inter_group_bar:
                intra_group_bar = tqdm(
                    names_in_group,
                    total=len(names_in_group),
                    leave=False,
                    ncols=79,
                    desc=f"({group_name})",
                )
                for img_name in intra_group_bar:
                    img_name_with_group = os.path.join(group_name, img_name)
                    gt, pre = get_gt_pre_with_name(
                        gt_root=gt_root,
                        pre_root=pre_root,
                        img_name=img_name_with_group,
                        pre_ext=pre_ext,
                        gt_ext=gt_ext,
                        to_normalize=False,
                    )
                    group_metric_recorder.update(group_name=group_name, pre=pre, gt=gt)
            method_results = group_metric_recorder.show(num_bits=num_bits, return_ndarray=False)
            method_curves = method_results["sequential"]
            method_metrics = method_results["numerical"]
            curves[dataset_name][method_name] = method_curves
            metrics[dataset_name][method_name] = method_metrics

            excel_recorder(
                row_data=method_metrics, dataset_name=dataset_name, method_name=method_name
            )
            txt_recorder(method_results=method_metrics, method_name=method_name)

    if save_npy:
        make_dir(os.path.dirname(curves_npy_path))
        make_dir(os.path.dirname(metrics_npy_path))
        np.save(curves_npy_path, curves)
        np.save(metrics_npy_path, metrics)
        colored_print(f"Results for all methods have been saved to {curves_npy_path} and {metrics_npy_path}")
    formatted_string = formatter_for_tabulate(metrics)
    colored_print(f"All methods have been evaluated:\n{formatted_string}")
Example #8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils import data
from torchvision import transforms

from model.generator import MSPL_Generator
from utils.misc import make_dir, write_log
from utils.visualize import denorm, tensor2img

# 1. Set Paths
MODEL_DIR = ''
INPUT_IMG_DIR = ''
OUTPUT_IMG_DIR = ''
make_dir(OUTPUT_IMG_DIR)

# 2. Set GPU or CPU
device = torch.device('cuda')  # switch to torch.device('cpu') when no GPU is available

# 3. Load Model
netG = MSPL_Generator(3, 3, 128, [4, 4, 4, 4]).to(device)
netG = nn.DataParallel(netG)

rgb_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

if isinstance(netG, nn.DataParallel):
    netG = netG.module
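The snippet ends after unwrapping the `DataParallel` module. A hedged sketch of how the remaining inference loop might look, assuming the checkpoint at `MODEL_DIR` is a bare state_dict, that `denorm` undoes the (0.5, 0.5) normalization, and that `tensor2img` returns an array `cv2.imwrite` can save (all three are assumptions, not confirmed by the source):

import glob
import os

import cv2
from PIL import Image

# Load pretrained weights (assumed to be a bare state_dict saved at MODEL_DIR).
netG.load_state_dict(torch.load(MODEL_DIR, map_location=device))
netG.eval()

for img_path in sorted(glob.glob(os.path.join(INPUT_IMG_DIR, '*.png'))):
    blur = rgb_to_tensor(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        output = netG(blur)
    # If the generator returns multi-scale outputs, keep only the final one.
    if isinstance(output, (list, tuple)):
        output = output[-1]
    result = tensor2img(denorm(output[0].cpu()))
    cv2.imwrite(os.path.join(OUTPUT_IMG_DIR, os.path.basename(img_path)), result)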
def cal_sod_matrics(
    sheet_name: str = "results",
    txt_path: str = "",
    to_append: bool = True,
    xlsx_path: str = "",
    methods_info: dict = None,
    datasets_info: dict = None,
    curves_npy_path: str = "./curves.npy",
    metrics_npy_path: str = "./metrics.npy",
    num_bits: int = 3,
    num_workers: int = 2,
    use_mp: bool = False,
    metric_names: tuple = ("mae", "fm", "em", "sm", "wfm"),
    ncols_tqdm: int = 79,
):
    """
    Save the results of all models on different datasets in a `npy` file in the form of a
    dictionary.

    ::

        {
          dataset1:{
            method1:[fm, em, p, r],
            method2:[fm, em, p, r],
            .....
          },
          dataset2:{
            method1:[fm, em, p, r],
            method2:[fm, em, p, r],
            .....
          },
          ....
        }

    :param sheet_name: the name of the sheet in the xlsx file
    :param txt_path: the path of the txt for saving results
    :param to_append: whether to append results to the original record
    :param xlsx_path: the path of the xlsx file for saving results
    :param methods_info: the method information
    :param datasets_info: the dataset information
    :param curves_npy_path: the npy file path for saving curve data
    :param metrics_npy_path: the npy file path for saving metric values
    :param num_bits: the number of decimal places used to format results
    :param num_workers: the number of workers of multiprocessing or multithreading
    :param use_mp: whether to use multiprocessing instead of multithreading
    :param metric_names: names of metrics
    :param ncols_tqdm: number of columns for tqdm
    """
    recorder = Recorder(
        txt_path=txt_path,
        to_append=to_append,
        max_method_name_width=max([len(x) for x in methods_info.keys()]),  # show full method names
        xlsx_path=xlsx_path,
        sheet_name=sheet_name,
        dataset_names=sorted(datasets_info.keys()),
        metric_names=[
            "sm", "wfm", "mae", "adpf", "avgf", "maxf", "adpe", "avge", "maxe"
        ],
    )

    for dataset_name, dataset_path in datasets_info.items():
        recorder.record_dataset_name(dataset_name)

        # get the ground-truth image information
        gt_info = dataset_path["mask"]
        gt_root = gt_info["path"]
        gt_ext = gt_info["suffix"]
        # list of ground-truth file names
        gt_index_file = dataset_path.get("index_file")
        if gt_index_file:
            gt_name_list = get_name_list(data_path=gt_index_file,
                                         name_suffix=gt_ext)
        else:
            gt_name_list = get_name_list(data_path=gt_root, name_suffix=gt_ext)
        assert len(gt_name_list) > 0, "there is no ground truth."

        # ==>> test the intersection between pre and gt for each method <<==
        tqdm.set_lock(RLock())
        pool_cls = pool.Pool if use_mp else pool.ThreadPool
        procs = pool_cls(processes=num_workers,
                         initializer=tqdm.set_lock,
                         initargs=(tqdm.get_lock(), ))
        procs_idx = 0
        for method_name, method_info in methods_info.items():
            method_root = method_info["path_dict"]
            method_dataset_info = method_root.get(dataset_name, None)
            if method_dataset_info is None:
                tqdm.write(
                    f"{method_name} does not have results on {dataset_name}")
                continue

            # file names and extension of the images under the prediction result directory
            pre_prefix = method_dataset_info.get("prefix", "")
            pre_suffix = method_dataset_info["suffix"]
            pre_root = method_dataset_info["path"]
            pre_name_list = get_name_list(
                data_path=pre_root,
                name_prefix=pre_prefix,
                name_suffix=pre_suffix,
            )

            # get the intersection
            eval_name_list = sorted(
                list(set(gt_name_list).intersection(pre_name_list)))
            if len(eval_name_list) == 0:
                tqdm.write(
                    f"{method_name} does not have results on {dataset_name}")
                continue

            procs.apply_async(
                func=evaluate_data,
                kwds=dict(
                    names=eval_name_list,
                    num_bits=num_bits,
                    pre_root=pre_root,
                    pre_prefix=pre_prefix,
                    pre_suffix=pre_suffix,
                    gt_root=gt_root,
                    gt_ext=gt_ext,
                    desc=
                    f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]",
                    proc_idx=procs_idx,
                    blocking=use_mp,
                    metric_names=metric_names,
                    ncols_tqdm=ncols_tqdm,
                ),
                callback=partial(recorder.record, method_name=method_name),
            )
            procs_idx += 1
        procs.close()
        procs.join()

    if curves_npy_path:
        make_dir(os.path.dirname(curves_npy_path))
        np.save(curves_npy_path, recorder.curves)
        print(f"All curves has been saved in {curves_npy_path}")
    if metrics_npy_path:
        make_dir(os.path.dirname(metrics_npy_path))
        np.save(metrics_npy_path, recorder.metrics)
        print(f"All metrics has been saved in {metrics_npy_path}")
    formatted_string = formatter_for_tabulate(recorder.metrics)
    print(f"All methods have been evaluated:\n{formatted_string}")