示例#1
0
def _make_video(video_path, imgs):
    """Encode a sequence of images into a video file using cv2.

    The first image determines the input size. Every frame is resized to the
    shape returned by `_resize`, converted with `cv2.COLOR_BGR2RGB` and
    written at 30 frames per second with the 'mp4v' codec.

    Parameters:
    video_path: a path ending with .mp4, for instance: "/results/pose2d.mp4"
    imgs: an iterable or generator with the images to turn into a video

    Raises:
    StopIteration: if `imgs` does not yield any image.
    """

    # `next` only works on iterators: wrap the input so that plain iterables
    # (lists, tuples, ...) are accepted as the documentation promises.
    imgs = iter(imgs)
    first_frame = next(imgs)
    # Put the peeked frame back in front so it is also written to the video.
    imgs = itertools.chain([first_frame], imgs)

    # cv2 expects sizes as (width, height) while array shapes start with
    # (rows, columns).
    shape = int(first_frame.shape[1]), int(first_frame.shape[0])
    logger.debug('Saving video to: {}'.format(video_path))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = 30
    # NOTE(review): `video_width` is not defined in this function; it is
    # presumably a module-level setting -- confirm.
    output_shape = _resize(current_shape=shape, new_width=video_width)
    logger.debug('Video size is: {}'.format(output_shape))
    video_writer = cv2.VideoWriter(video_path, fourcc, fps, output_shape)

    # Only show a progress bar when info-level logging is enabled.
    progress_bar = tqdm if logger.info_enabled() else lambda x: x
    try:
        for img in progress_bar(imgs):
            resized = cv2.resize(img, output_shape)
            rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            video_writer.write(rgb)
    finally:
        # Always finalize the file, even if a frame could not be processed.
        video_writer.release()

    logger.info('Video created at {}\n'.format(video_path))
示例#2
0
    def save_pose(self):
        """Saves the pose estimation results to a file in the output folder.

        The manually corrected 2D points are temporarily installed into
        `self.camNetAll` so that post-processing and (when calibration and
        pose are available) triangulation run on the corrected values. The
        uncorrected points are always put back afterwards, even on error.

        The result dictionary is pickled to
        `<output_folder>/pose_result_<input_folder with "/" replaced by "_">.pkl`.
        """
        pts2d = self.corrected_points2d_matrix()
        dict_merge = self.camNetAll.save_network(path=None)
        pts2d_orig = self.camNetAll.get_points2d_matrix()

        # temporarily incorporate corrected values
        self.camNetAll.set_points2d_matrix(pts2d)
        try:
            self.post_process(pts2d)
            dict_merge["points2d"] = pts2d

            if self.camNetLeft.has_calibration() and self.camNetLeft.has_pose():
                self.camNetAll.triangulate()
                pts3d = self.camNetAll.points3d_m
                if config["procrustes_apply"]:
                    print("Applying Procrustes on 3D Points")
                    pts3d = procrustes_seperate(pts3d)
                dict_merge["points3d"] = pts3d
            else:
                logger.debug("Triangulation skipped.")
        finally:
            # put uncorrected values back, whatever happened above
            self.camNetAll.set_points2d_matrix(pts2d_orig)

        save_path = os.path.join(
            self.output_folder,
            "pose_result_{}.pkl".format(self.input_folder.replace("/", "_")),
        )
        # Context manager so the file is flushed and closed deterministically.
        with open(save_path, "wb") as f:
            pickle.dump(dict_merge, f)
        print(f"Saved the pose at: {save_path}")
示例#3
0
    def _compute_mean(self):
        """Load the mean/std statistics stored in the file `config["mean"]`.

        Returns:
        A tuple `(mean, std)` taken from the "mean" and "std" entries of the
        object loaded with `torch.load`.

        Raises:
        FileNotFoundError: if the statistics file does not exist.
        """
        meanstd_file = config["mean"]
        if not isfile(meanstd_file):
            # The previous implementation had code recomputing the statistics
            # after this `raise`; it was unreachable and has been removed.
            raise FileNotFoundError(
                "Mean/std file not found: {}".format(meanstd_file))

        meanstd = torch.load(meanstd_file)
        mean, std = meanstd["mean"], meanstd["std"]

        if self.is_train:
            # Only the first three entries are reported.
            logger.debug(
                "    Mean: %.4f, %.4f, %.4f" % (mean[0], mean[1], mean[2]))
            logger.debug(
                "    Std:  %.4f, %.4f, %.4f" % (std[0], std[1], std[2]))

        return mean, std
示例#4
0
def find_default_camera_ordering(input_folder):
    """Infer the default camera ordering from the folder path using regexes.

    Each known pattern is searched in the path; the ordering associated with
    the first pattern that matches is returned.

    Parameters:
    input_folder: the folder path on which to run the regexes.

    Returns:
    A numpy array with the camera ordering, or None when no pattern matches.
    """
    # use `str` in case pathlib.Path instance
    folder = str(input_folder)

    # (regex, camera ordering) pairs, tried in this order
    known_users = (
        (r"/CLC/", [0, 6, 5, 4, 3, 2, 1]),
        (r"data/test", [0, 1, 2, 3, 4, 5, 6]),
    )

    for pattern, order in known_users:
        if re.search(pattern, folder):
            logger.debug(f"Default camera ordering found: {order}")
            return np.array(order)

    return None
示例#5
0
def load_heatmap(hm_path, shape):
    """Open the file at `hm_path` as a read-only float32 memory map.

    Parameters:
    hm_path: path of the heatmap file.
    shape: shape given to the memory-mapped array.

    Returns:
    The `np.memmap` object.
    """
    logger.debug(f"Heatmap shape: {shape}")
    return np.memmap(filename=hm_path, mode="r", shape=shape, dtype="float32")
示例#6
0
def write_camera_order(folder, cidread2cid):
    """Save the camera ordering `cidread2cid` with `np.save` inside `folder`.

    The target path is `<folder>/cam_order`.

    Raises:
    AssertionError: if `folder` is not an existing directory.
    """
    not_a_folder = (
        "Trying to write_camera_order into {}, which is not a folder".format(
            folder))
    assert os.path.isdir(folder), not_a_folder

    logger.debug(
        f"Writing the camera ordering {cidread2cid} into folder {folder}")
    target = os.path.join(folder, "cam_order")
    np.save(target, cidread2cid)
示例#7
0
 def calc_essential_matrix(points2d_1, points2d_2, intr):
     """Estimate the essential matrix between two sets of 2D points.

     Calls `cv2.findEssentialMat` with RANSAC (prob=0.9999, threshold=5) and
     the camera matrix `intr`, then logs the sum of the returned mask divided
     by its number of rows.

     Returns:
     The tuple `(E, mask)` as returned by `cv2.findEssentialMat`.
     """
     ransac_options = dict(method=cv2.RANSAC, prob=0.9999, threshold=5)
     essential, inlier_mask = cv2.findEssentialMat(
         points1=points2d_1,
         points2=points2d_2,
         cameraMatrix=intr,
         **ransac_options,
     )
     # presumably the mask holds one 0/1 flag per point pair -- TODO confirm
     inlier_ratio = np.sum(inlier_mask) / inlier_mask.shape[0]
     logger.debug("Essential matrix inlier ratio: {}".format(inlier_ratio))
     return essential, inlier_mask
def solve_belief_propagation(cam_list,
                             img_id,
                             bone_param,
                             num_peak=10,
                             prior=None):
    """Run belief propagation on every sufficiently visible limb of one image.

    Joints are grouped by limb. A `LegBP` chain is built for each limb whose
    joints are all seen by at least two of the cameras in `cam_list`; each
    chain is then propagated and solved.

    Parameters:
    cam_list: cameras to use; the position in this list is the camera index
        of the returned arrays.
    img_id: image id forwarded to `LegBP`.
    bone_param: bone parameters forwarded to `LegBP`.
    num_peak: forwarded to `LegBP`.
    prior: forwarded to `LegBP`.

    Returns:
    A list with one `(num_joints, 2)` float array per camera. Rows of joints
    that were not part of any solved chain stay at zero.
    """
    skeleton = config["skeleton"]

    # find all the connected parts: one list of joint ids per limb
    j_id_list_list = [[
        j for j in range(skeleton.num_joints) if skeleton.limb_id[j] == limb_id
    ] for limb_id in range(skeleton.num_limbs)]

    chain_list = list()
    for j_id_l in j_id_list_list:
        # Number of cameras seeing each joint of the limb.
        # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent.
        visible = np.zeros(shape=(len(j_id_l), ), dtype=int)
        for cam in cam_list:
            visible += [
                skeleton.camera_see_joint(cam.cam_id, j_id) for j_id in j_id_l
            ]
        # Limbs with a joint seen by fewer than two cameras are skipped.
        if np.all(visible >= 2):
            chain_list.append(
                LegBP(
                    cam_list=cam_list,
                    img_id=img_id,
                    j_id_list=j_id_l,
                    bone_param=bone_param,
                    num_peak=num_peak,
                    prior=prior,
                ))

    # number of candidates per joint, per chain
    logger.debug([[len(leg[i].candid_list) for i in range(len(leg.jointbp))]
                  for leg in chain_list])

    for chain in chain_list:
        chain.propagate()
        chain.solve()

    # read the best 2d locations
    points2d_list = [
        np.zeros((skeleton.num_joints, 2), dtype=float)
        for _ in range(len(cam_list))
    ]
    for leg in chain_list:
        for cam_idx in range(len(cam_list)):
            for idx, j_id in enumerate(leg.j_id_list):
                points2d_list[cam_idx][j_id] = leg[idx][
                    leg[idx].argmin].p2d[cam_idx]

    return points2d_list.copy()
示例#9
0
    def load_network(self, calib):
        """Apply the calibration stored in `calib` to the cameras of the network.

        For each camera whose `cam_id` has a non-empty entry in `calib`, the
        rotation ("R"), translation ("tvec"), intrinsics ("intr") and
        distortion ("distort") are set from that entry. Cameras without an
        entry are reported in the debug log and left untouched.

        Returns:
        `calib["meta"]`, or None when `calib` is None.
        """
        if calib is None:
            return None

        for cam in self.cam_list:
            cam_calib = calib[cam.cam_id] if cam.cam_id in calib else None
            if not cam_calib:
                logger.debug("Camera {} is not on the calibration file".format(
                    cam.cam_id))
                continue
            cam.set_R(cam_calib["R"])
            cam.set_tvec(cam_calib["tvec"])
            cam.set_intrinsic(cam_calib["intr"])
            cam.set_distort(cam_calib["distort"])

        return calib["meta"]
示例#10
0
def process_folder(model, loader, unlabeled, output_folder, overwrite,
                   num_classes, acc_joints):
    """Run the model on an unlabeled folder and save predictions and heatmaps.

    Parameters:
    model: network passed to `step`.
    loader: data loader passed to `step`; `loader.dataset.greatest_image_id()`
        sizes the heatmap file.
    unlabeled: folder being processed, used to derive the save paths.
    output_folder: output folder, used to derive the save paths.
    overwrite: when False and a prediction file already exists, nothing is
        computed.
    num_classes: forwarded to `step`.
    acc_joints: forwarded to `step`.

    Returns:
    `(pred, heatmap)` after flipping, or `(None, None)` when an existing
    prediction file was kept.
    """
    save_path_pred = get_save_path_pred(unlabeled, output_folder)
    save_path_heatmap = get_save_path_heatmap(unlabeled, output_folder)

    if os.path.isfile(save_path_pred):
        if not overwrite:
            logger.info("Prediction file exists, skipping pose estimation")
            return None, None
        logger.info("Overwriting existing predictions")

    save_path_heatmap.parent.mkdir(exist_ok=True, parents=True)
    save_path_pred.parent.mkdir(exist_ok=True, parents=True)

    logger.debug(f"creating heatmap path: {save_path_heatmap}")
    heatmap = np.memmap(
        filename=save_path_heatmap,
        dtype="float32",
        mode="w+",
        shape=(
            config["num_cameras"] + 1,
            loader.dataset.greatest_image_id() + 1,
            config["num_predict"],
            config["heatmap_shape"][0],
            config["heatmap_shape"][1],
        ),
    )  # num_cameras+1 for the mirrored camera 3
    logger.debug(f"creating heatmap shape: {heatmap.shape}")

    pred, heatmap, _, _, _ = step(
        loader=loader,
        model=model,
        optimizer=None,
        mode=Mode.test,
        heatmap=heatmap,
        epoch=0,
        num_classes=num_classes,
        acc_joints=acc_joints,
    )

    # Cameras listed in config["flip_cameras"] are given as camera ids;
    # translate them to read-order ids before flipping.
    _, cid2cidread = read_camera_order(
        get_output_path(unlabeled, output_folder))
    cid_to_reverse = config["flip_cameras"]
    cid_read_to_reverse = [cid2cidread[cid] for cid in cid_to_reverse]

    logger.debug("Flipping predictions")
    pred = flip_pred(pred, cid_read_to_reverse)
    logger.debug("Flipping heatmaps")
    heatmap = flip_heatmap(heatmap, cid_read_to_reverse)

    save_dict(pred, save_path_pred)
    # A memmap is already backed by the heatmap file; only save explicitly
    # when flipping returned something else.
    if not isinstance(heatmap, np.memmap):
        save_dict(heatmap, save_path_heatmap)

    logger.debug(f"Prediction shape: {pred.shape}")
    return pred, heatmap
示例#11
0
    def calibrate(self, cam_id_list=None):
        """Refine the calibration with bundle adjustment, then re-triangulate.

        Parameters:
        cam_id_list: cameras handed to `prepare_bundle_adjust_param`;
            defaults to `range(self.num_cameras)`.

        Returns:
        The result object returned by `least_squares`.
        """
        assert self.cam_list
        if cam_id_list is None:
            cam_id_list = range(self.num_cameras)

        # Called for its debug log only: the return value is not used.
        self.reprojection_error()
        (
            x0,
            points_2d,
            n_cameras,
            n_points,
            camera_indices,
            point_indices,
        ) = self.prepare_bundle_adjust_param(cam_id_list)
        logger.debug(f"Number of points for calibration: {n_points}")
        A = bundle_adjustment_sparsity(
            n_cameras, n_points, camera_indices, point_indices
        )
        res = least_squares(
            residuals,
            x0,
            jac_sparsity=A,
            # let the solver print its progress only in debug mode
            verbose=2 if logger.debug_enabled() else 0,
            x_scale="jac",
            ftol=1e-4,
            method="trf",
            args=(
                self.cam_list,
                n_cameras,
                n_points,
                camera_indices,
                point_indices,
                points_2d,
            ),
            max_nfev=1000,
        )

        logger.debug(
            "Bundle adjustment, Average reprojection error: {}".format(
                np.mean(np.abs(res.fun))
            )
        )

        self.triangulate()
        return res
示例#12
0
def get_max_img_id(path):
    """Binary-search the highest img_id below 100000 that has an image in `path`.

    The search assumes that images exist for every id up to the maximum and
    for none above it.

    Returns:
    The img_id found by the search.

    Raises:
    FileNotFoundError: if no image exists at the id the search ends on.
    """
    print(path)
    lower, upper = 0, 100000

    # `lower` only moves to ids where an image was found, `upper` only moves
    # to ids where none was found.
    while upper - lower > 1:
        middle = (lower + upper) // 2
        if image_exists_img_id(path, middle):
            lower = middle
        else:
            upper = middle

    # The bounds are now adjacent, so their midpoint is `lower`.
    curr = lower
    if image_exists_img_id(path, curr):
        return curr

    logger.debug("Cannot find image at {} with img_id {}".format(path, curr))
    raise FileNotFoundError("No image found.")
示例#13
0
    def reprojection_error(self, cam_indices=None, ignore_joint_list=None):
        """Collect the reprojection error of every visible 3D point.

        Parameters:
        cam_indices: indices into `self.cam_list` of the cameras to evaluate;
            defaults to all cameras.
        ignore_joint_list: joint ids to skip; defaults to
            `config["skeleton"].ignore_joint_id`.

        Returns:
        A list with one flattened array `projection - observation` per
        (image, joint, camera) triple for which the camera sees the joint.
        The mean absolute value of the list is written to the debug log.
        """
        if ignore_joint_list is None:
            ignore_joint_list = config["skeleton"].ignore_joint_id
        if cam_indices is None:
            cam_indices = range(len(self.cam_list))
        # Honor the requested subset (the argument used to be ignored).
        cams = [self.cam_list[cam_idx] for cam_idx in cam_indices]

        # Iterate over (image, joint) pairs only. Enumerating every scalar of
        # the array visited each pair once per coordinate and therefore
        # appended every error several times.
        num_images, num_joints = self.points3d_m.shape[:2]

        err_list = list()
        for img_id in range(num_images):
            for j_id in range(num_joints):
                if j_id in ignore_joint_list:
                    continue
                p3d = self.points3d_m[img_id, j_id].reshape(1, 3)
                for cam in cams:
                    if not config["skeleton"].camera_see_joint(
                            cam.cam_id, j_id):
                        continue
                    err_list.append(
                        (cam.project(p3d) - cam[img_id, j_id]).ravel())

        err_mean = np.mean(np.abs(err_list))
        logger.debug("Ignore_list {}:  {:.4f}".format(ignore_joint_list,
                                                      err_mean))
        return err_list
示例#14
0
    def reprojection_error(self):
        """Compute the absolute reprojection error per camera, image and joint.

        An entry is filled only when the camera sees the joint, the joint is
        not in `config["skeleton"].ignore_joint_id`, and neither the 3D point
        nor the 2D observation contains a zero coordinate. Every other entry
        stays at zero.

        Returns:
        An array of shape `(num_cameras, num_images, num_joints)` holding the
        summed absolute difference between projection and observation. Its
        mean is written to the debug log.
        """
        ignore_joint_list = config["skeleton"].ignore_joint_id
        num_images, num_joints = self.points3d.shape[0], self.points3d.shape[1]
        num_cams = len(self.cam_list)

        err = np.zeros((num_cams, num_images, num_joints))
        triples = product(range(num_images), range(num_joints), range(num_cams))
        for img_id, j_id, cam_idx in triples:
            p3d = self.points3d[img_id, j_id].reshape(1, 3)
            cam = self[cam_idx]

            # Guard clauses, evaluated in the same order as before.
            if not config["skeleton"].camera_see_joint(cam.cam_id, j_id):
                continue
            if j_id in ignore_joint_list:
                continue
            if np.any(self.points3d[img_id, j_id, :] == 0):
                continue
            if np.any(cam[img_id, j_id, :] == 0):
                continue

            difference = cam.project(p3d) - cam[img_id, j_id]
            err[cam_idx, img_id, j_id] = np.sum(np.abs(difference))

        err_mean = np.mean(np.abs(err))
        logger.debug("Ignore_list {}:  {:.4f}".format(ignore_joint_list, err_mean))
        return err
示例#15
0
def read_camera_order(folder):
    """Read the camera ordering stored in `folder` and compute its inverse.

    When `cam_order.npy` is missing, the default ordering
    `np.arange(config["num_cameras"])` is written to the folder and used.

    Returns:
    `(cidread2cid, cid2cidread)`: the stored ordering and the array such
    that `cid2cidread[cidread2cid[i]] == i`.

    Raises:
    AssertionError: if `folder` is not an existing directory.
    """
    not_a_folder = (
        "Trying to call read_camera_order on {}, which is not a folder".format(
            folder))
    assert os.path.isdir(folder), not_a_folder

    order_file = os.path.join(folder, "./cam_order.npy")
    if not os.path.isfile(order_file):
        order = np.arange(config["num_cameras"])
        write_camera_order(folder, order)
        logger.debug(
            "Could not find camera order under {}. Writing the default ordering {}."
            .format(folder, order))
    else:
        order = np.load(file=order_file, allow_pickle=True)

    cidread2cid = order.copy()

    # Invert the mapping: position in the stored order -> camera id.
    cid2cidread = np.zeros(cidread2cid.size, dtype=int)
    for read_position, camera_id in enumerate(cidread2cid):
        cid2cidread[camera_id] = read_position

    return cidread2cid, cid2cidread
示例#16
0
    def __init__(
        self,
        image_folder,
        output_folder,
        num_images=900,
        cam_list=None,
        cam_id_list=range(config["num_cameras"]),
        cid2cidread=None,
        pred=None,
    ):
        """Build the camera network, loading predictions and calibration.

        Parameters:
        image_folder: folder handed to every `Camera`.
        output_folder: folder searched for the camera order, the prediction
            file and the calibration file.
        num_images: number of images kept from the loaded predictions.
        cam_list: already constructed cameras; when given (and non-empty)
            nothing is loaded from disk.
        cam_id_list: ids of the cameras to create.
        cid2cidread: mapping from camera id to read-order id; read from
            `output_folder` when None.
        pred: NOTE(review): this argument is currently not used, the
            predictions always come from the file found in `output_folder`
            -- confirm whether it should take precedence.
        """
        self.image_folder = image_folder
        self.output_folder = output_folder
        self.points3d = None
        self.num_images = num_images
        self.num_cameras = len(cam_id_list)

        if cid2cidread is not None:
            self.cid2cidread = cid2cidread
        else:
            # read_camera_order returns (cidread2cid, cid2cidread): the
            # mapping needed here is the second element, not the first.
            _, self.cid2cidread = read_camera_order(self.output_folder)

        if cam_list:
            logger.debug("Camera list is already given, skipping loading.")
            self.cam_list = cam_list
            return

        self.cam_list = list()
        pred_path = find_pred_path(self.output_folder)
        if pred_path is not None:
            logger.debug("Reading predictions from {}".format(pred_path))
            pred = np.load(file=pred_path, mmap_mode="r", allow_pickle=True)
            # keep only the first `num_images` entries along the second axis
            pred = pred[:, : self.num_images]
        else:
            logger.debug("no pred file under {}".format(self.output_folder))
            pred = None

        for cam_id in cam_id_list:
            cam_id_read = self.cid2cidread[cam_id]
            # Cameras are created without 2D points when there is no
            # prediction file.
            points2d = (
                pred2points2d(pred, cam_id, cam_id_read, config["image_shape"])
                if pred is not None
                else None
            )
            self.cam_list.append(
                Camera(
                    cid=cam_id,
                    cid_read=cam_id_read,
                    image_folder=image_folder,
                    points2d=points2d,
                    hm=None,
                )
            )

        calibration_path = find_calib_path(self.output_folder)
        if calibration_path is not None:
            calibration = np.load(file=calibration_path, allow_pickle=True)
            logger.debug("Reading calibration from {}".format(self.output_folder))
            _ = self.load_network(calibration)
示例#17
0
def read_manual_corrections(d, output_folder, manual_path_list,
                            cidread2cid_global, num_classes):
    """Add manually corrected 2D points found under `manual_path_list` to `d`.

    Parameters:
    d: annotation dictionary, updated in place with
        `(folder_name, image_name) -> (2 * num_classes, 2)` float arrays.
    output_folder: joined after the folder name to locate the camera order.
    manual_path_list: root folders searched recursively for correction files.
    cidread2cid_global: cache `folder_name -> cidread2cid`, extended in place
        for folders that are not in it yet.
    num_classes: half the number of rows of each stored array.

    Raises:
    NotImplementedError: if a correction exists for a camera id above 6.
    """
    pose_corr_path_list = []
    for root in manual_path_list:
        logger.debug("Searching recursively: {}".format(root))
        pose_corr_path_list.extend(find_pose_corr_recursively(root))
    logger.debug("Number of manual correction files: {}".format(
        len(pose_corr_path_list)))

    for path in pose_corr_path_list:
        # NOTE: unpickling can execute arbitrary code; only point this
        # function at correction files coming from a trusted source.
        # The loaded dict gets its own name: reusing `d` shadowed the output
        # dictionary, so the corrections never reached the caller.
        with open(path, "rb") as f:
            corrections = pickle.load(f)

        folder_name = corrections["folder"]
        if folder_name not in cidread2cid_global:
            cam_folder = os.path.join(folder_name, output_folder)
            cidread2cid, _ = read_camera_order(cam_folder)
            cidread2cid_global[folder_name] = cidread2cid
        # Always go through the cache: the ordering of a folder seen earlier
        # is only available there.
        camera_order = cidread2cid_global[folder_name].tolist()

        for cid in range(config["num_cameras"]):
            for img_id, points2d in corrections[cid].items():
                cid_read = camera_order.index(cid)
                key = (folder_name, constr_img_name(cid_read, img_id))
                num_heatmaps = points2d.shape[0]
                half = num_heatmaps // 2

                # `np.float` was removed in NumPy 1.24; `float` is equivalent.
                pts = np.zeros((2 * num_classes, 2), dtype=float)
                if cid < 3:
                    # first half of the corrections -> first block of rows
                    pts[:half, :] = points2d[:half, :]
                elif 3 < cid < 7:
                    # second half -> rows starting at `num_classes`
                    pts[num_classes:num_classes + half, :] = points2d[half:, :]
                elif cid == 3:
                    # corrections of camera 3 are not used
                    continue
                else:
                    raise NotImplementedError

                d[key] = pts
示例#18
0
def load_weights(model, resume: str):
    """Load the weights of the checkpoint file `resume` into `model`.

    The checkpoint is mapped to the CPU when CUDA is not available. When the
    path contains "mpii", the `module.score*` entries with indices 0 to 7 are
    dropped from the checkpoint before it is merged into the current state of
    the model. In both cases the weights are loaded with `strict=False`.

    Raises:
    FileNotFoundError: if `resume` is not an existing file.
    """
    if not isfile(resume):
        logger.debug("=> no checkpoint found at '{}'".format(resume))
        raise FileNotFoundError

    logger.debug("Loading checkpoint '{}'".format(resume))
    if torch.cuda.is_available():
        checkpoint = torch.load(resume)
    else:
        checkpoint = torch.load(resume, map_location=torch.device("cpu"))

    if "mpii" in resume:  # weights for sh trained on mpii dataset
        logger.debug("Removing input/output layers")
        templates = (
            "module.score.{}.bias",
            "module.score.{}.weight",
            "module.score_.{}.weight",
        )
        ignore_weight_list = [
            template.format(i) for i in range(8) for template in templates
        ]
        for weight_name in ignore_weight_list:
            checkpoint["state_dict"].pop(weight_name, None)

        # Start from the current weights and overwrite what the checkpoint has.
        state = model.state_dict()
        state.update(checkpoint["state_dict"])
        logger.debug(model.state_dict())
        logger.debug(checkpoint["state_dict"])
        model.load_state_dict(state, strict=False)
    else:
        model.load_state_dict(checkpoint["state_dict"], strict=False)

    logger.debug("Loaded checkpoint '{}' (epoch {})".format(
        resume, checkpoint["epoch"]))
示例#19
0
def main(args):
    """Create the model, then predict on an unlabeled folder or train.

    When `args.unlabeled` is set, a dataset is built on that folder and the
    result of `process_folder` is returned. Otherwise the model is trained
    from `args.start_epoch` to `args.epochs`, with a validation pass and a
    checkpoint after every epoch.

    Returns:
    `(pred, heatmap)` in prediction mode, None after training.
    """
    logger.debug("Creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    network = hg(
        num_stacks=args.stacks,
        num_blocks=args.blocks,
        num_classes=args.num_classes,
        num_feats=args.features,
        inplanes=args.inplanes,
        init_stride=args.stride,
    )
    model = on_cuda(torch.nn.DataParallel(network))

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=5)

    if args.resume:
        load_weights(model, args.resume)

    num_params = sum(p.numel() for p in model.parameters())
    logger.debug("Total params: %.2fM" % (num_params / 1000000.0))

    # ---- prediction on an unlabeled folder
    if args.unlabeled:
        dataset = DrosophilaDataset(
            data_folder=args.data_folder,
            train=False,
            sigma=args.sigma,
            session_id_train_list=None,
            folder_train_list=None,
            img_res=args.img_res,
            hm_res=args.hm_res,
            augmentation=False,
            evaluation=True,
            unlabeled=args.unlabeled,
            num_classes=args.num_classes,
            # never go beyond the images actually present in the folder
            max_img_id=min(get_max_img_id(args.unlabeled), args.max_img_id),
            output_folder=args.output_folder,
        )
        loader = DataLoader(
            dataset,
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            drop_last=False,
        )
        pred, heatmap = process_folder(
            model,
            loader,
            args.unlabeled,
            args.output_folder,
            args.overwrite,
            num_classes=args.num_classes,
            acc_joints=args.acc_joints,
        )
        return pred, heatmap

    # ---- training
    train_loader, val_loader = create_dataloader()
    lr = args.lr
    best_acc = 0
    # arguments shared by the training and the validation pass
    shared_step_args = dict(
        model=model,
        optimizer=optimizer,
        heatmap=False,
        num_classes=args.num_classes,
        acc_joints=args.acc_joints,
    )
    for epoch in range(args.start_epoch, args.epochs):
        logger.debug("\nEpoch: %d | LR: %.8f" % (epoch + 1, lr))

        _, _, _, _, _ = step(
            loader=train_loader,
            mode=Mode.train,
            epoch=epoch,
            **shared_step_args,
        )
        val_pred, _, val_loss, val_acc, val_mse = step(
            loader=val_loader,
            mode=Mode.test,
            epoch=epoch,
            **shared_step_args,
        )
        scheduler.step(val_loss)

        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        state = {
            "epoch": epoch + 1,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_acc": best_acc,
            "optimizer": optimizer.state_dict(),
            "image_shape": args.img_res,
            "heatmap_shape": args.hm_res,
        }
        save_checkpoint(
            state,
            val_pred,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot,
        )
示例#20
0
    def __init__(
        self,
        data_folder,
        data_corr_folder=None,
        img_res=None,
        hm_res=None,
        train=True,
        sigma=1,
        jsonfile="drosophilaimaging-export.json",
        session_id_train_list=None,
        folder_train_list=None,
        augmentation=False,
        evaluation=False,
        unlabeled=None,
        num_classes=config["num_predict"],
        max_img_id=None,
        output_folder=None,
    ):
        """Collect the annotations of the dataset and load its mean/std.

        Annotations come from the json file (labeled data) or from the
        unlabeled folder. Entries whose image file is missing on disk are
        dropped, the remaining ones are normalized, and their keys are
        sorted when `evaluation` is set.

        Raises:
        ValueError: if `output_folder` is None.
        AssertionError: if `evaluation` and `augmentation` are both set, or
            if `unlabeled` is given without `evaluation`.
        """
        # where the data lives
        self.data_folder = data_folder  # root image folders
        self.data_corr_folder = data_corr_folder
        self.json_file = os.path.join(jsonfile)
        self.output_folder = output_folder
        self.unlabeled = unlabeled

        # what kind of samples are produced
        self.train = train
        self.is_train = train  # training set or test set
        self.augmentation = augmentation
        self.evaluation = evaluation
        self.img_res = img_res
        self.hm_res = hm_res
        self.sigma = sigma
        self.num_classes = num_classes
        self.max_img_id = max_img_id

        # which recordings are used
        self.session_id_train_list = session_id_train_list
        self.folder_train_list = folder_train_list
        # always empty here, so the manual corrections below are never read
        self.manual_path_list = []
        self.cidread2cid = dict()

        if self.output_folder is None:
            raise ValueError(
                "Please provide an output_folder relative to images folder")

        # evaluation and augmentation exclude each other
        assert not (self.evaluation and self.augmentation)
        # unlabeled data can only be evaluated
        assert evaluation or not self.unlabeled

        self.annotation_dict = dict()

        if not self.unlabeled:
            if isfile(self.json_file):
                logger.debug("Searching for json file")
                read_json(
                    self.annotation_dict,
                    self.json_file,
                    self.folder_train_list,
                    self.cidread2cid,
                )
            if self.train and self.manual_path_list:
                logger.debug("Searching for manual corrections")
                read_manual_corrections(
                    self.annotation_dict,
                    self.output_folder,
                    self.manual_path_list,
                    self.cidread2cid,
                    self.num_classes,
                )
        else:
            logger.debug("Searching unlabeled")
            read_unlabeled_folder(
                self.annotation_dict,
                self.unlabeled,
                self.output_folder,
                self.cidread2cid,
                self.max_img_id,
            )

        # make sure data is in the folder
        # (iterate over a snapshot of the keys: entries are removed on the way)
        for key in list(self.annotation_dict):
            folder_name, image_name = key
            image_file = os.path.join(
                self.data_folder,
                folder_name.replace("_network", ""),
                image_name + ".jpg",
            )
            if os.path.isfile(image_file):
                continue
            self.annotation_dict.pop(key, None)
            print("FileNotFound: {}/{} ".format(folder_name, image_name))

        normalize_annotations(self.annotation_dict, self.num_classes,
                              self.cidread2cid)

        self.annotation_key = list(self.annotation_dict.keys())
        if self.evaluation:  # sort keys

            def evaluation_order(key):
                # folder name, then the 4th and the 2nd "_"-separated fields
                # of the image name
                fields = key[1].split("_")
                return "_".join((key[0], fields[3], fields[1]))

            self.annotation_key.sort(key=evaluation_order)

        self.mean, self.std = self._compute_mean()

        split_name = "train" if self.train else "validation"
        folders = {key[0] for key in self.annotation_key}
        logger.debug("Folders inside {} data: {}".format(split_name, folders))
        logger.debug(
            "Successfully imported {} Images in Drosophila Dataset".format(
                len(self)))
示例#21
0
            mse_tarsus=am["tarsus_tip"].avg,
        )
        bar.next()

    bar.finish()
    return (predictions, heatmap, am["losses"].avg, am["acces"].avg,
            am["mse"].avg)


if __name__ == "__main__":
    # prepare
    parser = create_parser()
    args = parser.parse_args()

    args.train_joints = np.arange(args.num_classes)
    logger.debug(f"Training joints: {args.train_joints}")
    logger.debug(f"Acc joints: {args.acc_joints}")

    args.checkpoint = (args.checkpoint.replace(" ", "_").replace("(",
                                                                 "_").replace(
                                                                     ")", "_"))
    args.checkpoint = os.path.join(
        args.checkpoint,
        get_time() + "_{}_{}_{}_{}_{}_{}".format(
            "predict" if args.unlabeled else "training",
            args.arch,
            args.stacks,
            args.img_res,
            args.blocks,
            args.name,
        ),
示例#22
0
    def prepare_bundle_adjust_param(self,
                                    camera_id_list=None,
                                    ignore_joint_list=None,
                                    max_num_images=1000):
        """Assemble the initial parameters and observations for bundle adjustment.

        Parameters:
        camera_id_list: indices into `self.cam_list` of the cameras to
            adjust; defaults to all `self.num_cameras` cameras.
        ignore_joint_list: joint ids to leave out; defaults to
            `config["skeleton"].ignore_joint_id`.
        max_num_images: above this number of images, that many image ids are
            drawn at random instead of using all of them.

        Returns:
        A tuple `(x0, points2d, n_cameras, n_points, camera_indices,
        point_indices)`. `x0` is the flattened camera parameters (13 values
        per camera: rvec, tvec, focal lengths x and y, 5 distortion values)
        followed by the flattened 3D points. `camera_indices[k]` and
        `point_indices[k]` tell which camera and which 3D point produced the
        2D observation `points2d[k]`.
        """
        if ignore_joint_list is None:
            ignore_joint_list = config["skeleton"].ignore_joint_id
        if camera_id_list is None:
            camera_id_list = list(range(self.num_cameras))
        # Logged after the default is resolved so the log shows the real list.
        logger.debug(
            "Calibration ignore joint list {}".format(ignore_joint_list))

        cam_list = [self.cam_list[c] for c in camera_id_list]
        camera_params = np.zeros(shape=(len(cam_list), 13), dtype=float)
        # `cam_list` is already restricted to `camera_id_list`, so it has to
        # be indexed by position; indexing it by camera id was only correct
        # for the default list.
        for i, cam in enumerate(cam_list):
            camera_params[i, 0:3] = np.squeeze(cam.rvec)
            camera_params[i, 3:6] = np.squeeze(cam.tvec)
            camera_params[i, 6] = cam.focal_length_x
            camera_params[i, 7] = cam.focal_length_y
            camera_params[i, 8:13] = np.squeeze(cam.distort)

        point_indices = []
        camera_indices = []
        points2d_ba = []
        points3d_ba = []
        points3d_ba_source = dict()  # (img_id, j_id) -> point index
        points3d_ba_source_inv = dict()  # point index -> (img_id, j_id)
        point_index_counter = 0
        data_shape = self.points3d_m.shape

        # NOTE(review): both branches below leave out the last image(s)
        # (`high` is exclusive and `arange` stops at shape[0] - 2) --
        # confirm this is intended.
        if data_shape[0] > max_num_images:
            logger.debug(
                "There are too many ({}) images for calibration. Selecting {} randomly."
                .format(data_shape[0], max_num_images))
            img_id_list = np.random.randint(0,
                                            high=data_shape[0] - 1,
                                            size=(max_num_images))
        else:
            logger.debug("Using {} images for calibration".format(
                data_shape[0]))
            img_id_list = np.arange(data_shape[0] - 1)

        for img_id in img_id_list:
            for j_id in range(data_shape[1]):
                # These two checks do not depend on the camera.
                if j_id in ignore_joint_list:
                    continue
                if np.any(self.points3d_m[img_id, j_id, :] == 0):
                    continue

                cam_list_iter = list()
                points2d_iter = list()
                for cam in cam_list:
                    if np.any(cam[img_id, j_id, :] == 0):
                        continue
                    if not config["skeleton"].camera_see_joint(
                            cam.cam_id, j_id):
                        continue
                    # camera 3 is never used for calibration
                    if cam.cam_id == 3:
                        continue
                    cam_list_iter.append(cam)
                    points2d_iter.append(cam[img_id, j_id, :])

                # the point is seen by at least two cameras, add it to the bundle adjustment
                if len(cam_list_iter) >= 2:
                    points3d_iter = self.points3d_m[img_id, j_id, :]
                    points2d_ba.extend(points2d_iter)
                    points3d_ba.append(points3d_iter)
                    point_indices.extend([point_index_counter] *
                                         len(cam_list_iter))
                    points3d_ba_source[(img_id, j_id)] = point_index_counter
                    points3d_ba_source_inv[point_index_counter] = (img_id,
                                                                   j_id)
                    point_index_counter += 1
                    camera_indices.extend(
                        [cam.cam_id for cam in cam_list_iter])

        c = 0
        # make sure stripes from both sides share the same point id's
        # TODO move this into config file
        if "fly" in config["name"]:
            half_joints = config["skeleton"].num_joints // 2
            for idx, point_idx in enumerate(point_indices):
                img_id, j_id = points3d_ba_source_inv[point_idx]
                if (config["skeleton"].is_tracked_point(
                        j_id, config["skeleton"].Tracked.STRIPE)
                        and j_id > half_joints):
                    twin = (img_id, j_id - half_joints)
                    if twin in points3d_ba_source:
                        point_indices[idx] = points3d_ba_source[twin]
                        c += 1

        logger.debug("Replaced {} points".format(c))

        # Count the points before squeezing: with a single selected point the
        # squeeze collapses the array and its first dimension is no longer
        # the number of points.
        n_points = len(points3d_ba)
        points3d_ba = np.squeeze(np.array(points3d_ba))
        points2d_ba = np.squeeze(np.array(points2d_ba))

        # Renumber the camera ids that actually occur to 0..k-1, in sorted
        # order.
        cid2cidx = {
            v: k
            for (k, v) in enumerate(np.sort(np.unique(camera_indices)))
        }
        camera_indices = np.array([cid2cidx[cid] for cid in camera_indices])
        point_indices = np.array(point_indices)

        n_cameras = camera_params.shape[0]

        x0 = np.hstack((camera_params.ravel(), points3d_ba.ravel()))

        return (
            x0.copy(),
            points2d_ba.copy(),
            n_cameras,
            n_points,
            camera_indices,
            point_indices,
        )
示例#23
0
    def __init__(
        self,
        image_folder,
        output_folder,
        cam_list=None,
        calibration=None,
        num_images=900,
        cam_id_list=range(config["num_cameras"]),
        cid2cidread=None,
        heatmap=None,
        pred=None,
        hm_path=None,
        pred_path=None,
    ):
        """Set up the camera network for one image folder.

        Parameters:
        image_folder: stored as ``self.folder`` and ``self.dict_name`` and
            handed to every ``Camera`` that gets created here.
        output_folder: stored as ``self.folder_output``; the camera order,
            prediction file, heatmap file and calibration are looked up
            there whenever they are not passed in explicitly.
        cam_list: when truthy it is used as ``self.cam_list`` as-is and no
            prediction/heatmap loading takes place.
        calibration: when None, ``read_calib(output_folder)`` is tried; any
            non-None calibration is passed to ``self.load_network``.
        num_images: upper bound on the number of images; ``pred`` and
            ``heatmap`` are cut along their second axis when they hold more.
        cam_id_list: camera ids for which ``Camera`` objects are built.
        cid2cidread: indexable by camera id; the looked-up value is given to
            ``Camera`` as ``cid_read``. Read via ``read_camera_order`` when
            None.
        heatmap: already-loaded heatmap; loaded from ``hm_path`` when None.
        pred: already-loaded predictions; loaded from ``pred_path`` when None.
        hm_path: heatmap file path; searched with ``find_hm_path`` when None.
        pred_path: prediction file path; searched with ``find_pred_path``
            when None.
        """
        self.folder = image_folder
        self.folder_output = output_folder
        self.dict_name = image_folder
        self.points3d_m = None
        self.bone_param = None
        self.num_images = num_images
        self.num_joints = config["skeleton"].num_joints
        self.heatmap_shape = config["heatmap_shape"]
        self.image_shape = config["image_shape"]
        self.num_cameras = len(cam_id_list)

        # No explicit mapping given: obtain it from the output folder.
        if cid2cidread is None:
            _, cid2cidread = read_camera_order(self.folder_output)
        self.cid2cidread = cid2cidread

        if not cam_list:
            self.cam_list = []

            # --- predictions -------------------------------------------
            if pred_path is None:
                pred_path = find_pred_path(self.folder_output)
            if pred_path is None:
                missing_pred_msg = "no pred file under {}".format(
                    self.folder_output)
                logger.debug(missing_pred_msg)
            elif pred is None:
                logger.debug("loading pred path {}".format(pred_path))
                if pred_path.endswith(".json"):
                    folder_name = os.path.basename(image_folder)
                    pred = load_pred_from_json(pred_path, folder_name,
                                               self.num_images)
                else:
                    # NOTE(review): allow_pickle=True will unpickle objects
                    # stored in the file -- only safe for trusted files.
                    pred = np.load(file=pred_path,
                                   mmap_mode="r",
                                   allow_pickle=True)

            # Image count as found in the predictions, or the requested
            # count when there are no predictions at all.
            if pred is not None:
                num_images_in_pred = pred.shape[1]
            else:
                num_images_in_pred = num_images

            # --- heatmaps ----------------------------------------------
            if hm_path is None:
                hm_path = find_hm_path(self.folder_output)
            if hm_path is None:
                missing_hm_msg = "no heatmap file under {}".format(
                    self.folder_output)
                logger.debug(missing_hm_msg)
            elif heatmap is None:
                hm_height, hm_width = (self.heatmap_shape[0],
                                       self.heatmap_shape[1])
                # First axis holds one entry more than the configured number
                # of cameras.
                hm_shape = (
                    config["num_cameras"] + 1,
                    num_images_in_pred,
                    config["num_predict"],
                    hm_height,
                    hm_width,
                )
                logger.debug("Heatmap shape: {}".format(hm_shape))
                logger.debug("Reading hm from {}".format(hm_path))
                heatmap = load_heatmap(hm_path, hm_shape)

            # --- trim to the requested number of images ----------------
            frame_limit = self.num_images
            if frame_limit is not None and frame_limit < num_images_in_pred:
                if pred is not None:
                    pred = pred[:, :frame_limit]
                if heatmap is not None:
                    heatmap = heatmap[:, :frame_limit]

            # --- one Camera per requested id ---------------------------
            for cid in cam_id_list:
                cid_read = cid2cidread[cid]
                points2d_cam = pred2pred_cam(pred, cid, cid_read,
                                             self.image_shape,
                                             num_images_in_pred)
                camera = Camera(
                    cid=cid,
                    cid_read=cid_read,
                    image_folder=image_folder,
                    hm=heatmap,
                    points2d=points2d_cam,
                )
                self.cam_list.append(camera)
        else:
            logger.debug("Camera list is already given, skipping loading.")
            self.cam_list = cam_list

        # Try the calibration stored in the output folder when none is given.
        if calibration is None:
            logger.debug("Reading calibration from {}".format(
                self.folder_output))
            calibration = read_calib(self.folder_output)
        if calibration is not None:
            self.load_network(calibration)