Example #1: a GQ-CNN analysis method that evaluates a trained model on a TensorDataset split, saves binary classification results, and visualizes example grasp predictions.
    def _run_prediction_single_model(self, model_dir,
                                     model_output_dir,
                                     dataset_config):
        """ Analyze the performance of a single model. """
        # read in model config
        model_config_filename = os.path.join(model_dir, 'config.json')
        with open(model_config_filename) as data_file:
            model_config = json.load(data_file)

        # load model
        self.logger.info('Loading model %s' %(model_dir))
        log_file = None
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                log_file = handler.baseFilename
        gqcnn = get_gqcnn_model(verbose=self.verbose).load(model_dir, verbose=self.verbose, log_file=log_file)
        gqcnn.open_session()
        gripper_mode = gqcnn.gripper_mode
        angular_bins = gqcnn.angular_bins
        
        # read params from the config
        if dataset_config is None:
            dataset_dir = model_config['dataset_dir']
            split_name = model_config['split_name']
            image_field_name = model_config['image_field_name']
            pose_field_name = model_config['pose_field_name']
            metric_name = model_config['target_metric_name']
            metric_thresh = model_config['metric_thresh']
        else:
            dataset_dir = dataset_config['dataset_dir']
            split_name = dataset_config['split_name']
            image_field_name = dataset_config['image_field_name']
            pose_field_name = dataset_config['pose_field_name']
            metric_name = dataset_config['target_metric_name']
            metric_thresh = dataset_config['metric_thresh']
            gripper_mode = dataset_config['gripper_mode']
            
        self.logger.info('Loading dataset %s' %(dataset_dir))
        dataset = TensorDataset.open(dataset_dir)
        train_indices, val_indices, _ = dataset.split(split_name)
        
        # visualize conv filters
        conv1_filters = gqcnn.filters
        num_filt = conv1_filters.shape[3]
        d = utils.sqrt_ceil(num_filt)
        vis2d.clf()
        for k in range(num_filt):
            filt = conv1_filters[:,:,0,k]
            vis2d.subplot(d,d,k+1)
            vis2d.imshow(DepthImage(filt))
        figname = os.path.join(model_output_dir, 'conv1_filters.pdf')
        vis2d.savefig(figname, dpi=self.dpi)
        
        # aggregate training and validation true labels and predicted probabilities
        all_predictions = []
        if angular_bins > 0:
            all_predictions_raw = []
        all_labels = []
        for i in range(dataset.num_tensors):
            # log progress
            if i % self.log_rate == 0:
                self.logger.info('Predicting tensor %d of %d' %(i+1, dataset.num_tensors))

            # read in data
            image_arr = dataset.tensor(image_field_name, i).arr
            pose_arr = read_pose_data(dataset.tensor(pose_field_name, i).arr,
                                      gripper_mode)
            metric_arr = dataset.tensor(metric_name, i).arr
            label_arr = 1 * (metric_arr > metric_thresh)
            label_arr = label_arr.astype(np.uint8)
            if angular_bins > 0:
                # form mask to extract predictions from ground-truth angular bins
                raw_poses = dataset.tensor(pose_field_name, i).arr
                angles = raw_poses[:, 3]
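                # wrap the grasp angle into a single half-rotation so it can be
                # assigned to one of the `angular_bins` bins of width
                # pi / angular_bins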
                neg_ind = np.where(angles < 0)
                angles = np.abs(angles) % GeneralConstants.PI
                angles[neg_ind] *= -1
                g_90 = np.where(angles > (GeneralConstants.PI / 2))
                l_neg_90 = np.where(angles < (-1 * (GeneralConstants.PI / 2)))
                angles[g_90] -= GeneralConstants.PI
                angles[l_neg_90] += GeneralConstants.PI
                angles *= -1 # hack to fix reverse angle convention
                angles += (GeneralConstants.PI / 2)
                pred_mask = np.zeros((raw_poses.shape[0], angular_bins*2), dtype=bool)
                bin_width = GeneralConstants.PI / angular_bins
                # each angular bin has two output columns (failure/success), so
                # mark both columns of the ground-truth bin for every datapoint
                for j in range(angles.shape[0]):
                    pred_mask[j, int((angles[j] // bin_width)*2)] = True
                    pred_mask[j, int((angles[j] // bin_width)*2 + 1)] = True

            # predict with GQ-CNN
            predictions = gqcnn.predict(image_arr, pose_arr)
            if angular_bins > 0:
                raw_predictions = np.array(predictions)
                predictions = predictions[pred_mask].reshape((-1, 2))
            
            # aggregate
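            # predictions[:, 1] is the success-probability column of the
            # two-class (failure/success) output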
            all_predictions.extend(predictions[:,1].tolist())
            if angular_bins > 0:
                all_predictions_raw.extend(raw_predictions.tolist())
            all_labels.extend(label_arr.tolist())
            
        # close session
        gqcnn.close_session()            

        # create arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        train_predictions = all_predictions[train_indices]
        val_predictions = all_predictions[val_indices]
        train_labels = all_labels[train_indices]
        val_labels = all_labels[val_indices]
        if angular_bins > 0:
            all_predictions_raw = np.array(all_predictions_raw)
            train_predictions_raw = all_predictions_raw[train_indices]
            val_predictions_raw = all_predictions_raw[val_indices]        

        # aggregate results
        train_result = BinaryClassificationResult(train_predictions, train_labels)
        val_result = BinaryClassificationResult(val_predictions, val_labels)
        train_result.save(os.path.join(model_output_dir, 'train_result.cres'))
        val_result.save(os.path.join(model_output_dir, 'val_result.cres'))

        # get stats, plot curves
        self.logger.info('Model %s training error rate: %.3f' %(model_dir, train_result.error_rate))
        self.logger.info('Model %s validation error rate: %.3f' %(model_dir, val_result.error_rate))

        self.logger.info('Model %s training loss: %.3f' %(model_dir, train_result.cross_entropy_loss))
        self.logger.info('Model %s validation loss: %.3f' %(model_dir, val_result.cross_entropy_loss))

        # save images
        vis2d.figure()
        example_dir = os.path.join(model_output_dir, 'examples')
        if not os.path.exists(example_dir):
            os.mkdir(example_dir)

        # train
        self.logger.info('Saving training examples')
        train_example_dir = os.path.join(example_dir, 'train')
        if not os.path.exists(train_example_dir):
            os.mkdir(train_example_dir)
            
        # train TP
        true_positive_indices = train_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'true_positive_%03d.png' %(i)))

        # train FP
        false_positive_indices = train_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'false_positive_%03d.png' %(i)))

        # train TN
        true_negative_indices = train_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'true_negative_%03d.png' %(i)))

        # train FN
        false_negative_indices = train_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = train_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=train_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 train_result.pred_probs[j],
                                                                 train_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(train_example_dir, 'false_negative_%03d.png' %(i)))

        # val
        self.logger.info('Saving validation examples')
        val_example_dir = os.path.join(example_dir, 'val')
        if not os.path.exists(val_example_dir):
            os.mkdir(val_example_dir)

        # val TP
        true_positive_indices = val_result.true_positive_indices
        np.random.shuffle(true_positive_indices)
        true_positive_indices = true_positive_indices[:self.num_vis]
        for i, j in enumerate(true_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'true_positive_%03d.png' %(i)))

        # val FP
        false_positive_indices = val_result.false_positive_indices
        np.random.shuffle(false_positive_indices)
        false_positive_indices = false_positive_indices[:self.num_vis]
        for i, j in enumerate(false_positive_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'false_positive_%03d.png' %(i)))

        # val TN
        true_negative_indices = val_result.true_negative_indices
        np.random.shuffle(true_negative_indices)
        true_negative_indices = true_negative_indices[:self.num_vis]
        for i, j in enumerate(true_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'true_negative_%03d.png' %(i)))

        # val FN
        false_negative_indices = val_result.false_negative_indices
        np.random.shuffle(false_negative_indices)
        false_negative_indices = false_negative_indices[:self.num_vis]
        for i, j in enumerate(false_negative_indices):
            k = val_indices[j]
            datapoint = dataset.datapoint(k, field_names=[image_field_name,
                                                          pose_field_name])
            vis2d.clf()
            if angular_bins > 0:
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode, angular_preds=val_predictions_raw[j])
            else: 
                self._plot_grasp(datapoint, image_field_name, pose_field_name, gripper_mode)
            vis2d.title('Datapoint %d: Pred: %.3f Label: %.3f' %(k,
                                                                 val_result.pred_probs[j],
                                                                 val_result.labels[j]),
                        fontsize=self.font_size)
            vis2d.savefig(os.path.join(val_example_dir, 'false_negative_%03d.png' %(i)))
            
        # save summary stats
        train_summary_stats = {
            'error_rate': train_result.error_rate,
            'ap_score': train_result.ap_score,
            'auc_score': train_result.auc_score,
            'loss': train_result.cross_entropy_loss
        }
        train_stats_filename = os.path.join(model_output_dir, 'train_stats.json')
        with open(train_stats_filename, 'w') as f:
            json.dump(train_summary_stats, f,
                      indent=JSON_INDENT,
                      sort_keys=True)

        val_summary_stats = {
            'error_rate': val_result.error_rate,
            'ap_score': val_result.ap_score,
            'auc_score': val_result.auc_score,
            'loss': val_result.cross_entropy_loss            
        }
        val_stats_filename = os.path.join(model_output_dir, 'val_stats.json')
        with open(val_stats_filename, 'w') as f:
            json.dump(val_summary_stats, f,
                      indent=JSON_INDENT,
                      sort_keys=True)
        
        return train_result, val_result
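
A minimal sketch of how this method is typically driven, assuming the surrounding class is gqcnn's GQCNNAnalyzer (whose public analyze() wraps per-model helpers like this one); the config file and directories below are placeholder assumptions:

# Hypothetical usage sketch; config path and directories are assumptions.
from autolab_core import YamlConfig
from gqcnn import GQCNNAnalyzer

analysis_config = YamlConfig('cfg/tools/analyze_gqcnn_performance.yaml')
analyzer = GQCNNAnalyzer(analysis_config)
train_result, val_result = analyzer.analyze('models/GQCNN-2.0',
                                            'analysis/GQCNN-2.0')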
Example #2: a fine-tuning fragment (from a script along the lines of gqcnn's tools/finetune.py) that seeds the config, optionally timestamps the output directory, selects visible GPUs, and fine-tunes a pretrained GQ-CNN.
    if seed is not None:
        train_config["seed"] = seed
        train_config["gqcnn"]["seed"] = seed
    if tensorboard_port is not None:
        train_config["tensorboard_port"] = tensorboard_port
    gqcnn_params = train_config["gqcnn"]

    # Create a unique output folder based on the date and time.
    if save_datetime:
        # Create output dir.
        unique_name = time.strftime("%Y%m%d-%H%M%S")
        output_dir = os.path.join(output_dir, unique_name)
        utils.mkdir_safe(output_dir)

    # Set visible devices.
    if "gpu_list" in train_config:
        gqcnn_utils.set_cuda_visible_devices(train_config["gpu_list"])

    # Fine-tune the network.
    start_time = time.time()
    gqcnn = get_gqcnn_model(backend)(gqcnn_params)
    trainer = get_gqcnn_trainer(backend)(gqcnn,
                                         dataset_dir,
                                         split_name,
                                         output_dir,
                                         train_config,
                                         name=name)
    trainer.finetune(model_dir)
    logger.info("Total Fine-tuning Time: " +
                str(utils.get_elapsed_time(time.time() - start_time)))
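
The fragment above assumes names defined earlier in the script (train_config, dataset_dir, model_dir, split_name, output_dir, backend, name, seed, tensorboard_port, save_datetime). A minimal, hypothetical setup that would make it runnable could look like this; every path, the backend string, and the YAML file are placeholder assumptions:

# Hypothetical setup for the fine-tuning fragment above.
import os
import time

from autolab_core import Logger, YamlConfig
import autolab_core.utils as utils

from gqcnn import get_gqcnn_model, get_gqcnn_trainer
from gqcnn import utils as gqcnn_utils

logger = Logger.get_logger('tools/finetune.py')

backend = 'tf'
model_dir = 'models/GQCNN-2.0'                 # pretrained weights to fine-tune
dataset_dir = 'data/training/example_dataset'  # new training data
split_name = 'image_wise'
output_dir = 'models/'
name = 'GQCNN-2.0_finetuned'
seed = None
tensorboard_port = None
save_datetime = False

train_config = YamlConfig('cfg/finetune.yaml')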
Example #3: an interactive, cell-based training script that loads a dataset config and trains a GQ-CNN from scratch with the TensorFlow backend.
# missing imports for the script below; Logger, YamlConfig, and utils come from autolab_core
import json
import os
import time

from autolab_core import Logger, YamlConfig
import autolab_core.utils as utils

from gqcnn import get_gqcnn_model, get_gqcnn_trainer, utils as gqcnn_utils

#%%
logger = Logger.get_logger('tools/train.py')

dataset_dir = '/home/ai/git/gqcnn/data/training/Dexnet-2.0_testTraining'
train_config = YamlConfig('cfg/train_dex-net_2.0.yaml')
gqcnn_params = train_config['gqcnn']

#%%
config_filename = os.path.join(dataset_dir, 'config.json')
print(config_filename)
# print(os.getcwd())
#%%
with open(config_filename, 'r') as f:
    config = json.load(f)
#%%
start_time = time.time()
gqcnn = get_gqcnn_model('tf')(gqcnn_params)
#%%
trainer = get_gqcnn_trainer('tf')(gqcnn,
                                  'data/training/Dexnet-2.0_testTraining',
                                  'image_wise', 'models/', train_config,
                                  'GQCNN-2.0_Training_from_Scratch')
#%%
trainer.train()

#%%
logger.info('Total Training Time: ' +
            str(utils.get_elapsed_time(time.time() - start_time)))
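
Once training finishes, the saved model can be queried with the same loading and prediction API used by the analyzer in Example #1 (load, open_session, predict, close_session). A minimal inference sketch; the model directory, the zero-filled inputs, and the im_height/im_width/num_channels/pose_dim attributes are assumptions based on the TF implementation:

#%%
# Hypothetical inference sketch; model path and input shapes are assumptions.
import numpy as np

model_dir = 'models/GQCNN-2.0_Training_from_Scratch'
gqcnn = get_gqcnn_model('tf').load(model_dir)
gqcnn.open_session()

# dummy batch: one depth image and one pose vector, shaped to the network inputs
image_arr = np.zeros((1, gqcnn.im_height, gqcnn.im_width, gqcnn.num_channels))
pose_arr = np.zeros((1, gqcnn.pose_dim))

pred = gqcnn.predict(image_arr, pose_arr)  # shape (1, 2): [p_failure, p_success]
print('Predicted grasp success probability: %.3f' % pred[0, 1])

gqcnn.close_session()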