Code Example #1
    def test_depth_init(self):
        # valid data
        random_valid_data = np.random.rand(IM_HEIGHT,
                                           IM_WIDTH).astype(np.float32)
        im = DepthImage(random_valid_data)
        self.assertEqual(im.height, IM_HEIGHT)
        self.assertEqual(im.width, IM_WIDTH)
        self.assertEqual(im.channels, 1)
        self.assertEqual(im.type, np.float32)
        self.assertTrue(np.allclose(im.data, random_valid_data))

        # invalid channels
        random_data = np.random.rand(IM_HEIGHT, IM_WIDTH, 3).astype(np.float32)
        caught_bad_channels = False
        try:
            im = DepthImage(random_data)
        except Exception:
            caught_bad_channels = True
        self.assertTrue(caught_bad_channels)

        # invalid type
        random_data = np.random.rand(IM_HEIGHT, IM_WIDTH).astype(np.uint8)
        caught_bad_dtype = False
        try:
            im = DepthImage(random_data)
        except Exception:
            caught_bad_dtype = True
        self.assertTrue(caught_bad_dtype)
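
The same invalid-input checks read more idiomatically with unittest's `assertRaises` context manager. A minimal sketch, assuming only that `DepthImage` raises some exception for bad input (the exact exception type is not shown above, so the broad `Exception` is used):

        # invalid channels: construction should fail
        random_data = np.random.rand(IM_HEIGHT, IM_WIDTH, 3).astype(np.float32)
        with self.assertRaises(Exception):
            DepthImage(random_data)

        # invalid dtype: construction should fail
        random_data = np.random.rand(IM_HEIGHT, IM_WIDTH).astype(np.uint8)
        with self.assertRaises(Exception):
            DepthImage(random_data)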
Code Example #2
    def _get_actions(self, preds, ind, images, depths, camera_intr,
                     num_actions):
        """Generate the actions to be returned."""
        depth_im = DepthImage(images[0], frame=camera_intr.frame)
        point_cloud_im = camera_intr.deproject_to_image(depth_im)
        normal_cloud_im = point_cloud_im.normal_cloud_im()

        actions = []
        for i in range(num_actions):
            im_idx = ind[i, 0]
            h_idx = ind[i, 1]
            w_idx = ind[i, 2]
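            # Map the stride-space indices back to full-image pixel coordinates
            # at the center of the GQ-CNN receptive field.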
            center = Point(
                np.asarray([
                    w_idx * self._gqcnn_stride + self._gqcnn_recep_w // 2,
                    h_idx * self._gqcnn_stride + self._gqcnn_recep_h // 2
                ]))
            axis = -normal_cloud_im[center.y, center.x]
            if np.linalg.norm(axis) == 0:
                continue
            depth = depth_im[center.y, center.x, 0]
            if depth == 0.0:
                continue
            grasp = SuctionPoint2D(center,
                                   axis=axis,
                                   depth=depth,
                                   camera_intr=camera_intr)
            grasp_action = GraspAction(grasp, preds[im_idx, h_idx, w_idx, 0],
                                       DepthImage(images[im_idx]))
            actions.append(grasp_action)
        return actions
Code Example #3
def inpaint_depth_image(d_img, ix=0, iy=0):
    """Inpaint depth image on raw depth values.

    Only import code here to avoid making them required if we're not inpainting.

    Also, inpainting is slow, so crop some irrelevant values. But BE CAREFUL!
    Make sure any cropping here will lead to logical consistency with the
    processing in `camera.process_img_for_net` later. For now we crop the 'later
    part' of each dimension, which still leads to > 2x speed-up. The
    window size is 3 which I think means we can get away with a pixel difference
    of 3 when cropping but to be safe let's add a bit more, 50 pix to each side.

    For `ix` and `iy` see `camera.process_img_for_net`, makes inpainting faster.
    """
    d_img = d_img[ix:685, iy:1130]
    from perception import (ColorImage, DepthImage)
    print('now in-painting the depth image (shape {}), ix, iy = {}, {}...'.
          format(d_img.shape, ix, iy))
    start_t = time.time()
    d_img = DepthImage(d_img)
    d_img = d_img.inpaint()  # inpaint, then get d_img right away
    d_img = d_img.data  # get raw data back from the class
    cum_t = time.time() - start_t
    print('finished in-painting in {:.2f} seconds'.format(cum_t))
    return d_img
Code Example #4
File: analyzer.py Project: marctuscher/cv_pipeline
 def _plot_grasp(self, datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=None):
     """ Plots a single grasp represented as a datapoint. """
     image = DepthImage(datapoint[image_field_name][:,:,0])
     depth = datapoint[pose_field_name][2]
     width = 0
     grasps = []
     if gripper_mode == GripperMode.PARALLEL_JAW or \
        gripper_mode == GripperMode.LEGACY_PARALLEL_JAW:
         if angular_preds is not None:
             num_bins = angular_preds.shape[0] // 2
             bin_width = GeneralConstants.PI / num_bins
             for i in range(num_bins):
                 bin_cent_ang = i * bin_width + bin_width / 2
                 grasps.append(Grasp2D(center=image.center, angle=GeneralConstants.PI / 2 - bin_cent_ang, depth=depth, width=0.0))
             grasps.append(Grasp2D(center=image.center, angle=datapoint[pose_field_name][3], depth=depth, width=0.0))
         else:
             grasps.append(Grasp2D(center=image.center,
                         angle=0,
                         depth=depth,
                         width=0.0))
         width = datapoint[pose_field_name][-1]
     else:
         grasps.append(SuctionPoint2D(center=image.center,
                                axis=[1,0,0],
                                depth=depth))                
     vis2d.imshow(image)
     for i, grasp in enumerate(grasps[:-1]):
         vis2d.grasp(grasp, width=width, color=plt.cm.RdYlGn(angular_preds[i * 2 + 1]))
     vis2d.grasp(grasps[-1], width=width, color='b')
Code Example #5
    def _read_color_and_depth_image(self):
        """Read a color and depth image from the device.
        """
        frames = self._pipe.wait_for_frames()
        if self._depth_align:
            frames = self._align.process(frames)

        depth_frame = frames.get_depth_frame()
        color_frame = frames.get_color_frame()

        if not depth_frame or not color_frame:
            logging.warning('Could not retrieve frames.')
            return None, None

        if self._filter_depth:
            depth_frame = self._filter_depth_frame(depth_frame)

        # convert to numpy arrays
        depth_image = self._to_numpy(depth_frame, np.float32)
        color_image = self._to_numpy(color_frame, np.uint8)

        # convert depth to meters
        depth_image *= self._depth_scale

        # bgr to rgb
        color_image = color_image[..., ::-1]

        depth = DepthImage(depth_image, frame=self._frame)
        color = ColorImage(color_image, frame=self._frame)
        return color, depth
Code Example #6
 def _get_actions(self, preds, ind, images, depths, camera_intr,
                  num_actions):
     """Generate the actions to be returned."""
     actions = []
     # TODO(vsatish): These should use the max angle instead.
     ang_bin_width = GeneralConstants.PI / preds.shape[-1]
     for i in range(num_actions):
         im_idx = ind[i, 0]
         h_idx = ind[i, 1]
         w_idx = ind[i, 2]
         ang_idx = ind[i, 3]
         center = Point(
             np.asarray([
                 w_idx * self._gqcnn_stride + self._gqcnn_recep_w // 2,
                 h_idx * self._gqcnn_stride + self._gqcnn_recep_h // 2
             ]))
         ang = GeneralConstants.PI / 2 - (ang_idx * ang_bin_width +
                                          ang_bin_width / 2)
         depth = depths[im_idx, 0]
         grasp = Grasp2D(center,
                         ang,
                         depth,
                         width=self._gripper_width,
                         camera_intr=camera_intr)
         grasp_action = GraspAction(grasp, preds[im_idx, h_idx, w_idx,
                                                 ang_idx],
                                    DepthImage(images[im_idx]))
         actions.append(grasp_action)
     return actions
Code Example #7
def analyze_image_depths(path, bbox, out_name):
    """
    path should lead to a .npy file
    """
    img = np.load(path)
    img = np.reshape(img, img.shape[:2])
    img_slice = img[bbox[0]:bbox[2], bbox[1]:bbox[3]]
    vec = np.ndarray.flatten(img_slice)
    # vec = reject_outliers(vec)

    var = np.var(vec)
    mean = np.mean(vec)
    print("State for {}: Mean: {}, Standard Deviation: {}\n".format(
        out_name, mean, np.sqrt(var)))

    n, bins, patches = plt.hist(vec, vec.size // 10, facecolor="blue")

    plt.xlabel("depth value")
    plt.ylabel("count")
    plt.title("depth within region")
    plt.grid(True)
    # Save before showing; with many backends, saving after the window is
    # closed writes out an empty figure.
    plt.savefig(os.path.join(out_path, "graph_" + out_name),
                bbox_inches="tight")
    plt.show()
    plt.close()
    depth_img = DepthImage(img_slice)
    depth_img.save(os.path.join(out_path, out_name))
Code Example #8
 def rendered_images(data, render_mode=RenderMode.SEGMASK):
     rendered_images = []
     num_images = data.attrs[NUM_IMAGES_KEY]
     
     for i in range(num_images):
         # get the image data y'all
         image_key = IMAGE_KEY + '_' + str(i)
         image_data = data[image_key]
         image_arr = np.array(image_data[IMAGE_DATA_KEY])
         frame = image_data.attrs[IMAGE_FRAME_KEY]
         if render_mode == RenderMode.SEGMASK:
             image = BinaryImage(image_arr, frame)
         elif render_mode == RenderMode.DEPTH:
             image = DepthImage(image_arr, frame)
         elif render_mode == RenderMode.SCALED_DEPTH:
             image = ColorImage(image_arr, frame)
         R_camera_table =  image_data.attrs[CAM_ROT_KEY]
         t_camera_table =  image_data.attrs[CAM_POS_KEY]
         frame          =  image_data.attrs[CAM_FRAME_KEY]
         T_camera_world = RigidTransform(R_camera_table, t_camera_table,
                                         from_frame=frame,
                                         to_frame='world')
         
         rendered_images.append(ObjectRender(image, T_camera_world))
     return rendered_images
Code Example #9
    def plan_grasp(self,
                   depth,
                   rgb,
                   resetting=False,
                   camera_intr=None,
                   segmask=None):
        """
        Computes possible grasps.
        Parameters
        ----------
        depth: type `numpy`
            depth image
        rgb: type `numpy`
            rgb image
        camera_intr: type `perception.CameraIntrinsics`
            Intrinsic camera object.
        segmask: type `perception.BinaryImage`
            Binary segmask of detected object
        Returns
        -------
        type `GQCNNGrasp`
            Computed optimal grasp.
        """
        if "FC" in self.model:
            assert not (segmask is
                        None), "Fully-Convolutional policy expects a segmask."
        if camera_intr is None:
            camera_intr_filename = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "gqcnn/data/calib/primesense/primesense.intr")
            camera_intr = CameraIntrinsics.load(camera_intr_filename)

        depth_im = DepthImage(depth, frame=camera_intr.frame)
        color_im = ColorImage(rgb, frame=camera_intr.frame)

        valid_px_mask = depth_im.invalid_pixel_mask().inverse()
        if segmask is None:
            segmask = valid_px_mask
        else:
            segmask = segmask.mask_binary(valid_px_mask)

        # Inpaint.
        depth_im = depth_im.inpaint(
            rescale_factor=self.config["inpaint_rescale_factor"])
        # Aggregate color and depth images into a single
        # BerkeleyAutomation/perception `RgbdImage`.
        self.rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
        # Create an `RgbdImageState` with the `RgbdImage` and `CameraIntrinsics`.
        state = RgbdImageState(self.rgbd_im, camera_intr, segmask=segmask)

        # Set input sizes for fully-convolutional policy.
        if "FC" in self.model:
            self.policy_config["metric"]["fully_conv_gqcnn_config"][
                "im_height"] = depth_im.shape[0]
            self.policy_config["metric"]["fully_conv_gqcnn_config"][
                "im_width"] = depth_im.shape[1]

        return self.execute_policy(state, resetting)
Code Example #10
File: analyzer.py Project: anmakon/gqcnn
 def _read_single_image(self, dataset, dataset_dir, datapoint):
     tensor = datapoint // dataset.datapoints_per_file
     array = datapoint % dataset.datapoints_per_file
     pointer = np.load(
         dataset_dir +
         '/tensors/image_files_{:05d}.npz'.format(tensor))['arr_0'][array]
     image = DepthImage(
         np.load(dataset_dir +
                 '/images/depth_im_{:07d}.npy'.format(pointer))[:, :, 0])
     return image
Code Example #11
 def _read_depth_images(self, num_images):
     """ Reads depth images from the device """
     depth_images = self._ros_read_images(self._depth_image_buffer, num_images, self.staleness_limit)
     for i in range(0, num_images):
         depth_images[i] = depth_images[i] * MM_TO_METERS # convert to meters
         if self._flip_images:
             depth_images[i] = np.flipud(depth_images[i])
             depth_images[i] = np.fliplr(depth_images[i])
         depth_images[i] = DepthImage(depth_images[i], frame=self._frame) 
     return depth_images
Code Example #12
File: grasp_planner_node.py Project: nimpsch/gqcnn
    def read_images(self, req):
        """Reads images from a ROS service request.

        Parameters
        ----------
        req: :obj:`ROS ServiceRequest`
            ROS ServiceRequest for grasp planner service.
        """
        # Get the raw depth and color images as ROS `Image` objects.
        raw_color = req.color_image
        raw_depth = req.depth_image

        # Get the raw camera info as ROS `CameraInfo`.
        raw_camera_info = req.camera_info

        # Wrap the camera info in a BerkeleyAutomation/perception
        # `CameraIntrinsics` object.
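        # CameraInfo.K is the row-major 3x3 intrinsic matrix, so K[0]=fx,
        # K[4]=fy, K[2]=cx, K[5]=cy and K[1] is the skew term.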
        camera_intr = CameraIntrinsics(
            raw_camera_info.header.frame_id, raw_camera_info.K[0],
            raw_camera_info.K[4], raw_camera_info.K[2], raw_camera_info.K[5],
            raw_camera_info.K[1], raw_camera_info.height,
            raw_camera_info.width)

        # Create wrapped BerkeleyAutomation/perception RGB and depth images by
        # unpacking the ROS images using ROS `CvBridge`
        try:
            color_im = ColorImage(self.cv_bridge.imgmsg_to_cv2(
                raw_color, "rgb8"),
                                  frame=camera_intr.frame)
            depth_im = DepthImage(self.cv_bridge.imgmsg_to_cv2(
                raw_depth, desired_encoding="passthrough"),
                                  frame=camera_intr.frame)
        except CvBridgeError as cv_bridge_exception:
            rospy.logerr(cv_bridge_exception)

        # Check image sizes.
        if color_im.height != depth_im.height or \
           color_im.width != depth_im.width:
            msg = ("Color image and depth image must be the same shape! Color"
                   " is %d x %d but depth is %d x %d") % (
                       color_im.height, color_im.width, depth_im.height,
                       depth_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        if (color_im.height < self.min_height
                or color_im.width < self.min_width):
            msg = ("Color image is too small! Must be at least %d x %d"
                   " resolution but the requested image is only %d x %d") % (
                       self.min_height, self.min_width, color_im.height,
                       color_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        return color_im, depth_im, camera_intr
Code Example #13
def visualize_predictions(run_dir,
                          test_config,
                          pred_mask_dir,
                          pred_info_dir,
                          show_bbox=True,
                          show_class=True):
    """Visualizes predictions."""
    # Create subdirectory for prediction visualizations
    vis_dir = os.path.join(run_dir, 'vis')
    depth_dir = os.path.join(test_config['path'], test_config['images'])
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)

    indices_arr = np.load(
        os.path.join(test_config['path'], test_config['indices']))
    image_ids = np.arange(indices_arr.size)
    ##################################################################
    # Process each image
    ##################################################################
    print('VISUALIZING PREDICTIONS')
    for image_id in tqdm(image_ids):
        depth_image_fn = 'image_{:06d}.npy'.format(indices_arr[image_id])
        # depth_image_fn = 'image_{:06d}.png'.format(indices_arr[image_id])

        # Load image and ground truth data and resize for net
        depth_data = np.load(os.path.join(depth_dir, depth_image_fn))
        # image = ColorImage.open(os.path.join(depth_dir, '..', 'depth_ims', depth_image_fn)).data
        image = DepthImage(depth_data).to_color().data

        # load mask and info
        r = np.load(
            os.path.join(pred_info_dir,
                         'image_{:06}.npy'.format(image_id))).item()
        r_masks = np.load(
            os.path.join(pred_mask_dir, 'image_{:06}.npy'.format(image_id)))
        # Must transpose from (n, h, w) to (h, w, n)
        if r_masks.any():
            r['masks'] = np.transpose(r_masks, (1, 2, 0))
        else:
            r['masks'] = r_masks
        # Visualize
        fig = plt.figure(figsize=(1.7067, 1.7067), dpi=300, frameon=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        fig.add_axes(ax)
        visualize.display_instances(image,
                                    r['rois'],
                                    r['masks'],
                                    r['class_ids'], ['bg', 'obj'],
                                    ax=ax,
                                    show_bbox=show_bbox,
                                    show_class=show_class)
        file_name = os.path.join(vis_dir, 'vis_{:06d}'.format(image_id))
        fig.savefig(file_name, transparent=True, dpi=300)
        plt.close()
Code Example #14
def depth_to_bin(img_file):
    # Load the image as a numpy array and the camera intrinsics
    image = np.load(img_file)
    # Create and deproject a depth image of the data using the camera intrinsics
    di = DepthImage(image, frame=ci.frame)
    di = di.inpaint()
    bi = di.to_binary(threshold=0.85)
    #vis2d.figure()
    #vis2d.imshow(di)
    #vis2d.imshow(bi)
    #vis2d.show()
    return bi
Code Example #15
File: augment.py Project: jayef0/cv_pipeline
def inpaint(img):
    """
    Inpaint the image
    """
    # create DepthImage from gray version of img
    gray_img = skimage.color.rgb2gray(img)
    depth_img = DepthImage(gray_img)

    # zero out high-gradient areas and inpaint
    thresh_img = depth_img.threshold_gradients_pctile(0.95)
    inpaint_img = thresh_img.inpaint()
    return inpaint_img.data
Code Example #16
    def quality(self, state, actions, params):
        """Evaluate the quality of a set of actions according to a GQ-CNN.

        Parameters
        ----------
        state : :obj:`RgbdImageState`
            State of the world described by an RGB-D image.
        actions: :obj:`object`
            Set of grasping actions to evaluate.
        params: dict
            Optional parameters for quality evaluation.

        Returns
        -------
        :obj:`list` of float
            Real-valued grasp quality predictions for each action, between 0
            and 1.
        """
        # Form tensors.
        image_tensor, pose_tensor = self.grasps_to_tensors(actions, state)
        if params is not None and params["vis"]["tf_images"]:
            # Read vis params.
            k = params["vis"]["k"]
            d = utils.sqrt_ceil(k)

            # Display grasp transformed images.
            from visualization import Visualizer2D as vis2d
            vis2d.figure(size=(GeneralConstants.FIGSIZE,
                               GeneralConstants.FIGSIZE))
            for i, image_tf in enumerate(image_tensor[:k, ...]):
                depth = pose_tensor[i][0]
                vis2d.subplot(d, d, i + 1)
                vis2d.imshow(DepthImage(image_tf))
                vis2d.title("Image %d: d=%.3f" % (i, depth))
            vis2d.show()

        # Predict grasps.
        num_actions = len(actions)
        null_arr = -1 * np.ones(self._batch_size)
        predict_start = time()
        output_arr = np.zeros([num_actions, 2])
        cur_i = 0
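        # Run the network in fixed-size batches; the final batch may hold
        # fewer than `_batch_size` actions.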
        end_i = cur_i + min(self._batch_size, num_actions - cur_i)
        while cur_i < num_actions:
            output_arr[cur_i:end_i, :] = self.gqcnn(
                image_tensor[cur_i:end_i, :, :, 0],
                pose_tensor[cur_i:end_i, 0], null_arr)[0]
            cur_i = end_i
            end_i = cur_i + min(self._batch_size, num_actions - cur_i)
        q_values = output_arr[:, -1]
        self._logger.debug("Prediction took %.3f sec" %
                           (time() - predict_start))
        return q_values.tolist()
Code Example #17
    def read_images(self, req):
        """ Reads images from a ROS service request
        
        Parameters
        ----------
        req: :obj:`ROS ServiceRequest`
            ROS ServiceRequest for grasp planner service
        """
        # get the raw depth and color images as ROS Image objects
        raw_color = req.color_image
        raw_depth = req.depth_image

        # get the raw camera info as ROS CameraInfo object
        raw_camera_info = req.camera_info

        # wrap the camera info in a perception CameraIntrinsics object
        camera_intr = CameraIntrinsics(
            raw_camera_info.header.frame_id, raw_camera_info.K[0],
            raw_camera_info.K[4], raw_camera_info.K[2], raw_camera_info.K[5],
            raw_camera_info.K[1], raw_camera_info.height,
            raw_camera_info.width)

        # create wrapped perception RGB and depth images by unpacking the ROS
        # images using CvBridge
        try:
            color_im = ColorImage(self.cv_bridge.imgmsg_to_cv2(
                raw_color, "rgb8"),
                                  frame=camera_intr.frame)
            depth_im = DepthImage(self.cv_bridge.imgmsg_to_cv2(
                raw_depth, desired_encoding="passthrough"),
                                  frame=camera_intr.frame)
        except CvBridgeError as cv_bridge_exception:
            rospy.logerr(cv_bridge_exception)

        # check image sizes
        if color_im.height != depth_im.height or \
           color_im.width != depth_im.width:
            msg = 'Color image and depth image must be the same shape! Color is %d x %d but depth is %d x %d' % (
                color_im.height, color_im.width, depth_im.height,
                depth_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        if color_im.height < self.min_height or color_im.width < self.min_width:
            msg = 'Color image is too small! Must be at least %d x %d resolution but the requested image is only %d x %d' % (
                self.min_height, self.min_width, color_im.height,
                color_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        return color_im, depth_im, camera_intr
Code Example #18
    def get_state(self, depth, segmask):
        # Read images.
        depth_im = DepthImage(depth, frame=self.camera_intr.frame)
        color_im = ColorImage(np.zeros([depth_im.height, depth_im.width,
                                        3]).astype(np.uint8),
                              frame=self.camera_intr.frame)
        segmask = BinaryImage(segmask.astype(np.uint8) * 255,
                              frame=self.camera_intr.frame)

        # Inpaint.
        depth_im = depth_im.inpaint(rescale_factor=self.inpaint_rescale_factor)

        # Create state.
        rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
        state = RgbdImageState(rgbd_im, self.camera_intr, segmask=segmask)
        return state, rgbd_im
Code Example #19
def largest_planar_surface(filename):
    # Load the image as a numpy array and the camera intrinsics
    image = np.load(filename)
    # Create and deproject a depth image of the data using the camera intrinsics
    di = DepthImage(image, frame=ci.frame)
    di = di.inpaint()
    pc = ci.deproject(di)
    # Make a PCL type point cloud from the image
    p = pcl.PointCloud(pc.data.T.astype(np.float32))
    # Make a segmenter and segment the point cloud.
    seg = p.make_segmenter()
    seg.set_model_type(pcl.SACMODEL_PARALLEL_PLANE)
    seg.set_method_type(pcl.SAC_RANSAC)
    seg.set_distance_threshold(0.005)
    indices, model = seg.segment()
    return indices, model, image, pc
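
A short usage sketch for the function above; the `.npy` file name is hypothetical, and `model` holds the RANSAC plane coefficients returned by python-pcl:

indices, model, image, pc = largest_planar_surface('depth_0.npy')
plane_points = pc.data[:, indices]  # 3 x N_inliers subset of the deprojected cloud
a, b, c, d = model                  # plane equation: a*x + b*y + c*z + d = 0
print('plane normal: ({}, {}, {}), offset: {}'.format(a, b, c, d))
print('{} of {} deprojected points lie on the plane'.format(
    len(indices), pc.data.shape[1]))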
Code Example #20
    def _read_depth_image(self):
        """ Reads a depth image from the device """
        # read raw uint16 buffer
        im_arr = self._depth_stream.read_frame()
        raw_buf = im_arr.get_buffer_as_uint16()
        buf_array = np.array([raw_buf[i] for i in range(PrimesenseSensor.DEPTH_IM_WIDTH * PrimesenseSensor.DEPTH_IM_HEIGHT)])

        # convert to image in meters
        depth_image = buf_array.reshape(PrimesenseSensor.DEPTH_IM_HEIGHT,
                                        PrimesenseSensor.DEPTH_IM_WIDTH)
        depth_image = depth_image * MM_TO_METERS # convert to meters
        if self._flip_images:
            depth_image = np.flipud(depth_image)
        else:
            depth_image = np.fliplr(depth_image)
        return DepthImage(depth_image, frame=self._frame)
Code Example #21
    def quality(self, state, actions, params):
        """Evaluate the quality of a set of actions according to a GQ-CNN.

        Parameters
        ----------
        state : :obj:`RgbdImageState`
            State of the world described by an RGB-D image.
        actions: :obj:`object`
            Set of grasping actions to evaluate.
        params: dict
            Optional parameters for quality evaluation.

        Returns
        -------
        :obj:`list` of float
            Real-valued grasp quality predictions for each
            action, between 0 and 1.
        """
        # Form tensors.
        tensor_start = time()
        image_tensor, pose_tensor = self.grasps_to_tensors(actions, state)
        self._logger.info("Image transformation took %.3f sec" %
                          (time() - tensor_start))
        if params is not None and params["vis"]["tf_images"]:
            # Read vis params.
            k = params["vis"]["k"]
            d = utils.sqrt_ceil(k)

            # Display grasp transformed images.
            from visualization import Visualizer2D as vis2d
            vis2d.figure(size=(GeneralConstants.FIGSIZE,
                               GeneralConstants.FIGSIZE))
            for i, image_tf in enumerate(image_tensor[:k, ...]):
                depth = pose_tensor[i][0]
                vis2d.subplot(d, d, i + 1)
                vis2d.imshow(DepthImage(image_tf))
                vis2d.title("Image %d: d=%.3f" % (i, depth))
            vis2d.show()

        # Predict grasps.
        predict_start = time()
        output_arr = self.gqcnn.predict(image_tensor, pose_tensor)
        q_values = output_arr[:, -1]
        self._logger.info("Inference took %.3f sec" % (time() - predict_start))
        return q_values.tolist()
Code Example #22
    def quality(self, state, actions, params):
        """ Evaluate the quality of a set of actions according to a GQ-CNN.

        Parameters
        ----------
        state : :obj:`RgbdImageState`
            state of the world described by an RGB-D image
        actions: :obj:`object`
            set of grasping actions to evaluate
        params: dict
            optional parameters for quality evaluation

        Returns
        -------
        :obj:`list` of float
            real-valued grasp quality predictions for each action, between 0 and 1
        """
        # form tensors
        tensor_start = time()
        image_tensor, pose_tensor = self.grasps_to_tensors(actions, state)
        logging.info('Image transformation took %.3f sec' %
                     (time() - tensor_start))
        if params is not None and params['vis']['tf_images']:
            # read vis params
            k = params['vis']['k']
            d = utils.sqrt_ceil(k)

            # display grasp transformed images
            from visualization import Visualizer2D as vis2d
            vis2d.figure(size=(FIGSIZE, FIGSIZE))
            for i, image_tf in enumerate(image_tensor[:k, ...]):
                depth = pose_tensor[i][0]
                vis2d.subplot(d, d, i + 1)
                vis2d.imshow(DepthImage(image_tf))
                vis2d.title('Image %d: d=%.3f' % (i, depth))
            vis2d.show()

        # predict grasps
        predict_start = time()
        output_arr = self.gqcnn.predict(image_tensor, pose_tensor)
        q_values = output_arr[:, -1]
        logging.info('Inference took %.3f sec' % (time() - predict_start))
        return q_values.tolist()
Code Example #23
 def _plot_grasp(self, datapoint, image_field_name, pose_field_name,
                 gripper_mode):
     """ Plots a single grasp represented as a datapoint. """
     image = DepthImage(datapoint[image_field_name][:, :, 0])
     depth = datapoint[pose_field_name][2]
     width = 0
     if gripper_mode == GripperMode.PARALLEL_JAW or \
        gripper_mode == GripperMode.LEGACY_PARALLEL_JAW:
         grasp = Grasp2D(center=image.center,
                         angle=0,
                         depth=depth,
                         width=0.0)
         width = datapoint[pose_field_name][-1]
     else:
         grasp = SuctionPoint2D(center=image.center,
                                axis=[1, 0, 0],
                                depth=depth)
     vis2d.imshow(image)
     vis2d.grasp(grasp, width=width)
Code Example #24
def run_gqcnn(depth, seg_mask):
    best_angle = 0
    best_point = [0, 0]
    best_dist = 0
    depth_im = DepthImage(depth.astype("float32"), frame=camera_intr.frame)
    color_im = ColorImage(np.zeros([imageWidth, imageHeight,
                                    3]).astype(np.uint8),
                          frame=camera_intr.frame)
    print(seg_mask)
    segmask = BinaryImage(seg_mask)
    print(segmask)
    rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
    state_gqcnn = RgbdImageState(rgbd_im, camera_intr, segmask=segmask)
    policy_start = time.time()
    try:
        action = policy(state_gqcnn)
        logger.info("Planning took %.3f sec" % (time.time() - policy_start))
        best_point = [action.grasp.center[0], action.grasp.center[1]]
        best_angle = float(action.grasp.angle) * 180 / 3.141592
    except Exception as inst:
        print(inst)

    return best_angle, best_point, best_dist
Code Example #25
File: find_depth.py Project: tjdalsckd/gqcnnddddd
def run_gqcnn(depth, seg_mask):
    best_angle = 0
    best_point = [0, 0]
    best_dist = 0
    depth_im = DepthImage(depth.astype("float32"), frame=camera_intr.frame)
    color_im = ColorImage(np.zeros([imageWidth, imageHeight,
                                    3]).astype(np.uint8),
                          frame=camera_intr.frame)
    segmask = BinaryImage(seg_mask)
    rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
    state_gqcnn = RgbdImageState(rgbd_im, camera_intr, segmask=segmask)
    policy_start = time.time()
    q_value = -1
    try:
        action = policy(state_gqcnn)
        best_point = [action.grasp.center[0], action.grasp.center[1]]
        best_angle = float(action.grasp.angle) * 180 / 3.141592
        q_value = action.q_value
        print("inner :       ", action.q_value)

    except Exception as inst:
        print(inst)

    return q_value
Code Example #26
def visualize_predictions(run_dir, test_config, pred_mask_dir, pred_info_dir, show_bbox=True, show_class=True):
    """Visualizes predictions."""
    # Create subdirectory for prediction visualizations
    vis_dir = os.path.join(run_dir, 'vis')
    depth_dir = os.path.join(test_config['path'], test_config['images'])
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)

    indices_arr = np.load(os.path.join(test_config['path'], test_config['indices']))
    image_ids = np.arange(indices_arr.size)
    ##################################################################
    # Process each image
    ##################################################################
    print('VISUALIZING PREDICTIONS')
    for image_id in tqdm(image_ids):
        base_name = 'image_{:06d}'.format(indices_arr[image_id])
        depth_image_fn = base_name + '.npy'

        # Load image and ground truth data and resize for net
        depth_data = np.load(os.path.join(depth_dir, depth_image_fn))
        image = DepthImage(depth_data).to_color().data

        # load mask and info
        r = np.load(os.path.join(pred_info_dir, 'image_{:06}.npy'.format(image_id))).item()
        r_masks = np.load(os.path.join(pred_mask_dir, 'image_{:06}.npy'.format(image_id)))
        # Must transpose from (n, h, w) to (h, w, n)
        if r_masks.any():
            r['masks'] = np.transpose(r_masks, (1, 2, 0))
        else:
            r['masks'] = r_masks     
        # Visualize
        visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                                    ['bg', 'obj'], show_bbox=show_bbox, show_class=show_class)
        file_name = os.path.join(vis_dir, 'vis_{:06d}'.format(image_id))
        plt.savefig(file_name, bbox_inches='tight', pad_inches=0)
        plt.close()
Code Example #27
    policy_config = config["policy"]

    # Make relative paths absolute.
    if "gqcnn_model" in policy_config["metric"]:
        policy_config["metric"]["gqcnn_model"] = model_path
        if not os.path.isabs(policy_config["metric"]["gqcnn_model"]):
            policy_config["metric"]["gqcnn_model"] = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "..",
                policy_config["metric"]["gqcnn_model"])

    # Setup sensor.
    camera_intr = CameraIntrinsics.load(camera_intr_filename)

    # Read images.
    depth_data = np.load(depth_im_filename)
    depth_im = DepthImage(depth_data, frame=camera_intr.frame)
    color_im = ColorImage(np.zeros([depth_im.height, depth_im.width,
                                    3]).astype(np.uint8),
                          frame=camera_intr.frame)

    # Optionally read a segmask.
    segmask = None
    if segmask_filename is not None:
        segmask = BinaryImage.open(segmask_filename)
    valid_px_mask = depth_im.invalid_pixel_mask().inverse()
    if segmask is None:
        segmask = valid_px_mask
    else:
        segmask = segmask.mask_binary(valid_px_mask)

    # Inpaint.
Code Example #28
    def start_rendering(self):
        self._load_file_ids()

        for object_id in self.all_objects:
            self._load_data(object_id)

            for i, stable_pose in enumerate(self.stable_poses):
                try:
                    candidate_grasp_info = self.candidate_grasps_dict[
                        stable_pose.id]
                except KeyError:
                    continue

                if not candidate_grasp_info:
                    # `Warning(...)` only constructs an exception object and
                    # discards it, so log the message instead.
                    print("Candidate grasp info of object id %s empty, "
                          "continuing." % object_id)
                    continue
                T_obj_stp = stable_pose.T_obj_table.as_frames('obj', 'stp')
                T_obj_stp = self.object_mesh.get_T_surface_obj(T_obj_stp)

                T_table_obj = RigidTransform(from_frame='table',
                                             to_frame='obj')
                scene_objs = {
                    'table': SceneObject(self.table_mesh, T_table_obj)
                }

                urv = UniformPlanarWorksurfaceImageRandomVariable(
                    self.object_mesh,
                    [RenderMode.DEPTH_SCENE, RenderMode.SEGMASK],
                    'camera',
                    self.config['env_rv_params'],
                    scene_objs=scene_objs,
                    stable_pose=stable_pose)
                render_sample = urv.rvs(size=self.random_positions)
                # for render_sample in render_samples:

                binary_im = render_sample.renders[RenderMode.SEGMASK].image
                depth_im = render_sample.renders[
                    RenderMode.DEPTH_SCENE].image.crop(300, 300)
                orig_im = Image.fromarray(self._scale_image(depth_im.data))
                if self.show_images:
                    orig_im.show()
                orig_im.convert('RGB').save(self.output_dir + '/images/' +
                                            object_id + '_elev_' +
                                            str(self.elev) + '_original.png')
                print("Saved original")

                T_stp_camera = render_sample.camera.object_to_camera_pose
                shifted_camera_intr = render_sample.camera.camera_intr.crop(
                    300, 300, 240, 320)
                depth_points = self._reproject_to_3D(depth_im,
                                                     shifted_camera_intr)

                transformed_points, T_camera = self._transformation(
                    depth_points)

                camera_dir = np.dot(T_camera.rotation,
                                    np.array([0.0, 0.0, -1.0]))

                pcd = o3d.geometry.PointCloud()
                # print(camera_dir)
                pcd.points = o3d.utility.Vector3dVector(transformed_points.T)
                # TODO check normals!!
                #  pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
                #  pcd.normals = o3d.utility.Vector3dVector(-np.asarray(pcd.normals))
                normals = np.repeat([camera_dir],
                                    len(transformed_points.T),
                                    axis=0)
                pcd.normals = o3d.utility.Vector3dVector(normals)

                # cs_points = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
                # cs_lines = [[0, 1], [0, 2], [0, 3]]
                # colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
                # cs = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(cs_points),
                #                           lines=o3d.utility.Vector2iVector(cs_lines))
                # cs.colors = o3d.utility.Vector3dVector(colors)
                # o3d.visualization.draw_geometries([pcd])

                depth = self._o3d_meshing(pcd)

                # projected_depth_im,new_camera_intr,table_height = self._projection(new_points,shifted_camera_intr)
                new_camera_intr = shifted_camera_intr
                new_camera_intr.cx = 150
                new_camera_intr.cy = 150
                projected_depth_im = np.asarray(depth)
                projected_depth_im[projected_depth_im == 0.0] = -1.0
                table_height = np.median(
                    projected_depth_im[projected_depth_im != -1.0].flatten())
                print("Minimum depth:", min(projected_depth_im.flatten()))
                print("Maximum depth:", max(projected_depth_im.flatten()))

                im = Image.fromarray(self._scale_image(projected_depth_im))

                projected_depth_im = DepthImage(projected_depth_im,
                                                frame='new_camera')

                cx = projected_depth_im.center[1]
                cy = projected_depth_im.center[0]

                # Grasp conversion
                T_obj_old_camera = T_stp_camera * T_obj_stp.as_frames(
                    'obj', T_stp_camera.from_frame)
                T_obj_camera = T_camera.dot(T_obj_old_camera)
                for grasp_info in candidate_grasp_info:
                    grasp = grasp_info.grasp
                    collision_free = grasp_info.collision_free

                    grasp_2d = grasp.project_camera(T_obj_camera,
                                                    new_camera_intr)
                    dx = cx - grasp_2d.center.x
                    dy = cy - grasp_2d.center.y
                    translation = np.array([dy, dx])

                    # Project 3D old_camera_cs contact points into new camera cs

                    contact_points = np.append(grasp_info.contact_point1, 1).T
                    new_cam = np.dot(T_obj_camera.matrix, contact_points)
                    c1 = new_camera_intr.project(
                        Point(new_cam[0:3], frame=new_camera_intr.frame))
                    contact_points = np.append(grasp_info.contact_point2, 1).T
                    new_cam = np.dot(T_obj_camera.matrix, contact_points)
                    c2 = new_camera_intr.project(
                        Point(new_cam[0:3], frame=new_camera_intr.frame))

                    # Check if there are occlusions at contact points
                    if projected_depth_im.data[
                            c1.x, c1.y] == -1.0 or projected_depth_im.data[
                                c2.x, c2.y] == -1.0:
                        print("Contact point at occlusion")
                        contact_occlusion = True
                    else:
                        contact_occlusion = False
                    # Mark contact points in image
                    im = im.convert('RGB')
                    if False:
                        im_draw = ImageDraw.Draw(im)
                        im_draw.line([(c1[0], c1[1] - 10),
                                      (c1[0], c1[1] + 10)],
                                     fill=(255, 0, 0, 255))
                        im_draw.line([(c1[0] - 10, c1[1]),
                                      (c1[0] + 10, c1[1])],
                                     fill=(255, 0, 0, 255))
                        im_draw.line([(c2[0], c2[1] - 10),
                                      (c2[0], c2[1] + 10)],
                                     fill=(255, 0, 0, 255))
                        im_draw.line([(c2[0] - 10, c2[1]),
                                      (c2[0] + 10, c2[1])],
                                     fill=(255, 0, 0, 255))
                    if self.show_images:
                        im.show()
                    im.save(self.output_dir + '/images/' + object_id +
                            '_elev_' + str(self.elev) + '_reprojected.png')

                    # Transform and crop image

                    depth_im_tf = projected_depth_im.transform(
                        translation, grasp_2d.angle)
                    depth_im_tf = depth_im_tf.crop(96, 96)

                    # Apply transformation to contact points
                    trans_map = np.array([[1, 0, dx], [0, 1, dy]])
                    rot_map = cv2.getRotationMatrix2D(
                        (cx, cy), np.rad2deg(grasp_2d.angle), 1)
                    trans_map_aff = np.r_[trans_map, [[0, 0, 1]]]
                    rot_map_aff = np.r_[rot_map, [[0, 0, 1]]]
                    full_map = rot_map_aff.dot(trans_map_aff)
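                    # full_map applies the translation first, then the
                    # rotation about (cx, cy).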
                    # print("Full map",full_map)
                    c1_rotated = (np.dot(full_map, np.r_[c1.vector, [1]]) -
                                  np.array([150 - 48, 150 - 48, 0])) / 3
                    c2_rotated = (np.dot(full_map, np.r_[c2.vector, [1]]) -
                                  np.array([150 - 48, 150 - 48, 0])) / 3

                    grasp_line = depth_im_tf.data[48]
                    occlusions = len(np.where(np.squeeze(grasp_line) == -1)[0])

                    # Set occlusions to table height for resizing image
                    depth_im_tf.data[depth_im_tf.data == -1.0] = table_height

                    depth_image = Image.fromarray(np.asarray(depth_im_tf.data))\
                        .resize((32, 32), resample=Image.BILINEAR)
                    depth_im_tf_table = np.asarray(depth_image).reshape(
                        32, 32, 1)

                    # depth_im_tf_table = depth_im_tf.resize((32, 32,), interp='bilinear')

                    im = Image.fromarray(
                        self._scale_image(depth_im_tf_table.reshape(
                            32, 32))).convert('RGB')
                    draw = ImageDraw.Draw(im)
                    draw.line([(c1_rotated[0], c1_rotated[1] - 3),
                               (c1_rotated[0], c1_rotated[1] + 3)],
                              fill=(255, 0, 0, 255))
                    draw.line([(c1_rotated[0] - 3, c1_rotated[1]),
                               (c1_rotated[0] + 3, c1_rotated[1])],
                              fill=(255, 0, 0, 255))
                    draw.line([(c2_rotated[0], c2_rotated[1] - 3),
                               (c2_rotated[0], c2_rotated[1] + 3)],
                              fill=(255, 0, 0, 255))
                    draw.line([(c2_rotated[0] - 3, c2_rotated[1]),
                               (c2_rotated[0] + 3, c2_rotated[1])],
                              fill=(255, 0, 0, 255))
                    if self.show_images:
                        im.show()
                    im.save(self.output_dir + '/images/' + object_id +
                            '_elev_' + str(self.elev) + '_transformed.png')

                    hand_pose = np.r_[grasp_2d.center.y, grasp_2d.center.x,
                                      grasp_2d.depth, grasp_2d.angle,
                                      grasp_2d.center.y - new_camera_intr.cy,
                                      grasp_2d.center.x - new_camera_intr.cx,
                                      grasp_2d.width_px / 3]

                    self.tensor_datapoint[
                        'depth_ims_tf_table'] = depth_im_tf_table
                    self.tensor_datapoint['hand_poses'] = hand_pose
                    self.tensor_datapoint['obj_labels'] = self.cur_obj_label
                    self.tensor_datapoint['collision_free'] = collision_free
                    self.tensor_datapoint['pose_labels'] = self.cur_pose_label
                    self.tensor_datapoint[
                        'image_labels'] = self.cur_image_label
                    self.tensor_datapoint['files'] = [self.tensor, self.array]
                    self.tensor_datapoint['occlusions'] = occlusions
                    self.tensor_datapoint[
                        'contact_occlusion'] = contact_occlusion

                    for metric_name, metric_val in self.grasp_metrics[str(
                            grasp.id)].items():
                        coll_free_metric = (1 * collision_free) * metric_val
                        self.tensor_datapoint[metric_name] = coll_free_metric
                    self.tensor_dataset.add(self.tensor_datapoint)
                    print("Saved dataset point")
                    self.cur_image_label += 1
                self.cur_pose_label += 1
                gc.collect()
            self.cur_obj_label += 1

        self.tensor_dataset.flush()
Code Example #29
    def visualize(self):
        """ Visualize predictions """

        logging.info('Visualizing ' + self.datapoint_type)

        # iterate through shuffled file indices
        for i in self.indices:
            im_filename = self.im_filenames[i]
            pose_filename = self.pose_filenames[i]
            label_filename = self.label_filenames[i]

            logging.info('Loading Image File: ' + im_filename +
                         ' Pose File: ' + pose_filename + ' Label File: ' +
                         label_filename)

            # load tensors from files
            metric_tensor = np.load(os.path.join(self.data_dir,
                                                 label_filename))['arr_0']
            label_tensor = 1 * (metric_tensor > self.metric_thresh)
            image_tensor = np.load(os.path.join(self.data_dir,
                                                im_filename))['arr_0']
            hand_poses_tensor = np.load(
                os.path.join(self.data_dir, pose_filename))['arr_0']

            pose_tensor = self._read_pose_data(hand_poses_tensor,
                                               self.input_data_mode)

            # score with neural network
            pred_p_success_tensor = self._gqcnn.predict(
                image_tensor, pose_tensor)

            # compute results
            classification_result = ClassificationResult(
                [pred_p_success_tensor], [label_tensor])

            logging.info('Error rate on files: %.3f' %
                         (classification_result.error_rate))
            logging.info('Precision on files: %.3f' %
                         (classification_result.precision))
            logging.info('Recall on files: %.3f' %
                         (classification_result.recall))
            mispred_ind = classification_result.mispredicted_indices()
            correct_ind = classification_result.correct_indices()
            # IPython.embed()

            if self.datapoint_type == 'true_positive' or self.datapoint_type == 'true_negative':
                vis_ind = correct_ind
            else:
                vis_ind = mispred_ind
            num_visualized = 0
            # visualize
            for ind in vis_ind:
                # limit the number of sampled datapoints displayed per object
                if num_visualized >= self.samples_per_object:
                    break
                num_visualized += 1

                # don't visualize the datapoints that we don't want
                if self.datapoint_type == 'true_positive':
                    if classification_result.labels[ind] == 0:
                        continue
                elif self.datapoint_type == 'true_negative':
                    if classification_result.labels[ind] == 1:
                        continue
                elif self.datapoint_type == 'false_positive':
                    if classification_result.labels[ind] == 0:
                        continue
                elif self.datapoint_type == 'false_negative':
                    if classification_result.labels[ind] == 1:
                        continue

                logging.info('Datapoint %d of files for %s' %
                             (ind, im_filename))
                logging.info('Depth: %.3f' % (hand_poses_tensor[ind, 2]))

                data = image_tensor[ind, ...]
                if self.display_image_type == RenderMode.SEGMASK:
                    image = BinaryImage(data)
                elif self.display_image_type == RenderMode.GRAYSCALE:
                    image = GrayscaleImage(data)
                elif self.display_image_type == RenderMode.COLOR:
                    image = ColorImage(data)
                elif self.display_image_type == RenderMode.DEPTH:
                    image = DepthImage(data)
                elif self.display_image_type == RenderMode.RGBD:
                    image = RgbdImage(data)
                elif self.display_image_type == RenderMode.GD:
                    image = GdImage(data)

                vis2d.figure()

                if self.display_image_type == RenderMode.RGBD:
                    vis2d.subplot(1, 2, 1)
                    vis2d.imshow(image.color)
                    grasp = Grasp2D(Point(image.center,
                                          'img'), 0, hand_poses_tensor[ind, 2],
                                    self.gripper_width_m)
                    grasp.camera_intr = grasp.camera_intr.resize(1.0 / 3.0)
                    vis2d.grasp(grasp)
                    vis2d.subplot(1, 2, 2)
                    vis2d.imshow(image.depth)
                    vis2d.grasp(grasp)
                elif self.display_image_type == RenderMode.GD:
                    vis2d.subplot(1, 2, 1)
                    vis2d.imshow(image.gray)
                    grasp = Grasp2D(Point(image.center,
                                          'img'), 0, hand_poses_tensor[ind, 2],
                                    self.gripper_width_m)
                    grasp.camera_intr = grasp.camera_intr.resize(1.0 / 3.0)
                    vis2d.grasp(grasp)
                    vis2d.subplot(1, 2, 2)
                    vis2d.imshow(image.depth)
                    vis2d.grasp(grasp)
                else:
                    vis2d.imshow(image)
                    grasp = Grasp2D(Point(image.center,
                                          'img'), 0, hand_poses_tensor[ind, 2],
                                    self.gripper_width_m)
                    grasp.camera_intr = grasp.camera_intr.resize(1.0 / 3.0)
                    vis2d.grasp(grasp)
                vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                            (ind, classification_result.pred_probs[ind, 1],
                             classification_result.labels[ind]))
                vis2d.show()

        # cleanup
        self._cleanup()
Code Example #30
File: analyzer.py Project: wenlongli/gqcnn
    def _run_prediction_single_model(self, model_dir, model_output_dir,
                                     dataset_config):
        """Analyze the performance of a single model."""
        # Read in model config.
        model_config_filename = os.path.join(model_dir,
                                             GQCNNFilenames.SAVED_CFG)
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # Load model.
        self.logger.info("Loading model %s" % (model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(
            model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins

        # Read params from the config.
        if dataset_config is None:
            dataset_dir = model_config["dataset_dir"]
            split_name = model_config["split_name"]
            image_field_name = model_config["image_field_name"]
            pose_field_name = model_config["pose_field_name"]
            metric_name = model_config["target_metric_name"]
            metric_thresh = model_config["metric_thresh"]
        else:
            dataset_dir = dataset_config["dataset_dir"]
            split_name = dataset_config["split_name"]
            image_field_name = dataset_config["image_field_name"]
            pose_field_name = dataset_config["pose_field_name"]
            metric_name = dataset_config["target_metric_name"]
            metric_thresh = dataset_config["metric_thresh"]
            gripper_mode = dataset_config["gripper_mode"]

        self.logger.info("Loading dataset %s" % (dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)

        # Visualize conv filters.
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:, :, 0, k]
            vis2d.subplot(d, d, k + 1)
            vis2d.imshow(DepthImage(filt))
            figname = os.path.join(model_output_dir, "conv1_filters.pdf")
        vis2d.savefig(figname, dpi=self.dpi)

        # Aggregate training and validation true labels and predicted
        # probabilities.
        all_predictions = []
        if angular_bins > 0:
            all_predictions_raw = []
        all_labels = []
        for i in range(dataset.num_tensors):
            # Log progress.
            if i % self.log_rate == 0:
                self.logger.info("Predicting tensor %d of %d" %
                                 (i + 1, dataset.num_tensors))

            # Read in data.
            image_arr = dataset.tensor(image_field_name, i).arr
            pose_arr = read_pose_data(
                dataset.tensor(pose_field_name, i).arr, gripper_mode)
            metric_arr = dataset.tensor(metric_name, i).arr
            label_arr = 1 * (metric_arr > metric_thresh)
            label_arr = label_arr.astype(np.uint8)
            if angular_bins > 0:
                # Form mask to extract predictions from ground-truth angular
                # bins.
                raw_poses = dataset.tensor(pose_field_name, i).arr
                angles = raw_poses[:, 3]
                neg_ind = np.where(angles < 0)
                # TODO(vsatish): These should use the max angle instead.
                angles = np.abs(angles) % GeneralConstants.PI
                angles[neg_ind] *= -1
                g_90 = np.where(angles > (GeneralConstants.PI / 2))
                l_neg_90 = np.where(angles < (-1 * (GeneralConstants.PI / 2)))
                angles[g_90] -= GeneralConstants.PI
                angles[l_neg_90] += GeneralConstants.PI
                # TODO(vsatish): Fix this along with the others.
                angles *= -1  # Hack to fix reverse angle convention.
                angles += (GeneralConstants.PI / 2)
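                # Each angular bin owns two consecutive prediction columns
                # (negative/positive class), hence the *2 indexing below.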
                pred_mask = np.zeros((raw_poses.shape[0], angular_bins * 2),
                                     dtype=bool)
                bin_width = GeneralConstants.PI / angular_bins
                for j in range(angles.shape[0]):
                    pred_mask[j, int((angles[j] // bin_width) * 2)] = True
                    pred_mask[j, int((angles[j] // bin_width) * 2 + 1)] = True

            # Predict with GQ-CNN.
            predictions = gqcnn.predict(image_arr, pose_arr)
            if angular_bins > 0:
                raw_predictions = np.array(predictions)
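                # Keep only the output pair from the ground-truth angular bin
                # so binned models are scored the same way as unbinned ones.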
                predictions = predictions[pred_mask].reshape((-1, 2))

            # Aggregate.
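            # Column 1 is the predicted probability of the positive class
            # (grasp success).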
            all_predictions.extend(predictions[:, 1].tolist())
            if angular_bins > 0:
                all_predictions_raw.extend(raw_predictions.tolist())
            all_labels.extend(label_arr.tolist())

        # Close session.
        gqcnn.close_session()

        # Create arrays.
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]
        if angular_bins > 0:
            all_predictions_raw = np.array(all_predictions_raw)
            train_predictions_raw = all_predictions_raw[train_indices]
            val_predictions_raw = all_predictions_raw[val_indices]
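            # The raw per-bin predictions are kept so they can be passed to
            # `_plot_grasp` as `angular_preds` in the example plots below.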

        # Aggregate results.
        train_result = BinaryClassificationResult(train_predictions,
                                                  train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
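        # Persist the classification results to disk.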
        train_result.save(os.path.join(model_output_dir, "train_result.cres"))
        val_result.save(os.path.join(model_output_dir, "val_result.cres"))

        # Get stats, plot curves.
        self.logger.info("Model %s training error rate: %.3f" %
                         (model_dir, train_result.error_rate))
        self.logger.info("Model %s validation error rate: %.3f" %
                         (model_dir, val_result.error_rate))

        self.logger.info("Model %s training loss: %.3f" %
                         (model_dir, train_result.cross_entropy_loss))
        self.logger.info("Model %s validation loss: %.3f" %
                         (model_dir, val_result.cross_entropy_loss))

        # Save images.
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, "examples")
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # Train.
        self.logger.info("Saving training examples")
        train_example_dir = os.path.join(example_dir, "train")
        if not os.path.exists(train_example_dir):
            os.mkdir(train_example_dir)

        # Train TP.
        true_positive_indices = train_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
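        # Plot up to `self.num_vis` randomly sampled examples; the same
        # pattern repeats below for each confusion-matrix category on both
        # the training and validation splits.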
        for i, j in enumerate(true_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "true_positive_%03d.png" % (i)))

        # Train FP.
        false_positive_indices = train_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "false_positive_%03d.png" % (i)))

        # Train TN.
        true_negative_indices = train_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "true_negative_%03d.png" % (i)))

        # Train FN.
        false_negative_indices = train_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "false_negative_%03d.png" % (i)))

        # Val.
        self.logger.info("Saving validation examples")
        val_example_dir = os.path.join(example_dir, "val")
        if not os.path.exists(val_example_dir):
            os.mkdir(val_example_dir)

        # Val TP.
        true_positive_indices = val_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "true_positive_%03d.png" % (i)))

        # Val FP.
        false_positive_indices = val_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "false_positive_%03d.png" % (i)))

        # Val TN.
        true_negative_indices = val_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "true_negative_%03d.png" % (i)))

        # Val FN.
        false_negative_indices = val_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "false_negative_%03d.png" % (i)))

        # Save summary stats.
        train_summary_stats = {
            "error_rate": train_result.error_rate,
            "ap_score": train_result.ap_score,
            "auc_score": train_result.auc_score,
            "loss": train_result.cross_entropy_loss
        }
        train_stats_filename = os.path.join(model_output_dir,
                                            "train_stats.json")
        with open(train_stats_filename, "w") as f:
            json.dump(train_summary_stats,
                      f,
                      indent=JSON_INDENT,
                      sort_keys=True)

        val_summary_stats = {
            "error_rate": val_result.error_rate,
            "ap_score": val_result.ap_score,
            "auc_score": val_result.auc_score,
            "loss": val_result.cross_entropy_loss
        }
        val_stats_filename = os.path.join(model_output_dir, "val_stats.json")
        with open(val_stats_filename, "w") as f:
            json.dump(val_summary_stats,
                      f,
                      indent=JSON_INDENT,
                      sort_keys=True)

        return train_result, val_result