Example #1
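# Imports assumed for this snippet; the HPNET module path within
# hanging_points_cnn is an assumption, not verified.
import argparse

import torch

from hanging_points_cnn.learning_scripts.hpnet import HPNET
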
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrained_model',
                        '-p',
                        type=str,
                        help='Pretrained models',
                        required=True)  # noqa

    args = parser.parse_args()
    pretrained_model = args.pretrained_model

    config = {
        'output_channels': 1,
        'feature_extractor_name': 'resnet50',
        'confidence_thresh': 0.1,
        'depth_range': [100, 1500],
        'use_bgr': True,
        'use_bgr2gray': True,
        'roi_padding': 50
    }

    depth_range = config['depth_range']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HPNET(config).to(device)
    model.load_state_dict(torch.load(pretrained_model))
    model.eval()

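    # Re-save the weights with legacy (non-zipfile) serialization so that
    # older torch versions (e.g. torch 1.4 under ROS / Python 2) can load them.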
    torch.save(model.state_dict(),
               pretrained_model,
               _use_new_zipfile_serialization=False)
Example #2
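# Imports assumed for this snippet. Standard packages first; the
# hanging_points_cnn module paths below are assumptions, not verified.
import argparse
import os.path as osp
from pathlib import Path

import cameramodels
import cv2
import numpy as np
import open3d as o3d
import skrobot
import torch
import trimesh
import yaml
from skrobot.coordinates.math import matrix2quaternion
from skrobot.coordinates.math import rotation_matrix_from_axis
from torchvision import transforms

from hanging_points_cnn.learning_scripts.hpnet import HPNET
from hanging_points_cnn.utils.image import draw_roi
from hanging_points_cnn.utils.image import make_box
from hanging_points_cnn.utils.image import normalize_depth
from hanging_points_cnn.utils.image import overlay_heatmap
from hanging_points_cnn.utils.image import unnormalize_depth
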
def main():
    current_dir = osp.dirname(osp.abspath(__file__))
    trained_model_dir = osp.join(osp.dirname(osp.dirname(current_dir)),
                                 'pretrained_model')

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--input-dir',
                        '-i',
                        type=str,
                        help='input directory',
                        default=None)
    parser.add_argument('--color',
                        '-c',
                        type=str,
                        help='color image (.png)',
                        default=None)
    parser.add_argument('--depth',
                        '-d',
                        type=str,
                        help='depth image (.npy)',
                        default=None)
    parser.add_argument('--camera-info',
                        '-ci',
                        type=str,
                        help='camera info file (.yaml)',
                        default=None)

    parser.add_argument('--pretrained_model',
                        '-p',
                        type=str,
                        help='Pretrained models',
                        default=osp.join(trained_model_dir, 'hanging_1020.pt'))
    # default=osp.join(trained_model_dir, 'pouring.pt'))
    # default='/media/kosuke55/SANDISK-2/meshdata/random_shape_shapenet_hanging_render/1010/gan_2000per0-1000obj_1020.pt')  # gan hanging # noqa
    # default='/media/kosuke55/SANDISK-2/meshdata/random_shape_shapenet_pouring_render/1227/pouring_random_20201230_0215_5epoch.pt')  # gan pouring # noqa

    parser.add_argument('--predict-depth',
                        '-pd',
                        type=int,
                        help='predict-depth',
                        default=0)
    parser.add_argument('--task',
                        '-t',
                        type=str,
                        help='h (hanging) or p (pouring). '
                        'Not needed if roi size is the same in config.',
                        default='h')

    args = parser.parse_args()
    base_dir = args.input_dir
    pretrained_model = args.pretrained_model

    config_path = str(
        Path(osp.abspath(__file__)).parent.parent / 'learning_scripts' /
        'config' / 'gray_model.yaml')
    with open(config_path) as f:
        config = yaml.safe_load(f)

    task_type = args.task
    if task_type == 'p':
        task_type = 'pouring'
    else:
        task_type = 'hanging'

    target_size = tuple(config['target_size'])
    depth_range = config['depth_range']
    depth_roi_size = config['depth_roi_size'][task_type]

    print('task type: {}'.format(task_type))
    print('depth roi size: {}'.format(depth_roi_size))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HPNET(config).to(device)
    model.load_state_dict(torch.load(pretrained_model), strict=False)
    model.eval()

    viewer = skrobot.viewers.TrimeshSceneViewer(resolution=(640, 480))

    if base_dir is not None:
        color_paths = list(Path(base_dir).glob('**/color/*.png'))
    elif args.color is not None \
            and args.depth is not None \
            and args.camera_info is not None:
        color_paths = [args.color]

    else:
        return False

    is_first_loop = True
    try:
        for color_path in color_paths:
            if not is_first_loop:
                viewer.delete(pc)  # noqa
                for c in contact_point_sphere_list:  # noqa
                    viewer.delete(c)

            if base_dir is not None:
                camera_info_path = str(
                    (color_path.parent.parent / 'camera_info' /
                     color_path.stem).with_suffix('.yaml'))
                depth_path = str((color_path.parent.parent / 'depth' /
                                  color_path.stem).with_suffix('.npy'))
                color_path = str(color_path)
            else:
                camera_info_path = args.camera_info
                color_path = args.color
                depth_path = args.depth

            camera_model = cameramodels.PinholeCameraModel.from_yaml_file(
                camera_info_path)
            camera_model.target_size = target_size

            cv_bgr = cv2.imread(color_path)

            intrinsics = camera_model.open3d_intrinsic

            cv_bgr = cv2.resize(cv_bgr, target_size)
            cv_rgb = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2RGB)
            color = o3d.geometry.Image(cv_rgb)

            cv_depth = np.load(depth_path)
            cv_depth = cv2.resize(cv_depth,
                                  target_size,
                                  interpolation=cv2.INTER_NEAREST)
            depth = o3d.geometry.Image(cv_depth)

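            # Fuse color and depth into an RGB-D image and back-project it
            # to a point cloud for display in the scene viewer.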
            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
                color, depth, depth_trunc=4.0, convert_rgb_to_intensity=False)
            pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
                rgbd, intrinsics)
            trimesh_pc = trimesh.PointCloud(np.asarray(pcd.points),
                                            np.asarray(pcd.colors))
            pc = skrobot.model.PointCloudLink(trimesh_pc)

            viewer.add(pc)

            if config['use_bgr2gray']:
                gray = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2GRAY)
                gray = cv2.resize(gray, target_size)[..., None] / 255.
                normalized_depth = normalize_depth(cv_depth, depth_range[0],
                                                   depth_range[1])[..., None]
                in_feature = np.concatenate((normalized_depth, gray),
                                            axis=2).astype(np.float32)
            else:
                raise NotImplementedError()

            transform = transforms.Compose([transforms.ToTensor()])
            in_feature = transform(in_feature)

            in_feature = in_feature.to(device)
            in_feature = in_feature.unsqueeze(0)

            confidence, depth, rotation = model(in_feature)

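            # Convert the predicted confidence map to an 8-bit image,
            # clipping values to the [0, 255] range.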
            confidence = confidence[0, 0:1, ...]
            confidence_np = confidence.cpu().detach().numpy().copy() * 255
            confidence_np = confidence_np.transpose(1, 2, 0)
            confidence_np[confidence_np <= 0] = 0
            confidence_np[confidence_np >= 255] = 255
            confidence_img = confidence_np.astype(np.uint8)

            print(model.rois_list)
            contact_point_sphere_list = []
            roi_image = cv_bgr.copy()
            for i, (roi, roi_center) in enumerate(
                    zip(model.rois_list[0], model.rois_center_list[0])):
                if roi.tolist() == [0, 0, 0, 0]:
                    continue
                roi = roi.cpu().detach().numpy().copy()
                roi_image = draw_roi(roi_image, roi)
                hanging_point_x = roi_center[0]
                hanging_point_y = roi_center[1]
                v = rotation[i].cpu().detach().numpy()
                v /= np.linalg.norm(v)
                rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
                q = matrix2quaternion(rot)

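                # Cast a unit ray through the predicted pixel; scaling it
                # by the estimated depth below yields the 3-D contact point.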
                hanging_point = np.array(
                    camera_model.project_pixel_to_3d_ray(
                        [int(hanging_point_x),
                         int(hanging_point_y)]))

                if args.predict_depth:
                    dep = depth[i].cpu().detach().numpy().copy()
                    dep = unnormalize_depth(dep, depth_range[0],
                                            depth_range[1]) * 0.001
                    length = float(dep) / hanging_point[2]
                else:
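                    # No depth prediction: take the median of the valid
                    # depth pixels in a fixed-size ROI around the point.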
                    depth_roi = make_box(roi_center,
                                         width=depth_roi_size[1],
                                         height=depth_roi_size[0],
                                         img_shape=target_size,
                                         xywh=False)
                    depth_roi_clip = cv_depth[depth_roi[0]:depth_roi[2],
                                              depth_roi[1]:depth_roi[3]]

                    dep_roi_clip = depth_roi_clip[np.where(
                        np.logical_and(
                            config['depth_range'][0] < depth_roi_clip,
                            depth_roi_clip < config['depth_range'][1]))]

                    dep_roi_clip = np.median(dep_roi_clip) * 0.001

                    if np.isnan(dep_roi_clip):
                        continue
                    length = float(dep_roi_clip) / hanging_point[2]

                hanging_point *= length

                contact_point_sphere = skrobot.model.Sphere(0.001,
                                                            color=[255, 0, 0])
                contact_point_sphere.newcoords(
                    skrobot.coordinates.Coordinates(pos=hanging_point, rot=q))
                viewer.add(contact_point_sphere)
                contact_point_sphere_list.append(contact_point_sphere)

            if is_first_loop:
                viewer.show()

            heatmap = overlay_heatmap(cv_bgr, confidence_img)
            cv2.imshow('heatmap', heatmap)
            cv2.imshow('roi', roi_image)
            print('Next data: [ENTER] on image window.\n'
                  'Quit: [q] on image window.')
            key = cv2.waitKey(0)
            cv2.destroyAllWindows()
            if key == ord('q'):
                break

            is_first_loop = False

    except KeyboardInterrupt:
        pass
Example #3
class HangingPointsNet():
    def __init__(self):
        self.bridge = CvBridge()
        self.pub = rospy.Publisher("~output", Image, queue_size=10)
        self.pub_confidence = rospy.Publisher("~output/confidence",
                                              Image,
                                              queue_size=10)
        self.pub_depth = rospy.Publisher("~colorized_depth",
                                         Image,
                                         queue_size=10)
        self.pub_axis = rospy.Publisher("~axis", Image, queue_size=10)
        self.pub_axis_raw = rospy.Publisher("~axis_raw", Image, queue_size=10)
        self.pub_hanging_points = rospy.Publisher("/hanging_points",
                                                  PoseArray,
                                                  queue_size=10)

        self.gpu_id = rospy.get_param('~gpu', 0)
        self.predict_depth = rospy.get_param('~predict_depth', True)
        print('self.predict_depth: ', self.predict_depth)

        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        pretrained_model = rospy.get_param(
            '~pretrained_model',
            '/media/kosuke/SANDISK/hanging_points_net/checkpoints/gray/hpnet_bestmodel_20201025_1542.pt'
        )  # noqa
        task_type = rospy.get_param('~task_type', 'hanging')
        config_path = rospy.get_param('~config', None)

        if config_path is None:
            rospack = rospkg.RosPack()
            pack_path = rospack.get_path('hanging_points_cnn')
            config_path = osp.join(pack_path, 'hanging_points_cnn',
                                   'learning_scripts', 'config',
                                   'gray_model.yaml')
        print('Load ' + config_path)
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        self.transform = transforms.Compose([transforms.ToTensor()])

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.depth_range = self.config['depth_range']
        self.target_size = tuple(self.config['target_size'])
        self.depth_roi_size = self.config['depth_roi_size'][task_type]
        print('task type: {}'.format(task_type))
        print('depth roi size: {}'.format(self.depth_roi_size))

        self.model = HPNET(self.config).to(device)

        if osp.exists(pretrained_model):
            print('use pretrained model')
            self.model.load_state_dict(torch.load(pretrained_model),
                                       strict=False)

        self.model.eval()

        self.camera_model = None
        self.load_camera_info()

        self.use_coords = False

        self.subscribe()

    def subscribe(self):
        self.sub_camera_info = message_filters.Subscriber('~camera_info',
                                                          CameraInfo,
                                                          queue_size=1,
                                                          buff_size=2**24)
        self.sub_rgb_raw = message_filters.Subscriber('~rgb_raw',
                                                      Image,
                                                      queue_size=1,
                                                      buff_size=2**24)
        self.sub_rgb = message_filters.Subscriber('~rgb',
                                                  Image,
                                                  queue_size=1,
                                                  buff_size=2**24)
        self.sub_depth = message_filters.Subscriber('~depth',
                                                    Image,
                                                    queue_size=1,
                                                    buff_size=2**24)
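        # Approximately synchronize the four topics (0.1 s slop) before
        # invoking the callback.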
        sync = message_filters.ApproximateTimeSynchronizer([
            self.sub_camera_info, self.sub_rgb_raw, self.sub_rgb,
            self.sub_depth
        ],
                                                           queue_size=100,
                                                           slop=0.1)
        sync.registerCallback(self.callback)

    def get_full_camera_info(self, camera_info):
        full_camera_info = copy.copy(camera_info)
        full_camera_info.roi.x_offset = 0
        full_camera_info.roi.y_offset = 0
        full_camera_info.roi.height = 0
        full_camera_info.roi.width = 0

        return full_camera_info

    def load_camera_info(self):
        print('load camera info')
        self.camera_info = rospy.wait_for_message('~camera_info', CameraInfo)
        self.camera_model\
            = cameramodels.PinholeCameraModel.from_camera_info(
                self.camera_info)

    def callback(self, camera_info_msg, img_raw_msg, img_msg, depth_msg):
        ymin = camera_info_msg.roi.y_offset
        xmin = camera_info_msg.roi.x_offset
        ymax = camera_info_msg.roi.y_offset + camera_info_msg.roi.height
        xmax = camera_info_msg.roi.x_offset + camera_info_msg.roi.width
        self.camera_model.roi = [ymin, xmin, ymax, xmax]
        self.camera_model.target_size = self.target_size

        bgr_raw = self.bridge.imgmsg_to_cv2(img_raw_msg, "bgr8")
        bgr = self.bridge.imgmsg_to_cv2(img_msg, "bgr8")
        cv_depth = self.bridge.imgmsg_to_cv2(depth_msg, "32FC1")
        if cv_depth is None or bgr is None:
            return
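        # Clean the depth image: remove NaNs and zero out readings outside
        # the configured depth range.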
        remove_nan(cv_depth)
        cv_depth[cv_depth < self.depth_range[0]] = 0
        cv_depth[cv_depth > self.depth_range[1]] = 0
        bgr = cv2.resize(bgr, self.target_size)
        cv_depth = cv2.resize(cv_depth,
                              self.target_size,
                              interpolation=cv2.INTER_NEAREST)

        depth_bgr = colorize_depth(cv_depth, ignore_value=0)

        in_feature = cv_depth.copy().astype(np.float32) * 0.001

        if self.config['use_bgr2gray']:
            gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
            gray = cv2.resize(gray, self.target_size)[..., None] / 255.
            normalized_depth = normalize_depth(cv_depth, self.depth_range[0],
                                               self.depth_range[1])[..., None]
            in_feature = np.concatenate((normalized_depth, gray),
                                        axis=2).astype(np.float32)

        if self.transform:
            in_feature = self.transform(in_feature)

        in_feature = in_feature.to(self.device)
        in_feature = in_feature.unsqueeze(0)

        confidence, depth, rotation = self.model(in_feature)
        confidence = confidence[0, 0:1, ...]
        confidence_np = confidence.cpu().detach().numpy().copy() * 255
        confidence_np = confidence_np.transpose(1, 2, 0)
        confidence_np[confidence_np <= 0] = 0
        confidence_np[confidence_np >= 255] = 255
        confidence_img = confidence_np.astype(np.uint8)
        confidence_img = cv2.resize(confidence_img, self.target_size)
        heatmap = overlay_heatmap(bgr, confidence_img)

        axis_pred = bgr.copy()
        axis_pred_raw = bgr_raw.copy()

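        # Build one pose per detected ROI: position from the back-projected
        # ray scaled by the estimated depth, orientation from the predicted
        # axis.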
        hanging_points_pose_array = PoseArray()
        for i, (roi, roi_center) in enumerate(
                zip(self.model.rois_list[0], self.model.rois_center_list[0])):
            if roi.tolist() == [0, 0, 0, 0]:
                continue
            roi = roi.cpu().detach().numpy().copy()
            hanging_point_x = roi_center[0]
            hanging_point_y = roi_center[1]
            v = rotation[i].cpu().detach().numpy()
            v /= np.linalg.norm(v)
            rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
            q = matrix2quaternion(rot)

            hanging_point = np.array(
                self.camera_model.project_pixel_to_3d_ray(
                    [int(hanging_point_x),
                     int(hanging_point_y)]))

            if self.predict_depth:
                dep = depth[i].cpu().detach().numpy().copy()
                dep = unnormalize_depth(dep, self.depth_range[0],
                                        self.depth_range[1]) * 0.001
                length = float(dep) / hanging_point[2]
            else:
                depth_roi = make_box(roi_center,
                                     width=self.depth_roi_size[1],
                                     height=self.depth_roi_size[0],
                                     img_shape=self.target_size,
                                     xywh=False)
                depth_roi_clip = cv_depth[depth_roi[0]:depth_roi[2],
                                          depth_roi[1]:depth_roi[3]]
                dep_roi_clip = depth_roi_clip[np.where(
                    np.logical_and(self.depth_range[0] < depth_roi_clip,
                                   depth_roi_clip < self.depth_range[1]))]
                dep_roi_clip = np.median(dep_roi_clip) * 0.001
                if np.isnan(dep_roi_clip):
                    continue
                length = float(dep_roi_clip) / hanging_point[2]

            hanging_point *= length
            hanging_point_pose = Pose()
            hanging_point_pose.position.x = hanging_point[0]
            hanging_point_pose.position.y = hanging_point[1]
            hanging_point_pose.position.z = hanging_point[2]
            hanging_point_pose.orientation.w = q[0]
            hanging_point_pose.orientation.x = q[1]
            hanging_point_pose.orientation.y = q[2]
            hanging_point_pose.orientation.z = q[3]
            hanging_points_pose_array.poses.append(hanging_point_pose)

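            # Scale the ROI from network input resolution back to the raw
            # image, compensating for the camera ROI offset.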
            axis_pred_raw = cv2.rectangle(
                axis_pred_raw,
                (int(roi[0] * (xmax - xmin) / self.target_size[1] + xmin),
                 int(roi[1] * (ymax - ymin) / self.target_size[0] + ymin)),
                (int(roi[2] * (xmax - xmin) / self.target_size[1] + xmin),
                 int(roi[3] * (ymax - ymin) / self.target_size[0] + ymin)),
                (0, 255, 0), 1)
            try:
                axis_pred_raw = draw_axis(axis_pred_raw, quaternion2matrix(q),
                                          hanging_point,
                                          self.camera_model.full_K)
            except Exception:
                print('Fail to draw axis')

        axis_pred = self.camera_model.crop_image(axis_pred_raw,
                                                 copy=True).astype(np.uint8)

        msg_out = self.bridge.cv2_to_imgmsg(heatmap, "bgr8")
        msg_out.header.stamp = depth_msg.header.stamp

        confidence_msg = self.bridge.cv2_to_imgmsg(confidence_img, "mono8")
        confidence_msg.header.stamp = depth_msg.header.stamp

        colorized_depth_msg = self.bridge.cv2_to_imgmsg(depth_bgr, "bgr8")
        colorized_depth_msg.header.stamp = depth_msg.header.stamp

        axis_pred_msg = self.bridge.cv2_to_imgmsg(axis_pred, "bgr8")
        axis_pred_msg.header.stamp = depth_msg.header.stamp

        axis_pred_raw_msg = self.bridge.cv2_to_imgmsg(axis_pred_raw, "bgr8")
        axis_pred_raw_msg.header.stamp = depth_msg.header.stamp

        hanging_points_pose_array.header = camera_info_msg.header
        self.pub.publish(msg_out)
        self.pub_confidence.publish(confidence_msg)
        self.pub_depth.publish(colorized_depth_msg)
        self.pub_axis.publish(axis_pred_msg)
        self.pub_axis_raw.publish(axis_pred_raw_msg)
        self.pub_hanging_points.publish(hanging_points_pose_array)
Example #4
task_type = args.task
if task_type == 'p':
    task_type = 'pouring'
else:
    task_type = 'hanging'

target_size = tuple(config['target_size'])
depth_range = config['depth_range']
depth_roi_size = config['depth_roi_size'][task_type]

print('task type: {}'.format(task_type))
print('depth roi size: {}'.format(depth_roi_size))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HPNET(config).to(device)
model.load_state_dict(torch.load(pretrained_model), strict=False)
model.eval()

viewer = skrobot.viewers.TrimeshSceneViewer(resolution=(640, 480))
thresh_distance = 0.03

if is_sim_data:
    print('Inference with sim data')
    color_paths = list(sorted(Path(base_dir).glob('*/*/color/*.png')))
else:
    print('Inference with real data')
    color_paths = list(sorted(Path(base_dir).glob('*/color/*.png')))

gui = args.gui
save_dir = args.save_dir
Example #5
class Trainer(object):
    def __init__(self,
                 data_path,
                 test_data_path,
                 batch_size,
                 max_epoch,
                 pretrained_model,
                 train_data_num,
                 val_data_num,
                 save_dir,
                 lr,
                 config=None,
                 train_depth=False,
                 port=6006,
                 object_list=None,
                 data_augmentation=True):

        if config is None:
            warnings.warn('config is not specified. Using default config.')
            config = {
                'output_channels': 1,
                'feature_extractor_name': 'resnet50',
                'confidence_thresh': 0.3,
                'depth_range': [100, 1500],
                'use_bgr': True,
                'use_bgr2gray': True,
                'roi_padding': 50
            }
        self.config = config
        self.depth_range = config['depth_range']
        self.train_dataloader, self.val_dataloader\
            = load_dataset(data_path, batch_size,
                           use_bgr=self.config['use_bgr'],
                           use_bgr2gray=self.config['use_bgr2gray'],
                           depth_range=self.depth_range,
                           object_list=object_list,
                           data_augmentation=data_augmentation)
        self.test_dataloader \
            = load_test_dataset(test_data_path,
                                use_bgr=self.config['use_bgr'],
                                use_bgr2gray=self.config['use_bgr2gray'],
                                depth_range=self.depth_range)

        self.train_data_num = train_data_num
        self.val_data_num = val_data_num
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.lr = lr
        self.max_epoch = max_epoch
        self.time_now = datetime.now().strftime('%Y%m%d_%H%M')
        self.best_loss = 1e10

        self.cameramodel = None
        self.target_size = [256, 256]

        self.vis = visdom.Visdom(port=port)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('device is:{}'.format(self.device))

        self.model = HPNET(config).to(self.device)

        self.save_model_interval = 1
        if os.path.exists(pretrained_model):
            print('use pretrained model')
            self.model.load_state_dict(torch.load(pretrained_model),
                                       strict=False)

        self.prev_model = copy.deepcopy(self.model)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(0.9, 0.999),
                                          eps=1e-10,
                                          weight_decay=0,
                                          amsgrad=False)
        self.prev_optimizer = copy.deepcopy(self.optimizer)

        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda epo: 0.9**epo)

        self.now = datetime.now().strftime('%Y%m%d_%H%M')
        self.use_coords = False
        self.train_depth = train_depth

    def step(self, dataloader, mode):
        print('Start {}'.format(mode))
        # self.model = self.prev_model
        if mode == 'train':
            self.model.train()
        elif mode == 'val' or mode == 'test':
            self.model.eval()

        loss_sum = 0
        confidence_loss_sum = 0
        depth_loss_sum = 0
        rotation_loss_sum = 0
        rotation_loss_count = 0

        for index, (hp_data, depth_image, camera_info_path, hp_data_gt,
                    annotation_data) in tqdm.tqdm(enumerate(dataloader),
                                                  total=len(dataloader),
                                                  desc='{} epoch={}'.format(
                                                      mode, self.epo),
                                                  leave=False):

            # if index == 0:
            #     self.model = self.prev_model

            self.cameramodel\
                = cameramodels.PinholeCameraModel.from_yaml_file(
                    camera_info_path[0])
            self.cameramodel.target_size = self.target_size

            depth_image = hp_data.numpy().copy()[0, 0, ...]
            depth_image = np.nan_to_num(depth_image)
            depth_image = unnormalize_depth(depth_image, self.depth_range[0],
                                            self.depth_range[1])
            hp_data = hp_data.to(self.device)

            depth_image_bgr = colorize_depth(depth_image,
                                             ignore_value=self.depth_range[0])

            if mode == 'train':
                confidence, depth, rotation = self.model(hp_data)
            elif mode == 'val' or mode == 'test':
                with torch.no_grad():
                    confidence, depth, rotation = self.model(hp_data)

            confidence_np = confidence[0, ...].cpu().detach().numpy().copy()
            confidence_np[confidence_np >= 1] = 1.
            confidence_np[confidence_np <= 0] = 0.
            confidence_vis = cv2.cvtColor(confidence_np[0, ...] * 255,
                                          cv2.COLOR_GRAY2BGR)

            if mode != 'test':
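                # Build a per-pixel weight map: background pixels
                # (GT confidence < 0.5) get weight 0.5, foreground 1.0.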
                pos_weight = hp_data_gt.detach().numpy().copy()
                pos_weight = pos_weight[:, 0, ...]
                zeroidx = np.where(pos_weight < 0.5)
                nonzeroidx = np.where(pos_weight >= 0.5)
                pos_weight[zeroidx] = 0.5
                pos_weight[nonzeroidx] = 1.0
                pos_weight = torch.from_numpy(pos_weight)
                pos_weight = pos_weight.to(self.device)

                hp_data_gt = hp_data_gt.to(self.device)
                confidence_gt = hp_data_gt[:, 0:1, ...]
                rois_list_gt, rois_center_list_gt = find_rois(confidence_gt)

                criterion = HPNETLoss(self.use_coords).to(self.device)

                if self.model.rois_list is None or rois_list_gt is None:
                    return None, None

                annotated_rois = annotate_rois(self.model.rois_list,
                                               rois_list_gt, annotation_data)

                confidence_loss, depth_loss, rotation_loss = criterion(
                    confidence, hp_data_gt, pos_weight, depth, rotation,
                    annotated_rois)

                if self.train_depth:
                    loss = confidence_loss + rotation_loss + depth_loss
                else:
                    loss = confidence_loss + rotation_loss

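                # If the loss went NaN, restore the previous model and
                # optimizer state and skip this batch.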
                if torch.isnan(loss):
                    print('loss is nan!!')
                    self.model = self.prev_model
                    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                                      lr=self.lr,
                                                      betas=(0.9, 0.999),
                                                      eps=1e-10,
                                                      weight_decay=0,
                                                      amsgrad=False)
                    self.optimizer.load_state_dict(
                        self.prev_optimizer.state_dict())
                    continue
                else:
                    self.prev_model = copy.deepcopy(self.model)
                    self.prev_optimizer = copy.deepcopy(self.optimizer)

                if mode == 'train':
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                    self.optimizer.step()

                axis_gt = depth_image_bgr.copy()

                confidence_gt_vis = cv2.cvtColor(
                    confidence_gt[0, 0, ...].cpu().detach().numpy().copy() *
                    255, cv2.COLOR_GRAY2BGR)

                # Visualize gt axis and roi
                for roi, roi_c in zip(rois_list_gt[0], rois_center_list_gt[0]):
                    if roi.tolist() == [0, 0, 0, 0]:
                        continue
                    roi = roi.cpu().detach().numpy().copy()
                    cx = roi_c[0]
                    cy = roi_c[1]

                    depth_and_rotation_gt = get_value_gt([cx, cy],
                                                         annotation_data[0])
                    rotation_gt = depth_and_rotation_gt[1:]
                    depth_gt_val = depth_and_rotation_gt[0]
                    unnormalized_depth_gt_val = unnormalize_depth(
                        depth_gt_val, self.depth_range[0], self.depth_range[1])

                    hanging_point_pose = np.array(
                        self.cameramodel.project_pixel_to_3d_ray(
                            [int(cx), int(cy)])) \
                        * unnormalized_depth_gt_val * 0.001

                    if self.use_coords:
                        rot = quaternion2matrix(rotation_gt)
                    else:
                        v = np.matmul(quaternion2matrix(rotation_gt),
                                      [1, 0, 0])
                        rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')
                    try:
                        draw_axis(axis_gt, rot, hanging_point_pose,
                                  self.cameramodel.K)
                    except Exception:
                        print('Fail to draw axis')

                    confidence_gt_vis = draw_roi(confidence_gt_vis,
                                                 roi,
                                                 val=depth_gt_val,
                                                 gt=True)
                    axis_gt = draw_roi(axis_gt, roi, val=depth_gt_val, gt=True)

            # Visualize pred axis and roi
            axis_pred = depth_image_bgr.copy()

            for i, (roi, roi_c) in enumerate(
                    zip(self.model.rois_list[0],
                        self.model.rois_center_list[0])):

                if roi.tolist() == [0, 0, 0, 0]:
                    continue
                roi = roi.cpu().detach().numpy().copy()
                cx = roi_c[0]
                cy = roi_c[1]

                dep = depth[i].cpu().detach().numpy().copy()
                normalized_dep_pred = float(dep)
                dep = unnormalize_depth(dep, self.depth_range[0],
                                        self.depth_range[1])

                confidence_vis = draw_roi(confidence_vis,
                                          roi,
                                          val=normalized_dep_pred)
                axis_pred = draw_roi(axis_pred, roi, val=normalized_dep_pred)

                if mode != 'test':
                    if annotated_rois[i][2]:
                        confidence_vis = draw_roi(confidence_vis,
                                                  annotated_rois[i][0],
                                                  val=annotated_rois[i][1][0],
                                                  gt=True)
                        axis_pred = draw_roi(axis_pred,
                                             annotated_rois[i][0],
                                             val=annotated_rois[i][1][0],
                                             gt=True)

                hanging_point_pose = np.array(
                    self.cameramodel.project_pixel_to_3d_ray(
                        [int(cx), int(cy)])) * float(dep * 0.001)

                if self.use_coords:
                    # have not check this yet
                    q = rotation[i].cpu().detach().numpy().copy()
                    q /= np.linalg.norm(q)
                    rot = quaternion2matrix(q)

                else:
                    v = rotation[i].cpu().detach().numpy()
                    v /= np.linalg.norm(v)
                    rot = rotation_matrix_from_axis(v, [0, 1, 0], 'xy')

                try:
                    draw_axis(axis_pred, rot, hanging_point_pose,
                              self.cameramodel.K)
                except Exception:
                    print('Fail to draw axis')

            axis_pred = cv2.cvtColor(axis_pred, cv2.COLOR_BGR2RGB)
            confidence_vis = cv2.cvtColor(confidence_vis, cv2.COLOR_BGR2RGB)

            if self.config['use_bgr']:
                if self.config['use_bgr2gray']:
                    in_gray = hp_data.cpu().detach().numpy().copy()[0, 1:2,
                                                                    ...] * 255
                    in_gray = in_gray.transpose(1, 2, 0).astype(np.uint8)
                    in_gray = cv2.cvtColor(in_gray, cv2.COLOR_GRAY2RGB)
                    in_gray = in_gray.transpose(2, 0, 1)
                    in_img = in_gray
                else:
                    in_bgr = hp_data.cpu().detach().numpy().copy()[
                        0, 3:, ...].transpose(1, 2, 0)
                    in_rgb = cv2.cvtColor(in_bgr, cv2.COLOR_BGR2RGB).transpose(
                        2, 0, 1)
                    in_img = in_rgb

            if mode != 'test':
                confidence_loss_sum += confidence_loss.item()

                axis_gt = cv2.cvtColor(axis_gt, cv2.COLOR_BGR2RGB)
                confidence_gt_vis = cv2.cvtColor(confidence_gt_vis,
                                                 cv2.COLOR_BGR2RGB)

                if rotation_loss.item() > 0:
                    depth_loss_sum += depth_loss.item()
                    rotation_loss_sum += rotation_loss.item()
                    loss_sum = loss_sum \
                        + confidence_loss.item() \
                        + rotation_loss.item()
                    rotation_loss_count += 1

                if np.mod(index, 1) == 0:
                    print(
                        'epoch {}, {}/{},{} loss is confidence:{} rotation:{} depth:{}'
                        .format(  # noqa
                            self.epo, index, len(dataloader), mode,
                            confidence_loss.item(), rotation_loss.item(),
                            depth_loss.item()))

                self.vis.images(
                    [axis_gt.transpose(2, 0, 1),
                     axis_pred.transpose(2, 0, 1)],
                    win='{} axis'.format(mode),
                    opts=dict(title='{} axis'.format(mode)))
                self.vis.images(
                    [
                        confidence_gt_vis.transpose(2, 0, 1),
                        confidence_vis.transpose(2, 0, 1)
                    ],
                    win='{}_confidence_roi'.format(mode),
                    opts=dict(title='{} confidence(GT, Pred)'.format(mode)))

                if self.config['use_bgr']:
                    self.vis.images([in_img],
                                    win='{} in_gray'.format(mode),
                                    opts=dict(title='{} in_gray'.format(mode)))
            else:
                if self.config['use_bgr']:
                    self.vis.images(
                        [
                            in_img,
                            confidence_vis.transpose(2, 0, 1),
                            axis_pred.transpose(2, 0, 1)
                        ],
                        win='{}-{}'.format(mode, index),
                        opts=dict(
                            title='{}-{} hanging_point_depth (pred)'.format(
                                mode, index)))
                else:
                    self.vis.images(
                        [
                            confidence_vis.transpose(2, 0, 1),
                            axis_pred.transpose(2, 0, 1)
                        ],
                        win='{}-{}'.format(mode, index),
                        opts=dict(
                            title='{}-{} hanging_point_depth (pred)'.format(
                                mode, index)))

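            # Checkpoint the latest model every 1000 iterations (legacy
            # serialization, loadable from torch < 1.6).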
            if np.mod(index, 1000) == 0:
                save_file = osp.join(
                    self.save_dir,
                    'hpnet_latestmodel_' + self.time_now + '.pt')
                print('save {}'.format(save_file))
                torch.save(self.model.state_dict(),
                           save_file,
                           _use_new_zipfile_serialization=False)

        if mode != 'test':
            if len(dataloader) > 0:
                avg_confidence_loss\
                    = confidence_loss_sum / len(dataloader)
                if rotation_loss_count > 0:
                    avg_rotation_loss\
                        = rotation_loss_sum / rotation_loss_count
                    avg_depth_loss\
                        = depth_loss_sum / rotation_loss_count
                    avg_loss\
                        = loss_sum / rotation_loss_count
                else:
                    avg_rotation_loss = 1e10
                    avg_depth_loss = 1e10
                    avg_loss = 1e10
            else:
                avg_loss = loss_sum
                avg_confidence_loss = confidence_loss_sum
                avg_rotation_loss = rotation_loss_sum
                avg_depth_loss = depth_loss_sum

            self.vis.line(X=np.array([self.epo]),
                          Y=np.array([avg_confidence_loss]),
                          opts={'title': 'confidence'},
                          win='confidence loss',
                          name='{}_confidence_loss'.format(mode),
                          update='append')
            if rotation_loss_count > 0:
                self.vis.line(X=np.array([self.epo]),
                              Y=np.array([avg_rotation_loss]),
                              opts={'title': 'rotation loss'},
                              win='rotation loss',
                              name='{}_rotation_loss'.format(mode),
                              update='append')
                self.vis.line(X=np.array([self.epo]),
                              Y=np.array([avg_depth_loss]),
                              opts={'title': 'depth loss'},
                              win='depth loss',
                              name='{}_depth_loss'.format(mode),
                              update='append')
                self.vis.line(X=np.array([self.epo]),
                              Y=np.array([avg_loss]),
                              opts={'title': 'loss'},
                              win='loss',
                              name='{}_loss'.format(mode),
                              update='append')

            if mode == 'val':
                if np.mod(self.epo, self.save_model_interval) == 0:
                    save_file = osp.join(
                        self.save_dir,
                        'hpnet_latestmodel_' + self.time_now + '.pt')
                    print('save {}'.format(save_file))
                    torch.save(self.model.state_dict(),
                               save_file,
                               _use_new_zipfile_serialization=False)

                if self.best_loss > avg_loss:
                    print('update best model {} -> {}'.format(
                        self.best_loss, avg_loss))
                    self.best_loss = avg_loss
                    save_file = osp.join(
                        self.save_dir,
                        'hpnet_bestmodel_' + self.time_now + '.pt')
                    print('save {}'.format(save_file))
                    # For ROS (Python 2, torch 1.4)
                    torch.save(self.model.state_dict(),
                               save_file,
                               _use_new_zipfile_serialization=False)

    def train(self):
        for self.epo in range(self.max_epoch):
            self.step(self.train_dataloader, 'train')
            self.step(self.val_dataloader, 'val')
            self.step(self.test_dataloader, 'test')
            self.scheduler.step()