Example #1
    def _run_prediction_single_model(self, model_dir, model_output_dir,
                                     dataset_config):
        """Analyze the performance of a single model."""
        # Read in model config.
        model_config_filename = os.path.join(model_dir,
                                             GQCNNFilenames.SAVED_CFG)
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # Load model.
        self.logger.info("Loading model %s" % (model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(
            model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins

        # Read params from the config.
        if dataset_config is None:
            dataset_dir = model_config["dataset_dir"]
            split_name = model_config["split_name"]
            image_field_name = model_config["image_field_name"]
            pose_field_name = model_config["pose_field_name"]
            metric_name = model_config["target_metric_name"]
            metric_thresh = model_config["metric_thresh"]
        else:
            dataset_dir = dataset_config["dataset_dir"]
            split_name = dataset_config["split_name"]
            image_field_name = dataset_config["image_field_name"]
            pose_field_name = dataset_config["pose_field_name"]
            metric_name = dataset_config["target_metric_name"]
            metric_thresh = dataset_config["metric_thresh"]
            gripper_mode = dataset_config["gripper_mode"]

        self.logger.info("Loading dataset %s" % (dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)

        # Visualize conv filters.
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:, :, 0, k]
            vis2d.subplot(d, d, k + 1)
            vis2d.imshow(DepthImage(filt))
        figname = os.path.join(model_output_dir, "conv1_filters.pdf")
        vis2d.savefig(figname, dpi=self.dpi)

        # Aggregate training and validation true labels and predicted
        # probabilities.
        all_predictions = []
        if angular_bins > 0:
            all_predictions_raw = []
        all_labels = []
        for i in range(dataset.num_tensors):
            # Log progress.
            if i % self.log_rate == 0:
                self.logger.info("Predicting tensor %d of %d" %
                                 (i + 1, dataset.num_tensors))

            # Read in data.
            image_arr = dataset.tensor(image_field_name, i).arr
            pose_arr = read_pose_data(
                dataset.tensor(pose_field_name, i).arr, gripper_mode)
            metric_arr = dataset.tensor(metric_name, i).arr
            label_arr = 1 * (metric_arr > metric_thresh)
            label_arr = label_arr.astype(np.uint8)
            if angular_bins > 0:
                # Form mask to extract predictions from ground-truth angular
                # bins.
                raw_poses = dataset.tensor(pose_field_name, i).arr
                angles = raw_poses[:, 3]
                neg_ind = np.where(angles < 0)
                # TODO(vsatish): These should use the max angle instead.
                angles = np.abs(angles) % GeneralConstants.PI
                angles[neg_ind] *= -1
                g_90 = np.where(angles > (GeneralConstants.PI / 2))
                l_neg_90 = np.where(angles < (-1 * (GeneralConstants.PI / 2)))
                angles[g_90] -= GeneralConstants.PI
                angles[l_neg_90] += GeneralConstants.PI
                # TODO(vsatish): Fix this along with the others.
                angles *= -1  # Hack to fix reverse angle convention.
                angles += (GeneralConstants.PI / 2)
                pred_mask = np.zeros((raw_poses.shape[0], angular_bins * 2),
                                     dtype=bool)
                bin_width = GeneralConstants.PI / angular_bins
                # Use a separate index so the outer tensor loop variable `i`
                # is not shadowed.
                for j in range(angles.shape[0]):
                    pred_mask[j, int((angles[j] // bin_width) * 2)] = True
                    pred_mask[j, int((angles[j] // bin_width) * 2 + 1)] = True

            # Predict with GQ-CNN.
            predictions = gqcnn.predict(image_arr, pose_arr)
            if angular_bins > 0:
                raw_predictions = np.array(predictions)
                predictions = predictions[pred_mask].reshape((-1, 2))

            # Aggregate.
            all_predictions.extend(predictions[:, 1].tolist())
            if angular_bins > 0:
                all_predictions_raw.extend(raw_predictions.tolist())
            all_labels.extend(label_arr.tolist())

        # Close session.
        gqcnn.close_session()

        # Create arrays.
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]
        if angular_bins > 0:
            all_predictions_raw = np.array(all_predictions_raw)
            train_predictions_raw = all_predictions_raw[train_indices]
            val_predictions_raw = all_predictions_raw[val_indices]

        # Aggregate results.
        train_result = BinaryClassificationResult(train_predictions,
                                                  train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, "train_result.cres"))
        val_result.save(os.path.join(model_output_dir, "val_result.cres"))

        # Get stats, plot curves.
        self.logger.info("Model %s training error rate: %.3f" %
                         (model_dir, train_result.error_rate))
        self.logger.info("Model %s validation error rate: %.3f" %
                         (model_dir, val_result.error_rate))

        self.logger.info("Model %s training loss: %.3f" %
                         (model_dir, train_result.cross_entropy_loss))
        self.logger.info("Model %s validation loss: %.3f" %
                         (model_dir, val_result.cross_entropy_loss))

        # Save images.
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, "examples")
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # Train.
        self.logger.info("Saving training examples")
        train_example_dir = os.path.join(example_dir, "train")
        if not os.path.exists(train_example_dir):
            os.mkdir(train_example_dir)

        # Train TP.
        true_positive_indices = train_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "true_positive_%03d.png" % (i)))

        # Train FP.
        false_positive_indices = train_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "false_positive_%03d.png" % (i)))

        # Train TN.
        true_negative_indices = train_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "true_negative_%03d.png" % (i)))

        # Train FN.
        false_negative_indices = train_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title(
                "Datapoint %d: Pred: %.3f Label: %.3f" %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             "false_negative_%03d.png" % (i)))

        # Val.
        self.logger.info("Saving validation examples")
        val_example_dir = os.path.join(example_dir, "val")
        if not os.path.exists(val_example_dir):
            os.mkdir(val_example_dir)

        # Val TP.
        true_positive_indices = val_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "true_positive_%03d.png" % (i)))

        # Val FP.
        false_positive_indices = val_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "false_positive_%03d.png" % (i)))

        # Val TN.
        true_negative_indices = val_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "true_negative_%03d.png" % (i)))

        # Val FN.
        false_negative_indices = val_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint,
                                 image_field_name,
                                 pose_field_name,
                                 gripper_mode,
                                 angular_preds=val_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name,
                                 gripper_mode)
            vis2d.title("Datapoint %d: Pred: %.3f Label: %.3f" %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, "false_negative_%03d.png" % (i)))

        # Save summary stats.
        train_summary_stats = {
            "error_rate": train_result.error_rate,
            "ap_score": train_result.ap_score,
            "auc_score": train_result.auc_score,
            "loss": train_result.cross_entropy_loss
        }
        train_stats_filename = os.path.join(model_output_dir,
                                            "train_stats.json")
        json.dump(train_summary_stats,
                  open(train_stats_filename, "w"),
                  indent=JSON_INDENT,
                  sort_keys=True)

        val_summary_stats = {
            "error_rate": val_result.error_rate,
            "ap_score": val_result.ap_score,
            "auc_score": val_result.auc_score,
            "loss": val_result.cross_entropy_loss
        }
        val_stats_filename = os.path.join(model_output_dir, "val_stats.json")
        json.dump(val_summary_stats,
                  open(val_stats_filename, "w"),
                  indent=JSON_INDENT,
                  sort_keys=True)

        return train_result, val_result
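
The least obvious step in the example above is the angular-bin masking: grasp angles are folded into (-pi/2, pi/2], flipped and shifted into [0, pi), and the two softmax outputs belonging to the ground-truth bin are selected. Below is a minimal, self-contained numpy sketch of that logic; it assumes the same pose layout as above (angle in column 3) and substitutes np.pi for GeneralConstants.PI. The name angular_bin_mask is illustrative only, not part of the gqcnn API.

import numpy as np


def angular_bin_mask(raw_poses, angular_bins):
    """Boolean mask selecting the (neg, pos) outputs of each grasp's bin."""
    angles = raw_poses[:, 3].copy()
    neg_ind = np.where(angles < 0)
    angles = np.abs(angles) % np.pi            # fold magnitude into [0, pi)
    angles[neg_ind] *= -1                      # restore the original sign
    angles[angles > np.pi / 2] -= np.pi        # wrap into (-pi/2, pi/2]
    angles[angles < -np.pi / 2] += np.pi
    angles *= -1                               # reversed angle convention
    angles += np.pi / 2                        # shift into [0, pi)
    bins = (angles // (np.pi / angular_bins)).astype(int)
    bins = np.clip(bins, 0, angular_bins - 1)  # guard the angle == pi boundary
    mask = np.zeros((raw_poses.shape[0], 2 * angular_bins), dtype=bool)
    rows = np.arange(raw_poses.shape[0])
    mask[rows, 2 * bins] = True                # negative-class output of bin
    mask[rows, 2 * bins + 1] = True            # positive-class output of bin
    return mask


# Toy usage: 4 grasps, 16 angular bins, predictions shaped (N, 2 * bins).
poses = np.zeros((4, 6))
poses[:, 3] = [0.3, -1.2, 2.0, -2.9]
preds = np.random.rand(4, 32)
per_grasp = preds[angular_bin_mask(poses, 16)].reshape((-1, 2))
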
Example #2
    def _run_prediction_single_model(self, model_dir,
                                     model_output_dir,
                                     dataset_config):
        """ Analyze the performance of a single model. """
        # read in model config
        model_config_filename = os.path.join(model_dir, 'config.json')
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # load model
        self.logger.info('Loading model %s' %(model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins
        
        # read params from the config
        if dataset_config is None:
            dataset_dir = model_config['dataset_dir']
            split_name = model_config['split_name']
            image_field_name = model_config['image_field_name']
            pose_field_name = model_config['pose_field_name']
            metric_name = model_config['target_metric_name']
            metric_thresh = model_config['metric_thresh']
        else:
            dataset_dir = dataset_config['dataset_dir']
            split_name = dataset_config['split_name']
            image_field_name = dataset_config['image_field_name']
            pose_field_name = dataset_config['pose_field_name']
            metric_name = dataset_config['target_metric_name']
            metric_thresh = dataset_config['metric_thresh']
            gripper_mode = dataset_config['gripper_mode']
            
        self.logger.info('Loading dataset %s' %(dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)
        
        # visualize conv filters
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:,:,0,k]
            vis2d.subplot(d,d,k+1)
            vis2d.imshow(DepthImage(filt))
            figname = os.path.join(model_output_dir, 'conv1_filters.pdf')
        vis2d.savefig(figname, dpi=self.dpi)
        
        # aggregate training and validation true labels and predicted probabilities
        all_predictions = []
        if angular_bins > 0:
            all_predictions_raw = []
        all_labels = []
        for i in range(dataset.num_tensors):
            # log progress
            if i % self.log_rate == 0:
                self.logger.info('Predicting tensor %d of %d' %(i+1, dataset.num_tensors))

            # read in data
            image_arr = dataset.tensor(image_field_name, i).arr
            pose_arr = read_pose_data(dataset.tensor(pose_field_name, i).arr,
                                      gripper_mode)
            metric_arr = dataset.tensor(metric_name, i).arr
            label_arr = 1 * (metric_arr > metric_thresh)
            label_arr = label_arr.astype(np.uint8)
            if angular_bins > 0:
                # form mask to extract predictions from ground-truth angular bins
                raw_poses = dataset.tensor(pose_field_name, i).arr
                angles = raw_poses[:, 3]
                neg_ind = np.where(angles < 0)
                angles = np.abs(angles) % GeneralConstants.PI
                angles[neg_ind] *= -1
                g_90 = np.where(angles > (GeneralConstants.PI / 2))
                l_neg_90 = np.where(angles < (-1 * (GeneralConstants.PI / 2)))
                angles[g_90] -= GeneralConstants.PI
                angles[l_neg_90] += GeneralConstants.PI
                angles *= -1 # hack to fix reverse angle convention
                angles += (GeneralConstants.PI / 2)
                pred_mask = np.zeros((raw_poses.shape[0], angular_bins*2), dtype=bool)
                bin_width = GeneralConstants.PI / angular_bins
                # use a separate index so the outer tensor loop variable `i`
                # is not shadowed
                for j in range(angles.shape[0]):
                    pred_mask[j, int((angles[j] // bin_width)*2)] = True
                    pred_mask[j, int((angles[j] // bin_width)*2 + 1)] = True

            # predict with GQ-CNN
            predictions = gqcnn.predict(image_arr, pose_arr)
            if angular_bins > 0:
                raw_predictions = np.array(predictions)
                predictions = predictions[pred_mask].reshape((-1, 2))
            
            # aggregate
            all_predictions.extend(predictions[:,1].tolist())
            if angular_bins > 0:
                all_predictions_raw.extend(raw_predictions.tolist())
            all_labels.extend(label_arr.tolist())
            
        # close session
        gqcnn.close_session()            

        # create arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]
        if angular_bins > 0:
            all_predictions_raw = np.array(all_predictions_raw)
            train_predictions_raw = all_predictions_raw[train_indices]
            val_predictions_raw = all_predictions_raw[val_indices]        

        # aggregate results
        train_result = BinaryClassificationResult(train_predictions, train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, 'train_result.cres'))
        val_result.save(os.path.join(model_output_dir, 'val_result.cres'))

        # get stats, plot curves
        self.logger.info('Model %s training error rate: %.3f' %(model_dir, train_result.error_rate))
        self.logger.info('Model %s validation error rate: %.3f' %(model_dir, val_result.error_rate))

        self.logger.info('Model %s training loss: %.3f' %(model_dir, train_result.cross_entropy_loss))
        self.logger.info('Model %s validation loss: %.3f' %(model_dir, val_result.cross_entropy_loss))

        # save images
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, 'examples')
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # train
        self.logger.info('Saving training examples')
        train_example_dir = os.path.join(example_dir, 'train')
        if not os.path.exists(train_example_dir):
            os.mkdir(train_example_dir)
            
        # train TP
        true_positive_indices = train_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'true_positive_%03d.png' %(i)))

        # train FP
        false_positive_indices = train_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'false_positive_%03d.png' %(i)))

        # train TN
        true_negative_indices = train_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'true_negative_%03d.png' %(i)))

        # train FN
        false_negative_indices = train_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'false_negative_%03d.png' %(i)))

        # val
        self.logger.info('Saving validation examples')
        val_example_dir = os.path.join(example_dir, 'val')
        if not os.path.exists(val_example_dir):
            os.mkdir(val_example_dir)

        # val TP
        true_positive_indices = val_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'true_positive_%03d.png' %(i)))

        # val FP
        false_positive_indices = val_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'false_positive_%03d.png' %(i)))

        # val TN
        true_negative_indices = val_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'true_negative_%03d.png' %(i)))

        # val FN
        false_negative_indices = val_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'false_negative_%03d.png' %(i)))
            
        # save summary stats
        train_summary_stats = {
            'error_rate': train_result.error_rate,
            'ap_score': train_result.ap_score,
            'auc_score': train_result.auc_score,
            'loss': train_result.cross_entropy_loss
        }
        train_stats_filename = os.path.join(model_output_dir, 'train_stats.json')
        json.dump(train_summary_stats, open(train_stats_filename, 'w'),
                  indent=JSON_INDENT,
                  sort_keys=True)

        val_summary_stats = {
            'error_rate': val_result.error_rate,
            'ap_score': val_result.ap_score,
            'auc_score': val_result.auc_score,
            'loss': val_result.cross_entropy_loss            
        }
        val_stats_filename = os.path.join(model_output_dir, 'val_stats.json')
        json.dump(val_summary_stats, open(val_stats_filename, 'w'),
                  indent=JSON_INDENT,
                  sort_keys=True)        
        
        return train_result, val_result
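
A small hygiene note on the summary-stats step shared by the examples above: they pass an anonymous open(filename, 'w') handle straight to json.dump, which leaves closing and flushing the file to the garbage collector. A with block is the safer pattern. The sketch below mirrors the keys and file names used above; JSON_INDENT is assumed to match the module-level constant, and the helper name is illustrative only.

import json
import os

JSON_INDENT = 2  # assumed value of the module-level constant used above


def save_summary_stats(output_dir, split_name, result):
    """Write error rate / AP / AUC / loss for one split to <split>_stats.json."""
    stats = {
        "error_rate": result.error_rate,
        "ap_score": result.ap_score,
        "auc_score": result.auc_score,
        "loss": result.cross_entropy_loss,
    }
    stats_filename = os.path.join(output_dir, "%s_stats.json" % split_name)
    with open(stats_filename, "w") as f:  # closed even if json.dump raises
        json.dump(stats, f, indent=JSON_INDENT, sort_keys=True)


# save_summary_stats(model_output_dir, "train", train_result)
# save_summary_stats(model_output_dir, "val", val_result)
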
Example #3
    def visualise(self, model_dir, output_dir):
        """
        Evaluates the model on the dataset in self.datadir. Plots and saves the resulting classification accuracies.


        Parameters
        ----------
        model_dir (str): Path to the model.
        output_dir (str): Path to store the classification accuracies of the models

        Returns
        -------

        """
        # Create output dir if it doesn't exist yet
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        # Set up logger
        self.logger = Logger.get_logger(self.__class__.__name__,
                                        log_file=os.path.join(
                                            output_dir, "analysis.log"),
                                        silence=(not self.verbose),
                                        global_log_file=self.verbose)
        self.logger.info("Saving output to %s" % output_dir)

        model_config = YamlConfig(model_dir + '/config.json')
        self.gripper_mode = model_config['gqcnn']['gripper_mode']
        if 'pose_input' in model_config['gqcnn']:
            self.pose_input = model_config['gqcnn']['pose_input']

        # Load models
        checkpoints = []
        if self.analyse_checkpoints:
            model_files = os.listdir(model_dir)
            for model_file in model_files:
                # Checkpoint files are expected to be named 'model_<step>.*'.
                if 'model' in model_file and model_file[5] == '_':
                    checkpoints.append(int(model_file[6:].split('.')[0]))
        checkpoints.append('final')
        checkpoints = list(set(checkpoints))
        models = self._read_model(model_dir, checkpoints)

        # Initiate accuracy variables
        elev_bins = np.arange(2.5, 72.5, 5)

        accuracies = {}
        for checkpoint in checkpoints:
            accuracies[checkpoint] = {}
            for elev in elev_bins:
                accuracies[checkpoint][elev] = {
                    'acc': [], 'tp': [], 'tn': [], 'num_p': [], 'num_n': []}
        stepsize = 50

        # Read and predict data with all models
        for steps in range(0, len(self.files), stepsize):
            self.logger.info("Read in tensors %d to %d" % (steps, steps+stepsize))
            image_arr, pose_arr, all_labels, elev_arr = self._read_data(steps, stepsize)
            for elev in elev_bins:
                mask = (elev_arr.squeeze() >= elev - 2.5) & (elev_arr.squeeze() < elev + 2.5)
                images = image_arr[mask]
                poses = pose_arr[mask]
                labels = all_labels[mask]
                for cnt, model in enumerate(models):
                    preds = model.predict(images, poses)
                    if preds is not None:
                        results = BinaryClassificationResult(preds[:, 1], labels)
                        acc_entry = accuracies[checkpoints[cnt]][elev]
                        acc_entry['acc'].append(100 * results.accuracy)
                        acc_entry['tp'].append(len(results.true_positive_indices))
                        acc_entry['tn'].append(len(results.true_negative_indices))
                        acc_entry['num_p'].append(len(labels[labels == 1]))
                        acc_entry['num_n'].append(len(labels[labels == 0]))

        # Calculate prediction accuracy for all models and all elevation (phi) angles
        for checkpoint in checkpoints:
            true_acc = []
            false_acc = []
            all_acc = []
            self.logger.info("Checkpoint: " + str(checkpoint))
            for elev in elev_bins:
                try:
                    num_tp = sum(accuracies[checkpoint][elev]['tp'])
                    num_tn = sum(accuracies[checkpoint][elev]['tn'])
                    num_p = sum(accuracies[checkpoint][elev]['num_p'])
                    num_n = sum(accuracies[checkpoint][elev]['num_n'])
                    tacc = num_tp / num_p * 100
                    facc = num_tn / num_n * 100
                    acc = (num_tp + num_tn) / (num_p + num_n) * 100
                    true_acc.append(tacc)
                    false_acc.append(facc)
                    all_acc.append(acc)
                    self.logger.info("Elev: %.1f, Accuracy positive grasps: %.1f %%" % (elev, tacc))
                    self.logger.info("Elev: %.1f, Accuracy negative grasps: %.1f %%" % (elev, facc))
                    self.logger.info("Elev: %.1f, Accuracy all grasps: %.1f %%" % (elev, acc))
                except ZeroDivisionError:
                    self.logger.info("Elev: %.1f, no grasps" % elev)

            # Save output to txt file
            np.savetxt(output_dir + '/' + str(checkpoint) + '_tacc', true_acc, '%.1f')
            np.savetxt(output_dir + '/' + str(checkpoint) + '_facc', false_acc, '%.1f')
            np.savetxt(output_dir + '/' + str(checkpoint) + '_acc', all_acc, '%.1f')

            # Plot the outputs
            plt.figure()
            plt.plot(elev_bins, true_acc)
            plt.title("Prediction accuracy on positive grasps over varying elevation angles")
            plt.xlabel("Elevation angle [deg]")
            plt.ylabel("Accuracy [%]")
            plt.ylim((0, 100))
            plt.xlim((0, 60))
            plt.savefig(output_dir + '/' + str(checkpoint) + '_True_Accuracy.png')
            plt.close()

            plt.figure()
            plt.plot(elev_bins, false_acc)
            plt.title("Prediction accuracy on negative grasps over varying elevation angles")
            plt.xlabel("Elevation angle [deg]")
            plt.ylabel("Accuracy [%]")
            plt.ylim((0, 100))
            plt.xlim((0, 60))
            plt.savefig(output_dir + '/' + str(checkpoint) + '_Neg_Accuracy.png')
            plt.close()
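
The bookkeeping in Example #3 reduces to: bucket the grasps by elevation angle, accumulate true positives/negatives per bucket, and divide by the number of positive/negative samples seen. A compact numpy-only sketch of that aggregation follows; it assumes a 0.5 decision threshold on the positive-class probability (the internals of BinaryClassificationResult are not shown above), and the data at the bottom is purely synthetic.

import numpy as np


def per_elevation_accuracy(elev_deg, pred_pos_prob, labels, bin_width=5.0):
    """Return (bin_centers, accuracy in percent) for 5-degree elevation bins."""
    bin_centers = np.arange(2.5, 72.5, bin_width)
    pred_labels = (pred_pos_prob > 0.5).astype(int)  # assumed 0.5 threshold
    accuracies = []
    for center in bin_centers:
        mask = ((elev_deg >= center - bin_width / 2) &
                (elev_deg < center + bin_width / 2))
        if not np.any(mask):
            accuracies.append(np.nan)  # the "no grasps" case in the example
            continue
        correct = np.sum(pred_labels[mask] == labels[mask])
        accuracies.append(100.0 * correct / np.sum(mask))
    return bin_centers, np.array(accuracies)


# Synthetic data: 1000 grasps with random elevations, probabilities, labels.
rng = np.random.default_rng(0)
centers, acc = per_elevation_accuracy(rng.uniform(0, 70, 1000),
                                      rng.uniform(0, 1, 1000),
                                      rng.integers(0, 2, 1000))
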
Example #4
    def _run_prediction_single_model(self, model_dir, model_output_dir,
                                     dataset_config):
        """ Analyze the performance of a single model. """
        # read in model config
        model_config_filename = os.path.join(model_dir, 'config.json')
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # load model
        logging.info('Loading model %s' % (model_dir))
        gqcnn = GQCNN.load(model_dir)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode

        # read params from the config
        if dataset_config is None:
            dataset_dir = model_config['dataset_dir']
            split_name = model_config['split_name']
            image_field_name = model_config['image_field_name']
            pose_field_name = model_config['pose_field_name']
            metric_name = model_config['target_metric_name']
            metric_thresh = model_config['metric_thresh']
        else:
            dataset_dir = dataset_config['dataset_dir']
            split_name = dataset_config['split_name']
            image_field_name = dataset_config['image_field_name']
            pose_field_name = dataset_config['pose_field_name']
            metric_name = dataset_config['target_metric_name']
            metric_thresh = dataset_config['metric_thresh']
            gripper_mode = dataset_config['gripper_mode']

        logging.info('Loading dataset %s' % (dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)

        # visualize conv filters
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:, :, 0, k]
            vis2d.subplot(d, d, k + 1)
            vis2d.imshow(DepthImage(filt))
        figname = os.path.join(model_output_dir, 'conv1_filters.pdf')
        vis2d.savefig(figname, dpi=self.dpi)

        # aggregate training and validation true labels and predicted probabilities
        all_predictions = []
        all_labels = []
        for i in range(dataset.num_tensors):
            # log progress
            if i % self.log_rate == 0:
                logging.info('Predicting tensor %d of %d' %
                             (i + 1, dataset.num_tensors))

            # read in data
            image_arr = dataset.tensor(image_field_name, i).arr
            pose_arr = read_pose_data(
                dataset.tensor(pose_field_name, i).arr, gripper_mode)
            metric_arr = dataset.tensor(metric_name, i).arr
            label_arr = 1 * (metric_arr > metric_thresh)
            label_arr = label_arr.astype(np.uint8)

            # predict with GQ-CNN
            predictions = gqcnn.predict(image_arr, pose_arr)

            # aggregate
            all_predictions.extend(predictions[:, 1].tolist())
            all_labels.extend(label_arr.tolist())

        # close session
        gqcnn.close_session()

        # create arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]

        # aggregate results
        train_result = BinaryClassificationResult(train_predictions,
                                                  train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, 'train_result.cres'))
        val_result.save(os.path.join(model_output_dir, 'val_result.cres'))

        # get stats, plot curves
        logging.info('Model %s training error rate: %.3f' %
                     (model_dir, train_result.error_rate))
        logging.info('Model %s validation error rate: %.3f' %
                     (model_dir, val_result.error_rate))

        # save images
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, 'examples')
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # train
        logging.info('Saving training examples')
        train_example_dir = os.path.join(example_dir, 'train')
        if not os.path.exists(train_example_dir):
            os.mkdir(train_example_dir)

        # train TP
        true_positive_indices = train_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title(
                'Datapoint %d: Pred: %.3f Label: %.3f' %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             'true_positive_%03d.png' % (i)))

        # train FP
        false_positive_indices = train_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title(
                'Datapoint %d: Pred: %.3f Label: %.3f' %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             'false_positive_%03d.png' % (i)))

        # train TN
        true_negative_indices = train_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title(
                'Datapoint %d: Pred: %.3f Label: %.3f' %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             'true_negative_%03d.png' % (i)))

        # train FN
        false_negative_indices = train_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title(
                'Datapoint %d: Pred: %.3f Label: %.3f' %
                (k, train_result.pred_probs[j], train_result.labels[j]),
                fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(train_example_dir,
                             'false_negative_%03d.png' % (i)))

        # val
        logging.info('Saving validation examples')
        val_example_dir = os.path.join(example_dir, 'val')
        if not os.path.exists(val_example_dir):
            os.mkdir(val_example_dir)

        # val TP
        true_positive_indices = val_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, 'true_positive_%03d.png' % (i)))

        # val FP
        false_positive_indices = val_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, 'false_positive_%03d.png' % (i)))

        # val TN
        true_negative_indices = val_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, 'true_negative_%03d.png' % (i)))

        # val FN
        false_negative_indices = val_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(
                k, field_names=[image_field_name, pose_field_name])
            vis2d.clf()
            self._plot_grasp(datapoint, image_field_name, pose_field_name,
                             gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %
                        (k, val_result.pred_probs[j], val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(
                os.path.join(val_example_dir, 'false_negative_%03d.png' % (i)))

        # save summary stats
        train_summary_stats = {
            'error_rate': train_result.error_rate,
            'ap_score': train_result.ap_score,
            'auc_score': train_result.auc_score
        }
        train_stats_filename = os.path.join(model_output_dir,
                                            'train_stats.json')
        json.dump(train_summary_stats,
                  open(train_stats_filename, 'w'),
                  indent=JSON_INDENT,
                  sort_keys=True)

        val_summary_stats = {
            'error_rate': val_result.error_rate,
            'ap_score': val_result.ap_score,
            'auc_score': val_result.auc_score
        }
        val_stats_filename = os.path.join(model_output_dir, 'val_stats.json')
        json.dump(val_summary_stats,
                  open(val_stats_filename, 'w'),
                  indent=JSON_INDENT,
                  sort_keys=True)

        return train_result, val_result
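
All of the examples share the same labeling convention: a continuous grasp-quality metric is thresholded into a binary success label (label_arr = 1 * (metric_arr > metric_thresh)), and the classifier's positive-class column predictions[:, 1] is scored against it. The sketch below shows that scoring step in plain numpy, with a 0.5 decision threshold assumed in place of BinaryClassificationResult; the metric values and threshold are made-up placeholders.

import numpy as np


def error_rate(pred_pos_prob, metric_arr, metric_thresh, decision_thresh=0.5):
    """Fraction of grasps whose thresholded prediction disagrees with the label."""
    labels = (metric_arr > metric_thresh).astype(np.uint8)  # same rule as above
    pred_labels = (pred_pos_prob > decision_thresh).astype(np.uint8)
    return float(np.mean(pred_labels != labels))


# Placeholder values: a grasp metric thresholded at 0.002.
metrics = np.array([0.0005, 0.0031, 0.0042, 0.0001])
probs = np.array([0.10, 0.85, 0.40, 0.20])
print(error_rate(probs, metrics, metric_thresh=0.002))  # -> 0.25
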
Example #5
    def _run_prediction(self, model_dir, model_output_dir, data_dir,
                        noise_analysis, depth_analysis, perturb_analysis,
                        single_analysis):
        """Predict the outcome of the file for a single model."""

        # Read in model config.
        model_config_filename = os.path.join(model_dir, "config.json")
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # Load model.
        self.logger.info("Loading model %s" % (model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(
            model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins

        # Load data
        if noise_analysis:
            image_arr, pose_arr, labels, width_arr, file_arr, noise_arr = self._read_data(
                data_dir, noise=True)
        elif depth_analysis:
            image_arr, pose_arr, labels, width_arr, file_arr, depth_arr = self._read_data(
                data_dir, depth=True)
        elif perturb_analysis:
            image_arr, pose_arr, labels, width_arr, file_arr, perturb_arr = self._read_data(
                data_dir, perturb=True)
        elif single_analysis:
            image_arr, pose_arr, labels, width_arr, file_arr, perturb_arr = self._read_data(
                data_dir, perturb=True)
        else:
            image_arr, pose_arr, labels, width_arr, file_arr, obj_arr = self._read_data(
                data_dir)
        # Predict outcomes
        predictions = gqcnn.predict(image_arr, pose_arr)
        gqcnn.close_session()
        results = BinaryClassificationResult(predictions[:, 1], labels)

        # Log the results
        if noise_analysis:
            # Analyse the error rates in regard to the noise levels of the images
            noise_levels = np.unique(noise_arr)
            levels = len(noise_levels)
            for current_noise in noise_levels:
                pred = predictions[noise_arr[:, 0] == current_noise]
                lab = labels[noise_arr[:, 0] == current_noise]
                res = BinaryClassificationResult(pred[:, 1], lab)
                self._plot_histograms(pred[:, 1], lab, str(current_noise),
                                      model_output_dir)
                self.logger.info("Noise: %.4f Model %s error rate: %.3f" %
                                 (current_noise, model_dir, res.error_rate))
                self.logger.info(
                    "Noise: %.4f Model %s loss: %.3f" %
                    (current_noise, model_dir, res.cross_entropy_loss))
        elif depth_analysis:
            # Analyse the error rates in regard to the grasping depth in the images
            depth_levels = np.unique(depth_arr)
            levels = len(depth_levels)
            for current_depth in depth_levels:
                if current_depth == -1:
                    depth_mode = 'original'
                else:
                    depth_mode = 'relative %.2f' % (current_depth)
                pred = predictions[depth_arr == current_depth]
                lab = labels[depth_arr == current_depth]
                res = BinaryClassificationResult(pred[:, 1], lab)
                self._plot_histograms(pred[:, 1], lab, depth_mode,
                                      model_output_dir)
                self.logger.info("Depth %s Model %s error rate: %.3f" %
                                 (depth_mode, model_dir, res.error_rate))
                self.logger.info(
                    "Depth: %s Model %s loss: %.3f" %
                    (depth_mode, model_dir, res.cross_entropy_loss))
        elif perturb_analysis:
            # Analyze the error rates with respect to the grasp perturbations
            # in the images.
            perturb_levels = np.unique(perturb_arr)
            self.logger.info("Perturb levels: %s" % (perturb_levels))
            # Determine the perturbation type from the number of distinct
            # values per column (rotation, translation in x, translation in y).
            _rot = len(np.unique(perturb_arr[:, 0]))
            _trans = len(np.unique(perturb_arr[:, 1]))
            try:
                _transy = len(np.unique(perturb_arr[:, 2]))
            except IndexError:
                _transy = 0
                self.logger.info("No translation in y included")
            if _rot >= 2 and _trans <= 1 and _transy <= 1:
                perturbation = 'rotation'
                perturb_unit = 'deg'
                index = 0
            elif _rot <= 1 and _trans >= 2 and _transy <= 1:
                perturbation = 'translation'
                perturb_unit = 'pixel'
                index = 1
            elif _rot <= 1 and _trans <= 1 and _transy >= 2:
                perturbation = 'translationy'
                perturb_unit = 'pixel'
                index = 2
            else:
                raise ValueError(
                    "Perturbation array includes at least two different "
                    "perturbation types, which cannot be handled. Aborting.")
            levels = len(perturb_levels)
            accuracies = []
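            # Collect per-level accuracies (in percent) for the summary plot.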
            for current_perturb in perturb_levels:
                pred = predictions[perturb_arr[:, index] == current_perturb]
                lab = labels[perturb_arr[:, index] == current_perturb]
                res = BinaryClassificationResult(pred[:, 1], lab)
                perturb_mode = perturbation + ' %.0f ' % (
                    current_perturb) + perturb_unit
                self._plot_histograms(
                    pred[:, 1], lab,
                    perturbation + '_%.0f_' % (current_perturb) + perturb_unit,
                    model_output_dir)

                self.logger.info("Grasp %s Model %s error rate: %.3f" %
                                 (perturb_mode, model_dir, res.error_rate))
                accuracies.append(100 - res.error_rate)
                self.logger.info(
                    "Grasp %s Model %s loss: %.3f" %
                    (perturb_mode, model_dir, res.cross_entropy_loss))
            self._plot_grasp_perturbations(perturb_levels, accuracies,
                                           model_output_dir, perturbation)
        elif single_analysis:
            # Analyze the predictions for a single grasp file with respect to
            # its perturbation.
            perturb_levels = np.unique(perturb_arr)
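            # Identify the active perturbation type by checking which columns
            # contain non-zero entries (rotation, translation in x/y, scaling
            # of depth, scaling in x).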
            _rot = np.count_nonzero(perturb_arr[:, 0])
            _trans = np.count_nonzero(perturb_arr[:, 1])
            _transy = np.count_nonzero(perturb_arr[:, 2])
            _scalez = np.count_nonzero(perturb_arr[:, 3])
            _scalex = np.count_nonzero(perturb_arr[:, 4])
            if _rot >= 1 and _trans == 0 and _transy == 0 and _scalez == 0 and _scalex == 0:
                index = 0
                perturbation = 'rotation'
            elif _rot == 0 and _trans >= 1 and _transy == 0 and _scalez == 0 and _scalex == 0:
                perturbation = 'translation'
                index = 1
            elif _rot == 0 and _trans == 0 and _transy >= 1 and _scalez == 0 and _scalex == 0:
                perturbation = 'translationy'
                index = 2
            elif _rot == 0 and _trans == 0 and _transy == 0 and _scalez >= 1 and _scalex == 0:
                perturbation = 'scale_height'
                index = 3
            elif _rot == 0 and _trans == 0 and _transy == 0 and _scalez == 0 and _scalex >= 1:
                perturbation = 'scalex'
                index = 4
            else:
                perturbation = 'mixed'
                index = 5
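                # `index` 5 is only a placeholder; mixed perturbations are
                # expected to arrive as a single sample and are handled in
                # the `levels == 1` branch below.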
            # Create a new output directory for this file and perturbation
            # mode (unless only a single perturbed sample is analyzed).
            self.logger.info("Number of perturbed samples: %d" %
                             (len(perturb_arr)))
            if len(perturb_arr) == 1:
                self.logger.info("New output directory is: %s" %
                                 (model_output_dir))
            else:
                model_output_dir = os.path.join(
                    model_output_dir,
                    str(file_arr[0][0]) + '_' + str(file_arr[0][1]) + '_' +
                    perturbation)
                self.logger.info("New output directory is: %s" %
                                 (model_output_dir))
            if not os.path.exists(model_output_dir):
                os.mkdir(model_output_dir)
            # Set up new logger.
            self.logger = Logger.get_logger(self.__class__.__name__,
                                            log_file=os.path.join(
                                                model_output_dir,
                                                "analysis.log"),
                                            silence=(not self.verbose),
                                            global_log_file=self.verbose)
            levels = len(perturb_arr)
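            # Here `levels` counts individual samples rather than unique
            # perturbation values.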
            abs_pred_errors = []
            if levels == 1:
                self.logger.info(
                    "Mixed perturbation. Translationx %.1f, Translationy %.1f, "
                    "Rotation %.1f, Scale_height %.1f, Scale x %.1f" %
                    (perturb_arr[0][1], perturb_arr[0][2], perturb_arr[0][0],
                     perturb_arr[0][3], perturb_arr[0][4]))
                pred = predictions
                lab = labels
                res = BinaryClassificationResult(pred[:, 1], lab)
                self.logger.info("Grasp %s Model %s prediction: %.3f" %
                                 (perturbation, model_dir, pred[:, 1]))
                self.logger.info("Grasp %s Model %s error rate: %.3f" %
                                 (perturbation, model_dir, res.error_rate))
                self.logger.info(
                    "Grasp %s Model %s loss: %.3f" %
                    (perturbation, model_dir, res.cross_entropy_loss))

            else:
                for current_perturb in perturb_levels:
                    mask = perturb_arr[:, index] == current_perturb
                    pred = predictions[mask]
                    lab = labels[mask]
                    res = BinaryClassificationResult(pred[:, 1], lab)

                    if perturbation == 'rotation':
                        perturb_mode = 'rotation %.0f deg' % (current_perturb)
                    elif perturbation == 'translation':
                        perturb_mode = 'translation in x %.0f pixel' % (
                            current_perturb)
                    elif perturbation == 'translationy':
                        perturb_mode = 'translation in y %.0f pixel' % (
                            current_perturb)
                    elif perturbation == 'scale_height':
                        perturb_mode = 'scaling depth by %.0f' % (
                            current_perturb)
                    elif perturbation == 'scalex':
                        perturb_mode = 'scaling x by %.0f' % (current_perturb)
                    else:
                        # Fallback so perturb_mode is always defined.
                        perturb_mode = perturbation
                    pos_errors, neg_errors = self._calculate_prediction_errors(
                        pred[:, 1], lab)
                    # Only append positive errors if grasp was positive.
                    if pos_errors:
                        abs_pred_errors.append(pos_errors)
                    self.logger.info("Grasp %s Model %s prediction: %.3f" %
                                     (perturb_mode, model_dir, pred[:, 1]))
                    self.logger.info("Grasp %s Model %s error rate: %.3f" %
                                     (perturb_mode, model_dir, res.error_rate))
                    self.logger.info(
                        "Grasp %s Model %s loss: %.3f" %
                        (perturb_mode, model_dir, res.cross_entropy_loss))
                # Plot only if positive-grasp prediction errors were collected.
                if abs_pred_errors:
                    self._plot_single_grasp_perturbations(
                        perturb_levels, abs_pred_errors, model_output_dir,
                        perturbation)
        else:
            levels = 1
            self._plot_histograms(predictions[:, 1], labels, '',
                                  model_output_dir)
            self.logger.info("Model %s error rate: %.3f" %
                             (model_dir, results.error_rate))
            self.logger.info("Model %s loss: %.3f" %
                             (model_dir, results.cross_entropy_loss))

            if obj_arr is not None and 'Cornell' in data_dir:
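                # Build per-object confusion counts for the Cornell dataset,
                # keyed by the object name read from z.txt.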
                unique = np.unique(obj_arr).tolist()
                object_label = pd.read_csv(
                    DATA_PATH + "Cornell/original/z.txt",
                    sep=" ",
                    header=None,
                    usecols=[1, 2]).drop_duplicates().to_numpy()
                true_pos = dict()
                false_neg = dict()
                false_pos = dict()
                true_neg = dict()
                for obj in unique:
                    obj = int(obj)
                    true_pos[object_label[obj, 1]] = 0
                    false_pos[object_label[obj, 1]] = 0
                    true_neg[object_label[obj, 1]] = 0
                    false_neg[object_label[obj, 1]] = 0

                for obj, pred, label in zip(obj_arr, predictions[:, 1],
                                            labels):
                    if label == 1 and pred >= 0.5:
                        true_pos[object_label[obj, 1]] += 1
                    elif label == 1 and pred < 0.5:
                        false_neg[object_label[obj, 1]] += 1
                    elif label == 0 and pred >= 0.5:
                        false_pos[object_label[obj, 1]] += 1
                    elif label == 0 and pred < 0.5:
                        true_neg[object_label[obj, 1]] += 1
                self.logger.info("True positives per object: %s" % (true_pos))
                self._export_object_analysis(true_pos, false_neg, false_pos,
                                             true_neg, model_output_dir)

        # Log the overall label and prediction statistics.
        pos_lab = len(labels[labels == 1])
        neg_lab = len(labels[labels == 0])

        true_pos = len(results.true_positive_indices)
        true_neg = len(results.true_negative_indices)
        false_pos = neg_lab - true_neg
        false_neg = pos_lab - true_pos

        self.logger.info("%d samples, %d grasps" %
                         (len(labels), len(labels) / levels))
        self.logger.info("%d positive grasps, %d negative grasps" %
                         (pos_lab / levels, neg_lab / levels))
        self.logger.info("Model overall accuracy %.2f %%" %
                         (100 * results.accuracy))
        self.logger.info("Accuracy on positive grasps: %.2f %%" %
                         (true_pos / pos_lab * 100))
        self.logger.info("Accuracy on negative grasps: %.2f %%" %
                         (true_neg / neg_lab * 100))
        self.logger.info("True positive samples: %d" % true_pos)
        self.logger.info("True negative samples: %d" % true_neg)
        self.logger.info("Correct predictions: %d" % (true_pos + true_neg))
        self.logger.info("False positive samples: %d" % false_pos)
        self.logger.info("False negative samples: %d" % false_neg)
        self.logger.info("False predictions: %d" % (false_pos + false_pos))

        cnt = 0  # Counter for grouping the same images with different noise/depth levels
        if self.num_images is None or self.num_images > len(width_arr):
            self.num_images = len(width_arr)
        steps = int(len(width_arr) / self.num_images)
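        # Visualize every `steps`-th sample so that roughly self.num_images
        # grasps are plotted.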
        for j in range(0, len(width_arr), steps):
            try:
                # Reset the counter when a new source file starts.
                if file_arr[j][1] != file_arr[j - 1][1]:
                    cnt = 0
                else:
                    cnt += 1
            except (IndexError, TypeError):
                cnt += 1
            if noise_analysis:
                image = self._plot_grasp(image_arr[j],
                                         width_arr[j],
                                         results,
                                         j,
                                         noise_arr=noise_arr)
            elif depth_analysis:
                image = self._plot_grasp(image_arr[j],
                                         width_arr[j],
                                         results,
                                         j,
                                         depth_arr=depth_arr)
            elif perturb_analysis or single_analysis:
                print("Plot grasp")
                image = self._plot_grasp(image_arr[j],
                                         width_arr[j],
                                         results,
                                         j,
                                         perturb_arr=perturb_arr)
            else:
                image = self._plot_grasp(image_arr[j],
                                         width_arr[j],
                                         results,
                                         j,
                                         plt_results=False)
            try:
                if noise_analysis or depth_analysis or perturb_analysis or single_analysis:
                    image.save(
                        os.path.join(
                            model_output_dir, "%05d_%03d_example_%03d.png" %
                            (file_arr[j][0], file_arr[j][1], cnt)))
                else:
                    image.save(
                        os.path.join(
                            model_output_dir, "%05d_%03d.png" %
                            (file_arr[j][0], file_arr[j][1])))
            except (IndexError, TypeError):
                image.save(
                    os.path.join(model_output_dir, "Example_%03d.png" % (cnt)))
        if single_analysis:
            print("Plotting depth image")
            j = int(len(image_arr) / 2)
            # Plot pure depth image without prediction labeling.
            image = self._plot_grasp(image_arr[j],
                                     width_arr[j],
                                     results,
                                     j,
                                     plt_results=False)
            image.save(os.path.join(model_output_dir, "Depth_image.png"))
        return results