Example #1
def images(download_path):
    download_file_by_url(cmr_url, download_path, "SA_64x64.zip", "zip")
    img_path = os.path.join(download_path, "SA_64x64", "DICOM")
    cmr_images = read_dicom_images(img_path,
                                   sort_instance=True,
                                   sort_patient=True)

    return cmr_images[:5, ...]
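
Every example on this page calls the same helper. From these call sites, the signature appears to be download_file_by_url(url, output_directory, output_file_name, file_format): fetch url into output_directory under output_file_name, extracting it when file_format is an archive such as "zip" (Example #1 reads from the unpacked SA_64x64/DICOM folder afterwards). A minimal, self-contained sketch of that pattern follows; the import path and the sample URL are assumptions for illustration, not taken from the example above.

import os

from kale.utils.download import download_file_by_url  # assumed import path

download_path = "./data"
# Hypothetical URL and file name, for illustration only.
sample_url = "https://github.com/pykale/data/raw/main/videos/video_test_data.zip"
download_file_by_url(sample_url, download_path, "video_test_data.zip", "zip")
# The downloaded archive should now be present in download_path.
assert os.path.exists(os.path.join(download_path, "video_test_data.zip"))
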
Example #2
def test_download_file_by_url(param):
    url, output_file_name, file_format = param.split(";")

    # run twice to test the code path when the file already exists
    download_file_by_url(url, output_directory, output_file_name, file_format)
    download_file_by_url(url, output_directory, output_file_name, file_format)

    assert os.path.exists(output_directory.joinpath(output_file_name))
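
The param fixture driving this test is not shown. Below is a sketch of an equivalent parametrized version, assuming param carries semicolon-separated url;file_name;format triples (as the split(";") implies); the URL is a placeholder and tmp_path is pytest's built-in temporary-directory fixture.

import pytest

from kale.utils.download import download_file_by_url  # assumed import path


@pytest.mark.parametrize(
    "param",
    ["https://example.com/data/sample.zip;sample.zip;zip"],  # placeholder URL
)
def test_download_file_by_url_sketch(param, tmp_path):
    url, output_file_name, file_format = param.split(";")
    # Run twice: the second call exercises the branch taken when the file already exists.
    download_file_by_url(url, tmp_path, output_file_name, file_format)
    download_file_by_url(url, tmp_path, output_file_name, file_format)
    assert tmp_path.joinpath(output_file_name).exists()
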
Example #3
def images(download_path):
    download_file_by_url(cmr_url, download_path, "SA_64x64.zip", "zip")
    img_path = os.path.join(download_path, "SA_64x64", "DICOM")
    cmr_dcm_list = read_dicom_dir(img_path,
                                  sort_instance=True,
                                  sort_patient=True)
    cmr_images = dicom2arraylist(dicom_patient_list=cmr_dcm_list,
                                 return_patient_id=False)

    return cmr_images[:5]
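
Note the contrast with Example #1: read_dicom_images returns one stacked array, sliceable as cmr_images[:5, ...], while here read_dicom_dir returns a per-patient list of DICOM objects and dicom2arraylist converts it into a list of arrays, so the first five patients are taken with a plain cmr_images[:5]. The two fixtures appear to target an older and a newer version of the same pykale reading API.
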
Example #4
File: main.py Project: sz144/pykale
def main():
    args = arg_parse()
    # ---- setup device ----
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("==> Using device " + device)

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    seed.set_seed(cfg.SOLVER.SEED)
    # ---- setup logger and output ----
    output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME, args.output)
    os.makedirs(output_dir, exist_ok=True)
    logger = lu.construct_logger("gripnet", output_dir)
    logger.info("Using " + device)
    logger.info(cfg.dump())
    # ---- setup dataset ----
    download_file_by_url(cfg.DATASET.URL, cfg.DATASET.ROOT, "pose.pt", "pt")
    data = torch.load(os.path.join(cfg.DATASET.ROOT, "pose.pt"))
    device = torch.device(device)
    data = data.to(device)
    # ---- setup model ----
    print("==> Building model..")
    model = GripNet(
        cfg.GRIPN.GG_LAYERS, cfg.GRIPN.GD_LAYERS, cfg.GRIPN.DD_LAYERS, data.n_d_node, data.n_g_node, data.n_dd_edge_type
    ).to(device)
    # TODO Visualize model
    # ---- setup trainers ----
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.SOLVER.BASE_LR)
    # TODO
    trainer = Trainer(cfg, device, data, model, optimizer, logger, output_dir)

    if args.resume:
        # Load checkpoint
        print("==> Resuming from checkpoint..")
        cp = torch.load(args.resume)
        trainer.model.load_state_dict(cp["net"])
        trainer.optim.load_state_dict(cp["optim"])
        trainer.epochs = cp["epoch"]
        trainer.train_auprc = cp["train_auprc"]
        trainer.valid_auprc = cp["valid_auprc"]
        trainer.train_auroc = cp["train_auroc"]
        trainer.valid_auroc = cp["valid_auroc"]
        trainer.train_ap = cp["train_ap"]
        trainer.valid_ap = cp["valid_ap"]

    trainer.train()
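
The resume branch above expects a checkpoint dict with a fixed set of keys. A sketch of the matching save side, using the keys and attribute names read in the example (the file name is arbitrary):

checkpoint = {
    "net": trainer.model.state_dict(),
    "optim": trainer.optim.state_dict(),
    "epoch": trainer.epochs,
    "train_auprc": trainer.train_auprc,
    "valid_auprc": trainer.valid_auprc,
    "train_auroc": trainer.train_auroc,
    "valid_auroc": trainer.valid_auroc,
    "train_ap": trainer.train_ap,
    "valid_ap": trainer.valid_ap,
}
torch.save(checkpoint, os.path.join(output_dir, "checkpoint.pt"))
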
Example #5
def images(download_path):
    download_file_by_url(cmr_url, download_path, "SA_64x64.zip", "zip")
    img_path = os.path.join(download_path, "SA_64x64_v2.0", "DICOM")
    cmr_dcm_list = read_dicom_dir(img_path,
                                  sort_instance=True,
                                  sort_patient=True,
                                  check_series_uid=True)
    dcms = []
    for i in range(5):
        for j in range(len(cmr_dcm_list[i])):
            dcms.append(cmr_dcm_list[i][j])
    dcm5_list = check_dicom_series_uid(dcms)
    cmr_images = dicom2arraylist(dicom_patient_list=dcm5_list,
                                 return_patient_id=False)

    return cmr_images
Example #6
    def download(path):
        """Download dataset.
            Office-31 source: https://www.cc.gatech.edu/~judy/domainadapt/#datasets_code
            Caltech-256 source: http://www.vision.caltech.edu/Image_Datasets/Caltech256/
            Data with this library is adapted from: http://www.stat.ucla.edu/~jxie/iFRAME/code/imageClassification.rar
        """
        url = "https://github.com/pykale/data/raw/main/images/office"

        if not os.path.exists(path):
            os.makedirs(path)
        for domain_ in OFFICE_DOMAINS:
            filename = "%s.zip" % domain_
            data_path = os.path.join(path, filename)
            if os.path.exists(data_path):
                logging.info(f"Data file {filename} already exists.")
                continue
            else:
                data_url = "%s/%s" % (url, filename)
                download_file_by_url(data_url, path, filename, "zip")
                logging.info(f"Download {data_url} to {data_path}")

        logging.info("[DONE]")
        return
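
OFFICE_DOMAINS is defined outside this snippet; given the Office-31 and Caltech-256 sources cited in the docstring, it presumably lists the domain names (e.g. "amazon", "caltech", "dslr", "webcam"), each fetched as <domain>.zip from the GitHub data repository.
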
Example #7
def baseline_model(download_path):
    download_file_by_url(baseline_url, download_path, "baseline.mat", "mat")
    return loadmat(os.path.join(download_path, "baseline.mat"))
Example #8
def pose_data(download_path):
    download_file_by_url(pose_url, download_path, "pose.pt", "pt")
    return torch.load(os.path.join(download_path, "pose.pt"))
Example #9
def main():
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    save_images = cfg.OUTPUT.SAVE_IMAGES
    print(f"Save Images: {save_images}")

    # ---- initialize folder to store images ----
    save_images_location = cfg.OUTPUT.ROOT
    print(f"Save Images: {save_images_location}")

    if not os.path.exists(save_images_location):
        os.makedirs(save_images_location)

    # ---- setup dataset ----
    base_dir = cfg.DATASET.BASE_DIR
    file_format = cfg.DATASET.FILE_FORAMT  # (sic) spelling follows the config key
    download_file_by_url(cfg.DATASET.SOURCE, cfg.DATASET.ROOT,
                         "%s.%s" % (base_dir, file_format), file_format)

    img_path = os.path.join(cfg.DATASET.ROOT, base_dir, cfg.DATASET.IMG_DIR)
    images = read_dicom_images(img_path, sort_instance=True, sort_patient=True)

    mask_path = os.path.join(cfg.DATASET.ROOT, base_dir, cfg.DATASET.MASK_DIR)
    mask = read_dicom_images(mask_path, sort_instance=True)

    landmark_path = os.path.join(cfg.DATASET.ROOT, base_dir,
                                 cfg.DATASET.LANDMARK_FILE)
    landmark_df = pd.read_csv(
        landmark_path, index_col="Subject")  # read .csv file as dataframe
    landmarks = landmark_df.iloc[:, :6].values
    y = landmark_df["Group"].values
    y[np.where(y != 0)] = 1  # convert to binary classification problem, i.e. no PH vs PAH

    # plot the first phase of images
    if save_images:
        visualize.plot_multi_images(
            images[:, 0, ...],
            marker_locs=landmarks,
            im_kwargs=dict(cfg.IM_KWARGS),
            marker_kwargs=dict(cfg.MARKER_KWARGS),
        ).savefig(str(save_images_location) + "/0)first_phase.png")

    # ---- data pre-processing ----
    # ----- image registration -----
    img_reg, max_dist = reg_img_stack(images.copy(), landmarks)
    if save_images:
        visualize.plot_multi_images(img_reg[:, 0, ...],
                                    im_kwargs=dict(cfg.IM_KWARGS)).savefig(
                                        str(save_images_location) +
                                        "/1)image_registration")

    # ----- masking -----
    img_masked = mask_img_stack(img_reg.copy(), mask[0, 0, ...])
    if save_images:
        visualize.plot_multi_images(img_masked[:, 0, ...],
                                    im_kwargs=dict(cfg.IM_KWARGS)).savefig(
                                        str(save_images_location) +
                                        "/2)masking")

    # ----- resize -----
    img_rescaled = rescale_img_stack(img_masked.copy(),
                                     scale=1 / cfg.PROC.SCALE)
    if save_images:
        visualize.plot_multi_images(img_rescaled[:, 0, ...],
                                    im_kwargs=dict(cfg.IM_KWARGS)).savefig(
                                        str(save_images_location) +
                                        "/3)resize")

    # ----- normalization -----
    img_norm = normalize_img_stack(img_rescaled.copy())
    if save_images:
        visualize.plot_multi_images(img_norm[:, 0, ...],
                                    im_kwargs=dict(cfg.IM_KWARGS)).savefig(
                                        str(save_images_location) +
                                        "/4)normalize")

    # ---- evaluating machine learning pipeline ----
    x = img_norm.copy()
    trainer = MPCATrainer(classifier=cfg.PIPELINE.CLASSIFIER, n_features=200)
    cv_results = cross_validate(trainer,
                                x,
                                y,
                                cv=10,
                                scoring=["accuracy", "roc_auc"],
                                n_jobs=1)

    print("Averaged training time: {:.4f} seconds".format(
        np.mean(cv_results["fit_time"])))
    print("Averaged testing time: {:.4f} seconds".format(
        np.mean(cv_results["score_time"])))
    print("Averaged Accuracy: {:.4f}".format(
        np.mean(cv_results["test_accuracy"])))
    print("Averaged AUC: {:.4f}".format(np.mean(cv_results["test_roc_auc"])))

    # ---- model weights interpretation ----
    trainer.fit(x, y)

    weights = trainer.mpca.inverse_transform(
        trainer.clf.coef_) - trainer.mpca.mean_
    weights = rescale_img_stack(
        weights, cfg.PROC.SCALE)  # rescale weights to original shape
    weights = mask_img_stack(weights, mask[0, 0, ...])  # masking weights
    top_weights = model_weights.select_top_weight(
        weights, select_ratio=0.02)  # select top 2% weights
    if save_images:
        visualize.plot_weights(
            top_weights[0][0],
            background_img=images[0][0],
            im_kwargs=dict(cfg.IM_KWARGS),
            marker_kwargs=dict(cfg.WEIGHT_KWARGS),
        ).savefig(str(save_images_location) + "/5)weights")
Example #10
def main():
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    save_figs = cfg.OUTPUT.SAVE_FIG
    fig_format = cfg.SAVE_FIG_KWARGS.format
    print(f"Save Figures: {save_figs}")

    # ---- initialize folder to store images ----
    save_figures_location = cfg.OUTPUT.ROOT
    print(f"Save Figures: {save_figures_location}")

    if not os.path.exists(save_figures_location):
        os.makedirs(save_figures_location)

    # ---- setup dataset ----
    base_dir = cfg.DATASET.BASE_DIR
    file_format = cfg.DATASET.FILE_FORAMT  # (sic) spelling follows the config key
    download_file_by_url(cfg.DATASET.SOURCE, cfg.DATASET.ROOT, "%s.%s" % (base_dir, file_format), file_format)

    img_path = os.path.join(cfg.DATASET.ROOT, base_dir, cfg.DATASET.IMG_DIR)
    patient_dcm_list = read_dicom_dir(img_path, sort_instance=True, sort_patient=True)
    images, patient_ids = dicom2arraylist(patient_dcm_list, return_patient_id=True)
    patient_ids = np.array(patient_ids, dtype=int)
    n_samples = len(images)

    mask_path = os.path.join(cfg.DATASET.ROOT, base_dir, cfg.DATASET.MASK_DIR)
    mask_dcm = read_dicom_dir(mask_path, sort_instance=True)
    mask = dicom2arraylist(mask_dcm, return_patient_id=False)[0][0, ...]

    landmark_path = os.path.join(cfg.DATASET.ROOT, base_dir, cfg.DATASET.LANDMARK_FILE)
    landmark_df = pd.read_csv(landmark_path, index_col="Subject").loc[patient_ids]  # read .csv file as dataframe
    landmarks = landmark_df.iloc[:, :-1].values
    y = landmark_df["Group"].values
    y[np.where(y != 0)] = 1  # convert to binary classification problem, i.e. no PH vs PAH

    # plot the first phase of images with landmarks
    marker_names = list(landmark_df.columns[1::2])
    markers = []
    for marker in marker_names:
        marker_name = marker.split(" ")
        marker_name.pop(-1)
        marker_name = " ".join(marker_name)
        markers.append(marker_name)

    if save_figs:
        n_img_per_fig = 45
        n_figures = int(n_samples / n_img_per_fig) + 1
        for k in range(n_figures):
            visualize.plot_multi_images(
                [images[i][0, ...] for i in range(k * n_img_per_fig, min((k + 1) * n_img_per_fig, n_samples))],
                marker_locs=landmarks[k * n_img_per_fig : min((k + 1) * n_img_per_fig, n_samples), :],
                im_kwargs=dict(cfg.PLT_KWS.IM),
                marker_cmap="Set1",
                marker_kwargs=dict(cfg.PLT_KWS.MARKER),
                marker_titles=markers,
                image_titles=list(patient_ids[k * n_img_per_fig : min((k + 1) * n_img_per_fig, n_samples)]),
                n_cols=5,
            ).savefig(
                str(save_figures_location) + "/0)landmark_visualization_%s_of_%s.%s" % (k + 1, n_figures, fig_format),
                **dict(cfg.SAVE_FIG_KWARGS),
            )

    # ---- data pre-processing ----
    # ----- image registration -----
    img_reg, max_dist = reg_img_stack(images.copy(), landmarks, landmarks[0])
    plt_kwargs = {**{"im_kwargs": dict(cfg.PLT_KWS.IM), "image_titles": list(patient_ids)}, **dict(cfg.PLT_KWS.PLT)}
    if save_figs:
        visualize.plot_multi_images([img_reg[i][0, ...] for i in range(n_samples)], **plt_kwargs).savefig(
            str(save_figures_location) + "/1)image_registration.%s" % fig_format, **dict(cfg.SAVE_FIG_KWARGS)
        )

    # ----- masking -----
    img_masked = mask_img_stack(img_reg.copy(), mask)
    if save_figs:
        visualize.plot_multi_images([img_masked[i][0, ...] for i in range(n_samples)], **plt_kwargs).savefig(
            str(save_figures_location) + "/2)masking.%s" % fig_format, **dict(cfg.SAVE_FIG_KWARGS)
        )

    # ----- resize -----
    img_rescaled = rescale_img_stack(img_masked.copy(), scale=1 / cfg.PROC.SCALE)
    if save_figs:
        visualize.plot_multi_images([img_rescaled[i][0, ...] for i in range(n_samples)], **plt_kwargs).savefig(
            str(save_figures_location) + "/3)resize.%s" % fig_format, **dict(cfg.SAVE_FIG_KWARGS)
        )

    # ----- normalization -----
    img_norm = normalize_img_stack(img_rescaled.copy())
    if save_figs:
        visualize.plot_multi_images([img_norm[i][0, ...] for i in range(n_samples)], **plt_kwargs).savefig(
            str(save_figures_location) + "/4)normalize.%s" % fig_format, **dict(cfg.SAVE_FIG_KWARGS)
        )

    # ---- evaluating machine learning pipeline ----
    x = np.concatenate([img_norm[i].reshape((1,) + img_norm[i].shape) for i in range(n_samples)], axis=0)
    trainer = MPCATrainer(classifier=cfg.PIPELINE.CLASSIFIER, n_features=200)
    cv_results = cross_validate(trainer, x, y, cv=10, scoring=["accuracy", "roc_auc"], n_jobs=1)

    print("Averaged training time: {:.4f} seconds".format(np.mean(cv_results["fit_time"])))
    print("Averaged testing time: {:.4f} seconds".format(np.mean(cv_results["score_time"])))
    print("Averaged Accuracy: {:.4f}".format(np.mean(cv_results["test_accuracy"])))
    print("Averaged AUC: {:.4f}".format(np.mean(cv_results["test_roc_auc"])))

    # ---- model weights interpretation ----
    trainer.fit(x, y)

    weights = trainer.mpca.inverse_transform(trainer.clf.coef_) - trainer.mpca.mean_
    weights = rescale_img_stack(weights, cfg.PROC.SCALE)  # rescale weights to original shape
    weights = mask_img_stack(weights, mask)  # masking weights
    top_weights = model_weights.select_top_weight(weights, select_ratio=0.02)  # select top 2% weights
    if save_figs:
        visualize.plot_weights(
            top_weights[0][0],
            background_img=images[0][0],
            im_kwargs=dict(cfg.PLT_KWS.IM),
            marker_kwargs=dict(cfg.PLT_KWS.WEIGHT),
        ).savefig(str(save_figures_location) + "/5)weights.%s" % fig_format, **dict(cfg.SAVE_FIG_KWARGS))
Example #11
def gait(download_path):
    download_file_by_url(gait_url, download_path, "gait.mat", "mat")
    return loadmat(os.path.join(download_path, "gait.mat"))
Example #12
def test_get_source_target(source_cfg, target_cfg, valid_ratio, weight_type,
                           datasize_type, testing_cfg, class_subset):
    source_name, source_n_class, source_trainlist, source_testlist = source_cfg.split(
        ";")
    target_name, target_n_class, target_trainlist, target_testlist = target_cfg.split(
        ";")
    n_class = min(int(source_n_class), int(target_n_class))  # numeric min; min() on strings compares lexicographically

    # get cfg parameters
    cfg = testing_cfg
    cfg.DATASET.SOURCE = source_name
    cfg.DATASET.SRC_TRAINLIST = source_trainlist
    cfg.DATASET.SRC_TESTLIST = source_testlist
    cfg.DATASET.TARGET = target_name
    cfg.DATASET.TGT_TRAINLIST = target_trainlist
    cfg.DATASET.TGT_TESTLIST = target_testlist
    cfg.DATASET.WEIGHT_TYPE = weight_type
    cfg.DATASET.SIZE_TYPE = datasize_type

    download_file_by_url(
        url=url,
        output_directory=str(Path(cfg.DATASET.ROOT).parent.absolute()),
        output_file_name="video_test_data.zip",
        file_format="zip",
    )

    # test get_source_target
    source, target, num_classes = VideoDataset.get_source_target(
        VideoDataset(source_name), VideoDataset(target_name), seed, cfg)

    assert num_classes == n_class
    assert isinstance(source, dict)
    assert isinstance(target, dict)
    assert isinstance(source["rgb"], VideoDatasetAccess)
    assert isinstance(target["rgb"], VideoDatasetAccess)
    assert isinstance(source["flow"], VideoDatasetAccess)
    assert isinstance(target["flow"], VideoDatasetAccess)

    # test get_train & get_test
    assert isinstance(source["rgb"].get_train(), torch.utils.data.Dataset)
    assert isinstance(source["rgb"].get_test(), torch.utils.data.Dataset)
    assert isinstance(source["flow"].get_train(), torch.utils.data.Dataset)
    assert isinstance(source["flow"].get_test(), torch.utils.data.Dataset)

    # test get_train_valid
    train_valid = source["rgb"].get_train_valid(valid_ratio)
    assert isinstance(train_valid, list)
    assert isinstance(train_valid[0], torch.utils.data.Dataset)
    assert isinstance(train_valid[1], torch.utils.data.Dataset)

    # test action_multi_domain_datasets
    dataset = VideoMultiDomainDatasets(
        source,
        target,
        image_modality=cfg.DATASET.IMAGE_MODALITY,
        seed=seed,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
    )
    assert isinstance(dataset, DomainsDatasetBase)

    # test class subsets
    if source_cfg == SOURCES[1] and target_cfg == TARGETS[0]:
        dataset_subset = VideoMultiDomainDatasets(
            source,
            target,
            image_modality="rgb",
            seed=seed,
            config_weight_type=cfg.DATASET.WEIGHT_TYPE,
            config_size_type=cfg.DATASET.SIZE_TYPE,
            class_ids=class_subset,
        )

        train, valid = source["rgb"].get_train_valid(valid_ratio)
        test = source["rgb"].get_test()
        dataset_subset._rgb_source_by_split = {}
        dataset_subset._rgb_target_by_split = {}
        dataset_subset._rgb_source_by_split["train"] = get_class_subset(
            train, class_subset)
        dataset_subset._rgb_target_by_split[
            "train"] = dataset_subset._rgb_source_by_split["train"]
        dataset_subset._rgb_source_by_split["valid"] = get_class_subset(
            valid, class_subset)
        dataset_subset._rgb_source_by_split["test"] = get_class_subset(
            test, class_subset)

        # Ground truth length of the subset dataset
        train_dataset_subset_length = len(
            [1 for data in train if data[1] in class_subset])
        valid_dataset_subset_length = len(
            [1 for data in valid if data[1] in class_subset])
        test_dataset_subset_length = len(
            [1 for data in test if data[1] in class_subset])
        assert len(dataset_subset._rgb_source_by_split["train"]
                   ) == train_dataset_subset_length
        assert len(dataset_subset._rgb_source_by_split["valid"]
                   ) == valid_dataset_subset_length
        assert len(dataset_subset._rgb_source_by_split["test"]
                   ) == test_dataset_subset_length
        assert len(dataset_subset) == train_dataset_subset_length
Example #13
def test_video_domain_adapter(source_cfg, target_cfg, image_modality,
                              da_method, testing_cfg, testing_training_cfg):
    source_name, source_n_class, source_trainlist, source_testlist = source_cfg.split(
        ";")
    target_name, target_n_class, target_trainlist, target_testlist = target_cfg.split(
        ";")

    # get cfg parameters
    cfg = testing_cfg
    cfg.DATASET.SOURCE = source_name
    cfg.DATASET.SRC_TRAINLIST = source_trainlist
    cfg.DATASET.SRC_TESTLIST = source_testlist
    cfg.DATASET.TARGET = target_name
    cfg.DATASET.TGT_TRAINLIST = target_trainlist
    cfg.DATASET.TGT_TESTLIST = target_testlist
    cfg.DATASET.IMAGE_MODALITY = image_modality
    cfg.DATASET.WEIGHT_TYPE = WEIGHT_TYPE
    cfg.DATASET.SIZE_TYPE = DATASIZE_TYPE
    cfg.DAN.USERANDOM = False

    # download example data
    download_file_by_url(
        url=url,
        output_directory=str(Path(cfg.DATASET.ROOT).parent.absolute()),
        output_file_name="video_test_data.zip",
        file_format="zip",
    )

    # build dataset
    source, target, num_classes = VideoDataset.get_source_target(
        VideoDataset(source_name), VideoDataset(target_name), seed, cfg)

    dataset = VideoMultiDomainDatasets(
        source,
        target,
        image_modality=cfg.DATASET.IMAGE_MODALITY,
        seed=seed,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
    )

    # setup feature extractor
    if cfg.DATASET.IMAGE_MODALITY in ["rgb", "flow"]:
        class_feature_dim = 1024
        domain_feature_dim = class_feature_dim
        if cfg.DATASET.IMAGE_MODALITY == "rgb":
            feature_network = {"rgb": VideoBoringModel(3), "flow": None}
        else:
            feature_network = {"rgb": None, "flow": VideoBoringModel(2)}
    else:
        class_feature_dim = 2048
        domain_feature_dim = int(class_feature_dim / 2)
        feature_network = {
            "rgb": VideoBoringModel(3),
            "flow": VideoBoringModel(2)
        }

    # setup classifier
    classifier_network = ClassNetVideo(input_size=class_feature_dim,
                                       n_class=num_classes)
    train_params = testing_training_cfg["train_params"]
    method_params = {}
    method = domain_adapter.Method(da_method)

    # setup DA method
    if method.is_mmd_method():
        model = video_domain_adapter.create_mmd_based_video(
            method=method,
            dataset=dataset,
            image_modality=cfg.DATASET.IMAGE_MODALITY,
            feature_extractor=feature_network,
            task_classifier=classifier_network,
            **method_params,
            **train_params,
        )
    else:
        critic_input_size = domain_feature_dim
        # setup critic network
        if method.is_cdan_method():
            if cfg.DAN.USERANDOM:
                critic_input_size = 1024
            else:
                critic_input_size = domain_feature_dim * num_classes
        critic_network = DomainNetVideo(input_size=critic_input_size)

        if da_method == "CDAN":
            method_params["use_random"] = cfg.DAN.USERANDOM

        model = video_domain_adapter.create_dann_like_video(
            method=method,
            dataset=dataset,
            image_modality=cfg.DATASET.IMAGE_MODALITY,
            feature_extractor=feature_network,
            task_classifier=classifier_network,
            critic=critic_network,
            **method_params,
            **train_params,
        )

    ModelTestHelper.test_model(model, train_params)