def run_test(folder_path, override_dict, test_path, snapshot_iter, is_large,
             save_img_data):
    """Run a full test from a checkpoint stored in ``folder_path``.

    The options pickled in ``<folder_path>/PARAM.p`` are loaded and combined
    with ``override_dict`` through ``recursive_merge_dicts``. The merged
    options are used to build a ``Pipeline`` whose model directory is
    ``folder_path``. Inside the pipeline's graph, a session is created and
    ``run_full_test_from_checkpoint`` is invoked.

    Args:
        folder_path: Directory holding ``PARAM.p``; also passed to
            ``Pipeline`` as ``model_dir``.
        override_dict: Options merged with the loaded ones (passed as the
            second argument of ``recursive_merge_dicts``).
        test_path: Forwarded to ``run_full_test_from_checkpoint``.
        snapshot_iter: Forwarded to ``run_full_test_from_checkpoint``.
        is_large: Forwarded to ``run_full_test_from_checkpoint``.
        save_img_data: Forwarded to ``run_full_test_from_checkpoint``.
    """
    print("Folder path: %s" % folder_path)

    # NOTE: unpickling can execute arbitrary code, so ``folder_path`` should
    # only ever point at a trusted model directory.
    param_file_path = os.path.join(folder_path, "PARAM.p")
    with open(param_file_path, 'rb') as param_file:
        stored_options = pickle.load(param_file)

    merged_options = recursive_merge_dicts(stored_options, override_dict)

    # Test runs neither re-save the hyperparameters nor enable logging.
    pipeline_kwargs = {
        "model_dir": folder_path,
        "auto_save_hyperparameters": False,
        "use_logging": False,
    }
    pipeline = Pipeline(None, merged_options, **pipeline_kwargs)

    print(pipeline.opt)
    with pipeline.graph.as_default():
        session = pipeline.create_session()
        pipeline.run_full_test_from_checkpoint(
            session,
            test_path=test_path,
            snapshot_iter=snapshot_iter,
            is_large=is_large,
            save_img_data=save_img_data,
        )
Example #2
0
        "learning_rate": learning_rate,
        "max_epochs": 2000,
        "weight_decay": 1e-6,
        "test_steps": 5000,
        "test_limit": 200,
        "recon_weight": recon_weight,
    }
    opt["encoder_options"] = {
        "keypoint_num": num_keypoints,
        "patch_feature_dim": patch_feature_dim,
        "ae_recon_type": opt["recon_name"],
        "keypoint_concentration_loss_weight": 100.0,
        "keypoint_axis_balancing_loss_weight": 200.0,
        "keypoint_separation_loss_weight": keypoint_separation_loss_weight,
        "keypoint_separation_bandwidth": keypoint_separation_bandwidth,
        "keypoint_transform_loss_weight": kp_transform_loss,
        "keypoint_decoding_heatmap_levels": decoding_levels,
        "keypoint_decoding_heatmap_level_base": 0.5 ** (1 / 2),
        "image_channels": 3,
    }
    opt["decoder_options"] = copy(opt["encoder_options"])

    # -------------------------------------
    model_dir = os.path.join("results/exercise_25")
    vp = Pipeline(None, opt, model_dir=model_dir)
    print(vp.opt)
    with vp.graph.as_default():
        sess = vp.create_session()
        vp.run_full_train(sess, restore=True)
        vp.run_full_test(sess)