Пример #1
0
 def test_transform(self):
     """Check that ``transform`` with translation (2, 2) moves pixel (0, 0) to (2, 2).

     A random uint8 color image is transformed with the translation
     ``[2, 2]`` and a second argument of ``0.0``; the source pixel at
     (0, 0) must match the transformed pixel at (2, 2).
     """
     pixels = (255.0 * np.random.rand(IM_HEIGHT, IM_WIDTH, 3)).astype(np.uint8)
     source_im = ColorImage(pixels)

     shift = np.array([2, 2])
     shifted_im = source_im.transform(shift, 0.0)

     pixels_match = np.allclose(source_im[0, 0], shifted_im[2, 2])
     self.assertTrue(pixels_match)
Пример #2
0
    def test_color_init(self):
        """Check ColorImage construction with valid and invalid arrays.

        A uint8 array of shape (IM_HEIGHT, IM_WIDTH, 3) must be accepted and
        reported back through ``height``, ``width``, ``channels`` and
        ``data``. A two-dimensional array and a float32 array must both be
        rejected with an exception.
        """
        # valid data
        random_valid_data = (255.0 *
                             np.random.rand(IM_HEIGHT, IM_WIDTH, 3)).astype(
                                 np.uint8)
        im = ColorImage(random_valid_data)
        self.assertEqual(im.height, IM_HEIGHT)
        self.assertEqual(im.width, IM_WIDTH)
        self.assertEqual(im.channels, 3)
        self.assertTrue(np.allclose(im.data, random_valid_data))

        # invalid channels
        # assertRaises(Exception) is used instead of a bare ``except:`` so
        # that KeyboardInterrupt / SystemExit are never counted as a pass.
        random_data = np.random.rand(IM_HEIGHT, IM_WIDTH).astype(np.uint8)
        with self.assertRaises(Exception):
            ColorImage(random_data)

        # invalid type
        random_data = np.random.rand(IM_HEIGHT, IM_WIDTH, 3).astype(np.float32)
        with self.assertRaises(Exception):
            ColorImage(random_data)
Пример #3
0
    def test_mask_by_ind(self):
        """Check that masking by the single index (0, 0) zeroes pixel (1, 1).

        After ``mask_by_ind`` is called with only the index ``[0, 0]``, the
        channel values at (1, 1) of the result must sum to zero.
        """
        pixels = (255.0 * np.random.rand(IM_HEIGHT, IM_WIDTH, 3)).astype(np.uint8)
        color_im = ColorImage(pixels)

        pixel_indices = np.array([[0, 0]])
        masked_im = color_im.mask_by_ind(pixel_indices)

        channel_sum = np.sum(masked_im[1, 1])
        self.assertEqual(channel_sum, 0.0)
Пример #4
0
    def test_resize(self):
        """Check that ``resize`` scales height and width by the given factor.

        The image is resized first by 2.0 and then by 0.5; in each case the
        resulting dimensions must equal the factor times the original ones.
        """
        pixels = (255.0 * np.random.rand(IM_HEIGHT, IM_WIDTH, 3)).astype(np.uint8)
        original_im = ColorImage(pixels)

        # Upscale first, then downscale, always starting from the original.
        for scale in (2.0, 0.5):
            resized_im = original_im.resize(scale)
            self.assertEqual(resized_im.height, scale * IM_HEIGHT)
            self.assertEqual(resized_im.width, scale * IM_WIDTH)
Пример #5
0
    def _action_callback(self, action):
        """Paint a received action onto the current image and refresh the display.

        Parameters
        ----------
        action : object
            Message exposing ``mask_data``, ``height``, ``width``,
            ``action_type``, ``action_data`` and ``confidence`` (the
            attributes read below).
        """
        # Update image with binary mask
        # The raw mask bytes are reinterpreted as uint8 and reshaped to
        # (height, width) before being wrapped in a BinaryImage.
        binary_data = np.frombuffer(action.mask_data, dtype=np.uint8).reshape(action.height, action.width)
        binary_image = BinaryImage(binary_data)
        mask = binary_image.nonzero_pixels()
        # Blend the masked pixels in place: 30% of the existing value plus 70%
        # of the constant triple [200, 30, 30].
        # NOTE(review): channel order (RGB vs BGR) is not visible here - confirm.
        self._current_image.data[mask[:,0], mask[:,1], :] = (0.3 * self._current_image.data[mask[:,0], mask[:,1], :] +
                                                             0.7 * np.array([200, 30, 30], dtype=np.uint8))

        # Paint the appropriate action type over the current image
        action_type = action.action_type
        action_data = action.action_data

        if action_type == 'parallel_jaw':
            # action_data[0:2] is the overlay location, action_data[2:4] the axis.
            location = np.array([action_data[0], action_data[1]], dtype=np.uint32)
            axis = np.array([action_data[2], action_data[3]])
            # arctan of the component ratio yields an angle in (-90, 90) degrees;
            # an axis with a zero first component is mapped to 90 degrees.
            if axis[0] != 0:
                angle = np.rad2deg(np.arctan((axis[1] / axis[0])))
            else:
                angle = 90
            jaw_image = ColorImage(imutils.rotate_bound(self._jaw_image.data, angle))
            self._current_image = self._overlay_image(self._current_image, jaw_image, location)
        elif action_type == 'suction':
            # The suction marker is overlaid without rotation.
            location = np.array([action_data[0], action_data[1]], dtype=np.uint32)
            self._current_image = self._overlay_image(self._current_image, self._suction_image, location)
        elif action_type == 'push':
            # action_data[0:2] is the push start point, action_data[2:4] the end point.
            start = np.array([action_data[0], action_data[1]], dtype=np.int32)
            end = np.array([action_data[2], action_data[3]], dtype=np.int32)
            axis = (end - start).astype(np.float32)
            axis_len = np.linalg.norm(axis)
            # A zero-length push falls back to the unit axis [0, 1] to avoid
            # dividing by zero when normalising.
            if axis_len == 0:
                axis = np.array([0.0, 1.0])
                axis_len = 1.0
            axis = axis / axis_len

            if axis[0] != 0:
                angle = np.rad2deg(np.arctan2(axis[1], axis[0]))
            else:
                angle = (90.0 if axis[1] > 0 else -90.0)
            start_image = ColorImage(imutils.rotate_bound(self._push_image.data, angle))
            end_image = ColorImage(imutils.rotate_bound(self._push_end_image.data, angle))

            # Move the start marker backwards and the end marker forwards along
            # the axis by half the width of the respective (unrotated) marker image.
            start = np.array(start - axis * 0.5 * self._push_image.width, dtype=np.int32)
            end = np.array(end + axis * 0.5 * self._push_end_image.width, dtype=np.int32)

            self._current_image = self._overlay_image(self._current_image, start_image, start)
            self._current_image = self._overlay_image(self._current_image, end_image, end)

        # Update display
        self.image_signal.emit()
        self.conf_signal.emit(action.confidence)
Пример #6
0
 def rendered_images(data, render_mode=RenderMode.SEGMASK):
     """Build an ObjectRender for every image stored under ``data``.

     Parameters
     ----------
     data : mapping with an ``attrs`` attribute
         Container holding ``NUM_IMAGES_KEY`` in its ``attrs`` and one entry
         per image under the key ``IMAGE_KEY + '_' + str(i)``.
         NOTE(review): this looks like an HDF5 group - confirm against callers.
     render_mode : RenderMode
         Selects the image wrapper: SEGMASK -> BinaryImage,
         DEPTH -> DepthImage, SCALED_DEPTH -> ColorImage.

     Returns
     -------
     list of ObjectRender
         One render (image plus camera-to-world transform) per stored image.

     Raises
     ------
     ValueError
         If ``render_mode`` is not one of the three supported modes.
     """
     # Resolve the wrapper class once. Previously an unsupported mode left
     # ``image`` unbound, or silently reused the previous iteration's image.
     if render_mode == RenderMode.SEGMASK:
         image_class = BinaryImage
     elif render_mode == RenderMode.DEPTH:
         image_class = DepthImage
     elif render_mode == RenderMode.SCALED_DEPTH:
         image_class = ColorImage
     else:
         raise ValueError('Render mode %s not supported' % (str(render_mode)))

     renders = []
     num_images = data.attrs[NUM_IMAGES_KEY]

     for i in range(num_images):
         image_key = IMAGE_KEY + '_' + str(i)
         image_data = data[image_key]
         image_arr = np.array(image_data[IMAGE_DATA_KEY])
         image_frame = image_data.attrs[IMAGE_FRAME_KEY]
         image = image_class(image_arr, image_frame)

         # Camera pose stored alongside the image; the transform maps from
         # the stored camera frame to 'world'.
         R_camera_table = image_data.attrs[CAM_ROT_KEY]
         t_camera_table = image_data.attrs[CAM_POS_KEY]
         camera_frame = image_data.attrs[CAM_FRAME_KEY]
         T_camera_world = RigidTransform(R_camera_table, t_camera_table,
                                         from_frame=camera_frame,
                                         to_frame='world')

         renders.append(ObjectRender(image, T_camera_world))
     return renders
Пример #7
0
    def frames(self):
        """Read the next color/depth pair from the image directory.

        The files ``color_<index>.png`` and ``depth_<index>.npy`` are opened
        from the configured directory, after which the internal image index
        is advanced by one.

        Returns
        -------
        :obj:`tuple` of :obj:`ColorImage`, :obj:`DepthImage`, None
            The color image, the depth image, and ``None`` as the third
            element of the tuple.

        Raises
        ------
        RuntimeError
            If the device is not running or if the image index has moved
            past the number of images.
        """
        if not self._running:
            raise RuntimeError('Primesense device pointing to %s not runnning. Cannot read frames' %(self._path_to_images))

        # NOTE(review): '>' lets index == num_images through; if indices are
        # zero-based this may read one file too many - confirm.
        if self._im_index > self._num_images:
            raise RuntimeError('Primesense device is out of images')

        # Resolve both file names for the current index, then load them.
        index = self._im_index
        image_dir = self._path_to_images
        color_path = os.path.join(image_dir, 'color_%d.png' % (index))
        color_im = ColorImage.open(color_path, frame=self._frame)
        depth_path = os.path.join(image_dir, 'depth_%d.npy' % (index))
        depth_im = DepthImage.open(depth_path, frame=self._frame)

        # Only advance once both images were read successfully.
        self._im_index = index + 1
        return color_im, depth_im, None
Пример #8
0
    def _read_color_and_depth_image(self):
        """Grab one frame set from the pipeline and wrap it as images.

        Returns
        -------
        tuple
            ``(ColorImage, DepthImage)`` on success, or ``(None, None)`` when
            either the depth or the color frame is missing.
        """
        frameset = self._pipe.wait_for_frames()
        if self._depth_align:
            frameset = self._align.process(frameset)

        depth_frame = frameset.get_depth_frame()
        color_frame = frameset.get_color_frame()

        # Both frames are required; bail out if either one is missing.
        if not (depth_frame and color_frame):
            logging.warning('Could not retrieve frames.')
            return None, None

        if self._filter_depth:
            depth_frame = self._filter_depth_frame(depth_frame)

        # Depth becomes float32 and is scaled in place by the device depth
        # scale (the original comment states this converts to meters).
        depth_data = self._to_numpy(depth_frame, np.float32)
        color_data = self._to_numpy(color_frame, np.uint8)
        depth_data *= self._depth_scale

        # Reverse the last axis (channel order): bgr to rgb.
        color_data = color_data[..., ::-1]

        depth = DepthImage(depth_data, frame=self._frame)
        color = ColorImage(color_data, frame=self._frame)
        return color, depth
Пример #9
0
 def load(save_dir):
     """Load an RgbdImageState from the files stored in ``save_dir``.

     Required files are ``color.png``, ``depth.npy`` and ``camera.intr``.
     ``segmask.npy``, ``obj_segmask.npy`` and ``state.pkl`` are optional;
     the corresponding field is ``None`` when the file is absent.

     Parameters
     ----------
     save_dir : str
         Directory containing the saved state.

     Returns
     -------
     RgbdImageState
         State assembled from the color/depth images, the camera intrinsics
         and any optional data that was found.

     Raises
     ------
     ValueError
         If ``save_dir`` does not exist.
     """
     if not os.path.exists(save_dir):
         raise ValueError('Directory %s does not exist!' % (save_dir))
     color_image_filename = os.path.join(save_dir, 'color.png')
     depth_image_filename = os.path.join(save_dir, 'depth.npy')
     camera_intr_filename = os.path.join(save_dir, 'camera.intr')
     segmask_filename = os.path.join(save_dir, 'segmask.npy')
     obj_segmask_filename = os.path.join(save_dir, 'obj_segmask.npy')
     state_filename = os.path.join(save_dir, 'state.pkl')

     # All images share the frame recorded in the camera intrinsics.
     camera_intr = CameraIntrinsics.load(camera_intr_filename)
     color = ColorImage.open(color_image_filename, frame=camera_intr.frame)
     depth = DepthImage.open(depth_image_filename, frame=camera_intr.frame)

     segmask = None
     if os.path.exists(segmask_filename):
         segmask = BinaryImage.open(segmask_filename,
                                    frame=camera_intr.frame)
     obj_segmask = None
     if os.path.exists(obj_segmask_filename):
         obj_segmask = SegmentationImage.open(obj_segmask_filename,
                                              frame=camera_intr.frame)
     fully_observed = None
     if os.path.exists(state_filename):
         # SECURITY: unpickling executes arbitrary code; only load state
         # files from trusted sources.
         # The context manager closes the handle, which the previous
         # ``pkl.load(open(...))`` form never did.
         with open(state_filename, 'rb') as state_file:
             fully_observed = pkl.load(state_file)
     return RgbdImageState(RgbdImage.from_color_and_depth(color, depth),
                           camera_intr,
                           segmask=segmask,
                           obj_segmask=obj_segmask,
                           fully_observed=fully_observed)
Пример #10
0
    def plan_grasp(self,
                   depth,
                   rgb,
                   resetting=False,
                   camera_intr=None,
                   segmask=None):
        """
        Plan a grasp from a depth and an rgb array.

        Parameters
        ----------
        depth: type `numpy`
            Depth image data, wrapped in a `DepthImage`.
        rgb: type `numpy`
            Rgb image data, wrapped in a `ColorImage`.
        resetting: bool
            Forwarded unchanged to `execute_policy`.
        camera_intr: type `perception.CameraIntrinsics`
            Camera intrinsics. When None, they are loaded from the bundled
            "gqcnn/data/calib/primesense/primesense.intr" file located
            relative to this module.
        segmask: type `perception.BinaryImage`
            Binary segmask of the detected object. Required when the model
            name contains "FC".

        Returns
        -------
        type `GQCNNGrasp`
            Whatever `execute_policy` returns for the assembled state.
        """
        fully_conv = "FC" in self.model
        if fully_conv:
            assert segmask is not None, \
                "Fully-Convolutional policy expects a segmask."

        if camera_intr is None:
            module_dir = os.path.dirname(os.path.realpath(__file__))
            camera_intr = CameraIntrinsics.load(
                os.path.join(module_dir,
                             "gqcnn/data/calib/primesense/primesense.intr"))

        depth_im = DepthImage(depth, frame=camera_intr.frame)
        color_im = ColorImage(rgb, frame=camera_intr.frame)

        # Restrict the segmask to pixels with valid depth; with no segmask
        # given, the valid-pixel mask itself is used.
        valid_px_mask = depth_im.invalid_pixel_mask().inverse()
        if segmask is not None:
            segmask = segmask.mask_binary(valid_px_mask)
        else:
            segmask = valid_px_mask

        # Inpaint the depth image before building the state.
        depth_im = depth_im.inpaint(
            rescale_factor=self.config["inpaint_rescale_factor"])

        # Combine color and depth into an `RgbdImage` (kept on the instance)
        # and wrap it with the intrinsics into an `RgbdImageState`.
        self.rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
        state = RgbdImageState(self.rgbd_im, camera_intr, segmask=segmask)

        # The fully-convolutional policy needs the input size in its config.
        if fully_conv:
            fc_config = self.policy_config["metric"]["fully_conv_gqcnn_config"]
            fc_config["im_height"] = depth_im.shape[0]
            fc_config["im_width"] = depth_im.shape[1]

        return self.execute_policy(state, resetting)
Пример #11
0
    def __init__(self, screen):
        """Initialize the ROS GUI.

        Parameters
        ----------
        screen : int
            Index of the screen on which to show the window. An index that
            is not below the number of available screens falls back to the
            default screen.
        """
        super(RosGUI, self).__init__()

        # Set constants
        self._total_time = 0
        self._n_attempts = 0
        self._n_picks = 0
        self._state = 'pause'
        self._img_dir = os.path.join(os.getcwd(), 'images')

        # Initialize the UI
        self._init_UI()
        self._current_image = None

        # Connect the signals to slots
        self.weight_signal.connect(self.set_weight)
        self.conf_signal.connect(self.set_confidence)
        self.pph_signal.connect(self.set_pph)
        self.state_signal.connect(self.set_state)
        self.image_signal.connect(self.update_central_image)
        self.success_signal.connect(self.set_success)

        # Load grasp images and scale each one to its display size
        self._suction_image = ColorImage.open(self._image_fn('suction.png'))
        self._suction_image = self._suction_image.resize((40, 40))
        self._jaw_image = ColorImage.open(self._image_fn('gripper.png'))
        self._jaw_image = self._jaw_image.resize(0.06)
        self._push_image = ColorImage.open(self._image_fn('probe.png'))
        self._push_image = self._push_image.resize(0.2)
        self._push_end_image = ColorImage.open(self._image_fn('probe_end.png'))
        self._push_end_image = self._push_end_image.resize(0.06)

        # Initialize ROS subscribers
        self._init_subscribers()

        desktop = QApplication.desktop()
        # Screen indices are zero-based, so ``screen == screenCount()`` is
        # already out of range (the previous ``>`` check missed it).
        if screen >= desktop.screenCount():
            # Parenthesised form works as a statement on Python 2 and as a
            # function call on Python 3.
            print('Screen index is greater than number of available screens; using default screen')
            # -1 makes screenGeometry() use the default screen, so the
            # fallback announced above actually takes place.
            screen = -1
        screenRect = desktop.screenGeometry(screen)
        self.move(screenRect.topLeft())
        self.setWindowFlags(Qt.X11BypassWindowManagerHint)
        self.setFixedSize(screenRect.size())
        self.show()
Пример #12
0
 def _read_color_images(self, num_images):
     """Read ``num_images`` color images from the buffer as ColorImage objects.

     The list returned by ``_ros_read_images`` is updated in place: every
     entry is replaced by a ``ColorImage`` in this sensor's frame. When
     ``_flip_images`` is set, each raw image is first cast to uint8 and
     flipped both vertically and horizontally.
     """
     color_images = self._ros_read_images(self._color_image_buffer, num_images, self.staleness_limit)
     flip = self._flip_images
     for idx in range(num_images):
         raw = color_images[idx]
         if flip:
             # Vertical flip first, then horizontal. The second astype also
             # turns the flipped view into a fresh array.
             upside_down = np.flipud(raw.astype(np.uint8))
             raw = np.fliplr(upside_down.astype(np.uint8))
         color_images[idx] = ColorImage(raw, frame=self._frame)
     return color_images
Пример #13
0
    def read_images(self, req):
        """Reads images from a ROS service request.

        Parameters
        ---------
        req: :obj:`ROS ServiceRequest`
            ROS ServiceRequest for grasp planner service.

        Returns
        -------
        tuple
            ``(color_im, depth_im, camera_intr)``: the wrapped `ColorImage`,
            `DepthImage` and `CameraIntrinsics` built from the request.

        Raises
        ------
        rospy.ServiceException
            If the ROS images cannot be converted, if the color and depth
            images differ in shape, or if the color image is smaller than
            ``min_height`` x ``min_width``.
        """
        # Get the raw depth and color images as ROS `Image` objects.
        raw_color = req.color_image
        raw_depth = req.depth_image

        # Get the raw camera info as ROS `CameraInfo`.
        raw_camera_info = req.camera_info

        # Wrap the camera info in a BerkeleyAutomation/perception
        # `CameraIntrinsics` object.
        camera_intr = CameraIntrinsics(
            raw_camera_info.header.frame_id, raw_camera_info.K[0],
            raw_camera_info.K[4], raw_camera_info.K[2], raw_camera_info.K[5],
            raw_camera_info.K[1], raw_camera_info.height,
            raw_camera_info.width)

        # Create wrapped BerkeleyAutomation/perception RGB and depth images by
        # unpacking the ROS images using ROS `CvBridge`
        try:
            color_im = ColorImage(self.cv_bridge.imgmsg_to_cv2(
                raw_color, "rgb8"),
                                  frame=camera_intr.frame)
            depth_im = DepthImage(self.cv_bridge.imgmsg_to_cv2(
                raw_depth, desired_encoding="passthrough"),
                                  frame=camera_intr.frame)
        except CvBridgeError as cv_bridge_exception:
            rospy.logerr(cv_bridge_exception)
            # Without the images nothing below can run; previously execution
            # fell through and failed on the unbound ``color_im``.
            raise rospy.ServiceException(str(cv_bridge_exception))

        # Check image sizes.
        if color_im.height != depth_im.height or \
           color_im.width != depth_im.width:
            msg = ("Color image and depth image must be the same shape! Color"
                   " is %d x %d but depth is %d x %d") % (
                       color_im.height, color_im.width, depth_im.height,
                       depth_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        if (color_im.height < self.min_height
                or color_im.width < self.min_width):
            msg = ("Color image is too small! Must be at least %d x %d"
                   " resolution but the requested image is only %d x %d") % (
                       self.min_height, self.min_width, color_im.height,
                       color_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        return color_im, depth_im, camera_intr
Пример #14
0
    def read_images(self, req):
        """ Reads images from a ROS service request

        Parameters
        ---------
        req: :obj:`ROS ServiceRequest`
            ROS ServiceRequest for grasp planner service

        Returns
        -------
        tuple
            ``(color_im, depth_im, camera_intr)``: the wrapped ColorImage,
            DepthImage and CameraIntrinsics built from the request

        Raises
        ------
        rospy.ServiceException
            If the ROS images cannot be converted, if the color and depth
            images differ in shape, or if the color image is smaller than
            ``min_height`` x ``min_width``
        """
        # get the raw depth and color images as ROS Image objects
        raw_color = req.color_image
        raw_depth = req.depth_image

        # get the raw camera info as ROS CameraInfo object
        raw_camera_info = req.camera_info

        # wrap the camera info in a perception CameraIntrinsics object
        camera_intr = CameraIntrinsics(
            raw_camera_info.header.frame_id, raw_camera_info.K[0],
            raw_camera_info.K[4], raw_camera_info.K[2], raw_camera_info.K[5],
            raw_camera_info.K[1], raw_camera_info.height,
            raw_camera_info.width)

        # create wrapped Perception RGB and Depth Images by unpacking the ROS Images using CVBridge
        try:
            color_im = ColorImage(self.cv_bridge.imgmsg_to_cv2(
                raw_color, "rgb8"),
                                  frame=camera_intr.frame)
            depth_im = DepthImage(self.cv_bridge.imgmsg_to_cv2(
                raw_depth, desired_encoding="passthrough"),
                                  frame=camera_intr.frame)
        except CvBridgeError as cv_bridge_exception:
            rospy.logerr(cv_bridge_exception)
            # without the images nothing below can run; previously execution
            # fell through and failed on the unbound ``color_im``
            raise rospy.ServiceException(str(cv_bridge_exception))

        # check image sizes
        if color_im.height != depth_im.height or \
           color_im.width != depth_im.width:
            msg = 'Color image and depth image must be the same shape! Color is %d x %d but depth is %d x %d' % (
                color_im.height, color_im.width, depth_im.height,
                depth_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        if color_im.height < self.min_height or color_im.width < self.min_width:
            msg = 'Color image is too small! Must be at least %d x %d resolution but the requested image is only %d x %d' % (
                self.min_height, self.min_width, color_im.height,
                color_im.width)
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        return color_im, depth_im, camera_intr
Пример #15
0
    def test_shape_comp(self):
        """Check ``is_same_shape`` for equal and for different image sizes.

        Two random images of size IM_HEIGHT x IM_WIDTH must compare as the
        same shape; an image with doubled height and width must not.
        """
        def random_image(height, width):
            # Random uint8 color image of the requested size.
            pixels = (255.0 * np.random.rand(height, width, 3)).astype(np.uint8)
            return ColorImage(pixels)

        reference_im = random_image(IM_HEIGHT, IM_WIDTH)
        same_size_im = random_image(IM_HEIGHT, IM_WIDTH)
        self.assertTrue(reference_im.is_same_shape(same_size_im))

        double_size_im = random_image(2*IM_HEIGHT, 2*IM_WIDTH)
        self.assertFalse(reference_im.is_same_shape(double_size_im))
    def get_state(self, depth, segmask):
        """Build an RgbdImageState from a depth array and a segmentation mask.

        The color channel is an all-zero uint8 image of the same height and
        width as the depth image. The mask is cast to uint8 and multiplied
        by 255 before being wrapped in a BinaryImage, and the depth image is
        inpainted with ``self.inpaint_rescale_factor``.

        Returns
        -------
        tuple
            ``(state, rgbd_im)``: the RgbdImageState and the RgbdImage it
            was built from.
        """
        camera_frame = self.camera_intr.frame

        # Wrap the inputs as perception images in the camera frame.
        depth_im = DepthImage(depth, frame=camera_frame)
        blank_rgb = np.zeros([depth_im.height, depth_im.width, 3]).astype(np.uint8)
        color_im = ColorImage(blank_rgb, frame=camera_frame)
        mask_data = segmask.astype(np.uint8) * 255
        segmask = BinaryImage(mask_data, frame=camera_frame)

        # Inpaint the depth image before assembling the state.
        depth_im = depth_im.inpaint(rescale_factor=self.inpaint_rescale_factor)

        rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
        return RgbdImageState(rgbd_im, self.camera_intr, segmask=segmask), rgbd_im
Пример #17
0
    def _overlay_image(self, base, overlay, location):
        """Draw the nonzero pixels of ``overlay`` onto ``base``, centered at ``location``.

        Parameters
        ----------
        base : ColorImage
            Image to draw on; it is not modified.
        overlay : ColorImage
            Image whose nonzero pixels replace the corresponding base pixels.
        location : sequence of two integers
            ``(x, y)`` center of the overlay: index 0 is combined with the
            widths (columns) and index 1 with the heights (rows).

        Returns
        -------
        ColorImage
            New image with the size of ``base``; overlay parts that fall
            outside the base are cut off.
        """
        ow, oh = overlay.width, overlay.height
        bw, bh = base.width, base.height

        # Plain Python ints: unsigned numpy scalars could wrap around when
        # the overlay extends past the left/top edge.
        cx, cy = int(location[0]), int(location[1])

        # Floor division keeps the bounds integral on Python 3 as well.
        xmin, ymin = cx - ow // 2, cy - oh // 2
        xmax, ymax = xmin + ow, ymin + oh

        # Amount by which the overlay sticks out on each side of the base.
        padding_l = max(0, -xmin)
        padding_r = max(0, xmax - bw)
        padding_t = max(0, -ymin)
        padding_b = max(0, ymax - bh)

        # np.pad returns a new array, so base.data itself stays untouched.
        big_img = np.pad(base.data, ((padding_t, padding_b), (padding_l, padding_r), (0, 0)), mode='constant')

        # Overlay bounds in padded coordinates: the left/top padding shifts
        # the origin, which the previous implementation did not account for.
        x0, y0 = xmin + padding_l, ymin + padding_t
        insert = big_img[y0:y0 + oh, x0:x0 + ow, :]
        nzp = overlay.nonzero_pixels()
        insert[nzp[:, 0], nzp[:, 1], :] = overlay.data[nzp[:, 0], nzp[:, 1], :]

        # Cut the padding away again to recover the base-sized region.
        big_img = big_img[padding_t:padding_t + bh, padding_l:padding_l + bw, :]

        return ColorImage(big_img)
Пример #18
0
    def _read_color_image(self):
        """Read one frame from the color stream and return it as a ColorImage.

        The frame buffer is read as triplets and unpacked channel by channel
        into a COLOR_IM_HEIGHT x COLOR_IM_WIDTH x 3 array, which is cast to
        uint8. The image is flipped vertically when ``_flip_images`` is set
        and horizontally otherwise.
        """
        height = PrimesenseSensor.COLOR_IM_HEIGHT
        width = PrimesenseSensor.COLOR_IM_WIDTH
        num_pixels = width * height

        # read raw buffer
        frame = self._color_stream.read_frame()
        triplets = frame.get_buffer_as_triplet()

        # Fill one channel at a time, in the order the triplet stores them.
        color_image = np.zeros([height, width, 3])
        for channel in range(3):
            values = np.array([triplets[i][channel] for i in range(num_pixels)])
            color_image[:,:,channel] = values.reshape(height, width)

        # convert to uint8 image, then flip along the configured axis
        color_image = color_image.astype(np.uint8)
        flip = np.flipud if self._flip_images else np.fliplr
        color_image = flip(color_image)
        return ColorImage(color_image, frame=self._frame)
Пример #19
0
def run_gqcnn(depth, seg_mask):
    """Run the grasp policy on a depth image and report the planned grasp.

    The depth array is cast to float32 and paired with an all-zero color
    image; ``seg_mask`` is wrapped in a BinaryImage. Any exception raised
    while planning is printed and the defaults are returned instead.

    Returns
    -------
    tuple
        ``(best_angle, best_point, best_dist)``: the grasp angle converted
        to degrees, the grasp center as ``[c0, c1]`` and a distance that is
        always 0 here.
    """
    # Defaults returned when planning fails.
    best_angle = 0
    best_point = [0, 0]
    best_dist = 0

    depth_im = DepthImage(depth.astype("float32"), frame=camera_intr.frame)
    # NOTE(review): the shape is given as [imageWidth, imageHeight, 3]; confirm
    # the order against how these globals are defined.
    blank_rgb = np.zeros([imageWidth, imageHeight, 3]).astype(np.uint8)
    color_im = ColorImage(blank_rgb, frame=camera_intr.frame)

    print(seg_mask)
    segmask = BinaryImage(seg_mask)
    print(segmask)

    rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
    state_gqcnn = RgbdImageState(rgbd_im, camera_intr, segmask=segmask)

    policy_start = time.time()
    try:
        action = policy(state_gqcnn)
        logger.info("Planning took %.3f sec" % (time.time() - policy_start))
        center = action.grasp.center
        best_point = [center[0], center[1]]
        # Radians to degrees, using the same constant as before.
        best_angle = float(action.grasp.angle)*180/3.141592
    except Exception as inst:
        print(inst)

    return best_angle, best_point, best_dist
Пример #20
0
def run_gqcnn(depth, seg_mask):
    """Run the grasp policy on a depth image and return the grasp quality.

    The depth array is cast to float32 and paired with an all-zero color
    image; ``seg_mask`` is wrapped in a BinaryImage. Any exception raised
    while planning is printed.

    Returns
    -------
    The ``q_value`` of the planned action, or -1 when planning raised an
    exception before the value was read.
    """
    best_angle = 0
    best_point = [0, 0]
    best_dist = 0
    q_value = -1

    depth_im = DepthImage(depth.astype("float32"), frame=camera_intr.frame)
    # NOTE(review): the shape is given as [imageWidth, imageHeight, 3]; confirm
    # the order against how these globals are defined.
    blank_rgb = np.zeros([imageWidth, imageHeight, 3]).astype(np.uint8)
    color_im = ColorImage(blank_rgb, frame=camera_intr.frame)
    segmask = BinaryImage(seg_mask)

    state_gqcnn = RgbdImageState(
        RgbdImage.from_color_and_depth(color_im, depth_im),
        camera_intr,
        segmask=segmask)
    policy_start = time.time()

    try:
        action = policy(state_gqcnn)
        # The grasp center and angle are still read so that a malformed
        # action fails here, before q_value is taken.
        grasp = action.grasp
        best_point = [grasp.center[0], grasp.center[1]]
        best_angle = float(grasp.angle) * 180 / 3.141592
        q_value = action.q_value
        print("inner :       ", action.q_value)
    except Exception as inst:
        print(inst)

    return q_value
Пример #21
0
import IPython
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

from perception import ColorImage
from perception.models import ClassificationCNN

if __name__ == '__main__':
    # Classify a single image with a stored CNN and show the prediction.
    # Exit with a usage message instead of an IndexError traceback when
    # arguments are missing.
    if len(sys.argv) < 4:
        sys.exit('Usage: %s <model_dir> <model_type> <image_filename>' % (sys.argv[0]))

    model_dir = sys.argv[1]
    model_type = sys.argv[2]
    image_filename = sys.argv[3]

    #with open('data/images/imagenet.json', 'r') as f:
    #    label_to_category = eval(f.read())

    im = ColorImage.open(image_filename)
    cnn = ClassificationCNN.open(model_dir, model_typename=model_type)
    # NOTE(review): ``out`` is never used; the call is kept in case
    # predict() is needed before top_prediction() - confirm.
    out = cnn.predict(im)
    label = cnn.top_prediction(im)
    #category = label_to_category[label]

    # Display the image with the predicted label as the title.
    plt.figure()
    plt.imshow(im.bgr2rgb().data)
    plt.title('Pred: %d' % (label))
    plt.axis('off')
    plt.show()

    #IPython.embed()
Пример #22
0
    # Make relative paths absolute.
    if "gqcnn_model" in policy_config["metric"]:
        policy_config["metric"]["gqcnn_model"] = model_path
        if not os.path.isabs(policy_config["metric"]["gqcnn_model"]):
            policy_config["metric"]["gqcnn_model"] = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "..",
                policy_config["metric"]["gqcnn_model"])

    # Setup sensor.
    camera_intr = CameraIntrinsics.load(camera_intr_filename)

    # Read images.
    depth_data = np.load(depth_im_filename)
    depth_im = DepthImage(depth_data, frame=camera_intr.frame)
    color_im = ColorImage(np.zeros([depth_im.height, depth_im.width,
                                    3]).astype(np.uint8),
                          frame=camera_intr.frame)

    # Optionally read a segmask.
    segmask = None
    if segmask_filename is not None:
        segmask = BinaryImage.open(segmask_filename)
    valid_px_mask = depth_im.invalid_pixel_mask().inverse()
    if segmask is None:
        segmask = valid_px_mask
    else:
        segmask = segmask.mask_binary(valid_px_mask)

    # Inpaint.
    depth_im = depth_im.inpaint(rescale_factor=inpaint_rescale_factor)
Пример #23
0
 def saved_frames(self):
     """list of perception.ColorImage : The entries of ``_saved_frames``,
     each wrapped in a new ColorImage, in their stored order.
     """
     wrapped_frames = []
     for raw_frame in self._saved_frames:
         wrapped_frames.append(ColorImage(raw_frame))
     return wrapped_frames
Пример #24
0
def analyze_classification_performance(model_dir, config, dataset_path=None):
    """Evaluate a trained ClassificationCNN on a dataset and save the analysis.

    All artifacts are written to ``<model_dir>/analysis``:

    * ``<dataset_name>_analysis.log`` -- log file (replaced if it exists)
    * ``<dataset_name>_classes.pdf`` -- one example image per class
    * ``<split_name>_confusion.pdf`` -- confusion matrix per split
    * ``<split_name>.cres`` -- saved ClassificationResult per split
    * ``<dataset_name>_precision_recall.pdf`` and ``<dataset_name>_roc.pdf``

    Parameters
    ----------
    model_dir : str
        directory of the trained model; must contain ``training_config.yaml``
        (and ``splits.npz`` when ``dataset_path`` is None)
    config : dict-like
        analysis parameters. Keys read here: ``batch_size``, ``randomize``,
        ``plotting`` (``figsize``, ``font_size``, ``legend_font_size``,
        ``line_width``, ``colors``, ``dpi``) and, optionally,
        ``class_remapping`` (maps original labels to new labels)
    dataset_path : str, optional
        dataset to evaluate on. If None, the dataset named in the training
        config is used together with the saved train/val splits; otherwise
        the whole dataset is treated as a single split named after it.
    """
    # read params
    predict_batch_size = config['batch_size']
    randomize = config['randomize']

    plotting_config = config['plotting']
    figsize = plotting_config['figsize']
    font_size = plotting_config['font_size']
    legend_font_size = plotting_config['legend_font_size']
    line_width = plotting_config['line_width']
    colors = plotting_config['colors']
    dpi = plotting_config['dpi']
    style = '-'

    class_remapping = None
    if 'class_remapping' in config.keys():
        class_remapping = config['class_remapping']

    # read training config
    training_config_filename = os.path.join(model_dir, 'training_config.yaml')
    training_config = YamlConfig(training_config_filename)

    # read training params
    indices_filename = None
    if dataset_path is None:
        dataset_path = training_config['dataset']
        indices_filename = os.path.join(model_dir, 'splits.npz')
    dataset_prefix, dataset_name = os.path.split(dataset_path)
    if dataset_name == '':
        # dataset_path ended with a path separator; use the last real component
        _, dataset_name = os.path.split(dataset_prefix)
    x_names = training_config['x_names']
    y_name = training_config['y_name']
    x_name = x_names[0]

    # set analysis dir
    analysis_dir = os.path.join(model_dir, 'analysis')
    if not os.path.exists(analysis_dir):
        os.mkdir(analysis_dir)

    # setup log file
    experiment_log_filename = os.path.join(
        analysis_dir, '%s_analysis.log' % (dataset_name))
    if os.path.exists(experiment_log_filename):
        os.remove(experiment_log_filename)
    formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
    hdlr = logging.FileHandler(experiment_log_filename)
    hdlr.setFormatter(formatter)
    logging.getLogger().addHandler(hdlr)

    # setup plotting
    plt.figure(figsize=(figsize, figsize))

    # read dataset
    dataset = TensorDataset.open(dataset_path)

    # read dataset splits
    if indices_filename is None:
        splits = {dataset_name: np.arange(dataset.num_datapoints)}
    else:
        splits = np.load(indices_filename)['arr_0'].tolist()

    # load cnn
    logging.info('Loading model %s' % (model_dir))
    cnn = ClassificationCNN.open(model_dir)

    # count the number of datapoints of each class
    logging.info('Saving examples of each class')
    all_labels = np.arange(cnn.num_classes)
    label_counts = {l: 0 for l in all_labels}
    for tensor_ind in range(dataset.num_tensors):
        tensor = dataset.tensor(y_name, tensor_ind)
        for label in tensor:
            label_counts[label] += 1

    # save one example image per class in a d x d grid
    d = utils.sqrt_ceil(cnn.num_classes)
    plt.clf()
    for i, label in enumerate(all_labels):
        # scan the label tensors for the first datapoint with this label
        tensor_ind = 0
        label_found = False
        while not label_found and tensor_ind < dataset.num_tensors:
            tensor = dataset.tensor(y_name, tensor_ind)
            ind = np.where(tensor.arr == label)[0]
            if ind.shape[0] > 0:
                # convert the within-tensor index to a dataset-wide index
                ind = ind[0] + dataset.datapoints_per_tensor * (tensor_ind)
                label_found = True

            tensor_ind += 1

        if not label_found:
            continue
        datapoint = dataset[ind]
        example_im = datapoint[x_name]

        # scale the class frequency to a true percentage to match the '%%'
        # in the title (previously the raw fraction was printed as a percent)
        class_percent = (100.0 * float(label_counts[label]) /
                         dataset.num_datapoints)

        plt.subplot(d, d, i + 1)
        plt.imshow(example_im[:, :, :3].astype(np.uint8))
        plt.title('Class %03d: %.3f%%' % (label, class_percent), fontsize=3)
        plt.axis('off')
    plt.savefig(os.path.join(analysis_dir, '%s_classes.pdf' % (dataset_name)))

    # evaluate on dataset
    results = {}
    # .items() (rather than the Python-2-only .iteritems()) runs on 2 and 3
    for split_name, indices in splits.items():
        logging.info('Evaluating performance on split: %s' % (split_name))

        # predict
        if randomize:
            pred_probs, true_labels = cnn.evaluate_on_dataset(
                dataset, indices=indices, batch_size=predict_batch_size)
        else:
            # NOTE(review): this branch walks the entire dataset for every
            # split (``indices`` is not used) and reads the hard-coded fields
            # 'color_ims' / 'stp_labels' instead of x_name / y_name --
            # confirm whether that is intended.
            true_labels = []
            pred_probs = []
            for datapoint in dataset:
                im = ColorImage(
                    datapoint['color_ims'].astype(np.uint8)[:, :, :3])
                true_label = datapoint['stp_labels']
                pred_prob = cnn.predict(im)
                true_labels.append(true_label)
                pred_probs.append(pred_prob.ravel())

            true_labels = np.array(true_labels)
            pred_probs = np.array(pred_probs)

        # apply optional class re-mapping
        if class_remapping is not None:
            new_true_labels = np.zeros(true_labels.shape)
            for orig_label, new_label in class_remapping.items():
                new_true_labels[true_labels == orig_label] = new_label
            true_labels = new_true_labels

        # compute classification results
        result = ClassificationResult([pred_probs], [true_labels])
        results[split_name] = result

        # print stats
        logging.info('SPLIT: %s' % (split_name))
        logging.info('Acc: %.3f' % (result.accuracy))
        logging.info('AP: %.3f' % (result.ap_score))
        logging.info('AUC: %.3f' % (result.auc_score))

        # save confusion matrix
        confusion = result.confusion_matrix.data
        plt.clf()
        plt.imshow(confusion, cmap=plt.cm.gray, interpolation='none')
        plt.locator_params(nticks=cnn.num_classes)
        plt.xlabel('Predicted', fontsize=font_size)
        plt.ylabel('Actual', fontsize=font_size)
        plt.savefig(os.path.join(analysis_dir,
                                 '%s_confusion.pdf' % (split_name)),
                    dpi=dpi)

        # save analysis
        result_filename = os.path.join(analysis_dir, '%s.cres' % (split_name))
        result.save(result_filename)

    # plot
    colormap = plt.get_cmap('tab10')
    num_colors = 9

    # precision-recall curve, one line per split
    plt.clf()
    for i, split_name in enumerate(splits.keys()):
        result = results[split_name]
        precision, recall, taus = result.precision_recall_curve()
        color = colormap(float(colors[i % num_colors]) / num_colors)
        plt.plot(recall, precision, linewidth=line_width, color=color,
                 linestyle=style, label=split_name)
    plt.xlabel('Recall', fontsize=font_size)
    plt.ylabel('Precision', fontsize=font_size)
    plt.title('Precision-Recall Curve', fontsize=font_size)
    handles, plt_labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles, plt_labels, loc='best', fontsize=legend_font_size)
    plt.savefig(os.path.join(analysis_dir,
                             '%s_precision_recall.pdf' % (dataset_name)),
                dpi=dpi)

    # ROC curve, one line per split
    plt.clf()
    for i, split_name in enumerate(splits.keys()):
        result = results[split_name]
        fpr, tpr, taus = result.roc_curve()
        color = colormap(float(colors[i % num_colors]) / num_colors)
        plt.plot(fpr, tpr, linewidth=line_width, color=color,
                 linestyle=style, label=split_name)
    plt.xlabel('FPR', fontsize=font_size)
    plt.ylabel('TPR', fontsize=font_size)
    plt.title('Receiver Operating Characteristic', fontsize=font_size)
    handles, plt_labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles, plt_labels, loc='best', fontsize=legend_font_size)
    plt.savefig(os.path.join(analysis_dir, '%s_roc.pdf' % (dataset_name)),
                dpi=dpi)
Пример #25
0
        if not os.path.isabs(policy_config['metric']['gqcnn_model']):
            policy_config['metric']['gqcnn_model'] = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), '..',
                policy_config['metric']['gqcnn_model'])

    # setup sensor
    camera_intr = CameraIntrinsics.load(camera_intr_filename)

    # read images
    depth_data = np.load(depth_im_filename)
    c_image = np.array(cv2.imread(last_color_path))
    #print depth_data
    depth_im = DepthImage(depth_data, frame=camera_intr.frame)
    #color_im = ColorImage(np.zeros([depth_im.height, depth_im.width, 3]).astype(np.uint8),
    #                      frame=camera_intr.frame)
    color_im = ColorImage(c_image, frame=camera_intr.frame)
    # optionally read a segmask
    segmask = None
    if segmask_filename is not None:
        segmask = BinaryImage.open(segmask_filename)
        print "SegMask Opening __________________"
    valid_px_mask = depth_im.invalid_pixel_mask().inverse()
    if segmask is None:
        segmask = valid_px_mask
    else:
        segmask = segmask.mask_binary(valid_px_mask)

    # inpaint
    depth_im = depth_im.inpaint(rescale_factor=inpaint_rescale_factor)

    if True:  #'input_images' in policy_config['vis'].keys() and policy_config['vis']['input_images']:
Пример #26
0
    def visualize(self):
        """ Visualize predictions

        For every file index in ``self.indices`` this loads the image, pose
        and label tensors from ``self.data_dir``, scores them with
        ``self._gqcnn``, logs error rate / precision / recall, and then
        displays up to ``self.samples_per_object`` datapoints of the kind
        selected by ``self.datapoint_type`` ('true_positive',
        'true_negative', 'false_positive' or 'false_negative'), each with a
        grasp overlay. ``self._cleanup()`` is called at the end.

        Raises
        ------
        ValueError
            If ``self.display_image_type`` is not one of the handled
            RenderMode values.
        """

        logging.info('Visualizing ' + self.datapoint_type)

        # iterate through shuffled file indices
        for i in self.indices:
            im_filename = self.im_filenames[i]
            pose_filename = self.pose_filenames[i]
            label_filename = self.label_filenames[i]

            logging.info('Loading Image File: ' + im_filename +
                         ' Pose File: ' + pose_filename + ' Label File: ' +
                         label_filename)

            # load tensors from files
            metric_tensor = np.load(os.path.join(self.data_dir,
                                                 label_filename))['arr_0']
            # binarize the metric into 0/1 labels
            label_tensor = 1 * (metric_tensor > self.metric_thresh)
            image_tensor = np.load(os.path.join(self.data_dir,
                                                im_filename))['arr_0']
            hand_poses_tensor = np.load(
                os.path.join(self.data_dir, pose_filename))['arr_0']

            pose_tensor = self._read_pose_data(hand_poses_tensor,
                                               self.input_data_mode)

            # score with neural network
            pred_p_success_tensor = self._gqcnn.predict(
                image_tensor, pose_tensor)

            # compute results
            classification_result = ClassificationResult(
                [pred_p_success_tensor], [label_tensor])

            logging.info('Error rate on files: %.3f' %
                         (classification_result.error_rate))
            logging.info('Precision on files: %.3f' %
                         (classification_result.precision))
            logging.info('Recall on files: %.3f' %
                         (classification_result.recall))
            mispred_ind = classification_result.mispredicted_indices()
            correct_ind = classification_result.correct_indices()

            # true positives/negatives are correct predictions; everything
            # else is drawn from the mispredicted datapoints
            if self.datapoint_type in ('true_positive', 'true_negative'):
                vis_ind = correct_ind
            else:
                vis_ind = mispred_ind

            num_visualized = 0
            # visualize
            for ind in vis_ind:
                # limit the number of sampled datapoints displayed per object
                if num_visualized >= self.samples_per_object:
                    break

                # don't visualize the datapoints that we don't want.
                # True positives and false negatives have a positive true
                # label; true negatives and false positives have a negative
                # true label. (The false_* checks used to be swapped, which
                # displayed false negatives when false positives were
                # requested and vice versa.)
                true_label = classification_result.labels[ind]
                if self.datapoint_type in ('true_positive', 'false_negative'):
                    if true_label == 0:
                        continue
                elif self.datapoint_type in ('true_negative',
                                             'false_positive'):
                    if true_label == 1:
                        continue

                # only datapoints that are actually displayed count towards
                # the per-object limit
                num_visualized += 1

                logging.info('Datapoint %d of files for %s' %
                             (ind, im_filename))
                logging.info('Depth: %.3f' % (hand_poses_tensor[ind, 2]))

                # wrap the raw image data in the matching image class
                data = image_tensor[ind, ...]
                if self.display_image_type == RenderMode.SEGMASK:
                    image = BinaryImage(data)
                elif self.display_image_type == RenderMode.GRAYSCALE:
                    image = GrayscaleImage(data)
                elif self.display_image_type == RenderMode.COLOR:
                    image = ColorImage(data)
                elif self.display_image_type == RenderMode.DEPTH:
                    image = DepthImage(data)
                elif self.display_image_type == RenderMode.RGBD:
                    image = RgbdImage(data)
                elif self.display_image_type == RenderMode.GD:
                    image = GdImage(data)
                else:
                    # previously fell through with `image` unbound (or stale
                    # from the previous datapoint)
                    raise ValueError('Render mode {} not supported'.format(
                        self.display_image_type))

                vis2d.figure()

                # grasp at the image center, at the depth stored in column 2
                # of the hand pose row
                grasp = Grasp2D(Point(image.center, 'img'), 0,
                                hand_poses_tensor[ind, 2],
                                self.gripper_width_m)
                # NOTE(review): the 1/3 rescale of the intrinsics presumably
                # matches the crop/downsample of the stored images -- confirm
                grasp.camera_intr = grasp.camera_intr.resize(1.0 / 3.0)

                # two-channel-group images are shown as side-by-side panels
                if self.display_image_type == RenderMode.RGBD:
                    panels = (image.color, image.depth)
                elif self.display_image_type == RenderMode.GD:
                    panels = (image.gray, image.depth)
                else:
                    panels = None

                if panels is None:
                    vis2d.imshow(image)
                    vis2d.grasp(grasp)
                else:
                    for panel_ind, panel in enumerate(panels):
                        vis2d.subplot(1, 2, panel_ind + 1)
                        vis2d.imshow(panel)
                        vis2d.grasp(grasp)

                vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                            (ind, classification_result.pred_probs[ind, 1],
                             classification_result.labels[ind]))
                vis2d.show()

        # cleanup
        self._cleanup()
    def read_images(self, req):
        """Reads images from a ROS service request.

        Parameters
        ---------
        req: :obj:`ROS ServiceRequest`
            ROS ServiceRequest for grasp planner service. Its `color_image`,
            `depth_image` and `camera_info` fields are read.

        Returns
        -------
        CameraData
            Struct with `rgb_img` (ColorImage), `depth_img` (DepthImage) and
            `intrinsic_params` (CameraIntrinsics) filled in.

        Raises
        ------
        rospy.ServiceException
            If the ROS images cannot be converted, if the color and depth
            images differ in size, or if the color image is smaller than
            `self.min_height` x `self.min_width`.
        """
        # Get the raw depth and color images as ROS `Image` objects.
        raw_color = req.color_image
        raw_depth = req.depth_image

        # Get the raw camera info as ROS `CameraInfo`.
        raw_camera_info = req.camera_info

        # Wrap the camera info in a BerkeleyAutomation/perception
        # `CameraIntrinsics` object. `K` is indexed as a flattened 3x3
        # intrinsic matrix: fx=K[0], cx=K[2], fy=K[4], cy=K[5].
        camera_intr = CameraIntrinsics(
            frame=raw_camera_info.header.frame_id,
            fx=raw_camera_info.K[0],
            fy=raw_camera_info.K[4],
            cx=raw_camera_info.K[2],
            cy=raw_camera_info.K[5],
            width=raw_camera_info.width,
            height=raw_camera_info.height,
        )

        # Create wrapped BerkeleyAutomation/perception RGB and depth images by
        # unpacking the ROS images using ROS `CvBridge`
        try:
            color_im = ColorImage(self.cv_bridge.imgmsg_to_cv2(
                raw_color, "rgb8"),
                                  frame=camera_intr.frame)

            cv2_depth = self.cv_bridge.imgmsg_to_cv2(
                raw_depth, desired_encoding="passthrough")
            cv2_depth = np.array(cv2_depth, dtype=np.float32)

            # NOTE(review): presumably converts millimeters to meters --
            # confirm against the depth encoding of the camera driver.
            cv2_depth *= 0.001

            # Replace NaN depth readings with a large sentinel value.
            nan_idxs = np.isnan(cv2_depth)
            cv2_depth[nan_idxs] = 1000.0

            depth_im = DepthImage(cv2_depth, frame=camera_intr.frame)

        except CvBridgeError as cv_bridge_exception:
            # Abort the request: without this, `color_im` / `depth_im` would
            # be unbound and the size checks below would fail with an
            # unrelated UnboundLocalError.
            rospy.logerr(cv_bridge_exception)
            raise rospy.ServiceException(str(cv_bridge_exception))

        # Check image sizes.
        if (color_im.height != depth_im.height
                or color_im.width != depth_im.width):
            msg = ("Color image and depth image must be the same shape! Color"
                   " is %d x %d but depth is %d x %d") % (
                       color_im.height, color_im.width, depth_im.height,
                       depth_im.width)

            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        if (color_im.height < self.min_height
                or color_im.width < self.min_width):
            msg = ("Color image is too small! Must be at least %d x %d"
                   " resolution but the requested image is only %d x %d") % (
                       self.min_height, self.min_width, color_im.height,
                       color_im.width)

            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        # --- create CameraData struct --- #
        camera_data = CameraData()
        camera_data.rgb_img = color_im
        camera_data.depth_img = depth_im
        camera_data.intrinsic_params = camera_intr

        return camera_data
Пример #28
0
    def wrapped_render(self, render_modes, front_and_back=False):
        """Render the scene once and wrap the result in perception Image
        classes (Berkeley AUTOLab perception module), one per requested mode.

        Parameters
        ----------
        render_modes : list of perception.RenderMode
            The image types to produce, from the perception module's
            RenderMode enum.

        front_and_back : bool
            If True, all surface normals are treated as if they're facing the
            camera. Only forwarded to the renderer when a color render is
            needed.

        Returns
        -------
        list of perception.Image
            One Image sub-class instance per entry of render_modes, in the
            same order.

        Raises
        ------
        ValueError
            If an entry of render_modes is not a supported mode.
        """

        # A color render is required as soon as any mode other than the two
        # depth-only ones is requested.
        render_color = any(
            m != RenderMode.DEPTH and m != RenderMode.SCALED_DEPTH
            for m in render_modes)

        # Render raw images
        if render_color:
            color_im, depth_im = self.render(render_color,
                                             front_and_back=front_and_back)
        else:
            color_im = None
            depth_im = self.render(render_color)

        # Small factories so each wrapper is built only when a mode needs it.
        def _frame():
            return self.camera.intrinsics.frame

        def _wrap_color():
            return ColorImage(color_im, frame=_frame())

        def _wrap_depth():
            return DepthImage(depth_im, frame=_frame())

        # Build the wrapped image for each requested mode, in order.
        wrapped = []
        for requested in render_modes:
            if requested == RenderMode.SEGMASK:
                # every pixel with positive depth belongs to the mask
                wrapped_im = BinaryImage((depth_im > 0.0).astype(np.uint8),
                                         frame=_frame(),
                                         threshold=0)
            elif requested == RenderMode.COLOR:
                wrapped_im = _wrap_color()
            elif requested == RenderMode.GRAYSCALE:
                wrapped_im = _wrap_color().to_grayscale()
            elif requested == RenderMode.DEPTH:
                wrapped_im = _wrap_depth()
            elif requested == RenderMode.SCALED_DEPTH:
                wrapped_im = _wrap_depth().to_color()
            elif requested == RenderMode.RGBD:
                color_wrapped = _wrap_color()
                depth_wrapped = _wrap_depth()
                wrapped_im = RgbdImage.from_color_and_depth(
                    color_wrapped, depth_wrapped)
            elif requested == RenderMode.GD:
                gray_wrapped = _wrap_color().to_grayscale()
                depth_wrapped = _wrap_depth()
                wrapped_im = GdImage.from_grayscale_and_depth(
                    gray_wrapped, depth_wrapped)
            else:
                raise ValueError(
                    'Render mode {} not supported'.format(requested))
            wrapped.append(wrapped_im)

        return wrapped
Пример #29
0
#%%
# Interactive (cell-based) debugging script: plan a grasp on one RGB-D frame
# and draw the planned grasp center on the images.
# Load the debug configuration from YAML.
cfg = YamlConfig('cfg/gqcnn_pj_dbg.yaml')
#%%
# Build the grasping policy from the 'policy' section of the config.
grasp_policy = CrossEntropyRobustGraspingPolicy(cfg['policy'])
# grasp_policy = RobustGraspingPolicy(cfg['policy'])
#%%
# Fetch the color and depth frames from `rosco` (defined outside this
# snippet; presumably a ROS camera wrapper -- TODO confirm).
img = rosco.rgb
d = rosco.depth
#%%
# Show the raw color frame.
plt.imshow(img)
#%%
# Show the raw depth frame.
plt.imshow(d)
#print(img)
#print(d)
#%%
# Wrap the frames in perception image classes (color as uint8 tagged
# "bgr8", depth as float32), inpaint both with the configured rescale
# factor, and combine them into the RGB-D state passed to the policy.
# NOTE(review): `camera_int` is not defined in this snippet -- it must come
# from an earlier cell.
color_im = ColorImage(img.astype(np.uint8), encoding="bgr8", frame='pcl')
depth_im = DepthImage(d.astype(np.float32), frame='pcl')
color_im = color_im.inpaint(rescale_factor=cfg['inpaint_rescale_factor'])
depth_im = depth_im.inpaint(rescale_factor=cfg['inpaint_rescale_factor'])
rgbd_im = RgbdImage.from_color_and_depth(color_im, depth_im)
rgbd_state = RgbdImageState(rgbd_im, camera_int)
#%%
# Run the policy on the state to plan a grasp.
grasp = grasp_policy(rgbd_state)
#%%
# Convert the color frame from BGR to RGB and draw a small circle
# (radius 2, thickness 3) at the planned grasp center, then show it.
img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img2 = cv2.circle(img2,(int(grasp.grasp.center.x),int(grasp.grasp.center.y)),2,(255,0,0),3)
plt.imshow(img2)

#%%
# Same marker drawn on the depth frame.
# NOTE(review): COLOR_BGR2RGB is applied to `d`, which is wrapped as a
# DepthImage above; if `d` is single-channel this conversion looks wrong
# (COLOR_GRAY2RGB may have been intended) -- verify.
img3 = cv2.cvtColor(d, cv2.COLOR_BGR2RGB)
img3 = cv2.circle(img3,(int(grasp.grasp.center.x),int(grasp.grasp.center.y)),2,(255,0,0),3)
Пример #30
0
def generate_segmask_dataset(output_dataset_path,
                             config,
                             save_tensors=True,
                             warm_start=False):
    """ Generate a segmentation training dataset

    Parameters
    ----------
    dataset_path : str
        path to store the dataset
    config : dict
        dictionary-like objects containing parameters of the simulator and visualization
    save_tensors : bool
        save tensor datasets (for recreating state)
    warm_start : bool
        restart dataset generation from a previous state
    """

    # read subconfigs
    dataset_config = config['dataset']
    image_config = config['images']
    vis_config = config['vis']

    # debugging
    debug = config['debug']
    if debug:
        np.random.seed(SEED)

    # read general parameters
    num_states = config['num_states']
    num_images_per_state = config['num_images_per_state']

    states_per_flush = config['states_per_flush']
    states_per_garbage_collect = config['states_per_garbage_collect']

    # set max obj per state
    max_objs_per_state = config['state_space']['heap']['max_objs']

    # read image parameters
    im_height = config['state_space']['camera']['im_height']
    im_width = config['state_space']['camera']['im_width']
    segmask_channels = max_objs_per_state + 1

    # create the dataset path and all subfolders if they don't exist
    if not os.path.exists(output_dataset_path):
        os.mkdir(output_dataset_path)

    image_dir = os.path.join(output_dataset_path, 'images')
    if not os.path.exists(image_dir):
        os.mkdir(image_dir)
    color_dir = os.path.join(image_dir, 'color_ims')
    if image_config['color'] and not os.path.exists(color_dir):
        os.mkdir(color_dir)
    depth_dir = os.path.join(image_dir, 'depth_ims')
    if image_config['depth'] and not os.path.exists(depth_dir):
        os.mkdir(depth_dir)
    amodal_dir = os.path.join(image_dir, 'amodal_masks')
    if image_config['amodal'] and not os.path.exists(amodal_dir):
        os.mkdir(amodal_dir)
    modal_dir = os.path.join(image_dir, 'modal_masks')
    if image_config['modal'] and not os.path.exists(modal_dir):
        os.mkdir(modal_dir)
    semantic_dir = os.path.join(image_dir, 'semantic_masks')
    if image_config['semantic'] and not os.path.exists(semantic_dir):
        os.mkdir(semantic_dir)

    # setup logging
    experiment_log_filename = os.path.join(output_dataset_path,
                                           'dataset_generation.log')
    if os.path.exists(experiment_log_filename) and not warm_start:
        os.remove(experiment_log_filename)
    Logger.add_log_file(logger, experiment_log_filename, global_log_file=True)
    config.save(
        os.path.join(output_dataset_path, 'dataset_generation_params.yaml'))
    metadata = {}
    num_prev_states = 0

    # set dataset params
    if save_tensors:

        # read dataset subconfigs
        state_dataset_config = dataset_config['states']
        image_dataset_config = dataset_config['images']
        state_tensor_config = state_dataset_config['tensors']
        image_tensor_config = image_dataset_config['tensors']

        obj_pose_dim = POSE_DIM * max_objs_per_state
        obj_com_dim = POINT_DIM * max_objs_per_state
        state_tensor_config['fields']['obj_poses']['height'] = obj_pose_dim
        state_tensor_config['fields']['obj_coms']['height'] = obj_com_dim
        state_tensor_config['fields']['obj_ids']['height'] = max_objs_per_state

        image_tensor_config['fields']['camera_pose']['height'] = POSE_DIM

        if image_config['color']:
            image_tensor_config['fields']['color_im'] = {
                'dtype': 'uint8',
                'channels': 3,
                'height': im_height,
                'width': im_width
            }

        if image_config['depth']:
            image_tensor_config['fields']['depth_im'] = {
                'dtype': 'float32',
                'channels': 1,
                'height': im_height,
                'width': im_width
            }

        if image_config['modal']:
            image_tensor_config['fields']['modal_segmasks'] = {
                'dtype': 'uint8',
                'channels': segmask_channels,
                'height': im_height,
                'width': im_width
            }

        if image_config['amodal']:
            image_tensor_config['fields']['amodal_segmasks'] = {
                'dtype': 'uint8',
                'channels': segmask_channels,
                'height': im_height,
                'width': im_width
            }

        if image_config['semantic']:
            image_tensor_config['fields']['semantic_segmasks'] = {
                'dtype': 'uint8',
                'channels': 1,
                'height': im_height,
                'width': im_width
            }

        # create dataset filenames
        state_dataset_path = os.path.join(output_dataset_path, 'state_tensors')
        image_dataset_path = os.path.join(output_dataset_path, 'image_tensors')

        if warm_start:

            if not os.path.exists(state_dataset_path) or not os.path.exists(
                    image_dataset_path):
                logger.error(
                    'Attempting to warm start without saved tensor dataset')
                exit(1)

            # open datasets
            logger.info('Opening state dataset')
            state_dataset = TensorDataset.open(state_dataset_path,
                                               access_mode='READ_WRITE')
            logger.info('Opening image dataset')
            image_dataset = TensorDataset.open(image_dataset_path,
                                               access_mode='READ_WRITE')

            # read configs
            state_tensor_config = state_dataset.config
            image_tensor_config = image_dataset.config

            # clean up datasets (there may be datapoints with indices corresponding to non-existent data)
            num_state_datapoints = state_dataset.num_datapoints
            num_image_datapoints = image_dataset.num_datapoints
            num_prev_states = num_state_datapoints

            # clean up images
            image_ind = num_image_datapoints - 1
            image_datapoint = image_dataset[image_ind]
            while image_ind > 0 and image_datapoint[
                    'state_ind'] >= num_state_datapoints:
                image_ind -= 1
                image_datapoint = image_dataset[image_ind]
            images_to_remove = num_image_datapoints - 1 - image_ind
            logger.info('Deleting last %d image tensors' % (images_to_remove))
            if images_to_remove > 0:
                image_dataset.delete_last(images_to_remove)
                num_image_datapoints = image_dataset.num_datapoints
        else:
            # create datasets from scratch
            logger.info('Creating datasets')

            state_dataset = TensorDataset(state_dataset_path,
                                          state_tensor_config)
            image_dataset = TensorDataset(image_dataset_path,
                                          image_tensor_config)

        # read templates
        state_datapoint = state_dataset.datapoint_template
        image_datapoint = image_dataset.datapoint_template

    if warm_start:

        if not os.path.exists(
                os.path.join(output_dataset_path, 'metadata.json')):
            logger.error(
                'Attempting to warm start without previously created dataset')
            exit(1)

        # Read metadata and indices
        metadata = json.load(
            open(os.path.join(output_dataset_path, 'metadata.json'), 'r'))
        test_inds = np.load(os.path.join(image_dir,
                                         'test_indices.npy')).tolist()
        train_inds = np.load(os.path.join(image_dir,
                                          'train_indices.npy')).tolist()

        # set obj ids and splits
        reverse_obj_ids = metadata['obj_ids']
        obj_id_map = utils.reverse_dictionary(reverse_obj_ids)
        obj_splits = metadata['obj_splits']
        obj_keys = obj_splits.keys()
        mesh_filenames = metadata['meshes']

        # Get list of images generated so far
        generated_images = sorted(
            os.listdir(color_dir)) if image_config['color'] else sorted(
                os.listdir(depth_dir))
        num_total_images = len(generated_images)

        # Do our own calculation if no saved tensors
        if num_prev_states == 0:
            num_prev_states = num_total_images // num_images_per_state

        # Find images to remove and remove them from all relevant places if they exist
        num_images_to_remove = num_total_images - (num_prev_states *
                                                   num_images_per_state)
        logger.info(
            'Deleting last {} invalid images'.format(num_images_to_remove))
        for k in range(num_images_to_remove):
            im_name = generated_images[-(k + 1)]
            im_basename = os.path.splitext(im_name)[0]
            im_ind = int(im_basename.split('_')[1])
            if os.path.exists(os.path.join(depth_dir, im_name)):
                os.remove(os.path.join(depth_dir, im_name))
            if os.path.exists(os.path.join(color_dir, im_name)):
                os.remove(os.path.join(color_dir, im_name))
            if os.path.exists(os.path.join(semantic_dir, im_name)):
                os.remove(os.path.join(semantic_dir, im_name))
            if os.path.exists(os.path.join(modal_dir, im_basename)):
                shutil.rmtree(os.path.join(modal_dir, im_basename))
            if os.path.exists(os.path.join(amodal_dir, im_basename)):
                shutil.rmtree(os.path.join(amodal_dir, im_basename))
            if im_ind in train_inds:
                train_inds.remove(im_ind)
            elif im_ind in test_inds:
                test_inds.remove(im_ind)

    else:

        # Create initial env to generate metadata
        env = BinHeapEnv(config)
        obj_id_map = env.state_space.obj_id_map
        obj_keys = env.state_space.obj_keys
        obj_splits = env.state_space.obj_splits
        mesh_filenames = env.state_space.mesh_filenames
        save_obj_id_map = obj_id_map.copy()
        save_obj_id_map[ENVIRONMENT_KEY] = np.iinfo(np.uint32).max
        reverse_obj_ids = utils.reverse_dictionary(save_obj_id_map)
        metadata['obj_ids'] = reverse_obj_ids
        metadata['obj_splits'] = obj_splits
        metadata['meshes'] = mesh_filenames
        json.dump(metadata,
                  open(os.path.join(output_dataset_path, 'metadata.json'),
                       'w'),
                  indent=JSON_INDENT,
                  sort_keys=True)
        train_inds = []
        test_inds = []

    # generate states and images
    state_id = num_prev_states
    while state_id < num_states:

        # create env and set objects
        create_start = time.time()
        env = BinHeapEnv(config)
        env.state_space.obj_id_map = obj_id_map
        env.state_space.obj_keys = obj_keys
        env.state_space.set_splits(obj_splits)
        env.state_space.mesh_filenames = mesh_filenames
        create_stop = time.time()
        logger.info('Creating env took %.3f sec' %
                    (create_stop - create_start))

        # sample states
        states_remaining = num_states - state_id
        for i in range(min(states_per_garbage_collect, states_remaining)):

            # log progress every `log_rate` states
            if state_id % config['log_rate'] == 0:
                logger.info('State: %06d' % (state_id))

            try:
                # reset env
                env.reset()
                state = env.state
                split = state.metadata['split']

                # render state
                if vis_config['state']:
                    env.view_3d_scene()

                # Save state if desired
                if save_tensors:

                    # set obj state variables
                    obj_pose_vec = np.zeros(obj_pose_dim)
                    obj_com_vec = np.zeros(obj_com_dim)
                    obj_id_vec = np.iinfo(
                        np.uint32).max * np.ones(max_objs_per_state)
                    j = 0
                    for obj_state in state.obj_states:
                        obj_pose_vec[j * POSE_DIM:(j + 1) *
                                     POSE_DIM] = obj_state.pose.vec
                        obj_com_vec[j * POINT_DIM:(j + 1) *
                                    POINT_DIM] = obj_state.center_of_mass
                        obj_id_vec[j] = int(obj_id_map[obj_state.key])
                        j += 1

                    # store datapoint env params
                    state_datapoint['state_id'] = state_id
                    state_datapoint['obj_poses'] = obj_pose_vec
                    state_datapoint['obj_coms'] = obj_com_vec
                    state_datapoint['obj_ids'] = obj_id_vec
                    state_datapoint['split'] = split

                    # store state datapoint
                    image_start_ind = image_dataset.num_datapoints
                    image_end_ind = image_start_ind + num_images_per_state
                    state_datapoint['image_start_ind'] = image_start_ind
                    state_datapoint['image_end_ind'] = image_end_ind

                    # clean up
                    del obj_pose_vec
                    del obj_com_vec
                    del obj_id_vec

                    # add state
                    state_dataset.add(state_datapoint)

                # render images
                for k in range(num_images_per_state):

                    # reset the camera
                    if num_images_per_state > 1:
                        env.reset_camera()

                    obs = env.render_camera_image(color=image_config['color'])
                    if image_config['color']:
                        color_obs, depth_obs = obs
                    else:
                        depth_obs = obs

                    # vis obs
                    if vis_config['obs']:
                        if image_config['depth']:
                            plt.figure()
                            plt.imshow(depth_obs)
                            plt.title('Depth Observation')
                        if image_config['color']:
                            plt.figure()
                            plt.imshow(color_obs)
                            plt.title('Color Observation')
                        plt.show()

                    if image_config['modal'] or image_config[
                            'amodal'] or image_config['semantic']:

                        # render segmasks
                        amodal_segmasks, modal_segmasks = env.render_segmentation_images(
                        )

                        # retrieve segmask data
                        modal_segmask_arr = np.iinfo(np.uint8).max * np.ones(
                            [im_height, im_width, segmask_channels],
                            dtype=np.uint8)
                        amodal_segmask_arr = np.iinfo(np.uint8).max * np.ones(
                            [im_height, im_width, segmask_channels],
                            dtype=np.uint8)
                        stacked_segmask_arr = np.zeros(
                            [im_height, im_width, 1], dtype=np.uint8)

                        modal_segmask_arr[:, :, :env.
                                          num_objects] = modal_segmasks
                        amodal_segmask_arr[:, :, :env.
                                           num_objects] = amodal_segmasks

                        if image_config['semantic']:
                            for j in range(env.num_objects):
                                this_obj_px = np.where(
                                    modal_segmasks[:, :, j] > 0)
                                stacked_segmask_arr[this_obj_px[0],
                                                    this_obj_px[1], 0] = j + 1

                    # visualize
                    if vis_config['semantic']:
                        plt.figure()
                        plt.imshow(stacked_segmask_arr.squeeze())
                        plt.show()

                    if save_tensors:
                        # save image data as tensors
                        if image_config['color']:
                            image_datapoint['color_im'] = color_obs
                        if image_config['depth']:
                            image_datapoint['depth_im'] = depth_obs[:, :, None]
                        if image_config['modal']:
                            image_datapoint[
                                'modal_segmasks'] = modal_segmask_arr
                        if image_config['amodal']:
                            image_datapoint[
                                'amodal_segmasks'] = amodal_segmask_arr
                        if image_config['semantic']:
                            image_datapoint[
                                'semantic_segmasks'] = stacked_segmask_arr

                        image_datapoint['camera_pose'] = env.camera.pose.vec
                        image_datapoint[
                            'camera_intrs'] = env.camera.intrinsics.vec
                        image_datapoint['state_ind'] = state_id
                        image_datapoint['split'] = split

                        # add image
                        image_dataset.add(image_datapoint)

                    # Save color/depth images and modal/amodal/semantic masks as image files
                    if image_config['color']:
                        ColorImage(color_obs).save(
                            os.path.join(
                                color_dir, 'image_{:06d}.png'.format(
                                    num_images_per_state * state_id + k)))
                    if image_config['depth']:
                        DepthImage(depth_obs).save(
                            os.path.join(
                                depth_dir, 'image_{:06d}.png'.format(
                                    num_images_per_state * state_id + k)))
                    if image_config['modal']:
                        modal_id_dir = os.path.join(
                            modal_dir,
                            'image_{:06d}'.format(num_images_per_state *
                                                  state_id + k))
                        if not os.path.exists(modal_id_dir):
                            os.mkdir(modal_id_dir)
                        for i in range(env.num_objects):
                            BinaryImage(modal_segmask_arr[:, :, i]).save(
                                os.path.join(modal_id_dir,
                                             'channel_{:03d}.png'.format(i)))
                    if image_config['amodal']:
                        amodal_id_dir = os.path.join(
                            amodal_dir,
                            'image_{:06d}'.format(num_images_per_state *
                                                  state_id + k))
                        if not os.path.exists(amodal_id_dir):
                            os.mkdir(amodal_id_dir)
                        for i in range(env.num_objects):
                            BinaryImage(amodal_segmask_arr[:, :, i]).save(
                                os.path.join(amodal_id_dir,
                                             'channel_{:03d}.png'.format(i)))
                    if image_config['semantic']:
                        GrayscaleImage(stacked_segmask_arr.squeeze()).save(
                            os.path.join(
                                semantic_dir, 'image_{:06d}.png'.format(
                                    num_images_per_state * state_id + k)))

                    # Save split
                    if split == TRAIN_ID:
                        train_inds.append(num_images_per_state * state_id + k)
                    else:
                        test_inds.append(num_images_per_state * state_id + k)

                # auto-flush split indices (and tensors, if saved) every `states_per_flush` states
                if state_id % states_per_flush == 0:
                    np.save(os.path.join(image_dir, 'train_indices.npy'),
                            train_inds)
                    np.save(os.path.join(image_dir, 'test_indices.npy'),
                            test_inds)
                    if save_tensors:
                        state_dataset.flush()
                        image_dataset.flush()

                # drop local references to the state and its object states before collecting garbage
                for obj_state in state.obj_states:
                    del obj_state
                del state
                gc.collect()

                # update state id
                state_id += 1

            except Exception as e:
                # log an error
                logger.warning('Heap failed!')
                logger.warning('%s' % (str(e)))
                logger.warning(traceback.print_exc())
                if debug:
                    raise

                del env
                gc.collect()
                env = BinHeapEnv(config)
                env.state_space.obj_id_map = obj_id_map
                env.state_space.obj_keys = obj_keys
                env.state_space.set_splits(obj_splits)
                env.state_space.mesh_filenames = mesh_filenames

        # garbage collect
        del env
        gc.collect()

    # write all datasets to file, save indices
    np.save(os.path.join(image_dir, 'train_indices.npy'), train_inds)
    np.save(os.path.join(image_dir, 'test_indices.npy'), test_inds)
    if save_tensors:
        state_dataset.flush()
        image_dataset.flush()

    logger.info('Generated %d image datapoints' %
                (state_id * num_images_per_state))