Example #1
    def __init__(self, hyperparams):
        # Get configs, hyperparameters & initialize agent
        self.ptconf, self.dconf, self.camconf = hyperparams[
            'ptconf'], hyperparams['dconf'], hyperparams['camconf']
        self.env_name = hyperparams['env']
        self.gui_on = hyperparams['bullet_gui_on']
        self.hyperparams = hyperparams
        config = copy(AGENT_BULLET)
        config.update(hyperparams)
        Agent.__init__(self, config)

        # Setup bullet environment
        self._setup_conditions()
        self._setup_world(hyperparams['filename'])
        self.setup_bullet()
        self.setup_inference_camera()
        # Get demo data
        self.demo_vid = imageio.get_reader(
            join(self.dconf.DEMO_DIR, self.dconf.DEMO_NAME,
                 'rgb/{}.mp4'.format(self.dconf.SEQNAME)))
        self.demo_frames = []
        self.reset_condition = hyperparams['reset_condition']
        for im in self.demo_vid:
            self.demo_frames.append(im)
        # self.cost_tgt_mean = hyperparams['cost_tgt_mean']
        # self.cost_tgt_std = hyperparams['cost_tgt_std']
        self.objects_centroids_mrcnn = hyperparams['objects_centroids_mrcnn']

        # Feature embedding network (disabled: no TCN is loaded here)
        self.tcn = None

        # Setup Mask RCNN
        inference_config = InferenceConfig()
        with tf.device('/device:GPU:1'):
            self.mrcnn = modellib.MaskRCNN(
                mode='inference',
                model_dir=join(self.ptconf.EXP_DIR, self.ptconf.EXP_NAME,
                               'mrcnn_logs'),
                config=inference_config)
            self.mrcnn.load_weights(self.ptconf.WEIGHTS_FILE_PATH_MRCNN,
                                    by_name=True)
        self.class_names = gconf.CLASS_NAMES_W_BG
        # self.target_ids = gconf.CLASS_IDS
        self.target_ids = [1, 2]
        self.colors = visualize.random_colors(7)
        self.plotting_on = hyperparams['plotting_on']
        if self.plotting_on:
            self.fig, self.ax = visualize.get_ax()
        self.mrcnn_centroids_last_known = {
            key: None
            for key in self.target_ids
        }
        # st()  # leftover debugger breakpoint, disabled so initialization runs through

        # Warm-up: run MRCNN once and discard the result, since the first
        # detection is slow while the model initializes lazily
        self.vid_seqname = 0
        rgb_crop, depth_crop = self.get_images(0)
        results = self.mrcnn.detect([rgb_crop], verbose=0)
        self.blob_detector = BlobDetector()
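
The warm-up detection above discards its result; its only purpose is to pay the one-time model initialization cost before any timing-sensitive control loop starts. A minimal standalone sketch of the same pattern, assuming only an object exposing the Matterport-style detect([image], verbose=0) interface (the helper name and timing printout are illustrative, not from the original code):

import time

def warm_up(detector, sample_image, passes=1):
    # Throwaway detections: the first call triggers lazy model
    # initialization, so later calls reflect steady-state latency.
    for i in range(passes):
        start = time.time()
        detector.detect([sample_image], verbose=0)
        print('warm-up pass {} took {:.2f}s'.format(i, time.time() - start))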
Example #2
    def _get_rcnn_features(self, image, depth_rescaled):
        results = self.rcnn.detect([image], verbose=0)
        r = results[0]
        encountered_ids = []
        all_cropped_boxes = []
        all_centroids_unordered = []  # X Y Z
        all_centroids = dict()
        all_visual_features_unordered = []
        all_visual_features = dict()
        for i, box in enumerate(r['rois']):
            class_id = r['class_ids'][i]
            if class_id not in self.target_ids or class_id in encountered_ids:
                continue
            encountered_ids.append(class_id)
            cropped = utils.crop_box(image, box, y_offset=20, x_offset=20)
            # cropped = utils.resize_image(cropped, max_dim=299)[0]
            cropped = cv2.resize(cropped, (299, 299))
            all_cropped_boxes.append(cropped)

            # Depth centroid: median of the in-mask depth values is robust to outliers
            masked_depth = depth_rescaled * r['masks'][:, :, i]
            masked_depth = masked_depth[np.where(masked_depth > 0)]
            z = np.median(masked_depth)

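            # Pixel-space centroid: the center of the detection's bounding box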
            x, y = utils.get_box_center(box)

            all_centroids_unordered.append([x, y, z])
            all_visual_features_unordered.append(r['roi_features'][i])

        all_cropped_boxes = np.asarray(all_cropped_boxes)
        all_centroids_unordered = np.asarray(all_centroids_unordered)

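        # Re-key by class ID; dicts preserve insertion order, so the stacked
        # arrays below follow the order in which classes were encountered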
        for i in range(all_cropped_boxes.shape[0]):
            all_visual_features[
                encountered_ids[i]] = all_visual_features_unordered[i]
            all_centroids[encountered_ids[i]] = all_centroids_unordered[i]
        all_centroids = np.asarray([val for key, val in all_centroids.items()])
        all_visual_features = np.asarray(
            [val for key, val in all_visual_features.items()])
        if self.plot_mode:
            fig, ax = visualize.get_ax()
            ax = visualize.display_instances(image,
                                             r['rois'],
                                             r['masks'],
                                             r['class_ids'],
                                             self.class_names,
                                             r['scores'],
                                             ax=ax,
                                             colors=self.colors)
        else:
            fig = None
        return all_visual_features, all_centroids, fig
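
The re-keying step above is the load-bearing idiom: detections come back in arbitrary order, so each one is stored under its class ID before the arrays are stacked. A standalone toy version of the idiom (all IDs and coordinates hypothetical):

import numpy as np

target_ids = [1, 2]
# (class_id, centroid) pairs as a detector might emit them, out of order
detections = [(2, [10.0, 20.0, 0.5]), (1, [30.0, 5.0, 0.7]), (3, [0.0, 0.0, 0.0])]

by_id = {cid: np.asarray(c) for cid, c in detections if cid in target_ids}
stacked = np.asarray(list(by_id.values()))  # rows in encounter order: ID 2, then ID 1
print(stacked)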
Example #3
    def get_raw_rcnn_results(self, image):
        results = self.rcnn.detect([image], verbose=0)
        r = results[0]
        if self.plot_mode:
            fig, ax = visualize.get_ax()
            ax = visualize.display_instances(image,
                                             r['rois'],
                                             r['masks'],
                                             r['class_ids'],
                                             self.class_names,
                                             r['scores'],
                                             ax=ax,
                                             colors=self.colors)
        else:
            fig = None
        return r, fig
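
The returned r follows the Matterport Mask R-CNN results convention: parallel arrays with one row per detection. A toy consumer using a stand-in dict (all values hypothetical):

import numpy as np

r = {
    'rois': np.array([[10, 20, 110, 220]]),   # boxes as (y1, x1, y2, x2)
    'class_ids': np.array([1]),
    'scores': np.array([0.98]),
}
for box, cid, score in zip(r['rois'], r['class_ids'], r['scores']):
    print('class {} at box {} (score {:.2f})'.format(cid, box, score))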
Example #4
imgs = imgs[TRUNCATE_FRONT + step_offset_front:-TRUNCATE_BACK - pose_offset -
            step_offset_back:step_size]
dimgs = dimgs[TRUNCATE_FRONT + step_offset_front:-TRUNCATE_BACK - pose_offset -
              step_offset_back:step_size]
print('length of truncated demo video: ', len(imgs))
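
The extended slices above do truncate-then-subsample in one expression: drop unreliable frames at both ends, then keep every step_size-th frame of what remains. A toy illustration with hypothetical values:

frames = list(range(100))
TRUNCATE_FRONT, TRUNCATE_BACK = 5, 5
step_offset_front, step_offset_back = 0, 0
pose_offset, step_size = 2, 3

# Equivalent to frames[5:-7:3] with these values
kept = frames[TRUNCATE_FRONT + step_offset_front:
              -TRUNCATE_BACK - pose_offset - step_offset_back:step_size]
print(len(kept))  # 30 frames survive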

##### PLOTTING & COST PARSING #####
### Compute pixel feature target and plot images ###

tgt_imgs = []
plot_imgs = []  # target imgs stacked horizontally for plotting purposes
points_tgt = []

if agent['visualize_mrcnn']:
    fig, ax = visualize.get_ax()
else:
    ax = None

# reader_w_robot = imageio.get_reader('/media/zhouxian/ed854110-6801-4dcd-9acf-c4f904955d71/iccv2019/gps_data/pour/lacan9_view0/data_files/2019-03-29_01-10-44/vids/rgb_sample_itr_9.mp4')
# imgs_r = []
# for img in reader_w_robot:
#     imgs_r.append(img)

# Compute MRCNN centroids & masks if not loaded from file
for tt in range(agent['T']):
    dimg = dimgs[tt]
    img = imgs[tt]
    agent['demo_imgs'].append(img)
Example #5
if FROM_DATASET:
    # Evaluate on every image in the validation dataset

    image_ids = dataset.image_ids
    APs = []
    for image_id in image_ids:

        image = dataset.load_image(image_id)

        results = model.detect([image], verbose=1)

        r = results[0]
        visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                                    dataset.class_names, r['scores'],
                                    ax=visualize.get_ax()[1], colors=colors)

else:
    for seq_name in os.listdir('/home/msieb/projects/Mask_RCNN/datasets/baxter/test'):
        print("Processing ", seq_name)
        # seq_name = args.seqname
        dataset_dir = os.path.join(DATASET_DIR, "test", seq_name)
        filenames = os.listdir(dataset_dir)
        filenames = [file for file in filenames if '.jpg' in file]
        filenames = sorted(filenames, key=lambda x: x.split('.')[0])

        for ii, file in enumerate(filenames):
            # Frame subsampling stub: a step of 1 keeps every frame;
            # raise the modulus to process only every N-th frame
            if ii % 1 != 0:
                continue
            # Load image and ground truth data
            # image, image_meta, gt_class_id, gt_bbox, gt_mask =\