def main():
    """Convert KITTI object labels into an ImageDataset JSON file.

    Command-line options select the KITTI Object root directory, the split
    file (one image name per line) and the object categories to keep.

    For every image in the split, objects of the requested categories are
    kept only if they pass the height / occlusion / truncation thresholds
    defined below. Images with no remaining object are skipped. For each kept
    object the visible and amodal boxes, viewpoint, projected center and
    center distance (all relative to camera 2) are stored.

    The dataset is named ``kitti_<split file stem>`` and is handed to
    ``ImageDataset.write_data_to_json`` as ``<dataset name>.json``.

    Raises:
        AssertionError: if a required path is missing, an image cannot be
            loaded, or one of the geometric consistency checks fails.
    """
    root_dir_default = osp.join(_init_paths.root_dir, 'data', 'kitti', 'KITTI-Object')
    splits_file_default = osp.join(_init_paths.root_dir, 'data', 'kitti', 'splits', 'trainval.txt')
    all_categories = ['Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram']

    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--root_dir", default=root_dir_default, help="Path to KITTI Object directory")
    parser.add_argument("-s", "--split_file", default=splits_file_default, help="Path to split file")
    parser.add_argument("-c", "--categories", type=str, nargs='+', default=['Car'], choices=all_categories, help="Object type (category)")
    args = parser.parse_args()

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print("------------- Config ------------------")
    for arg in vars(args):
        print("{} \t= {}".format(arg, getattr(args, arg)))

    assert osp.exists(args.root_dir), 'KITTI Object dir "{}" does not exist'.format(args.root_dir)
    assert osp.exists(args.split_file), 'Path to split file does not exist: {}'.format(args.split_file)

    # Close the split file deterministically instead of leaking the handle.
    with open(args.split_file) as split_file:
        image_names = [line.rstrip() for line in split_file]
    num_of_images = len(image_names)
    print('Using Split {} with {} images'.format(osp.basename(args.split_file), num_of_images))

    root_dir = osp.join(args.root_dir, 'training')
    label_dir = osp.join(root_dir, 'label_2_updated')
    image_dir = osp.join(root_dir, 'image_2')
    calib_dir = osp.join(root_dir, 'calib')

    assert osp.exists(root_dir)
    assert osp.exists(label_dir)
    assert osp.exists(image_dir)
    assert osp.exists(calib_dir)

    dataset_name = 'kitti_' + osp.splitext(osp.basename(args.split_file))[0]
    dataset = ImageDataset(dataset_name)
    dataset.set_rootdir(root_dir)

    # Using slightly harder settings than the standard KITTI hardness levels
    min_height = 20  # minimum height for evaluated groundtruth/detections
    max_occlusion = 2  # maximum occlusion level of the groundtruth used for evaluation
    max_truncation = 0.7  # maximum truncation level of the groundtruth used for evaluation

    total_num_of_objects = 0

    print('Creating ImageDataset. May take long time')
    for image_name in tqdm(image_names):
        image_file_path = osp.join(image_dir, image_name + '.png')
        label_file_path = osp.join(label_dir, image_name + '.txt')
        calib_file_path = osp.join(calib_dir, image_name + '.txt')

        assert osp.exists(image_file_path)
        assert osp.exists(label_file_path)
        assert osp.exists(calib_file_path)

        objects = read_kitti_object_labels(label_file_path)

        # Filter the objects by category and by the hardness thresholds above.
        # Keys are the object's index in the label file, so ids stay stable
        # even when some objects are dropped.
        filtered_objects = {}
        for obj_id, obj in enumerate(objects):
            if obj['type'] not in args.categories:
                continue

            bbx = np.asarray(obj['bbox'])

            too_hard = False
            if (bbx[3] - bbx[1]) < min_height:
                too_hard = True
            if obj['occlusion'] > max_occlusion:
                too_hard = True
            if obj['truncation'] > max_truncation:
                too_hard = True

            if not too_hard:
                filtered_objects[obj_id] = obj

        # Images without any usable object are not added to the dataset.
        if not filtered_objects:
            continue

        total_num_of_objects += len(filtered_objects)

        image = cv2.imread(image_file_path)
        # cv2.imread signals failure by returning None rather than raising.
        assert image is not None, "Failed to load image '{}'".format(image_file_path)
        W = image.shape[1]
        H = image.shape[0]
        calib_data = read_kitti_calib_file(calib_file_path)
        P0 = calib_data['P0'].reshape((3, 4))
        P2 = calib_data['P2'].reshape((3, 4))
        K = P0[:3, :3]
        # The code below relies on P0 and P2 sharing the same left 3x3 block.
        assert np.all(P2[:3, :3] == K)

        cam2_center = -np.linalg.inv(K).dot(P2[:, 3])

        velo_T_cam0 = get_kitti_cam0_to_velo(calib_data)
        velo_T_cam2 = velo_T_cam0 * Pose(t=cam2_center)
        cam2_T_velo = get_kitti_velo_to_cam(calib_data, cam2_center)
        # Sanity check: both routes to the velo <-> cam2 transform must agree.
        assert np.allclose(velo_T_cam2.inverse().matrix(), cam2_T_velo.matrix())

        annotation = OrderedDict()
        annotation['image_file'] = osp.relpath(image_file_path, root_dir)
        annotation['image_size'] = NoIndent([W, H])
        # `np.float` was only an alias of the builtin `float` and has been
        # removed from NumPy; the builtin gives the same dtype.
        annotation['image_intrinsic'] = NoIndent(K.astype(float).tolist())

        obj_infos = []
        for obj_id in sorted(filtered_objects):
            obj = filtered_objects[obj_id]
            obj_pose_cam2 = get_kitti_object_pose(obj, velo_T_cam0, cam2_center)
            obj_pose_cam0 = get_kitti_object_pose(obj, velo_T_cam0, np.zeros(3))
            assert np.allclose(obj_pose_cam0.t - obj_pose_cam2.t, cam2_center)

            bbx_visible = np.array(obj['bbox'])
            bbx_amodal = get_kitti_amodal_bbx(obj, K, obj_pose_cam2)

            obj_origin_proj = project_point(K, obj_pose_cam2.t)
            distance = np.linalg.norm(obj_pose_cam2.t)

            # Rotation taking the object's translation vector onto the +Z
            # axis; the assert below checks exactly that.
            delta_rot = rotation_from_two_vectors(obj_pose_cam2.t, np.array([0., 0., 1.]))
            obj_rel_rot = np.matmul(delta_rot, obj_pose_cam2.R)
            assert np.allclose(delta_rot.dot(obj_pose_cam2.t), np.array([0., 0., distance]))

            viewpoint = viewpoint_from_rotation(obj_rel_rot)

            # Round-trip checks: the viewpoint must reproduce the rotation.
            R_vp = rotation_from_viewpoint(viewpoint)
            assert np.allclose(R_vp, obj_rel_rot, rtol=1e-03), "R_vp = \n{}\nobj_rel_rot = \n{}\n".format(R_vp, obj_rel_rot)
            assert np.allclose(np.matmul(delta_rot.T, R_vp), obj_pose_cam2.R, rtol=1e-04)

            # Cross-check the recomputed alpha against the labelled one.
            pred_alpha = get_kitti_alpha_from_object_pose(obj_pose_cam2, velo_T_cam2)
            alpha_diff = wrap_to_pi(pred_alpha - obj['alpha'])
            assert np.abs(alpha_diff) < 0.011, "{} vs {}. alpha_diff={}".format(pred_alpha, obj['alpha'], alpha_diff)

            obj_info = OrderedDict()

            obj_info['id'] = obj_id
            obj_info['category'] = obj['type'].lower()
            obj_info['dimension'] = NoIndent(obj['dimension'][::-1])  # [length, width, height]
            obj_info['bbx_visible'] = NoIndent(bbx_visible.tolist())
            obj_info['bbx_amodal'] = NoIndent(np.around(bbx_amodal, decimals=6).tolist())
            obj_info['viewpoint'] = NoIndent(np.around(viewpoint, decimals=6).tolist())
            obj_info['center_proj'] = NoIndent(np.around(obj_origin_proj, decimals=6).tolist())
            obj_info['center_dist'] = round(float(distance), 6)

            obj_infos.append(obj_info)
        annotation['object_infos'] = obj_infos
        dataset.add_image_info(annotation)

    print('Finished creating dataset with {} images and {} objects.'.format(dataset.num_of_images(), total_num_of_objects))

    metainfo = OrderedDict()
    metainfo['total_num_of_objects'] = total_num_of_objects
    metainfo['categories'] = NoIndent([x.lower() for x in args.categories])
    metainfo['min_height'] = min_height
    metainfo['max_occlusion'] = max_occlusion
    metainfo['max_truncation'] = max_truncation
    dataset.set_metainfo(metainfo)

    out_json_filename = dataset_name + '.json'
    print('Saving annotations to {}'.format(out_json_filename))
    dataset.write_data_to_json(out_json_filename)
Пример #2
0
def main():
    """Extract the objects of one category from an image dataset JSON file.

    Loads the dataset given by ``--image_dataset_file``, keeps only the
    objects whose ``category`` equals ``--category`` and, for objects that
    carry a ``score``, whose score is at least ``--score_thresh``. Visible
    boxes are clipped to the image size. Images left without any object are
    dropped. The new dataset is named ``<original name>_<category>`` and is
    handed to ``write_data_to_json`` as ``<new name>.json``.

    Raises:
        AssertionError: if the input path is not a file or a visible box
            fails ``assert_bbx``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i",
                        "--image_dataset_file",
                        required=True,
                        type=str,
                        help="Path to image dataset file to split")
    parser.add_argument("-c",
                        "--category",
                        required=True,
                        type=str,
                        help="category to separate out")
    parser.add_argument("-s",
                        "--score_thresh",
                        default=0.0,
                        type=float,
                        help="minimum score")
    args = parser.parse_args()

    assert osp.isfile(args.image_dataset_file
                      ), '{} either do not exist or not a file'.format(
                          args.image_dataset_file)

    print('Loading image dataset from {}'.format(args.image_dataset_file))
    image_dataset = ImageDataset.from_json(args.image_dataset_file)
    print(image_dataset)
    num_of_objects = sum(
        len(img_info['object_infos'])
        for img_info in image_dataset.image_infos())
    print("total number of objects = {}".format(num_of_objects))

    new_image_infos = []

    print('selecting object_infos with category {}'.format(args.category))
    for im_info in tqdm(image_dataset.image_infos()):
        new_im_info = OrderedDict()

        # Plain image-level fields are copied as they are ...
        for im_info_field in ['image_file', 'segm_file']:
            if im_info_field in im_info:
                new_im_info[im_info_field] = im_info[im_info_field]

        # ... while list-valued ones are wrapped in NoIndent.
        for im_info_field in ['image_size', 'image_intrinsic']:
            if im_info_field in im_info:
                new_im_info[im_info_field] = NoIndent(im_info[im_info_field])

        W = im_info['image_size'][0]
        H = im_info['image_size'][1]

        new_obj_infos = []
        for obj_id, obj_info in enumerate(im_info['object_infos']):
            if obj_info['category'] != args.category:
                continue

            # 'score' is optional (it is copied conditionally further down),
            # so only objects that actually carry a score are thresholded.
            # An unconditional lookup would raise KeyError on score-less data.
            if 'score' in obj_info and obj_info['score'] < args.score_thresh:
                continue

            new_obj_info = OrderedDict()

            # Validate, clip to the image, then validate the clipped box.
            vbbx = np.array(obj_info['bbx_visible'])
            assert_bbx(vbbx)
            vbbx = clip_bbx_by_image_size(vbbx, W, H)
            assert_bbx(vbbx)
            new_obj_info['bbx_visible'] = NoIndent(vbbx.tolist())

            # Fall back to a 1-based position in the image's object list when
            # the object has no id; the loaded object itself is not modified.
            new_obj_info['id'] = obj_info.get('id', obj_id + 1)
            new_obj_info['category'] = obj_info['category']

            for obj_info_field in [
                    'viewpoint', 'bbx_amodal', 'center_proj', 'dimension'
            ]:
                if obj_info_field in obj_info:
                    new_obj_info[obj_info_field] = NoIndent(
                        obj_info[obj_info_field])

            for obj_info_field in [
                    'center_dist', 'occlusion', 'truncation', 'shape_file',
                    'score'
            ]:
                if obj_info_field in obj_info:
                    new_obj_info[obj_info_field] = obj_info[obj_info_field]
            new_obj_infos.append(new_obj_info)

        # Only keep images that still contain at least one object.
        if new_obj_infos:
            new_im_info['object_infos'] = new_obj_infos
            new_image_infos.append(new_im_info)

    new_dataset = ImageDataset(
        name="{}_{}".format(image_dataset.name(), args.category))
    new_dataset.set_image_infos(new_image_infos)
    new_dataset.set_rootdir(image_dataset.rootdir())
    num_of_objects = sum(
        len(img_info['object_infos'])
        for img_info in new_dataset.image_infos())

    metainfo = OrderedDict()
    metainfo['total_num_of_objects'] = num_of_objects
    metainfo['categories'] = NoIndent([args.category])
    metainfo['score_thresh'] = args.score_thresh
    new_dataset.set_metainfo(metainfo)

    print(new_dataset)
    print("new number of objects = {}".format(num_of_objects))

    new_dataset.write_data_to_json(new_dataset.name() + ".json")
Пример #3
0
def main():
    """Import Pascal3D annotations of one category into an ImageDataset JSON.

    Command-line options select the dataset root, the split, the sub dataset
    (``imagenet`` or ``pascal``), the category, an optional output dataset
    name, and whether truncated / occluded / difficult objects are kept.

    For every image name in the split file that has a ``.mat`` annotation,
    the objects of the requested category are converted to object infos
    (visible box, optional amodal box, viewpoint, projected center and coarse
    occlusion / truncation / difficulty flags). Objects with zero distance,
    an all-zero viewpoint, or a box that is empty after clipping are skipped.
    Images without any remaining object are not added.

    The result is handed to ``write_data_to_json`` as ``<dataset name>.json``.

    Raises:
        AssertionError: if a required path is missing, an image cannot be
            loaded, or an annotation fails one of the sanity checks.
    """
    root_dir_default = osp.join(_init_paths.root_dir, 'data', 'pascal3D',
                                'Pascal3D-Dataset')
    split_choices = ['train', 'val', 'trainval', 'test']
    sub_dataset_choices = ['imagenet', 'pascal']
    category_choices = ['car', 'motorbike', 'bicycle', 'bus']

    parser = argparse.ArgumentParser()
    parser.add_argument("-r",
                        "--root_dir",
                        default=root_dir_default,
                        help="Path to Pascal3d Object directory")
    parser.add_argument("-s",
                        "--split",
                        default='trainval',
                        choices=split_choices,
                        help="Split type")
    parser.add_argument("-d",
                        "--sub_dataset",
                        default='imagenet',
                        choices=sub_dataset_choices,
                        help="Sub dataset type")
    parser.add_argument("-c",
                        "--category",
                        type=str,
                        default='car',
                        choices=category_choices,
                        help="Object type (category)")
    parser.add_argument("-n",
                        "--dataset_name",
                        type=str,
                        help="Optional output dataset name")
    parser.add_argument('--no-truncated',
                        dest='keep_truncated',
                        action='store_false',
                        help="use this to remove truncated objects")
    parser.set_defaults(keep_truncated=True)
    parser.add_argument('--no-occluded',
                        dest='keep_occluded',
                        action='store_false',
                        help="use this to remove occluded objects")
    parser.set_defaults(keep_occluded=True)
    parser.add_argument('--no-difficult',
                        dest='keep_difficult',
                        action='store_false',
                        help="use this to remove difficult objects")
    parser.set_defaults(keep_difficult=True)
    args = parser.parse_args()

    assert osp.exists(args.root_dir), "Directory '{}' do not exist".format(
        args.root_dir)
    anno_dir = osp.join(args.root_dir, 'AnnotationsFixed',
                        '{}_{}'.format(args.category, args.sub_dataset))
    image_dir = osp.join(args.root_dir, 'Images',
                         '{}_{}'.format(args.category, args.sub_dataset))
    assert osp.exists(anno_dir), "Directory '{}' do not exist".format(anno_dir)
    assert osp.exists(image_dir), "Directory '{}' do not exist".format(
        image_dir)

    split_file = osp.join(_init_paths.root_dir, 'data', 'pascal3D', 'splits',
                          '{}_{}.txt'.format(args.sub_dataset, args.split))
    assert osp.exists(split_file), "Split file '{}' do not exist".format(
        split_file)

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print("split = {}".format(args.split))
    print("sub_dataset = {}".format(args.sub_dataset))
    print("category = {}".format(args.category))
    print("anno_dir = {}".format(anno_dir))
    print("image_dir = {}".format(image_dir))
    print("keep_truncated = {}".format(args.keep_truncated))
    print("keep_occluded = {}".format(args.keep_occluded))
    print("keep_difficult = {}".format(args.keep_difficult))

    # Close the split file deterministically instead of leaking the handle.
    with open(split_file) as split_fp:
        image_names = [line.rstrip() for line in split_fp]
    num_of_images = len(image_names)
    print('Using split {} with {} images'.format(args.split, num_of_images))

    # imagenet uses JPEG while pascal images are in jpg format
    image_ext = '.JPEG' if args.sub_dataset == 'imagenet' else '.jpg'

    if args.dataset_name:
        dataset_name = args.dataset_name
    else:
        dataset_name = 'pascal3d_{}_{}_{}'.format(args.sub_dataset, args.split,
                                                  args.category)
    dataset = ImageDataset(dataset_name)
    dataset.set_rootdir(args.root_dir)

    print("Importing dataset ...")
    for image_name in tqdm(image_names):
        anno_file = osp.join(anno_dir, image_name + '.mat')
        image_file = osp.join(image_dir, image_name + image_ext)

        # Split entries without an annotation file are silently skipped.
        if not osp.exists(anno_file):
            continue
        assert osp.exists(image_file), "Image file '{}' do not exist".format(
            image_file)

        image_info = OrderedDict()
        image_info['image_file'] = osp.relpath(image_file, args.root_dir)

        image = cv2.imread(image_file)
        # cv2.imread returns None on failure; test for that first so the
        # assertion message is shown instead of an AttributeError.
        assert image is not None and image.size, \
            "image loaded from '{}' is empty".format(image_file)

        W = image.shape[1]
        H = image.shape[0]
        image_info['image_size'] = NoIndent([W, H])

        record = sio.loadmat(anno_file)['record'].flatten()[0]
        assert record['filename'][
            0] == image_name + image_ext, "{} vs {}".format(
                record['filename'][0], image_name + image_ext)

        record_objects = record['objects'].flatten()
        obj_infos = []

        # obj_id is the object's index in the annotation record, so ids stay
        # stable even when some objects are filtered out.
        for obj_id, rec_obj in enumerate(record_objects):
            category = rec_obj['class'].flatten()[0]
            if category != args.category:
                continue

            occluded = bool(rec_obj['occluded'].flatten()[0])
            truncated = bool(rec_obj['truncated'].flatten()[0])
            difficult = bool(rec_obj['difficult'].flatten()[0])

            if not args.keep_truncated and truncated:
                continue
            if not args.keep_occluded and occluded:
                continue
            if not args.keep_difficult and difficult:
                continue

            rec_vp = rec_obj['viewpoint'].flatten()[0]
            distance = rec_vp['distance'].flatten()[0]
            # NOTE(review): zero distance presumably marks a missing
            # viewpoint annotation - confirm against the dataset docs.
            if distance == 0.0:
                continue

            azimuth = math.radians(rec_vp['azimuth'][0, 0])
            elevation = math.radians(rec_vp['elevation'][0, 0])
            tilt = math.radians(rec_vp['theta'][0, 0])
            if azimuth == 0.0 and elevation == 0.0 and tilt == 0.0:
                continue

            # `np.float` was only an alias of the builtin `float` and has
            # been removed from NumPy; the builtin gives the same dtype.
            viewpoint = np.around(np.array([azimuth, elevation, tilt],
                                           dtype=float),
                                  decimals=6)
            viewpoint = wrap_to_pi_array(viewpoint)

            assert_viewpoint(viewpoint)

            assert rec_vp['focal'][
                0,
                0] == 1, "rec_vp['focal'] is expected to be 1 but got {}".format(
                    rec_vp['focal'][0, 0])
            center_proj = np.array([rec_vp['px'][0, 0], rec_vp['py'][0, 0]],
                                   dtype=float)
            assert_coord2D(center_proj)

            vbbx = rec_obj['bbox'].flatten()
            assert_bbx(vbbx)
            vbbx = clip_bbx_by_image_size(vbbx, W, H)
            # Drop boxes that are empty after clipping to the image.
            if np.any(vbbx[:2] >= vbbx[2:]):
                continue

            obj_info = OrderedDict()
            obj_info['id'] = obj_id
            obj_info['category'] = category

            # since we dont have precise measure, use an approximate measure
            obj_info['occlusion'] = 0.5 if occluded else 0.0
            obj_info['truncation'] = 0.5 if truncated else 0.0
            obj_info['difficulty'] = 0.5 if difficult else 0.0

            vbbx = np.around(vbbx, decimals=6)
            assert_bbx(vbbx)
            obj_info['bbx_visible'] = NoIndent(vbbx.tolist())

            # The amodal box is optional; it is used only when the field
            # exists and holds exactly four values.
            if 'abbx' in rec_obj.dtype.names:
                abbx = rec_obj['abbx'].flatten()
                if abbx.shape == (4, ):
                    assert_bbx(abbx)
                    obj_info['bbx_amodal'] = NoIndent(
                        np.around(abbx, decimals=6).tolist())

            obj_info['viewpoint'] = NoIndent(viewpoint.tolist())
            obj_info['center_proj'] = NoIndent(
                np.around(center_proj, decimals=6).tolist())

            obj_infos.append(obj_info)

        # only add if we have at least 1 object
        if obj_infos:
            image_info['object_infos'] = obj_infos
            dataset.add_image_info(image_info)

    total_num_of_objects = sum(
        len(img_info['object_infos']) for img_info in dataset.image_infos())
    print('Finished creating dataset with {} images and {} objects.'.format(
        dataset.num_of_images(), total_num_of_objects))

    num_of_objects_with_abbx = sum(
        1 for img_info in dataset.image_infos()
        for obj_info in img_info['object_infos']
        if 'bbx_amodal' in obj_info)
    print("Number of objects with bbx_amodal information = {}".format(
        num_of_objects_with_abbx))
    metainfo = OrderedDict()
    metainfo['total_num_of_objects'] = total_num_of_objects
    metainfo['categories'] = NoIndent([args.category])
    dataset.set_metainfo(metainfo)

    out_json_filename = dataset_name + '.json'
    dataset.write_data_to_json(out_json_filename)
def test_single_weights_file(weights_file, net, input_dataset):
    """Evaluate an already initialized net with a new set of weights.

    Args:
        weights_file: Path of the weights file, loaded with
            ``net.copy_from``. Its name must contain
            ``iter_<number>.caffemodel``; the number is reported as the
            iteration in the returned metrics.
        net: Initialized network. Its first layer (``net.layers[0]``) must
            provide ``data_samples``, ``batch_size``, ``image_loader`` and
            the datum-id bookkeeping used below, with one data sample per
            object of ``input_dataset``.
        input_dataset: ImageDataset the net's first layer was set up with.

    Returns:
        A ``(result_dataset, performance_metric)`` tuple. ``result_dataset``
        mirrors the images of ``input_dataset`` and holds one object info per
        data sample, with ``bbx_amodal`` / ``viewpoint`` / ``center_proj``
        filled from the matching ``pred_*`` blobs when the net has them.
        ``performance_metric`` maps every scalar net output whose name
        contains "accuracy", "iou" or "error" to its mean over all batches,
        plus ``'iter'``.

    Raises:
        AssertionError: if the layer state does not match ``input_dataset``
            or a prediction fails its sanity check.
        IndexError: if ``weights_file`` does not contain an iteration number.
    """
    net.copy_from(weights_file)
    data_layer = net.layers[0]
    data_layer.generate_datum_ids()

    input_num_of_objects = sum(
        len(image_info['object_infos']) for image_info in input_dataset.image_infos())
    assert data_layer.curr_data_ids_idx == 0
    assert data_layer.number_of_datapoints() == input_num_of_objects
    # list(range(...)) keeps this comparison meaningful on Python 3, where a
    # list never compares equal to a range object.
    assert data_layer.data_ids == list(range(input_num_of_objects))

    data_samples = data_layer.data_samples
    num_of_data_samples = len(data_samples)
    batch_size = data_layer.batch_size
    num_of_batches = int(np.ceil(num_of_data_samples / float(batch_size)))

    assert len(data_layer.image_loader) == input_dataset.num_of_images()

    # Create Result dataset
    result_dataset = ImageDataset(input_dataset.name())
    result_dataset.set_rootdir(input_dataset.rootdir())
    result_dataset.set_metainfo(input_dataset.metainfo().copy())

    # Add weight file and its md5 checksum to metainfo
    result_dataset.metainfo()['weights_file'] = weights_file
    with open(weights_file, 'rb') as weights_fp:
        result_dataset.metainfo()['weights_file_md5'] = md5(weights_fp.read()).hexdigest()

    # Set the image level fields
    for input_im_info in input_dataset.image_infos():
        result_im_info = OrderedDict()
        result_im_info['image_file'] = input_im_info['image_file']
        result_im_info['image_size'] = NoIndent(input_im_info['image_size'])
        # Intrinsics are optional, matching run_inference().
        if 'image_intrinsic' in input_im_info:
            result_im_info['image_intrinsic'] = NoIndent(input_im_info['image_intrinsic'])
        result_im_info['object_infos'] = []
        result_dataset.add_image_info(result_im_info)

    assert result_dataset.num_of_images() == input_dataset.num_of_images()

    # Sanity check to run on each kind of prediction.
    assert_funcs = {
        "viewpoint": assert_viewpoint,
        "bbx_visible": assert_bbx,
        "bbx_amodal": assert_bbx,
        "center_proj": assert_coord2D,
    }

    performance_metric = {}

    print('Evaluating for {} batches with {} imaes per batch.'.format(num_of_batches, batch_size))
    for b in tqdm.trange(num_of_batches):
        start_idx = batch_size * b
        # The last batch may be only partially filled with new samples.
        end_idx = min(batch_size * (b + 1), num_of_data_samples)
        output = net.forward()

        # store all accuracy outputs
        for key in output:
            if not any(tag in key for tag in ("accuracy", "iou", "error")):
                continue
            assert np.squeeze(output[key]).shape == (), "Expects {} output to be scalar but got {}".format(key, output[key].shape)
            performance_metric.setdefault(key, []).append(float(np.squeeze(output[key])))

        for i in range(start_idx, end_idx):
            sample = data_samples[i]
            image_info = result_dataset.image_infos()[sample['image_id']]

            object_info = OrderedDict()

            # since we are not changing category or id, they are directly copied
            object_info['id'] = sample['id']
            object_info['category'] = sample['category']

            # since we are not predicting bbx_visible, it is directly copied
            object_info['bbx_visible'] = NoIndent(sample['bbx_visible'].tolist())

            for info in ["bbx_amodal", "viewpoint", "center_proj"]:
                pred_info = "pred_" + info
                if pred_info in net.blobs:
                    # Row within the current batch, not the global index.
                    prediction = np.squeeze(net.blobs[pred_info].data[i - start_idx, ...])
                    assert_funcs[info](prediction)
                    object_info[info] = NoIndent(prediction.tolist())

            image_info['object_infos'].append(object_info)

    for key in sorted(performance_metric):
        performance_metric[key] = np.mean(performance_metric[key])
        print('Test set {}: {:.4f}'.format(key, performance_metric[key]))

    # Raw pattern with an escaped dot and at least one digit, so an empty
    # capture can never reach int().
    regex = re.compile(r'iter_([0-9]+)\.caffemodel')
    performance_metric['iter'] = int(regex.findall(weights_file)[0])

    result_num_of_objects = sum(
        len(image_info['object_infos']) for image_info in result_dataset.image_infos())
    assert result_num_of_objects == num_of_data_samples
    return result_dataset, performance_metric
def run_inference(weights_file, net, input_dataset):
    """Run inference with an already initialized net and a new set of weights.

    Args:
        weights_file: Path of the weights file, loaded with ``net.copy_from``.
        net: Initialized network. Its first layer (``net.layers[0]``) must be
            configured for testing: one data sample per image of
            ``input_dataset``, one image per batch, dynamic rois per image,
            no flipping and no jittering (all asserted below).
        input_dataset: ImageDataset whose object infos supply the boxes; the
            net is run once per image.

    Returns:
        A new ImageDataset with the same images as ``input_dataset``. Each
        object info keeps ``id`` / ``category`` / ``score`` / ``bbx_visible``
        where present and gains ``bbx_amodal`` / ``viewpoint`` /
        ``center_proj`` from the matching ``pred_*`` blobs when the net has
        them. The metainfo additionally records the weights file and its md5.

    Raises:
        AssertionError: if the layer configuration or blob shapes do not
            match the expectations, or a prediction fails its sanity check.
    """
    net.copy_from(weights_file)
    data_layer = net.layers[0]
    data_layer.generate_datum_ids()

    num_of_images = input_dataset.num_of_images()
    assert data_layer.curr_data_ids_idx == 0
    assert data_layer.number_of_datapoints() == num_of_images
    # list(range(...)) keeps this comparison meaningful on Python 3, where a
    # list never compares equal to a range object.
    assert data_layer.data_ids == list(range(num_of_images))

    assert len(data_layer.image_loader) == num_of_images
    assert len(data_layer.data_samples) == num_of_images
    assert data_layer.rois_per_image < 0, "rois_per_image need to be dynamic for testing"
    assert data_layer.imgs_per_batch == 1, "We only support one image per batch while testing"
    assert data_layer.flip_ratio < 0, "No flipping while testing"
    assert data_layer.jitter_iou_min > 1, "No jittering"

    # Create Result dataset
    result_dataset = ImageDataset(input_dataset.name())
    result_dataset.set_rootdir(input_dataset.rootdir())
    result_dataset.set_metainfo(input_dataset.metainfo().copy())

    # Add weight file and its md5 checksum to metainfo
    result_dataset.metainfo()['weights_file'] = weights_file
    with open(weights_file, 'rb') as weights_fp:
        result_dataset.metainfo()['weights_file_md5'] = md5(
            weights_fp.read()).hexdigest()

    # Set the image level fields
    for input_im_info in input_dataset.image_infos():
        result_im_info = OrderedDict()
        result_im_info['image_file'] = input_im_info['image_file']
        result_im_info['image_size'] = input_im_info['image_size']
        if 'image_intrinsic' in input_im_info:
            result_im_info['image_intrinsic'] = input_im_info[
                'image_intrinsic']
        # Copy only the per-object fields that are not predicted by the net.
        obj_infos = []
        for input_obj_info in input_im_info['object_infos']:
            obj_info = OrderedDict()
            for field in ['id', 'category', 'score', 'bbx_visible']:
                if field in input_obj_info:
                    obj_info[field] = input_obj_info[field]
            obj_infos.append(obj_info)
        result_im_info['object_infos'] = obj_infos
        assert len(result_im_info['object_infos']) == len(
            input_im_info['object_infos'])
        result_dataset.add_image_info(result_im_info)

    assert result_dataset.num_of_images() == num_of_images
    assert len(data_layer.data_samples) == num_of_images
    # Per image, the layer must hold as many objects as the result dataset,
    # because predictions are matched to objects by position below.
    for result_img_info, layer_img_info in zip(result_dataset.image_infos(),
                                               data_layer.data_samples):
        assert len(result_img_info['object_infos']) == len(
            layer_img_info['object_infos'])

    # Sanity check to run on each kind of prediction.
    assert_funcs = {
        "viewpoint": assert_viewpoint,
        "bbx_visible": assert_bbx,
        "bbx_amodal": assert_bbx,
        "center_proj": assert_coord2D,
    }
    predicted_fields = ["bbx_amodal", "viewpoint", "center_proj"]

    print('Running inference for {} images.'.format(num_of_images))
    for image_id in tqdm.trange(num_of_images):
        # Run forward pass (one image per batch, see the asserts above)
        net.forward()

        img_info = result_dataset.image_infos()[image_id]
        expected_num_of_rois = len(img_info['object_infos'])
        assert net.blobs['rois'].data.shape == (
            expected_num_of_rois,
            5), "{}_{}".format(net.blobs['rois'].data.shape,
                               expected_num_of_rois)

        for info in predicted_fields:
            pred_info = "pred_" + info
            if pred_info in net.blobs:
                assert net.blobs[pred_info].data.shape[
                    0] == expected_num_of_rois

        # Row i of every prediction blob belongs to the i-th object.
        for i, obj_info in enumerate(img_info['object_infos']):
            for info in predicted_fields:
                pred_info = "pred_" + info
                if pred_info in net.blobs:
                    prediction = np.squeeze(net.blobs[pred_info].data[i, ...])
                    assert_funcs[info](prediction)
                    obj_info[info] = prediction.tolist()

    return result_dataset