Example #1
    def test_rotation_matrix_from_axis(self):
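        """rotation_matrix_from_axis should build valid right-handed rotations."""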
        x_axis = (1, 0, 0)
        y_axis = (0, 1, 0)
        rot = rotation_matrix_from_axis(x_axis, y_axis)
        _check_valid_rotation(rot)
        testing.assert_array_almost_equal(rot, np.eye(3))

        x_axis = (1, 1, 1)
        y_axis = (0, 0, 1)
        rot = rotation_matrix_from_axis(x_axis, y_axis)
        testing.assert_array_almost_equal(
            rot, [[0.57735027, -0.40824829, 0.70710678],
                  [0.57735027, -0.40824829, -0.70710678],
                  [0.57735027, 0.81649658, 0.0]])

        x_axis = (1, 1, 1)
        y_axis = (0, 0, -1)
        rot = rotation_matrix_from_axis(x_axis, y_axis)
        _check_valid_rotation(rot)
        testing.assert_array_almost_equal(
            rot, [[0.57735027, 0.40824829, -0.70710678],
                  [0.57735027, 0.40824829, 0.70710678],
                  [0.57735027, -0.81649658, 0.0]])

        rot = rotation_matrix_from_axis(y_axis, x_axis, axes='yx')
        _check_valid_rotation(rot)
Example #2
def main():
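    # Inference demo: load a trained HPNET model, predict hanging/pouring
    # contact points from color/depth images, and visualize them with a
    # skrobot TrimeshSceneViewer.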
    current_dir = osp.dirname(osp.abspath(__file__))
    trained_model_dir = osp.join(osp.dirname(osp.dirname(current_dir)),
                                 'pretrained_model')

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--input-dir',
                        '-i',
                        type=str,
                        help='input directory',
                        default=None)
    parser.add_argument('--color',
                        '-c',
                        type=str,
                        help='color image (.png)',
                        default=None)
    parser.add_argument('--depth',
                        '-d',
                        type=str,
                        help='depth image (.npy)',
                        default=None)
    parser.add_argument('--camera-info',
                        '-ci',
                        type=str,
                        help='camera info file (.yaml)',
                        default=None)

    parser.add_argument('--pretrained_model',
                        '-p',
                        type=str,
                        help='Pretrained model',
                        default=osp.join(trained_model_dir, 'hanging_1020.pt'))
    # default=osp.join(trained_model_dir, 'pouring.pt'))
    # default='/media/kosuke55/SANDISK-2/meshdata/random_shape_shapenet_hanging_render/1010/gan_2000per0-1000obj_1020.pt')  # gan hanging # noqa
    # default='/media/kosuke55/SANDISK-2/meshdata/random_shape_shapenet_pouring_render/1227/pouring_random_20201230_0215_5epoch.pt')  # gan pouring # noqa

    parser.add_argument('--predict-depth',
                        '-pd',
                        type=int,
                        help='If nonzero, use the network-predicted depth '
                        'instead of the median depth inside the ROI.',
                        default=0)
    parser.add_argument('--task',
                        '-t',
                        type=str,
                        help='h(hanging) or p(pouring). '
                        'Not needed if roi size is the same in config.',
                        default='h')

    args = parser.parse_args()
    base_dir = args.input_dir
    pretrained_model = args.pretrained_model

    config_path = str(
        Path(osp.abspath(__file__)).parent.parent / 'learning_scripts' /
        'config' / 'gray_model.yaml')
    with open(config_path) as f:
        config = yaml.safe_load(f)

    task_type = args.task
    if task_type == 'p':
        task_type = 'pouring'
    else:
        task_type = 'hanging'

    target_size = tuple(config['target_size'])
    depth_range = config['depth_range']
    depth_roi_size = config['depth_roi_size'][task_type]

    print('task type: {}'.format(task_type))
    print('depth roi size: {}'.format(depth_roi_size))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HPNET(config).to(device)
    model.load_state_dict(
        torch.load(pretrained_model, map_location=device), strict=False)
    model.eval()

    viewer = skrobot.viewers.TrimeshSceneViewer(resolution=(640, 480))

    if base_dir is not None:
        color_paths = list(Path(base_dir).glob('**/color/*.png'))
    elif args.color is not None \
            and args.depth is not None \
            and args.camera_info is not None:
        color_paths = [args.color]

    else:
        return False

    is_first_loop = True
    try:
        for color_path in color_paths:
            if not is_first_loop:
                viewer.delete(pc)  # noqa
                for c in contact_point_sphere_list:  # noqa
                    viewer.delete(c)

            if base_dir is not None:
                camera_info_path = str(
                    (color_path.parent.parent / 'camera_info' /
                     color_path.stem).with_suffix('.yaml'))
                depth_path = str((color_path.parent.parent / 'depth' /
                                  color_path.stem).with_suffix('.npy'))
                color_path = str(color_path)
            else:
                camera_info_path = args.camera_info
                color_path = args.color
                depth_path = args.depth

            camera_model = cameramodels.PinholeCameraModel.from_yaml_file(
                camera_info_path)
            camera_model.target_size = target_size

            cv_bgr = cv2.imread(color_path)

            intrinsics = camera_model.open3d_intrinsic

            cv_bgr = cv2.resize(cv_bgr, target_size)
            cv_rgb = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2RGB)
            color = o3d.geometry.Image(cv_rgb)

            cv_depth = np.load(depth_path)
            cv_depth = cv2.resize(cv_depth,
                                  target_size,
                                  interpolation=cv2.INTER_NEAREST)
            depth = o3d.geometry.Image(cv_depth)

            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
                color, depth, depth_trunc=4.0, convert_rgb_to_intensity=False)
            pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
                rgbd, intrinsics)
            trimesh_pc = trimesh.PointCloud(np.asarray(pcd.points),
                                            np.asarray(pcd.colors))
            pc = skrobot.model.PointCloudLink(trimesh_pc)

            viewer.add(pc)

            if config['use_bgr2gray']:
                gray = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2GRAY)
                gray = cv2.resize(gray, target_size)[..., None] / 255.
                normalized_depth = normalize_depth(cv_depth, depth_range[0],
                                                   depth_range[1])[..., None]
                in_feature = np.concatenate((normalized_depth, gray),
                                            axis=2).astype(np.float32)
            else:
                raise NotImplementedError()

            transform = transforms.Compose([transforms.ToTensor()])
            in_feature = transform(in_feature)

            in_feature = in_feature.to(device)
            in_feature = in_feature.unsqueeze(0)

            confidence, depth, rotation = model(in_feature)

            confidence = confidence[0, 0:1, ...]
            confidence_np = confidence.cpu().detach().numpy().copy() * 255
            confidence_np = confidence_np.transpose(1, 2, 0)
            confidence_np[confidence_np <= 0] = 0
            confidence_np[confidence_np >= 255] = 255
            confidence_img = confidence_np.astype(np.uint8)

            print(model.rois_list)
            contact_point_sphere_list = []
            roi_image = cv_bgr.copy()
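            # For each predicted ROI: build a rotation from the predicted axis,
            # back-project the ROI center to a 3D ray, then scale the ray by the
            # predicted depth or by the median depth inside a ROI around the
            # center.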
            for i, (roi, roi_center) in enumerate(
                    zip(model.rois_list[0], model.rois_center_list[0])):
                if roi.tolist() == [0, 0, 0, 0]:
                    continue
                roi = roi.cpu().detach().numpy().copy()
                roi_image = draw_roi(roi_image, roi)
                hanging_point_x = roi_center[0]
                hanging_point_y = roi_center[1]
                v = rotation[i].cpu().detach().numpy()
                v /= np.linalg.norm(v)
                rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
                q = matrix2quaternion(rot)

                hanging_point = np.array(
                    camera_model.project_pixel_to_3d_ray(
                        [int(hanging_point_x),
                         int(hanging_point_y)]))

                if args.predict_depth:
                    dep = depth[i].cpu().detach().numpy().copy()
                    dep = unnormalize_depth(dep, depth_range[0],
                                            depth_range[1]) * 0.001
                    length = float(dep) / hanging_point[2]
                else:
                    depth_roi = make_box(roi_center,
                                         width=depth_roi_size[1],
                                         height=depth_roi_size[0],
                                         img_shape=target_size,
                                         xywh=False)
                    depth_roi_clip = cv_depth[depth_roi[0]:depth_roi[2],
                                              depth_roi[1]:depth_roi[3]]

                    dep_roi_clip = depth_roi_clip[np.where(
                        np.logical_and(
                            config['depth_range'][0] < depth_roi_clip,
                            depth_roi_clip < config['depth_range'][1]))]

                    dep_roi_clip = np.median(dep_roi_clip) * 0.001

                    if np.isnan(dep_roi_clip):
                        continue
                    length = float(dep_roi_clip) / hanging_point[2]

                hanging_point *= length

                contact_point_sphere = skrobot.model.Sphere(0.001,
                                                            color=[255, 0, 0])
                contact_point_sphere.newcoords(
                    skrobot.coordinates.Coordinates(pos=hanging_point, rot=q))
                viewer.add(contact_point_sphere)
                contact_point_sphere_list.append(contact_point_sphere)

            if is_first_loop:
                viewer.show()

            heatmap = overlay_heatmap(cv_bgr, confidence_img)
            cv2.imshow('heatmap', heatmap)
            cv2.imshow('roi', roi_image)
            print('Next data: [ENTER] on image window.\n'
                  'Quit: [q] on image window.')
            key = cv2.waitKey(0)
            cv2.destroyAllWindows()
            if key == ord('q'):
                break

            is_first_loop = False

    except KeyboardInterrupt:
        pass
Example #3
    def callback(self, camera_info_msg, img_raw_msg, img_msg, depth_msg):
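        # ROS callback: run HPNET on a synchronized color/depth pair and publish
        # the heatmap, confidence map, colorized depth, axis images and the
        # detected hanging point poses.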
        ymin = camera_info_msg.roi.y_offset
        xmin = camera_info_msg.roi.x_offset
        ymax = camera_info_msg.roi.y_offset + camera_info_msg.roi.height
        xmax = camera_info_msg.roi.x_offset + camera_info_msg.roi.width
        self.camera_model.roi = [ymin, xmin, ymax, xmax]
        self.camera_model.target_size = self.target_size

        bgr_raw = self.bridge.imgmsg_to_cv2(img_raw_msg, "bgr8")
        bgr = self.bridge.imgmsg_to_cv2(img_msg, "bgr8")
        cv_depth = self.bridge.imgmsg_to_cv2(depth_msg, "32FC1")
        if cv_depth is None or bgr is None:
            return
        remove_nan(cv_depth)
        cv_depth[cv_depth < self.depth_range[0]] = 0
        cv_depth[cv_depth > self.depth_range[1]] = 0
        bgr = cv2.resize(bgr, self.target_size)
        cv_depth = cv2.resize(cv_depth,
                              self.target_size,
                              interpolation=cv2.INTER_NEAREST)

        depth_bgr = colorize_depth(cv_depth, ignore_value=0)

        in_feature = cv_depth.copy().astype(np.float32) * 0.001

        if self.config['use_bgr2gray']:
            gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
            gray = cv2.resize(gray, self.target_size)[..., None] / 255.
            normalized_depth = normalize_depth(cv_depth, self.depth_range[0],
                                               self.depth_range[1])[..., None]
            in_feature = np.concatenate((normalized_depth, gray),
                                        axis=2).astype(np.float32)

        if self.transform:
            in_feature = self.transform(in_feature)

        in_feature = in_feature.to(self.device)
        in_feature = in_feature.unsqueeze(0)

        confidence, depth, rotation = self.model(in_feature)
        confidence = confidence[0, 0:1, ...]
        confidence_np = confidence.cpu().detach().numpy().copy() * 255
        confidence_np = confidence_np.transpose(1, 2, 0)
        confidence_np[confidence_np <= 0] = 0
        confidence_np[confidence_np >= 255] = 255
        confidence_img = confidence_np.astype(np.uint8)
        confidence_img = cv2.resize(confidence_img, self.target_size)
        heatmap = overlay_heatmap(bgr, confidence_img)

        axis_pred = bgr.copy()
        axis_pred_raw = bgr_raw.copy()

        hanging_points_pose_array = PoseArray()
        for i, (roi, roi_center) in enumerate(
                zip(self.model.rois_list[0], self.model.rois_center_list[0])):
            if roi.tolist() == [0, 0, 0, 0]:
                continue
            roi = roi.cpu().detach().numpy().copy()
            hanging_point_x = roi_center[0]
            hanging_point_y = roi_center[1]
            v = rotation[i].cpu().detach().numpy()
            v /= np.linalg.norm(v)
            rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
            q = matrix2quaternion(rot)

            hanging_point = np.array(
                self.camera_model.project_pixel_to_3d_ray(
                    [int(hanging_point_x),
                     int(hanging_point_y)]))

            if self.predict_depth:
                dep = depth[i].cpu().detach().numpy().copy()
                dep = unnormalize_depth(dep, self.depth_range[0],
                                        self.depth_range[1]) * 0.001
                length = float(dep) / hanging_point[2]
            else:
                depth_roi = make_box(roi_center,
                                     width=self.depth_roi_size[1],
                                     height=self.depth_roi_size[0],
                                     img_shape=self.target_size,
                                     xywh=False)
                depth_roi_clip = cv_depth[depth_roi[0]:depth_roi[2],
                                          depth_roi[1]:depth_roi[3]]
                dep_roi_clip = depth_roi_clip[np.where(
                    np.logical_and(self.depth_range[0] < depth_roi_clip,
                                   depth_roi_clip < self.depth_range[1]))]
                dep_roi_clip = np.median(dep_roi_clip) * 0.001
                if np.isnan(dep_roi_clip):
                    continue
                length = float(dep_roi_clip) / hanging_point[2]

            hanging_point *= length
            hanging_point_pose = Pose()
            hanging_point_pose.position.x = hanging_point[0]
            hanging_point_pose.position.y = hanging_point[1]
            hanging_point_pose.position.z = hanging_point[2]
            hanging_point_pose.orientation.w = q[0]
            hanging_point_pose.orientation.x = q[1]
            hanging_point_pose.orientation.y = q[2]
            hanging_point_pose.orientation.z = q[3]
            hanging_points_pose_array.poses.append(hanging_point_pose)

            axis_pred_raw = cv2.rectangle(
                axis_pred_raw,
                (int(roi[0] * (xmax - xmin) / self.target_size[1] + xmin),
                 int(roi[1] * (ymax - ymin) / self.target_size[0] + ymin)),
                (int(roi[2] * (xmax - xmin) / self.target_size[1] + xmin),
                 int(roi[3] * (ymax - ymin) / self.target_size[0] + ymin)),
                (0, 255, 0), 1)
            try:
                axis_pred_raw = draw_axis(axis_pred_raw, quaternion2matrix(q),
                                          hanging_point,
                                          self.camera_model.full_K)
            except Exception:
                print('Failed to draw axis')

        axis_pred = self.camera_model.crop_image(axis_pred_raw,
                                                 copy=True).astype(np.uint8)

        msg_out = self.bridge.cv2_to_imgmsg(heatmap, "bgr8")
        msg_out.header.stamp = depth_msg.header.stamp

        confidence_msg = self.bridge.cv2_to_imgmsg(confidence_img, "mono8")
        confidence_msg.header.stamp = depth_msg.header.stamp

        colorized_depth_msg = self.bridge.cv2_to_imgmsg(depth_bgr, "bgr8")
        colorized_depth_msg.header.stamp = depth_msg.header.stamp

        axis_pred_msg = self.bridge.cv2_to_imgmsg(axis_pred, "bgr8")
        axis_pred_msg.header.stamp = depth_msg.header.stamp

        axis_pred_raw_msg = self.bridge.cv2_to_imgmsg(axis_pred_raw, "bgr8")
        axis_pred_raw_msg.header.stamp = depth_msg.header.stamp

        hanging_points_pose_array.header = camera_info_msg.header
        self.pub.publish(msg_out)
        self.pub_confidence.publish(confidence_msg)
        self.pub_depth.publish(colorized_depth_msg)
        self.pub_axis.publish(axis_pred_msg)
        self.pub_axis_raw.publish(axis_pred_raw_msg)
        self.pub_hanging_points.publish(hanging_points_pose_array)
Example #4
        for i, (roi, roi_center) in enumerate(
                zip(model.rois_list[0], model.rois_center_list[0])):
            if roi.tolist() == [0, 0, 0, 0]:
                continue
            roi = roi.cpu().detach().numpy().copy()
            roi_image = draw_roi(roi_image, roi)
            roi_heatmap = draw_roi(roi_heatmap, roi)
            hanging_point_x = roi_center[0]
            hanging_point_y = roi_center[1]
            cv2.circle(roi_image,
                       (int(hanging_point_x), int(hanging_point_y)),
                       10, (19, 208, 251),
                       thickness=-1)
            v = rotation[i].cpu().detach().numpy()
            v /= np.linalg.norm(v)

            rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
            quaternion = matrix2quaternion(rot)

            camera_model_crop_resize \
                = camera_model.crop_resize_camera_info(target_size=target_size)

            hanging_point = np.array(
                camera_model_crop_resize.project_pixel_to_3d_ray(
                    [int(hanging_point_x),
                     int(hanging_point_y)]))

            draw_axis(axis_image, rot, hanging_point, camera_model.K)
            if args.predict_depth:
                dep = depth[i].cpu().detach().numpy().copy()
                dep = unnormalize_depth(dep, depth_range[0],
                                        depth_range[1]) * 0.001
    def step(self, dataloader, mode):
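        # Run one pass over the dataloader in 'train', 'val' or 'test' mode,
        # accumulating the confidence/rotation/depth losses and logging
        # visualizations via self.vis.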
        print('Start {}'.format(mode))
        # self.model = self.prev_model
        if mode == 'train':
            self.model.train()
        elif mode == 'val' or mode == 'test':
            self.model.eval()

        loss_sum = 0
        confidence_loss_sum = 0
        depth_loss_sum = 0
        rotation_loss_sum = 0
        rotation_loss_count = 0

        for index, (hp_data, depth_image, camera_info_path, hp_data_gt,
                    annotation_data) in tqdm.tqdm(enumerate(dataloader),
                                                  total=len(dataloader),
                                                  desc='{} epoch={}'.format(
                                                      mode, self.epo),
                                                  leave=False):

            # if index == 0:
            #     self.model = self.prev_model

            self.cameramodel\
                = cameramodels.PinholeCameraModel.from_yaml_file(
                    camera_info_path[0])
            self.cameramodel.target_size = self.target_size

            depth_image = hp_data.numpy().copy()[0, 0, ...]
            depth_image = np.nan_to_num(depth_image)
            depth_image = unnormalize_depth(depth_image, self.depth_range[0],
                                            self.depth_range[1])
            hp_data = hp_data.to(self.device)

            depth_image_bgr = colorize_depth(depth_image,
                                             ignore_value=self.depth_range[0])

            if mode == 'train':
                confidence, depth, rotation = self.model(hp_data)
            elif mode == 'val' or mode == 'test':
                with torch.no_grad():
                    confidence, depth, rotation = self.model(hp_data)

            confidence_np = confidence[0, ...].cpu().detach().numpy().copy()
            confidence_np[confidence_np >= 1] = 1.
            confidence_np[confidence_np <= 0] = 0.
            confidence_vis = cv2.cvtColor(confidence_np[0, ...] * 255,
                                          cv2.COLOR_GRAY2BGR)

            if mode != 'test':
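                # Training/validation only: weight positive confidence pixels,
                # match predicted ROIs against GT ROIs, compute the HPNET losses
                # and (in train mode) backpropagate.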
                pos_weight = hp_data_gt.detach().numpy().copy()
                pos_weight = pos_weight[:, 0, ...]
                zeroidx = np.where(pos_weight < 0.5)
                nonzeroidx = np.where(pos_weight >= 0.5)
                pos_weight[zeroidx] = 0.5
                pos_weight[nonzeroidx] = 1.0
                pos_weight = torch.from_numpy(pos_weight)
                pos_weight = pos_weight.to(self.device)

                hp_data_gt = hp_data_gt.to(self.device)
                confidence_gt = hp_data_gt[:, 0:1, ...]
                rois_list_gt, rois_center_list_gt = find_rois(confidence_gt)

                criterion = HPNETLoss(self.use_coords).to(self.device)

                if self.model.rois_list is None or rois_list_gt is None:
                    return None, None

                annotated_rois = annotate_rois(self.model.rois_list,
                                               rois_list_gt, annotation_data)

                confidence_loss, depth_loss, rotation_loss = criterion(
                    confidence, hp_data_gt, pos_weight, depth, rotation,
                    annotated_rois)

                if self.train_depth:
                    loss = confidence_loss + rotation_loss + depth_loss
                else:
                    loss = confidence_loss + rotation_loss

                if torch.isnan(loss):
                    print('loss is nan!!')
                    self.model = self.prev_model
                    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                                      lr=self.lr,
                                                      betas=(0.9, 0.999),
                                                      eps=1e-10,
                                                      weight_decay=0,
                                                      amsgrad=False)
                    self.optimizer.load_state_dict(
                        self.prev_optimizer.state_dict())
                    continue
                else:
                    self.prev_model = copy.deepcopy(self.model)
                    self.prev_optimizer = copy.deepcopy(self.optimizer)

                if mode == 'train':
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                    self.optimizer.step()

                axis_gt = depth_image_bgr.copy()

                confidence_gt_vis = cv2.cvtColor(
                    confidence_gt[0, 0, ...].cpu().detach().numpy().copy() *
                    255, cv2.COLOR_GRAY2BGR)

                # Visualize gt axis and roi
                for roi, roi_c in zip(rois_list_gt[0], rois_center_list_gt[0]):
                    if roi.tolist() == [0, 0, 0, 0]:
                        continue
                    roi = roi.cpu().detach().numpy().copy()
                    cx = roi_c[0]
                    cy = roi_c[1]

                    depth_and_rotation_gt = get_value_gt([cx, cy],
                                                         annotation_data[0])
                    rotation_gt = depth_and_rotation_gt[1:]
                    depth_gt_val = depth_and_rotation_gt[0]
                    unnormalized_depth_gt_val = unnormalize_depth(
                        depth_gt_val, self.depth_range[0], self.depth_range[1])

                    hanging_point_pose = np.array(
                        self.cameramodel.project_pixel_to_3d_ray(
                            [int(cx), int(cy)])) \
                        * unnormalized_depth_gt_val * 0.001

                    if self.use_coords:
                        rot = quaternion2matrix(rotation_gt)
                    else:
                        v = np.matmul(quaternion2matrix(rotation_gt),
                                      [1, 0, 0])
                        rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
                    try:
                        draw_axis(axis_gt, rot, hanging_point_pose,
                                  self.cameramodel.K)
                    except Exception:
                        print('Failed to draw axis')

                    confidence_gt_vis = draw_roi(confidence_gt_vis,
                                                 roi,
                                                 val=depth_gt_val,
                                                 gt=True)
                    axis_gt = draw_roi(axis_gt, roi, val=depth_gt_val, gt=True)

            # Visualize pred axis and roi
            axis_pred = depth_image_bgr.copy()

            for i, (roi, roi_c) in enumerate(
                    zip(self.model.rois_list[0],
                        self.model.rois_center_list[0])):

                if roi.tolist() == [0, 0, 0, 0]:
                    continue
                roi = roi.cpu().detach().numpy().copy()
                cx = roi_c[0]
                cy = roi_c[1]

                dep = depth[i].cpu().detach().numpy().copy()
                normalized_dep_pred = float(dep)
                dep = unnormalize_depth(dep, self.depth_range[0],
                                        self.depth_range[1])

                confidence_vis = draw_roi(confidence_vis,
                                          roi,
                                          val=normalized_dep_pred)
                axis_pred = draw_roi(axis_pred, roi, val=normalized_dep_pred)

                if mode != 'test':
                    if annotated_rois[i][2]:
                        confidence_vis = draw_roi(confidence_vis,
                                                  annotated_rois[i][0],
                                                  val=annotated_rois[i][1][0],
                                                  gt=True)
                        axis_pred = draw_roi(axis_pred,
                                             annotated_rois[i][0],
                                             val=annotated_rois[i][1][0],
                                             gt=True)

                hanging_point_pose = np.array(
                    self.cameramodel.project_pixel_to_3d_ray(
                        [int(cx), int(cy)])) * float(dep * 0.001)

                if self.use_coords:
                    # this branch has not been checked yet
                    q = rotation[i].cpu().detach().numpy().copy()
                    q /= np.linalg.norm(q)
                    rot = quaternion2matrix(q)

                else:
                    v = rotation[i].cpu().detach().numpy()
                    v /= np.linalg.norm(v)
                    rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')

                try:
                    draw_axis(axis_pred, rot, hanging_point_pose,
                              self.cameramodel.K)
                except Exception:
                    print('Failed to draw axis')

            axis_pred = cv2.cvtColor(axis_pred, cv2.COLOR_BGR2RGB)
            confidence_vis = cv2.cvtColor(confidence_vis, cv2.COLOR_BGR2RGB)

            if self.config['use_bgr']:
                if self.config['use_bgr2gray']:
                    in_gray = hp_data.cpu().detach().numpy().copy()[0, 1:2,
                                                                    ...] * 255
                    in_gray = in_gray.transpose(1, 2, 0).astype(np.uint8)
                    in_gray = cv2.cvtColor(in_gray, cv2.COLOR_GRAY2RGB)
                    in_gray = in_gray.transpose(2, 0, 1)
                    in_img = in_gray
                else:
                    in_bgr = hp_data.cpu().detach().numpy().copy()[
                        0, 3:, ...].transpose(1, 2, 0)
                    in_rgb = cv2.cvtColor(in_bgr, cv2.COLOR_BGR2RGB).transpose(
                        2, 0, 1)
                    in_img = in_rgb

            if mode != 'test':
                confidence_loss_sum += confidence_loss.item()

                axis_gt = cv2.cvtColor(axis_gt, cv2.COLOR_BGR2RGB)
                confidence_gt_vis = cv2.cvtColor(confidence_gt_vis,
                                                 cv2.COLOR_BGR2RGB)

                if rotation_loss.item() > 0:
                    depth_loss_sum += depth_loss.item()
                    rotation_loss_sum += rotation_loss.item()
                    loss_sum = loss_sum \
                        + confidence_loss.item() \
                        + rotation_loss.item()
                    rotation_loss_count += 1

                if np.mod(index, 1) == 0:
                    print(
                        'epoch {}, {}/{},{} loss is confidence:{} rotation:{} depth:{}'
                        .format(  # noqa
                            self.epo, index, len(dataloader), mode,
                            confidence_loss.item(), rotation_loss.item(),
                            depth_loss.item()))

                self.vis.images(
                    [axis_gt.transpose(2, 0, 1),
                     axis_pred.transpose(2, 0, 1)],
                    win='{} axis'.format(mode),
                    opts=dict(title='{} axis'.format(mode)))
                self.vis.images(
                    [
                        confidence_gt_vis.transpose(2, 0, 1),
                        confidence_vis.transpose(2, 0, 1)
                    ],
                    win='{}_confidence_roi'.format(mode),
                    opts=dict(title='{} confidence(GT, Pred)'.format(mode)))

                if self.config['use_bgr']:
                    self.vis.images([in_img],
                                    win='{} in_gray'.format(mode),
                                    opts=dict(title='{} in_gray'.format(mode)))
            else:
                if self.config['use_bgr']:
                    self.vis.images(
                        [
                            in_img,
                            confidence_vis.transpose(2, 0, 1),
                            axis_pred.transpose(2, 0, 1)
                        ],
                        win='{}-{}'.format(mode, index),
                        opts=dict(
                            title='{}-{} hanging_point_depth (pred)'.format(
                                mode, index)))
                else:
                    self.vis.images(
                        [
                            confidence_vis.transpose(2, 0, 1),
                            axis_pred.transpose(2, 0, 1)
                        ],
                        win='{}-{}'.format(mode, index),
                        opts=dict(
                            title='{}-{} hanging_point_depth (pred)'.format(
                                mode, index)))

            if np.mod(index, 1000) == 0:
                save_file = osp.join(
                    self.save_dir,
                    'hpnet_latestmodel_' + self.time_now + '.pt')
                print('save {}'.format(save_file))
                torch.save(self.model.state_dict(),
                           save_file,
                           _use_new_zipfile_serialization=False)

        if mode != 'test':
            if len(dataloader) > 0:
                avg_confidence_loss\
                    = confidence_loss_sum / len(dataloader)
                if rotation_loss_count > 0:
                    avg_rotation_loss\
                        = rotation_loss_sum / rotation_loss_count
                    avg_depth_loss\
                        = depth_loss_sum / rotation_loss_count
                    avg_loss\
                        = loss_sum / rotation_loss_count
                else:
                    avg_rotation_loss = 1e10
                    avg_depth_loss = 1e10
                    avg_loss = 1e10
            else:
                avg_loss = loss_sum
                avg_confidence_loss = confidence_loss_sum
                avg_rotation_loss = rotation_loss_sum
                avg_depth_loss = depth_loss_sum

            self.vis.line(X=np.array([self.epo]),
                          Y=np.array([avg_confidence_loss]),
                          opts={'title': 'confidence'},
                          win='confidence loss',
                          name='{}_confidence_loss'.format(mode),
                          update='append')
            if rotation_loss_count > 0:
                self.vis.line(X=np.array([self.epo]),
                              Y=np.array([avg_rotation_loss]),
                              opts={'title': 'rotation loss'},
                              win='rotation loss',
                              name='{}_rotation_loss'.format(mode),
                              update='append')
                self.vis.line(X=np.array([self.epo]),
                              Y=np.array([avg_depth_loss]),
                              opts={'title': 'depth loss'},
                              win='depth loss',
                              name='{}_depth_loss'.format(mode),
                              update='append')
                self.vis.line(X=np.array([self.epo]),
                              Y=np.array([avg_loss]),
                              opts={'title': 'loss'},
                              win='loss',
                              name='{}_loss'.format(mode),
                              update='append')

            if mode == 'val':
                if np.mod(self.epo, self.save_model_interval) == 0:
                    save_file = osp.join(
                        self.save_dir,
                        'hpnet_latestmodel_' + self.time_now + '.pt')
                    print('save {}'.format(save_file))
                    torch.save(self.model.state_dict(),
                               save_file,
                               _use_new_zipfile_serialization=False)

                if self.best_loss > avg_loss:
                    print('update best model {} -> {}'.format(
                        self.best_loss, avg_loss))
                    self.best_loss = avg_loss
                    save_file = osp.join(
                        self.save_dir,
                        'hpnet_bestmodel_' + self.time_now + '.pt')
                    print('save {}'.format(save_file))
                    # For ROS (Python 2, torch 1.4)
                    torch.save(self.model.state_dict(),
                               save_file,
                               _use_new_zipfile_serialization=False)
#!/usr/bin/env python
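"""Combine a box SDF and a bunny mesh SDF into a UnionSDF, sample points on
the union surface, and visualize randomly oriented axes at those points."""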

import numpy as np

import skrobot
from skrobot.coordinates.math import rotation_matrix_from_axis
from skrobot.model import Axis
from skrobot.model import Box
from skrobot.model import MeshLink
from skrobot.sdf import UnionSDF

b = Box(extents=[0.05, 0.1, 0.05], with_sdf=True)
m = MeshLink(visual_mesh=skrobot.data.bunny_objpath(), with_sdf=True)
b.translate([0, 0.1, 0])
u = UnionSDF([b.sdf, m.sdf])
axis = Axis(axis_radius=0.001, axis_length=0.3)
viewer = skrobot.viewers.TrimeshSceneViewer(resolution=(640, 480))
viewer.add(b)
viewer.add(m)
viewer.show()
pts, sd_vals = u.surface_points()

for _ in range(100):
    idx = np.random.randint(len(pts))
    rot = rotation_matrix_from_axis(np.random.random(3), np.random.random(3))
    ax = Axis(axis_radius=0.001, axis_length=0.01, pos=pts[idx], rot=rot)
    viewer.add(ax)