def test_open(self):
    """Opening a TensorDataset that does not exist must raise FileNotFoundError."""
    # assertRaises fails the test (with a descriptive message) if no
    # FileNotFoundError is raised, replacing the manual success flag.
    with self.assertRaises(FileNotFoundError):
        TensorDataset.open(TEST_TENSOR_DATASET_NAME)
    def __init__(self, dataset_path, frame=None, loop=True):
        """Set up a virtual sensor backed by a TensorDataset on disk.

        Parameters
        ----------
        dataset_path : str
            Path of the TensorDataset to open.
        frame : str, optional
            Frame name used for the sensor and for both the color and IR
            camera intrinsics.
        loop : bool, optional
            Accepted by the constructor but not stored or used here.
            NOTE(review): confirm whether other methods expect it.
        """
        # bookkeeping; the color and IR frames both default to `frame`
        self._dataset_path = dataset_path
        self._frame = self._color_frame = self._ir_frame = frame
        self._im_index = 0
        self._running = False

        from autolab_core import TensorDataset

        self._dataset = TensorDataset.open(self._dataset_path)
        self._num_images = self._dataset.num_datapoints

        # images may have been stored rescaled; intrinsics are resized by the
        # inverse of the recorded factor (1.0 when no factor is recorded)
        metadata = self._dataset.metadata
        if "image_rescale_factor" in metadata.keys():
            self._image_rescale_factor = 1.0 / metadata["image_rescale_factor"]
        else:
            self._image_rescale_factor = 1.0

        # intrinsics vector is read from the first datapoint only
        intr_field = TensorDatasetVirtualSensor.CAMERA_INTR_FIELD
        first_datapoint = self._dataset.datapoint(0, [intr_field])
        intr_vec = first_datapoint[intr_field]

        # separate intrinsics objects for the color and IR streams
        self._color_intr = CameraIntrinsics.from_vec(
            intr_vec, frame=self._color_frame).resize(self._image_rescale_factor)
        self._ir_intr = CameraIntrinsics.from_vec(
            intr_vec, frame=self._ir_frame).resize(self._image_rescale_factor)
def get_dataset(config, args):
    """Open the existing TensorDataset at ``args.output`` for reading and writing.

    Parameters
    ----------
    config : dict
        Unused; kept for interface compatibility.
    args : argparse.Namespace
        Must provide ``output``, the path of the dataset to open.

    Returns
    -------
    TensorDataset
        The dataset opened with ``access_mode='READ_WRITE'``.
    """
    output_path = args.output
    dataset = TensorDataset.open(output_path, access_mode='READ_WRITE')
    return dataset
Example #4
0
 def __init__(self, datasets):
     """Open each tensor dataset and index its datapoints by object id.

     Parameters
     ----------
     datasets : dict
         Maps a dataset name to the directory of a TensorDataset; each
         directory must also contain an ``obj_ids.json`` file.

     Sets ``self.datasets`` (name -> opened TensorDataset), ``self.obj_ids``
     (name -> parsed ``obj_ids.json``) and ``self.id_to_datapoint``
     (obj_id -> list of ``(obj_id, datapoint_index)`` tuples).
     """
     from collections import defaultdict

     self.datasets = {
         dataset_name: TensorDataset.open(datasets[dataset_name])
         for dataset_name in datasets.keys()
     }
     self.obj_ids = {}
     # Group with a dict instead of itertools.groupby: groupby only merges
     # *adjacent* equal keys, so an obj_id that reappeared later (e.g. in a
     # second dataset) overwrote its earlier group and dropped datapoints.
     grouped = defaultdict(list)
     for dataset_name in datasets.keys():
         with open(datasets[dataset_name] + "/obj_ids.json",
                   "r") as read_file:
             obj_ids = json.load(read_file)
         dataset = self.datasets[dataset_name]
         for i in range(dataset.num_datapoints):
             obj_id = dataset.datapoint(i)['obj_id']
             # NOTE(review): i is a per-dataset index and the dataset name is
             # not recorded, so entries from different datasets are ambiguous.
             grouped[obj_id].append((obj_id, i))
         self.obj_ids[dataset_name] = obj_ids
     self.id_to_datapoint = dict(grouped)
Example #5
0
    # read merge parameters from the config
    # NOTE(review): display_rate is not used in the visible part of this
    # fragment; presumably it is consumed further down.
    input_dataset_names = cfg["input_datasets"]
    output_dataset_name = cfg["output_dataset"]
    display_rate = cfg["display_rate"]

    # modify list of dataset names
    # An input that has a "tensors" subdirectory is used directly; otherwise
    # the names returned by utils.filenames(dataset_name, tag="dataset_") are
    # used instead (presumably the sub-datasets inside it -- confirm the
    # semantics of utils.filenames).
    all_input_dataset_names = []
    for dataset_name in input_dataset_names:
        tensor_dir = os.path.join(dataset_name, "tensors")
        if os.path.exists(tensor_dir):
            all_input_dataset_names.append(dataset_name)
        else:
            dataset_subdirs = utils.filenames(dataset_name, tag="dataset_")
            all_input_dataset_names.extend(dataset_subdirs)

    # open tensor dataset
    # The first input's config is the template for the output schema. It is
    # deep-copied so deleting excluded fields below does not mutate the
    # opened dataset's config.
    # NOTE(review): raises IndexError if no input datasets were found.
    dataset = TensorDataset.open(all_input_dataset_names[0])
    tensor_config = copy.deepcopy(dataset.config)
    for field_name in cfg["exclude_fields"]:
        if field_name in tensor_config["fields"].keys():
            del tensor_config["fields"][field_name]
    # assuming "fields" is a dict, this is a live keys view, not a copy
    field_names = tensor_config["fields"].keys()
    # same field names with "rewards" mapped to "grasp_metrics"
    # (presumably the name some input datasets use -- confirm)
    alt_field_names = [
        f if f != "rewards" else "grasp_metrics" for f in field_names
    ]

    # init tensor dataset
    output_dataset = TensorDataset(output_dataset_name, tensor_config)

    # copy config
    # path for the merge config inside the output dataset directory; the
    # actual write happens past the end of this fragment
    out_config_filename = os.path.join(output_dataset_name,
                                       "merge_config.yaml")
Example #6
0
def generate_segmask_dataset(output_dataset_path,
                             config,
                             save_tensors=True,
                             warm_start=False):
    """ Generate a segmentation training dataset

    Renders ``config['num_states']`` simulated heap states with
    ``config['num_images_per_state']`` camera images each and writes the
    enabled image types (color, depth, modal/amodal/semantic masks) under
    ``<output_dataset_path>/images``. Train/test image indices are saved to
    ``images/train_indices.npy`` and ``images/test_indices.npy``, object
    metadata to ``metadata.json``, and, if ``save_tensors`` is set, states and
    images to the ``state_tensors`` / ``image_tensors`` TensorDatasets.

    Parameters
    ----------
    output_dataset_path : str
        path to store the dataset
    config : dict
        dictionary-like object containing parameters of the simulator and
        visualization; it must also provide ``save(filename)``
    save_tensors : bool
        save tensor datasets (for recreating state)
    warm_start : bool
        restart dataset generation from a previous state

    Raises
    ------
    SystemExit
        if ``warm_start`` is set but the saved tensor datasets (when
        ``save_tensors``) or ``metadata.json`` are missing
    """

    # read subconfigs
    dataset_config = config['dataset']
    image_config = config['images']
    vis_config = config['vis']

    # debugging: seed numpy's RNG so runs are reproducible
    debug = config['debug']
    if debug:
        np.random.seed(SEED)

    # read general parameters
    num_states = config['num_states']
    num_images_per_state = config['num_images_per_state']

    states_per_flush = config['states_per_flush']
    states_per_garbage_collect = config['states_per_garbage_collect']

    # set max obj per state
    max_objs_per_state = config['state_space']['heap']['max_objs']

    # read image parameters
    im_height = config['state_space']['camera']['im_height']
    im_width = config['state_space']['camera']['im_width']
    # one more channel than the maximum number of objects per state
    segmask_channels = max_objs_per_state + 1

    # create the dataset path and all subfolders if they don't exist
    if not os.path.exists(output_dataset_path):
        os.mkdir(output_dataset_path)

    image_dir = os.path.join(output_dataset_path, 'images')
    if not os.path.exists(image_dir):
        os.mkdir(image_dir)
    color_dir = os.path.join(image_dir, 'color_ims')
    if image_config['color'] and not os.path.exists(color_dir):
        os.mkdir(color_dir)
    depth_dir = os.path.join(image_dir, 'depth_ims')
    if image_config['depth'] and not os.path.exists(depth_dir):
        os.mkdir(depth_dir)
    amodal_dir = os.path.join(image_dir, 'amodal_masks')
    if image_config['amodal'] and not os.path.exists(amodal_dir):
        os.mkdir(amodal_dir)
    modal_dir = os.path.join(image_dir, 'modal_masks')
    if image_config['modal'] and not os.path.exists(modal_dir):
        os.mkdir(modal_dir)
    semantic_dir = os.path.join(image_dir, 'semantic_masks')
    if image_config['semantic'] and not os.path.exists(semantic_dir):
        os.mkdir(semantic_dir)

    # setup logging (a fresh log unless we are resuming)
    experiment_log_filename = os.path.join(output_dataset_path,
                                           'dataset_generation.log')
    if os.path.exists(experiment_log_filename) and not warm_start:
        os.remove(experiment_log_filename)
    Logger.add_log_file(logger, experiment_log_filename, global_log_file=True)
    config.save(
        os.path.join(output_dataset_path, 'dataset_generation_params.yaml'))
    metadata = {}
    num_prev_states = 0

    # set dataset params
    if save_tensors:

        # read dataset subconfigs
        state_dataset_config = dataset_config['states']
        image_dataset_config = dataset_config['images']
        state_tensor_config = state_dataset_config['tensors']
        image_tensor_config = image_dataset_config['tensors']

        # size the per-state fields for the maximum number of objects
        obj_pose_dim = POSE_DIM * max_objs_per_state
        obj_com_dim = POINT_DIM * max_objs_per_state
        state_tensor_config['fields']['obj_poses']['height'] = obj_pose_dim
        state_tensor_config['fields']['obj_coms']['height'] = obj_com_dim
        state_tensor_config['fields']['obj_ids']['height'] = max_objs_per_state

        image_tensor_config['fields']['camera_pose']['height'] = POSE_DIM

        # add a tensor field for each enabled image type
        if image_config['color']:
            image_tensor_config['fields']['color_im'] = {
                'dtype': 'uint8',
                'channels': 3,
                'height': im_height,
                'width': im_width
            }

        if image_config['depth']:
            image_tensor_config['fields']['depth_im'] = {
                'dtype': 'float32',
                'channels': 1,
                'height': im_height,
                'width': im_width
            }

        if image_config['modal']:
            image_tensor_config['fields']['modal_segmasks'] = {
                'dtype': 'uint8',
                'channels': segmask_channels,
                'height': im_height,
                'width': im_width
            }

        if image_config['amodal']:
            image_tensor_config['fields']['amodal_segmasks'] = {
                'dtype': 'uint8',
                'channels': segmask_channels,
                'height': im_height,
                'width': im_width
            }

        if image_config['semantic']:
            image_tensor_config['fields']['semantic_segmasks'] = {
                'dtype': 'uint8',
                'channels': 1,
                'height': im_height,
                'width': im_width
            }

        # create dataset filenames
        state_dataset_path = os.path.join(output_dataset_path, 'state_tensors')
        image_dataset_path = os.path.join(output_dataset_path, 'image_tensors')

        if warm_start:

            if not os.path.exists(state_dataset_path) or not os.path.exists(
                    image_dataset_path):
                logger.error(
                    'Attempting to warm start without saved tensor dataset')
                raise SystemExit(1)

            # open datasets
            logger.info('Opening state dataset')
            state_dataset = TensorDataset.open(state_dataset_path,
                                               access_mode='READ_WRITE')
            logger.info('Opening image dataset')
            image_dataset = TensorDataset.open(image_dataset_path,
                                               access_mode='READ_WRITE')

            # read configs
            state_tensor_config = state_dataset.config
            image_tensor_config = image_dataset.config

            # clean up datasets (there may be datapoints with indices corresponding to non-existent data)
            num_state_datapoints = state_dataset.num_datapoints
            num_image_datapoints = image_dataset.num_datapoints
            num_prev_states = num_state_datapoints

            # clean up images: walk back from the end past every image whose
            # state was never saved. Checking image_ind >= 0 lets image 0 be
            # removed too and avoids indexing an empty dataset.
            image_ind = num_image_datapoints - 1
            while (image_ind >= 0 and image_dataset[image_ind]['state_ind'] >=
                   num_state_datapoints):
                image_ind -= 1
            images_to_remove = num_image_datapoints - 1 - image_ind
            logger.info('Deleting last %d image tensors' % (images_to_remove))
            if images_to_remove > 0:
                image_dataset.delete_last(images_to_remove)
                num_image_datapoints = image_dataset.num_datapoints
        else:
            # create datasets from scratch
            logger.info('Creating datasets')

            state_dataset = TensorDataset(state_dataset_path,
                                          state_tensor_config)
            image_dataset = TensorDataset(image_dataset_path,
                                          image_tensor_config)

        # read templates (reused and refilled for every datapoint added)
        state_datapoint = state_dataset.datapoint_template
        image_datapoint = image_dataset.datapoint_template

    metadata_filename = os.path.join(output_dataset_path, 'metadata.json')
    if warm_start:

        if not os.path.exists(metadata_filename):
            logger.error(
                'Attempting to warm start without previously created dataset')
            raise SystemExit(1)

        # Read metadata and indices
        with open(metadata_filename, 'r') as metadata_file:
            metadata = json.load(metadata_file)
        test_inds = np.load(os.path.join(image_dir,
                                         'test_indices.npy')).tolist()
        train_inds = np.load(os.path.join(image_dir,
                                          'train_indices.npy')).tolist()

        # set obj ids and splits
        reverse_obj_ids = metadata['obj_ids']
        obj_id_map = utils.reverse_dictionary(reverse_obj_ids)
        obj_splits = metadata['obj_splits']
        obj_keys = obj_splits.keys()
        mesh_filenames = metadata['meshes']

        # Get list of images generated so far
        generated_images = sorted(
            os.listdir(color_dir)) if image_config['color'] else sorted(
                os.listdir(depth_dir))
        num_total_images = len(generated_images)

        # Do our own calculation if no saved tensors
        if num_prev_states == 0:
            num_prev_states = num_total_images // num_images_per_state

        # Find images to remove and remove them from all relevant places if they exist
        num_images_to_remove = num_total_images - (num_prev_states *
                                                   num_images_per_state)
        logger.info(
            'Deleting last {} invalid images'.format(num_images_to_remove))
        for k in range(num_images_to_remove):
            im_name = generated_images[-(k + 1)]
            im_basename = os.path.splitext(im_name)[0]
            im_ind = int(im_basename.split('_')[1])
            if os.path.exists(os.path.join(depth_dir, im_name)):
                os.remove(os.path.join(depth_dir, im_name))
            if os.path.exists(os.path.join(color_dir, im_name)):
                os.remove(os.path.join(color_dir, im_name))
            if os.path.exists(os.path.join(semantic_dir, im_name)):
                os.remove(os.path.join(semantic_dir, im_name))
            if os.path.exists(os.path.join(modal_dir, im_basename)):
                shutil.rmtree(os.path.join(modal_dir, im_basename))
            if os.path.exists(os.path.join(amodal_dir, im_basename)):
                shutil.rmtree(os.path.join(amodal_dir, im_basename))
            if im_ind in train_inds:
                train_inds.remove(im_ind)
            elif im_ind in test_inds:
                test_inds.remove(im_ind)

    else:

        # Create initial env to generate metadata
        env = BinHeapEnv(config)
        obj_id_map = env.state_space.obj_id_map
        obj_keys = env.state_space.obj_keys
        obj_splits = env.state_space.obj_splits
        mesh_filenames = env.state_space.mesh_filenames
        save_obj_id_map = obj_id_map.copy()
        save_obj_id_map[ENVIRONMENT_KEY] = np.iinfo(np.uint32).max
        reverse_obj_ids = utils.reverse_dictionary(save_obj_id_map)
        metadata['obj_ids'] = reverse_obj_ids
        metadata['obj_splits'] = obj_splits
        metadata['meshes'] = mesh_filenames
        with open(metadata_filename, 'w') as metadata_file:
            json.dump(metadata,
                      metadata_file,
                      indent=JSON_INDENT,
                      sort_keys=True)
        train_inds = []
        test_inds = []

    # generate states and images
    state_id = num_prev_states
    while state_id < num_states:

        # create env and set objects
        create_start = time.time()
        env = BinHeapEnv(config)
        env.state_space.obj_id_map = obj_id_map
        env.state_space.obj_keys = obj_keys
        env.state_space.set_splits(obj_splits)
        env.state_space.mesh_filenames = mesh_filenames
        create_stop = time.time()
        logger.info('Creating env took %.3f sec' %
                    (create_stop - create_start))

        # sample states
        states_remaining = num_states - state_id
        for _ in range(min(states_per_garbage_collect, states_remaining)):

            # log current rollout
            if state_id % config['log_rate'] == 0:
                logger.info('State: %06d' % (state_id))

            # remember how much was recorded before this attempt so a failed
            # attempt can be rolled back before retrying the same state_id
            num_train_before = len(train_inds)
            num_test_before = len(test_inds)
            if save_tensors:
                num_states_before = state_dataset.num_datapoints
                num_images_before = image_dataset.num_datapoints

            try:
                # reset env
                env.reset()
                state = env.state
                split = state.metadata['split']

                # render state
                if vis_config['state']:
                    env.view_3d_scene()

                # Save state if desired
                if save_tensors:

                    # set obj state variables; unused object slots keep
                    # zeros (poses/coms) or uint32 max (ids)
                    obj_pose_vec = np.zeros(obj_pose_dim)
                    obj_com_vec = np.zeros(obj_com_dim)
                    obj_id_vec = np.iinfo(
                        np.uint32).max * np.ones(max_objs_per_state)
                    for j, obj_state in enumerate(state.obj_states):
                        obj_pose_vec[j * POSE_DIM:(j + 1) *
                                     POSE_DIM] = obj_state.pose.vec
                        obj_com_vec[j * POINT_DIM:(j + 1) *
                                    POINT_DIM] = obj_state.center_of_mass
                        obj_id_vec[j] = int(obj_id_map[obj_state.key])

                    # store datapoint env params
                    state_datapoint['state_id'] = state_id
                    state_datapoint['obj_poses'] = obj_pose_vec
                    state_datapoint['obj_coms'] = obj_com_vec
                    state_datapoint['obj_ids'] = obj_id_vec
                    state_datapoint['split'] = split

                    # store state datapoint: the range of image datapoints
                    # that will belong to this state
                    image_start_ind = image_dataset.num_datapoints
                    image_end_ind = image_start_ind + num_images_per_state
                    state_datapoint['image_start_ind'] = image_start_ind
                    state_datapoint['image_end_ind'] = image_end_ind

                    # clean up
                    del obj_pose_vec
                    del obj_com_vec
                    del obj_id_vec

                    # add state
                    state_dataset.add(state_datapoint)

                # render images
                for k in range(num_images_per_state):
                    # global image index, used for file names and splits
                    image_index = num_images_per_state * state_id + k
                    image_name = 'image_{:06d}'.format(image_index)

                    # reset the camera
                    if num_images_per_state > 1:
                        env.reset_camera()

                    obs = env.render_camera_image(color=image_config['color'])
                    if image_config['color']:
                        color_obs, depth_obs = obs
                    else:
                        depth_obs = obs

                    # vis obs
                    if vis_config['obs']:
                        if image_config['depth']:
                            plt.figure()
                            plt.imshow(depth_obs)
                            plt.title('Depth Observation')
                        if image_config['color']:
                            plt.figure()
                            plt.imshow(color_obs)
                            plt.title('Color Observation')
                        plt.show()

                    if image_config['modal'] or image_config[
                            'amodal'] or image_config['semantic']:

                        # render segmasks
                        amodal_segmasks, modal_segmasks = env.render_segmentation_images(
                        )

                        # retrieve segmask data; channels beyond the number
                        # of objects stay at the uint8 maximum
                        modal_segmask_arr = np.iinfo(np.uint8).max * np.ones(
                            [im_height, im_width, segmask_channels],
                            dtype=np.uint8)
                        amodal_segmask_arr = np.iinfo(np.uint8).max * np.ones(
                            [im_height, im_width, segmask_channels],
                            dtype=np.uint8)
                        stacked_segmask_arr = np.zeros(
                            [im_height, im_width, 1], dtype=np.uint8)

                        modal_segmask_arr[:, :, :env.
                                          num_objects] = modal_segmasks
                        amodal_segmask_arr[:, :, :env.
                                           num_objects] = amodal_segmasks

                        # semantic mask: pixels of object j get label j + 1,
                        # background stays 0
                        if image_config['semantic']:
                            for j in range(env.num_objects):
                                this_obj_px = np.where(
                                    modal_segmasks[:, :, j] > 0)
                                stacked_segmask_arr[this_obj_px[0],
                                                    this_obj_px[1], 0] = j + 1

                    # visualize (only when the semantic mask was computed;
                    # otherwise stacked_segmask_arr is unbound)
                    if vis_config['semantic'] and image_config['semantic']:
                        plt.figure()
                        plt.imshow(stacked_segmask_arr.squeeze())
                        plt.show()

                    if save_tensors:
                        # save image data as tensors
                        if image_config['color']:
                            image_datapoint['color_im'] = color_obs
                        if image_config['depth']:
                            image_datapoint['depth_im'] = depth_obs[:, :, None]
                        if image_config['modal']:
                            image_datapoint[
                                'modal_segmasks'] = modal_segmask_arr
                        if image_config['amodal']:
                            image_datapoint[
                                'amodal_segmasks'] = amodal_segmask_arr
                        if image_config['semantic']:
                            image_datapoint[
                                'semantic_segmasks'] = stacked_segmask_arr

                        image_datapoint['camera_pose'] = env.camera.pose.vec
                        image_datapoint[
                            'camera_intrs'] = env.camera.intrinsics.vec
                        image_datapoint['state_ind'] = state_id
                        image_datapoint['split'] = split

                        # add image
                        image_dataset.add(image_datapoint)

                    # Save depth image and semantic masks
                    if image_config['color']:
                        ColorImage(color_obs).save(
                            os.path.join(color_dir, image_name + '.png'))
                    if image_config['depth']:
                        DepthImage(depth_obs).save(
                            os.path.join(depth_dir, image_name + '.png'))
                    if image_config['modal']:
                        modal_id_dir = os.path.join(modal_dir, image_name)
                        if not os.path.exists(modal_id_dir):
                            os.mkdir(modal_id_dir)
                        for ch in range(env.num_objects):
                            BinaryImage(modal_segmask_arr[:, :, ch]).save(
                                os.path.join(modal_id_dir,
                                             'channel_{:03d}.png'.format(ch)))
                    if image_config['amodal']:
                        amodal_id_dir = os.path.join(amodal_dir, image_name)
                        if not os.path.exists(amodal_id_dir):
                            os.mkdir(amodal_id_dir)
                        for ch in range(env.num_objects):
                            BinaryImage(amodal_segmask_arr[:, :, ch]).save(
                                os.path.join(amodal_id_dir,
                                             'channel_{:03d}.png'.format(ch)))
                    if image_config['semantic']:
                        GrayscaleImage(stacked_segmask_arr.squeeze()).save(
                            os.path.join(semantic_dir, image_name + '.png'))

                    # Save split
                    if split == TRAIN_ID:
                        train_inds.append(image_index)
                    else:
                        test_inds.append(image_index)

                # auto-flush after every so many timesteps
                if state_id % states_per_flush == 0:
                    np.save(os.path.join(image_dir, 'train_indices.npy'),
                            train_inds)
                    np.save(os.path.join(image_dir, 'test_indices.npy'),
                            test_inds)
                    if save_tensors:
                        state_dataset.flush()
                        image_dataset.flush()

                # release the state before the next iteration
                del state
                gc.collect()

                # update state id
                state_id += 1

            except Exception as e:
                # log an error; format_exc returns the traceback text
                # (print_exc would print to stderr and return None)
                logger.warning('Heap failed!')
                logger.warning('%s' % (str(e)))
                logger.warning(traceback.format_exc())
                if debug:
                    raise

                # roll back anything the failed attempt already recorded so
                # retrying the same state_id does not leave duplicates
                del train_inds[num_train_before:]
                del test_inds[num_test_before:]
                if save_tensors:
                    num_new_states = (state_dataset.num_datapoints -
                                      num_states_before)
                    if num_new_states > 0:
                        state_dataset.delete_last(num_new_states)
                    num_new_images = (image_dataset.num_datapoints -
                                      num_images_before)
                    if num_new_images > 0:
                        image_dataset.delete_last(num_new_images)

                # rebuild the environment before retrying
                del env
                gc.collect()
                env = BinHeapEnv(config)
                env.state_space.obj_id_map = obj_id_map
                env.state_space.obj_keys = obj_keys
                env.state_space.set_splits(obj_splits)
                env.state_space.mesh_filenames = mesh_filenames

        # garbage collect
        del env
        gc.collect()

    # write all datasets to file, save indices
    np.save(os.path.join(image_dir, 'train_indices.npy'), train_inds)
    np.save(os.path.join(image_dir, 'test_indices.npy'), test_inds)
    if save_tensors:
        state_dataset.flush()
        image_dataset.flush()

    logger.info('Generated %d image datapoints' %
                (state_id * num_images_per_state))
Example #7
0
def run_parallel_bin_picking_benchmark(input_dataset_path,
                                       heap_ids,
                                       timesteps,
                                       output_dataset_path,
                                       config_filename):
    """Roll out the bin-picking policy in parallel and merge the results.

    Currently disabled: the function raises immediately because the work is
    not yet split across workers by ``heap_ids`` and ``timesteps``. The code
    after the ``raise`` sketches the intended pipeline: start Ray, run
    ``rollout_bin_picking_policy_in_parallel`` once per CPU, then merge the
    per-worker TensorDatasets into a single dataset at
    ``output_dataset_path``, renumbering object ids and heap ids so that they
    are unique across workers.

    Parameters
    ----------
    input_dataset_path : str
        dataset path handed to each parallel rollout (see NOTE below)
    heap_ids : list
        heap ids to evaluate (not yet used)
    timesteps : list
        timesteps to evaluate (not yet used)
    output_dataset_path : str
        where the merged TensorDataset is written
    config_filename : str
        path to the YAML config; it is read for the ``ray`` section and
        ``num_rollouts``

    Raises
    ------
    NotImplementedError
        always, until heap ids and timesteps are split across workers
    """
    raise NotImplementedError('Cannot run in parallel. Need to split up the heap ids and timesteps')

    # load config
    config = YamlConfig(config_filename)

    # init ray
    ray_config = config['ray']
    num_cpus = ray_config['num_cpus']
    ray.init(num_cpus=num_cpus,
             redirect_output=ray_config['redirect_output'])

    # rollouts: split the requested number of rollouts evenly across workers
    num_rollouts = config['num_rollouts'] // num_cpus
    # NOTE(review): the original passed an undefined name `dataset_path`
    # here; `input_dataset_path` is assumed to be the intended argument --
    # confirm against rollout_bin_picking_policy_in_parallel.
    dataset_ids = [rollout_bin_picking_policy_in_parallel.remote(input_dataset_path,
                                                                 config_filename,
                                                                 num_rollouts)
                   for _ in range(num_cpus)]
    dataset_filenames = ray.get(dataset_ids)
    if not dataset_filenames:
        return

    # merge datasets: reuse the tensor config of the first worker's dataset
    subproc_dataset = TensorDataset.open(dataset_filenames[0])
    tensor_config = subproc_dataset.config

    # open the merged output dataset (the original used the undefined
    # name `dataset_path` here)
    dataset = TensorDataset(output_dataset_path, tensor_config)
    dataset.add_metadata('action_ids', subproc_dataset.metadata['action_ids'])

    # sentinel for an empty object slot; np.uint32(-1) raises OverflowError
    # on NumPy >= 2, so spell out the same value explicitly
    invalid_obj_id = np.iinfo(np.uint32).max

    # add datapoints
    obj_id = 0
    heap_id = 0
    obj_ids = {}
    num_merged = 0
    for dataset_filename in dataset_filenames:
        logging.info('Aggregating data from %s', dataset_filename)
        subproc_dataset = TensorDataset.open(dataset_filename)
        subproc_obj_ids = subproc_dataset.metadata['obj_ids']
        for datapoint in subproc_dataset:
            # a datapoint at timestep 0 starts a new heap; the check uses a
            # global count (not a per-file one) so the first heap of each
            # worker does not reuse the previous worker's last heap id
            if num_merged > 0 and datapoint['timesteps'] == 0:
                heap_id += 1

            # remap per-worker object ids to globally unique ids
            for i, subproc_obj_id in enumerate(datapoint['obj_ids']):
                if subproc_obj_id != invalid_obj_id:
                    subproc_obj_key = subproc_obj_ids[str(subproc_obj_id)]
                    if subproc_obj_key not in obj_ids:
                        obj_ids[subproc_obj_key] = obj_id
                        obj_id += 1
                    datapoint['obj_ids'][i] = obj_ids[subproc_obj_key]

            # remap grasped obj id
            subproc_grasped_obj_id = datapoint['grasped_obj_ids']
            grasped_obj_key = subproc_obj_ids[str(subproc_grasped_obj_id)]
            datapoint['grasped_obj_ids'] = obj_ids[grasped_obj_key]

            # assign the global heap id
            datapoint['heap_ids'] = heap_id

            # add datapoint to dataset
            dataset.add(datapoint)
            num_merged += 1

    # write to disk, storing the global-id -> object-key mapping
    obj_ids = utils.reverse_dictionary(obj_ids)
    dataset.add_metadata('obj_ids', obj_ids)
    dataset.flush()
Example #8
0
    def _run_prediction_single_model(self, model_dir,
                                     model_output_dir,
                                     dataset_config):
        """ Analyze the performance of a single model.

        Loads the GQ-CNN stored in ``model_dir``, predicts on every tensor of
        the evaluation dataset and writes the following to
        ``model_output_dir`` (which must already exist):

        * ``conv1_filters.pdf`` -- the first-layer convolution filters
        * ``train_result.cres`` / ``val_result.cres`` -- the saved
          ``BinaryClassificationResult`` of each split
        * ``examples/{train,val}/<category>_NNN.png`` -- up to
          ``self.num_vis`` randomly chosen true/false positive/negative
          grasps per split
        * ``train_stats.json`` / ``val_stats.json`` -- error rate, AP score,
          AUC score and cross-entropy loss

        When the model uses angular bins (``angular_bins > 0``), only the
        prediction pair of the bin containing each grasp's ground-truth
        angle is scored, and the full per-bin output is passed to the plots.

        Parameters
        ----------
        model_dir : str
            directory holding the trained model and its ``config.json``
        model_output_dir : str
            directory where the analysis outputs are written
        dataset_config : dict or None
            dataset parameters (``dataset_dir``, ``split_name``,
            ``image_field_name``, ``pose_field_name``,
            ``target_metric_name``, ``metric_thresh``, ``gripper_mode``);
            when None, the same keys (except ``gripper_mode``, which then
            comes from the model) are read from the model's config

        Returns
        -------
        tuple
            ``(train_result, val_result)``
        """
        # read in model config
        model_config_filename = os.path.join(model_dir, 'config.json')
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # load model, pointing it at the same log file as this analyzer (if any)
        self.logger.info('Loading model %s' %(model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins

        # read params from the explicit dataset config if given, otherwise
        # from the config the model was trained with
        params = model_config if dataset_config is None else dataset_config
        dataset_dir = params['dataset_dir']
        split_name = params['split_name']
        image_field_name = params['image_field_name']
        pose_field_name = params['pose_field_name']
        metric_name = params['target_metric_name']
        metric_thresh = params['metric_thresh']
        if dataset_config is not None:
            gripper_mode = dataset_config['gripper_mode']

        self.logger.info('Loading dataset %s' %(dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)

        # visualize conv filters (figure path bound outside the loop so a
        # model with zero filters does not hit a NameError)
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:, :, 0, k]
            vis2d.subplot(d, d, k + 1)
            vis2d.imshow(DepthImage(filt))
        figname = os.path.join(model_output_dir, 'conv1_filters.pdf')
        vis2d.savefig(figname, dpi=self.dpi)

        # aggregate training and validation true labels and predicted probabilities
        all_predictions = []
        all_predictions_raw = []  # full per-bin outputs; only filled when angular_bins > 0
        all_labels = []
        try:
            for i in range(dataset.num_tensors):
                # log progress
                if i % self.log_rate == 0:
                    self.logger.info('Predicting tensor %d of %d' %(i+1, dataset.num_tensors))

                # read in data; a grasp is labelled 1 when its metric exceeds the threshold
                image_arr = dataset.tensor(image_field_name, i).arr
                pose_arr = read_pose_data(dataset.tensor(pose_field_name, i).arr,
                                          gripper_mode)
                metric_arr = dataset.tensor(metric_name, i).arr
                label_arr = (metric_arr > metric_thresh).astype(np.uint8)
                if angular_bins > 0:
                    # form mask to extract predictions from ground-truth angular bins
                    raw_poses = dataset.tensor(pose_field_name, i).arr
                    angles = raw_poses[:, 3]
                    neg_ind = np.where(angles < 0)
                    angles = np.abs(angles) % GeneralConstants.PI
                    angles[neg_ind] *= -1
                    # wrap angles into [-PI/2, PI/2]
                    g_90 = np.where(angles > (GeneralConstants.PI / 2))
                    l_neg_90 = np.where(angles < (-1 * (GeneralConstants.PI / 2)))
                    angles[g_90] -= GeneralConstants.PI
                    angles[l_neg_90] += GeneralConstants.PI
                    angles *= -1  # hack to fix reverse angle convention
                    angles += (GeneralConstants.PI / 2)
                    # angles now lie in [0, PI]; an angle of exactly PI would
                    # map to bin `angular_bins`, one past the last column
                    # pair, so clamp the bin indices into the valid range
                    bin_width = GeneralConstants.PI / angular_bins
                    bin_inds = np.clip((angles // bin_width).astype(int), 0, angular_bins - 1)
                    rows = np.arange(raw_poses.shape[0])
                    pred_mask = np.zeros((raw_poses.shape[0], angular_bins * 2), dtype=bool)
                    pred_mask[rows, 2 * bin_inds] = True
                    pred_mask[rows, 2 * bin_inds + 1] = True

                # predict with GQ-CNN
                predictions = gqcnn.predict(image_arr, pose_arr)
                if angular_bins > 0:
                    # keep the full per-bin output for plotting, then reduce
                    # each row to the two columns of its ground-truth bin
                    all_predictions_raw.extend(np.array(predictions).tolist())
                    predictions = predictions[pred_mask].reshape((-1, 2))

                # aggregate column 1 (scored as the predicted probability) and labels
                all_predictions.extend(predictions[:, 1].tolist())
                all_labels.extend(label_arr.tolist())
        finally:
            # close session even if prediction fails part-way
            gqcnn.close_session()

        # create arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]
        train_predictions_raw = None
        val_predictions_raw = None
        if angular_bins > 0:
            all_predictions_raw = np.array(all_predictions_raw)
            train_predictions_raw = all_predictions_raw[train_indices]
            val_predictions_raw = all_predictions_raw[val_indices]

        # aggregate results
        train_result = BinaryClassificationResult(train_predictions, train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, 'train_result.cres'))
        val_result.save(os.path.join(model_output_dir, 'val_result.cres'))

        # get stats, plot curves
        self.logger.info('Model %s training error rate: %.3f' %(model_dir, train_result.error_rate))
        self.logger.info('Model %s validation error rate: %.3f' %(model_dir, val_result.error_rate))

        self.logger.info('Model %s training loss: %.3f' %(model_dir, train_result.cross_entropy_loss))
        self.logger.info('Model %s validation loss: %.3f' %(model_dir, val_result.cross_entropy_loss))

        # save images
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, 'examples')
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # for each split, plot up to num_vis random examples of each outcome
        # category (order of the shuffles matches the original train-then-val,
        # TP/FP/TN/FN sequence so seeded runs pick the same examples)
        splits = (
            ('training', 'train', train_result, train_indices, train_predictions_raw),
            ('validation', 'val', val_result, val_indices, val_predictions_raw),
        )
        categories = ('true_positive', 'false_positive', 'true_negative', 'false_negative')
        for split_desc, split_dirname, result, split_indices, split_preds_raw in splits:
            self.logger.info('Saving %s examples' % split_desc)
            split_example_dir = os.path.join(example_dir, split_dirname)
            if not os.path.exists(split_example_dir):
                os.mkdir(split_example_dir)

            for category in categories:
                category_indices = getattr(result, category + '_indices')
                np.random.shuffle(category_indices)
                category_indices = category_indices[:self.num_vis]
                for i, j in enumerate(category_indices):
                    # j indexes within the split, k is the dataset datapoint index
                    k = split_indices[j]
                    datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                                  pose_field_name])
                    vis2d.clf()
                    if angular_bins > 0:
                        self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=split_preds_raw[j])
                    else:
                        self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
                    vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                         result.pred_probs[j],
                                                                         result.labels[j]),
                                fontsize=self.font_size)
                    vis2d.savefig(os.path.join(split_example_dir, '%s_%03d.png' %(category, i)))

        # save summary stats (files are closed deterministically via `with`)
        for stats_basename, result in (('train_stats.json', train_result),
                                       ('val_stats.json', val_result)):
            summary_stats = {
                'error_rate': result.error_rate,
                'ap_score': result.ap_score,
                'auc_score': result.auc_score,
                'loss': result.cross_entropy_loss
            }
            stats_filename = os.path.join(model_output_dir, stats_basename)
            with open(stats_filename, 'w') as stats_file:
                json.dump(summary_stats, stats_file,
                          indent=JSON_INDENT,
                          sort_keys=True)

        return train_result, val_result
Example #9
0
    def _run_prediction_single_model(self, model_dir, model_output_dir,
                                     dataset_config):
        """ Analyze the performance of a single model.

        Loads the GQ-CNN stored in ``model_dir``, predicts on every tensor of
        the evaluation dataset and writes the following to
        ``model_output_dir`` (which must already exist):

        * ``conv1_filters.pdf`` -- the first-layer convolution filters
        * ``train_result.cres`` / ``val_result.cres`` -- the saved
          ``BinaryClassificationResult`` of each split
        * ``examples/{train,val}/<category>_NNN.png`` -- up to
          ``self.num_vis`` randomly chosen true/false positive/negative
          grasps per split
        * ``train_stats.json`` / ``val_stats.json`` -- error rate, AP score
          and AUC score

        Parameters
        ----------
        model_dir : str
            directory holding the trained model and its ``config.json``
        model_output_dir : str
            directory where the analysis outputs are written
        dataset_config : dict or None
            dataset parameters (``dataset_dir``, ``split_name``,
            ``image_field_name``, ``pose_field_name``,
            ``target_metric_name``, ``metric_thresh``, ``gripper_mode``);
            when None, the same keys (except ``gripper_mode``, which then
            comes from the model) are read from the model's config

        Returns
        -------
        tuple
            ``(train_result, val_result)``
        """
        # read in model config
        model_config_filename = os.path.join(model_dir, 'config.json')
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # load model
        logging.info('Loading model %s' % (model_dir))
        gqcnn = GQCNN.load(model_dir)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode

        # read params from the explicit dataset config if given, otherwise
        # from the config the model was trained with
        params = model_config if dataset_config is None else dataset_config
        dataset_dir = params['dataset_dir']
        split_name = params['split_name']
        image_field_name = params['image_field_name']
        pose_field_name = params['pose_field_name']
        metric_name = params['target_metric_name']
        metric_thresh = params['metric_thresh']
        if dataset_config is not None:
            gripper_mode = dataset_config['gripper_mode']

        logging.info('Loading dataset %s' % (dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)

        # visualize conv filters (figure path bound outside the loop so a
        # model with zero filters does not hit a NameError)
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:, :, 0, k]
            vis2d.subplot(d, d, k + 1)
            vis2d.imshow(DepthImage(filt))
        figname = os.path.join(model_output_dir, 'conv1_filters.pdf')
        vis2d.savefig(figname, dpi=self.dpi)

        # aggregate training and validation true labels and predicted probabilities
        all_predictions = []
        all_labels = []
        try:
            for i in range(dataset.num_tensors):
                # log progress
                if i % self.log_rate == 0:
                    logging.info('Predicting tensor %d of %d' %
                                 (i + 1, dataset.num_tensors))

                # read in data; a grasp is labelled 1 when its metric exceeds the threshold
                image_arr = dataset.tensor(image_field_name, i).arr
                pose_arr = read_pose_data(
                    dataset.tensor(pose_field_name, i).arr, gripper_mode)
                metric_arr = dataset.tensor(metric_name, i).arr
                label_arr = (metric_arr > metric_thresh).astype(np.uint8)

                # predict with GQ-CNN and keep column 1 (scored as the predicted probability)
                predictions = gqcnn.predict(image_arr, pose_arr)
                all_predictions.extend(predictions[:, 1].tolist())
                all_labels.extend(label_arr.tolist())
        finally:
            # close session even if prediction fails part-way
            gqcnn.close_session()

        # create arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]

        # aggregate results
        train_result = BinaryClassificationResult(train_predictions,
                                                  train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, 'train_result.cres'))
        val_result.save(os.path.join(model_output_dir, 'val_result.cres'))

        # get stats, plot curves
        logging.info('Model %s training error rate: %.3f' %
                     (model_dir, train_result.error_rate))
        logging.info('Model %s validation error rate: %.3f' %
                     (model_dir, val_result.error_rate))

        # save images
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, 'examples')
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # for each split, plot up to num_vis random examples of each outcome
        # category (order of the shuffles matches the original train-then-val,
        # TP/FP/TN/FN sequence so seeded runs pick the same examples)
        splits = (
            ('training', 'train', train_result, train_indices),
            ('validation', 'val', val_result, val_indices),
        )
        categories = ('true_positive', 'false_positive', 'true_negative',
                      'false_negative')
        for split_desc, split_dirname, result, split_indices in splits:
            logging.info('Saving %s examples' % split_desc)
            split_example_dir = os.path.join(example_dir, split_dirname)
            if not os.path.exists(split_example_dir):
                os.mkdir(split_example_dir)

            for category in categories:
                category_indices = getattr(result, category + '_indices')
                np.random.shuffle(category_indices)
                category_indices = category_indices[:self.num_vis]
                for i, j in enumerate(category_indices):
                    # j indexes within the split, k is the dataset datapoint index
                    k = split_indices[j]
                    datapoint = dataset.datapoint(
                        k, field_names=[image_field_name, pose_field_name])
                    vis2d.clf()
                    self._plot_grasp(datapoint, image_field_name,
                                     pose_field_name, gripper_mode)
                    vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                                (k, result.pred_probs[j], result.labels[j]),
                                fontsize=self.font_size)
                    vis2d.savefig(
                        os.path.join(split_example_dir,
                                     '%s_%03d.png' % (category, i)))

        # save summary stats (files are closed deterministically via `with`)
        for stats_basename, result in (('train_stats.json', train_result),
                                       ('val_stats.json', val_result)):
            summary_stats = {
                'error_rate': result.error_rate,
                'ap_score': result.ap_score,
                'auc_score': result.auc_score
            }
            stats_filename = os.path.join(model_output_dir, stats_basename)
            with open(stats_filename, 'w') as stats_file:
                json.dump(summary_stats,
                          stats_file,
                          indent=JSON_INDENT,
                          sort_keys=True)

        return train_result, val_result
def compute_dataset_statistics(dataset_path,
                               output_path,
                               config):
    """
    Compute the statistics of fields of a TensorDataset.

    For each field in ``config['analysis_fields']`` the values from every
    datapoint are gathered into a numpy array, and two files are written to
    ``output_path``:

    * ``<field>_stats.json`` with the mean, median, standard deviation,
      ``num_percentiles`` evenly spaced percentiles (keys ``'%d_pctile'``)
      and, for each threshold ``t``, the fraction of values strictly greater
      than ``t`` (keys ``'pct_above_%.3f'``);
    * ``histogram_<field>.pdf`` with an (unnormalized) histogram of the values.

    Fields that end up with no values (e.g. an empty dataset) are skipped with
    a warning instead of crashing on ``np.min`` / ``np.max``.

    Parameters
    ----------
    dataset_path : str
        path to the dataset
    output_path : str
        directory where the statistics and histograms are saved; an existing
        statistics file is overwritten after a warning is logged
    config : :obj:`YamlConfig`
        parameters for the analysis; the keys ``analysis_fields``,
        ``num_percentiles``, ``thresholds``, ``log_rate``, ``num_bins``,
        ``font_size``, ``line_width`` and ``dpi`` are read
    """
    # parse config
    analysis_fields = config['analysis_fields']
    num_percentiles = config['num_percentiles']
    thresholds = config['thresholds']
    log_rate = config['log_rate']

    num_bins = config['num_bins']
    font_size = config['font_size']
    line_width = config['line_width']  # NOTE(review): read but currently unused
    dpi = config['dpi']

    # open the dataset to analyze
    dataset = TensorDataset.open(dataset_path)
    num_datapoints = dataset.num_datapoints

    # allocate one value buffer per analyzed field
    analysis_data = {field: [] for field in analysis_fields}

    # loop through dataset
    for i in range(num_datapoints):
        if i % log_rate == 0:
            logging.info('Reading datapoint %d of %d' % (i + 1, num_datapoints))

        # read only the analyzed fields; .items() works on Python 2 and 3
        # (iteritems() is Python 2 only)
        datapoint = dataset.datapoint(i, analysis_fields)
        for key, value in datapoint.items():
            analysis_data[key].append(value)

    # analyze statistics
    for field, data in analysis_data.items():
        # init arrays
        data = np.array(data)
        if data.size == 0:
            # np.min / np.max below raise on empty arrays
            logging.warning('No data for field %s, skipping' % (field))
            continue

        # init filename
        stats_filename = os.path.join(output_path, '%s_stats.json' % (field))
        if os.path.exists(stats_filename):
            logging.warning('Statistics file %s exists!' % (stats_filename))

        # stats
        stats = {
            'name': str(field),
            'mean': float(np.mean(data)),
            'median': float(np.median(data)),
            'std': float(np.std(data)),
        }
        for i in range(num_percentiles):
            pctile = int((100.0 / num_percentiles) * i)
            pctile_field = '%d_pctile' % (pctile)
            stats[pctile_field] = float(np.percentile(data, pctile))
        for t in thresholds:
            t_field = 'pct_above_%.3f' % (t)
            stats[t_field] = float(np.mean(1 * (data > t)))
        # use a context manager so the file handle is always closed
        with open(stats_filename, 'w') as stats_file:
            json.dump(stats,
                      stats_file,
                      indent=2,
                      sort_keys=True)

        # histogram: never ask for more bins than there are distinct values
        num_unique = np.unique(data).shape[0]
        nb = min(num_bins, data.shape[0], num_unique)
        bounds = (np.min(data), np.max(data))
        vis2d.figure()
        utils.histogram(data,
                        nb,
                        bounds,
                        normalized=False,
                        plot=True)
        vis2d.xlabel(field, fontsize=font_size)
        vis2d.ylabel('Count', fontsize=font_size)
        data_filename = os.path.join(output_path, 'histogram_%s.pdf' % (field))
        vis2d.show(data_filename, dpi=dpi)
Example #11
0
        default=None,
        help="path to the dataset to use for training and validation",
    )
    parser.add_argument("split_name",
                        type=str,
                        default=None,
                        help="name to use for the split")
    parser.add_argument(
        "--train_pct",
        type=float,
        default=0.8,
        help="percent of data to use for training",
    )
    parser.add_argument(
        "--field_name",
        type=str,
        default=None,
        help="name of the field to split on",
    )
    args = parser.parse_args()
    dataset_dir = args.dataset_dir
    split_name = args.split_name
    train_pct = args.train_pct
    field_name = args.field_name

    # create split
    dataset = TensorDataset.open(dataset_dir)
    train_indices, val_indices = dataset.make_split(split_name,
                                                    train_pct=train_pct,
                                                    field_name=field_name)
Permission to use, copy, modify, and distribute this software and its documentation for educational,
research, and not-for-profit purposes, without fee and without a signed licensing agreement, is
hereby granted, provided that the above copyright notice, this paragraph and the following two
paragraphs appear in all copies, modifications, and distributions. Contact The Office of Technology
Licensing, UC Berkeley, 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-
7201, otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial licensing opportunities.

IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL,
INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF
THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF REGENTS HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED
HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE
MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
"""
"""
Print the size of a tensor dataset
Author: Jeff Mahler
"""
import os
import sys

from autolab_core import TensorDataset

if __name__ == '__main__':
    # Usage: python <script> DATASET_PATH
    # Prints the number of datapoints in the TensorDataset at DATASET_PATH.
    if len(sys.argv) < 2:
        # fail with a usage message instead of an IndexError traceback
        sys.exit('Usage: %s DATASET_PATH' % sys.argv[0])
    dataset = TensorDataset.open(sys.argv[1])
    # a single pre-formatted string prints identically on Python 2 and 3
    # (the old `print 'a', b` statement is a SyntaxError on Python 3)
    print('Num datapoints %d' % dataset.num_datapoints)
Example #13
0
    def test_single_read_write(self):
        """Round-trip a single datapoint through a TensorDataset.

        The test proceeds in these steps:

        1. Write one random datapoint and one metadata value.
        2. Check the flushed ``.npz`` tensor files on disk.
        3. Re-open the dataset and read the datapoint back via
           ``datapoint(0)``, via iteration, and field by field.
        4. Re-open read-write, delete the datapoint, and check that the
           tensor files are removed.
        """
        # seed
        np.random.seed(SEED)
        random.seed(SEED)

        # remove the on-disk dataset even if an assertion fails midway
        self.addCleanup(shutil.rmtree, TEST_TENSOR_DATASET_NAME,
                        ignore_errors=True)

        def assert_value_equal(read_value, write_value):
            # strings must match exactly; numeric values up to np.allclose
            if isinstance(read_value, str):
                self.assertEqual(read_value, write_value)
            else:
                self.assertTrue(np.allclose(read_value, write_value))

        # create the dataset; any failure propagates with its traceback
        # instead of being hidden by a bare `except:`
        dataset = TensorDataset(TEST_TENSOR_DATASET_NAME, TENSOR_CONFIG)

        # check field names
        write_datapoint = dataset.datapoint_template
        for field_name in write_datapoint:
            self.assertIn(field_name, dataset.field_names)

        # add the datapoint
        write_datapoint['float_value'] = np.random.rand()
        write_datapoint['int_value'] = int(100 * np.random.rand())
        write_datapoint['str_value'] = utils.gen_experiment_id()
        write_datapoint['vector_value'] = np.random.rand(HEIGHT)
        write_datapoint['matrix_value'] = np.random.rand(HEIGHT, WIDTH)
        write_datapoint['image_value'] = np.random.rand(
            HEIGHT, WIDTH, CHANNELS)
        dataset.add(write_datapoint)

        # check num datapoints
        self.assertEqual(dataset.num_datapoints, 1)

        # add metadata
        metadata_num = np.random.rand()
        dataset.add_metadata('test', metadata_num)

        # check written arrays
        dataset.flush()
        for field_name in dataset.field_names:
            filename = os.path.join(TEST_TENSOR_DATASET_NAME, 'tensors',
                                    '%s_00000.npz' % (field_name))
            # close the NpzFile handle once the array has been read
            with np.load(filename) as npz_file:
                value = npz_file['arr_0']
            assert_value_equal(value[0], write_datapoint[field_name])

        # re-open the dataset
        del dataset
        dataset = TensorDataset.open(TEST_TENSOR_DATASET_NAME)

        # read metadata
        self.assertTrue(np.allclose(dataset.metadata['test'], metadata_num))

        # read datapoint
        read_datapoint = dataset.datapoint(0)
        for field_name in dataset.field_names:
            assert_value_equal(read_datapoint[field_name],
                               write_datapoint[field_name])

        # check iterator
        for read_datapoint in dataset:
            for field_name in dataset.field_names:
                assert_value_equal(read_datapoint[field_name],
                                   write_datapoint[field_name])

        # read individual fields
        for field_name in dataset.field_names:
            read_datapoint = dataset.datapoint(0, field_names=[field_name])
            assert_value_equal(read_datapoint[field_name],
                               write_datapoint[field_name])

        # re-open the dataset with read-write access so it can be modified
        del dataset
        dataset = TensorDataset.open(TEST_TENSOR_DATASET_NAME,
                                     access_mode=READ_WRITE_ACCESS)

        # delete datapoint
        dataset.delete_last()

        # check that the dataset is empty and its tensor files are gone
        self.assertEqual(dataset.num_datapoints, 0)
        self.assertEqual(dataset.num_tensors, 0)
        for field_name in dataset.field_names:
            filename = os.path.join(TEST_TENSOR_DATASET_NAME, 'tensors',
                                    '%s_00000.npz' % (field_name))
            self.assertFalse(os.path.exists(filename))
Example #14
0
    def test_multi_tensor_read_write(self):
        """Exercise a TensorDataset whose datapoints span several tensor files.

        The test proceeds in these steps:

        1. Write ``DATAPOINTS_PER_FILE + 1`` random datapoints (two tensors).
        2. Re-open read-write and check sequential and random-access reads.
        3. Delete the last datapoint and check that the second tensor and its
           files are removed.
        4. Re-add datapoints and check interleaved reads and writes.
        5. Delete several datapoints at once.
        """
        # seed
        np.random.seed(SEED)
        random.seed(SEED)

        # remove the on-disk dataset even if an assertion fails midway
        self.addCleanup(shutil.rmtree, TEST_TENSOR_DATASET_NAME,
                        ignore_errors=True)

        def assert_datapoint_equal(read_datapoint, write_datapoint,
                                   field_names):
            # strings must match exactly; numeric fields up to np.allclose
            for field_name in field_names:
                read_value = read_datapoint[field_name]
                write_value = write_datapoint[field_name]
                if isinstance(read_value, str):
                    self.assertEqual(read_value, write_value)
                else:
                    self.assertTrue(np.allclose(read_value, write_value))

        # open dataset
        dataset = TensorDataset(TEST_TENSOR_DATASET_NAME, TENSOR_CONFIG)

        write_datapoints = []
        for _ in range(DATAPOINTS_PER_FILE + 1):
            write_datapoint = {}
            write_datapoint['float_value'] = np.random.rand()
            write_datapoint['int_value'] = int(100 * np.random.rand())
            write_datapoint['str_value'] = utils.gen_experiment_id()
            write_datapoint['vector_value'] = np.random.rand(HEIGHT)
            write_datapoint['matrix_value'] = np.random.rand(HEIGHT, WIDTH)
            write_datapoint['image_value'] = np.random.rand(
                HEIGHT, WIDTH, CHANNELS)
            dataset.add(write_datapoint)
            write_datapoints.append(write_datapoint)

        # check num datapoints: one more than fits in a file -> two tensors
        self.assertEqual(dataset.num_datapoints, DATAPOINTS_PER_FILE + 1)
        self.assertEqual(dataset.num_tensors, 2)

        # check read
        dataset.flush()
        del dataset
        dataset = TensorDataset.open(TEST_TENSOR_DATASET_NAME,
                                     access_mode=READ_WRITE_ACCESS)
        for i, read_datapoint in enumerate(dataset):
            assert_datapoint_equal(read_datapoint, write_datapoints[i],
                                   dataset.field_names)

        for i, read_datapoint in enumerate(dataset):
            # check iterator item
            assert_datapoint_equal(read_datapoint, write_datapoints[i],
                                   dataset.field_names)

            # check random item
            ind = np.random.choice(dataset.num_datapoints)
            assert_datapoint_equal(dataset.datapoint(ind),
                                   write_datapoints[ind],
                                   dataset.field_names)

        # check deletion: removing the only datapoint of the second tensor
        # must drop that tensor and its files
        dataset.delete_last()
        self.assertEqual(dataset.num_datapoints, DATAPOINTS_PER_FILE)
        self.assertEqual(dataset.num_tensors, 1)
        for field_name in dataset.field_names:
            filename = os.path.join(TEST_TENSOR_DATASET_NAME, 'tensors',
                                    '%s_00001.npz' % (field_name))
            # previously the path was built but never checked
            self.assertFalse(os.path.exists(filename))
        dataset.add(write_datapoints[-1])
        for write_datapoint in write_datapoints:
            dataset.add(write_datapoint)
        self.assertEqual(dataset.num_datapoints,
                         2 * (DATAPOINTS_PER_FILE + 1))
        self.assertEqual(dataset.num_tensors, 3)

        # check valid: the dataset now holds write_datapoints twice in a row
        for i in range(dataset.num_datapoints):
            assert_datapoint_equal(dataset.datapoint(i),
                                   write_datapoints[i % len(write_datapoints)],
                                   dataset.field_names)

        # check read then write out of order
        ind = np.random.choice(DATAPOINTS_PER_FILE)
        assert_datapoint_equal(dataset.datapoint(ind), write_datapoints[ind],
                               dataset.field_names)

        write_datapoint = write_datapoints[0]
        dataset.add(write_datapoint)
        assert_datapoint_equal(dataset.datapoint(dataset.num_datapoints - 1),
                               write_datapoint, dataset.field_names)
        dataset.delete_last()

        # check data integrity
        for i, read_datapoint in enumerate(dataset):
            assert_datapoint_equal(read_datapoint,
                                   write_datapoints[i % len(write_datapoints)],
                                   dataset.field_names)

        # delete last copy of write_datapoints
        dataset.delete_last(len(write_datapoints))
        self.assertEqual(dataset.num_datapoints, DATAPOINTS_PER_FILE + 1)
        self.assertEqual(dataset.num_tensors, 2)
        for i, read_datapoint in enumerate(dataset):
            assert_datapoint_equal(read_datapoint, write_datapoints[i],
                                   dataset.field_names)
    # parse args
    parser = argparse.ArgumentParser(description='Subsamples a dataset')
    parser.add_argument('dataset_path',
                        type=str,
                        default=None,
                        help='directory of the dataset to subsample')
    parser.add_argument('output_path',
                        type=str,
                        default=None,
                        help='directory to store the subsampled dataset')
    args = parser.parse_args()
    dataset_path = args.dataset_path
    output_path = args.output_path

    dataset = TensorDataset.open(dataset_path)
    out_dataset = TensorDataset(output_path, dataset.config)

    ind = np.arange(dataset.num_datapoints)
    np.random.shuffle(ind)

    for i, j in enumerate(ind):
        logging.info('Saving datapoint %d' % (i))
        datapoint = dataset[j]
        out_dataset.add(datapoint)
    out_dataset.flush()

    for split_name in dataset.split_names:
        _, val_indices, _ = dataset.split(split_name)
        new_val_indices = []
        for i in range(ind.shape[0]):
Example #16
0
    def _run_prediction_single_model(self, model_dir, model_output_dir,
                                     dataset_config):
        """Analyze the performance of a single model."""
        # Read in model config.
        model_config_filename = os.path.join(model_dir,
                                             GQCNNFilenames.SAVED_CFG)
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # Load model.
        self.logger.info("Loading model %s" % (model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(
            model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins

        # Read params from the config.
        if dataset_config is None:
            dataset_dir = model_config["dataset_dir"]
            split_name = model_config["split_name"]
            image_field_name = model_config["image_field_name"]
            pose_field_name = model_config["pose_field_name"]
            metric_name = model_config["target_metric_name"]
            metric_thresh = model_config["metric_thresh"]
        else:
            dataset_dir = dataset_config["dataset_dir"]
            split_name = dataset_config["split_name"]
            image_field_name = dataset_config["image_field_name"]
            pose_field_name = dataset_config["pose_field_name"]
            metric_name = dataset_config["target_metric_name"]
            metric_thresh = dataset_config["metric_thresh"]
            gripper_mode = dataset_config["gripper_mode"]

        self.logger.info("Loading dataset %s" % (dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)

        # Visualize conv filters.
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:, :, 0, k]
            vis2d.subplot(d, d, k + 1)
            vis2d.imshow(DepthImage(filt))
            figname = os.path.join(model_output_dir, "conv1_filters.pdf")
        vis2d.savefig(figname, dpi=self.dpi)

        # Aggregate training and validation true labels and predicted
        # probabilities.
        all_predictions = []
        if angular_bins > 0:
            all_predictions_raw = []
        all_labels = []
        for i in range(dataset.num_tensors):
            # Log progress.
            if i % self.log_rate == 0:
                self.logger.info("Predicting tensor %d of %d" %
                                 (i + 1, dataset.num_tensors))

            # Read in data.
            image_arr = dataset.tensor(image_field_name, i).arr
            pose_arr = read_pose_data(
                dataset.tensor(pose_field_name, i).arr, gripper_mode)
            metric_arr = dataset.tensor(metric_name, i).arr
            label_arr = 1 * (metric_arr > metric_thresh)
            label_arr = label_arr.astype(np.uint8)
            if angular_bins > 0:
                # Form mask to extract predictions from ground-truth angular
                # bins.
                raw_poses = dataset.tensor(pose_field_name, i).arr
                angles = raw_poses[:, 3]
                neg_ind = np.where(angles < 0)
                # TODO(vsatish): These should use the max angle instead.
                angles = np.abs(angles) % GeneralConstants.PI
                angles[neg_ind] *= -1
                g_90 = np.where(angles > (GeneralConstants.PI / 2))
                l_neg_90 = np.where(angles < (-1 * (GeneralConstants.PI / 2)))
                angles[g_90] -= GeneralConstants.PI
                angles[l_neg_90] += GeneralConstants.PI
                # TODO(vsatish): Fix this along with the others.
                angles *= -1  # Hack to fix reverse angle convention.
                angles += (GeneralConstants.PI / 2)
                pred_mask = np.zeros((raw_poses.shape[0], angular_bins * 2),
                                     dtype=bool)
                bin_width = GeneralConstants.PI / angular_bins
                for i in range(angles.shape[0]):
                    pred_mask[i, int((angles[i] // bin_width) * 2)] = True
                    pred_mask[i, int((angles[i] // bin_width) * 2 + 1)] = True

            # Predict with GQ-CNN.
            predictions = gqcnn.predict(image_arr, pose_arr)
            if angular_bins > 0:
                raw_predictions = np.array(predictions)
                predictions = predictions[pred_mask].reshape((-1, 2))

            # Aggregate.
            all_predictions.extend(predictions[:, 1].tolist())
            if angular_bins > 0:
                all_predictions_raw.extend(raw_predictions.tolist())
            all_labels.extend(label_arr.tolist())

        # Close session.
        gqcnn.close_session()

        # Create arrays.
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]
        if angular_bins > 0:
            all_predictions_raw = np.array(all_predictions_raw)
            train_predictions_raw = all_predictions_raw[train_indices]
            val_predictions_raw = all_predictions_raw[val_indices]

        # Aggregate results.
        train_result = BinaryClassificationResult(train_predictions,
                                                  train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, "train_result.cres"))
        val_result.save(os.path.join(model_output_dir, "val_result.cres"))

        # Get stats, plot curves.
        self.logger.info("Model %s training error rate: %.3f" %
                         (model_dir, train_result.error_rate))
        self.logger.info("Model %s validation error rate: %.3f" %
                         (model_dir, val_result.error_rate))

        self.logger.info("Model %s training loss: %.3f" %
                         (model_dir, train_result.cross_entropy_loss))
        self.logger.info("Model %s validation loss: %.3f" %
                         (model_dir, val_result.cross_entropy_loss))

        # Save images.
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, "examples")
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # Train.
        self.logger.info("Saving training examples")
        train_example_dir = os.path.join(example_dir, "train")
        if not os.path.exists(train_example_dir):
            os.mkdir(train_example_dir)

        # Train TP.
        true_positive_indices = train_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "true_positive_%03d.png" % (i)))

        # Train FP.
        false_positive_indices = train_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "false_positive_%03d.png" % (i)))

        # Train TN.
        true_negative_indices = train_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "true_negative_%03d.png" % (i)))

        # Train TP.
        false_negative_indices = train_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "false_negative_%03d.png" % (i)))

        # Val.
        self.logger.info("Saving validation examples")
        val_example_dir = os.path.join(example_dir, "val")
        if not os.path.exists(val_example_dir):
            os.mkdir(val_example_dir)

        # Val TP.
        true_positive_indices = val_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "true_positive_%03d.png" % (i)))

        # Val FP.
        false_positive_indices = val_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "false_positive_%03d.png" % (i)))

        # Val TN.
        true_negative_indices = val_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "true_negative_%03d.png" % (i)))

        # Val FN.
        false_negative_indices = val_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "false_negative_%03d.png" % (i)))

        # Save summary stats.
        train_summary_stats = {
            "error_rate": train_result.error_rate,
            "ap_score": train_result.ap_score,
            "auc_score": train_result.auc_score,
            "loss": train_result.cross_entropy_loss
        }
        train_stats_filename = os.path.join(model_output_dir,
                                            "train_stats.json")
        json.dump(train_summary_stats,
                  open(train_stats_filename, "w"),
                  indent=JSON_INDENT,
                  sort_keys=True)

        val_summary_stats = {
            "error_rate": val_result.error_rate,
            "ap_score": val_result.ap_score,
            "auc_score": val_result.auc_score,
            "loss": val_result.cross_entropy_loss
        }
        val_stats_filename = os.path.join(model_output_dir, "val_stats.json")
        json.dump(val_summary_stats,
                  open(val_stats_filename, "w"),
                  indent=JSON_INDENT,
                  sort_keys=True)

        return train_result, val_result
            get_model_config(model_config, .3, .9, 6.6e-04, 6.5e-02, 2.2e-01,
                             True)),
        TopplingModel(
            get_model_config(model_config, .3, .9, 6.6e-04, 6.5e-02, 2.2e-01,
                             False)),
        TopplingModel(
            get_model_config(model_config, .3, .9, 6.6e-04, 6.5e-02, 2.2e-01,
                             True)),
        TopplingModel(
            get_model_config(model_config, .3, .9, 6.6e-04, 6.5e-02, 2.2e-01,
                             False))
    ]
    use_sensitivities = [False, False, True, True]
    model_names = [
        'Baseline', 'Baseline+Rotations', 'Baseline+Robustness', 'Robust Model'
    ]

    env = GraspingEnv(config, config['vis'])
    env.reset()

    datasets, obj_id_to_keys = [], []
    for dataset_name in os.listdir(args.datasets):
        dataset_name = dataset_name.split(' ')[0]
        dataset_path = os.path.join(args.datasets, dataset_name)
        datasets.append(TensorDataset.open(dataset_path))
        with open(os.path.join(dataset_path, "obj_keys.json"),
                  "r") as read_file:
            obj_id_to_keys.append(json.load(read_file))
    visualize(env, datasets, obj_id_to_keys, models, model_names,
              use_sensitivities)