Beispiel #1
0
def run_on_validation(net_heat, num_steps=128):
    """Evaluate the heatmap network and print the mean SMAPE of its angles.

    Weights are loaded from ``f.checkpoint_heat_path`` into ``net_heat``.
    Then ``num_steps`` single-image batches are drawn from
    ``load_utils.test_loader(1, load_angle=True)``. For each batch the
    network outputs are passed to ``cap.cobb_angles``, the resulting angles
    are scored against the loaded angles with ``cap.SMAPE``, and the image
    stored under ``"pairs_img"`` is written through ``cap.cvsave`` using the
    step index as its name.

    The network is switched to eval mode for the duration of the run, and
    the mode it had on entry is restored afterwards.

    Args:
        net_heat: Module whose forward pass returns four values; the first
            two are used as ``out_pcm`` and ``out_paf``. Inputs are moved to
            the GPU with ``.cuda()``, so the module is expected to live
            there as well.
        num_steps: Number of batches to evaluate. Defaults to 128, the
            value that used to be hard-coded.
    """
    save_path = f.checkpoint_heat_path
    net_heat.load_state_dict(torch.load(save_path))
    test_data_loader = load_utils.test_loader(1, load_angle=True)

    # Validation must not run with training-mode layers; remember the mode
    # on entry so the caller gets the module back unchanged.
    was_training = net_heat.training
    net_heat.eval()
    avg_smape = []
    try:
        for step in range(num_steps):
            # The labels returned by the loader are not needed here.
            test_imgs, _, test_angles = next(test_data_loader)
            # Insert a channel axis and scale the pixel values by 1/255.
            test_imgs_f = np.asarray(test_imgs, np.float32)[:, np.newaxis, :, :]
            test_imgs_01 = test_imgs_f / 255.0
            test_imgs_tensor = torch.from_numpy(test_imgs_01).cuda()
            with torch.no_grad():
                out_pcm, out_paf, _, _ = net_heat(test_imgs_tensor)  # NCHW
            # Tensors produced under no_grad carry no graph, so they can be
            # converted without an explicit detach().
            np_pcm = out_pcm.cpu().numpy()
            np_paf = out_paf.cpu().numpy()

            cobb_dict = cap.cobb_angles(np_pcm[0, 4:6],
                                        np_paf[0],
                                        test_imgs[0],
                                        np_pcm[0, 6],
                                        use_filter=False)
            pred_angles = cobb_dict["angles"]
            pairs_img = cobb_dict["pairs_img"]

            smape = cap.SMAPE(pred_angles, test_angles[0])
            avg_smape.append(smape)
            print(step, smape)
            print(pred_angles - test_angles[0])
            cap.cvsave(pairs_img, "{}".format(step))
            print("end-----------------------------")
    finally:
        net_heat.train(was_training)
    print("SMAPE:", np.mean(avg_smape))
Beispiel #2
0
def parse_cobb_angle_by_annotated_points():
    """Exercise the Cobb-angle parser on annotated points instead of predictions.

    For each of 128 single-sample batches from
    ``load_utils.test_loader(1, load_angle=True)``, the left/right centre
    points derived from the labels are run through the ``cap`` angle-parsing
    steps. The difference between the three resulting angles and the loaded
    angles is printed for every step.
    """
    import confidence_map as cm

    loader = load_utils.test_loader(1, load_angle=True)
    for step in range(128):
        test_imgs, test_labels, test_angles = next(loader)

        # Ground-truth centre points, laid out as [lr][N][17(joint)][xy];
        # only the first sample of the batch is used.
        left_centers, right_centers = cm.ConfidenceMap()._find_LCenter_RCenter(
            test_labels)
        # Order the pairs by y. (Pair filtering is left disabled in this check.)
        pairs = cap.sort_pairs_by_y((left_centers[0], right_centers[0]))

        # [p_len][xy] vector coordinates
        bones = cap.bone_vectors(pairs)
        # Indices of the higher and lower bone of the max angle; a1 is its value
        upper_idx, lower_idx, a1 = cap.max_angle_indices(bones, pairs)

        midpoints = (pairs[0] + pairs[1]) / 2
        if cap.isS(midpoints):
            a2, a3 = cap.handle_isS_branch(pairs, upper_idx, lower_idx,
                                           test_imgs[0].shape[0])
        else:
            # Measure against the first bone ...
            cos_top = cap.cos_angle(bones[upper_idx], bones[0])
            a2 = np.rad2deg(np.arccos(cos_top))
            # ... and against the last bone.
            # Note: use last bone on submit test set gains better results
            cos_bottom = cap.cos_angle(bones[lower_idx], bones[-1])
            a3 = np.rad2deg(np.arccos(cos_bottom))

        diff = np.array([a1, a2, a3]) - test_angles[0]
        print(step)
        print(diff)
        print("------------end---------------")
Beispiel #3
0
                        default=5,
                        type=int,
                        required=False,
                        help="batch size")
    parser.add_argument("--trainval", action='store_true', default=False)
    args = parser.parse_args()
    batch_size = args.s
    if args.trainval:  # Final training, use train and val set
        train_data_loader = load_utils.train_loader(batch_size,
                                                    load_angle=True,
                                                    use_trainval=True)
        print("--- Using [train, val] set as training set!")
    else:
        train_data_loader = load_utils.train_loader(batch_size,
                                                    load_angle=True)
    test_data_loader = load_utils.test_loader(batch_size, load_angle=True)

    net_heat = part_affinity_field_net.SpineModelPAF()
    net_heat.cuda()
    net_heat.eval()
    net_angle = part_affinity_field_net.CobbAngleModel()
    net_angle.cuda()

    # Load heatmap network checkpoint
    save_path_heat = f.checkpoint_heat_trainval_path if args.trainval else f.checkpoint_heat_path
    if path.exists(save_path_heat):
        net_heat.load_state_dict(torch.load(save_path_heat))
    else:
        raise FileNotFoundError(
            "Heatmap model checkpoint not found: {}.".format(save_path_heat))
Beispiel #4
0
    os.makedirs(f.train_results, exist_ok=True)
    os.makedirs(f.checkpoint, exist_ok=True)

    net = ladder_shufflenet.LadderModelAdd()

    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    batch_size = args.s
    print("Training with batch size: %d" % batch_size)
    if args.trainval:  # Final training, use train and val set
        train_data_loader = load_utils.train_loader(batch_size,
                                                    use_trainval=True)
        print("--- Using [train, val] set as training set!")
    else:
        train_data_loader = load_utils.train_loader(batch_size)
    test_data_loader = load_utils.test_loader(batch_size)
    device = torch.device("cuda")

    # Load checkpoint
    # If in trainval mode, no "trainval" checkpoint found,
    # and the checkpoint for "train" mode exists,
    # then load the "train" checkpoint for "trainval" training
    if not args.trainval:
        save_path = f.checkpoint_heat_path
        if path.exists(save_path):
            net.load_state_dict(torch.load(save_path))
            print("Model loaded")
        else:
            print("New model created")
    else:  # Trainval mode
        save_path = f.checkpoint_heat_trainval_path