def predict_and_save_stage1_masks(self, h5_data_path, h5_result_saved_path, fold_k=0, batch_size=4):
    """Predict stage-1 masks for every image of a fold and save them to an HDF5 file.

    Images from both the training and the validation split of fold ``fold_k``
    are read from ``h5_data_path``, preprocessed with ``DataSet.preprocess`` and
    passed through ``self.predict``. Each prediction is stored as a float32
    dataset, named by its sample key, inside the group
    ``stage1_fold{fold_k}_predict_masks`` of ``h5_result_saved_path``.
    Masks saved by a previous run under the same key are replaced.

    Args:
        h5_data_path: str, path of the h5 file holding the training data.
        h5_result_saved_path: str, path of the h5 file the predicted masks are
            written to (opened in append mode, created if it does not exist).
        fold_k: int, index of the fold whose train/val split is used.
        batch_size: int, batch size used for prediction.

    Returns:
        None.

    Raises:
        ValueError: if the number of predictions differs from the number of keys.
    """
    group_name = "stage1_fold{}_predict_masks".format(fold_k)
    # The context manager guarantees the result file is flushed and closed,
    # even when loading or prediction raises.
    with h5py.File(h5_result_saved_path, "a") as f_result:
        # Returns the existing group on re-runs, creates it otherwise.
        stage1_predict_masks_grp = f_result.require_group(group_name)

        dataset = DataSet(h5_data_path, fold_k)
        images_train = dataset.get_images(is_train=True)
        images_val = dataset.get_images(is_train=False)
        keys_train = dataset.get_keys(is_train=True)
        keys_val = dataset.get_keys(is_train=False)
        images = np.concatenate([images_train, images_val], axis=0)
        keys = np.concatenate([keys_train, keys_val], axis=0)

        print("predicting ...")
        images = dataset.preprocess(images, mode="image")
        y_pred = self.predict(images, batch_size, use_channels=1)
        print(y_pred.shape)
        if len(y_pred) != len(keys):
            raise ValueError(
                "Got {} predictions for {} keys.".format(len(y_pred), len(keys)))

        print("Saving predicted masks ...")
        for key, mask in zip(keys, y_pred):
            # create_dataset raises if the name already exists, so drop any
            # mask left over from a previous run before writing the new one.
            if key in stage1_predict_masks_grp:
                del stage1_predict_masks_grp[key]
            stage1_predict_masks_grp.create_dataset(key, dtype=np.float32, data=mask)
    print("Done.")
def predict_from_h5data_old(self, h5_data_path, val_fold_nb, is_train=False, save_dir=None, color_lst=None):
    """Predict masks for one split of a fold and optionally save visualizations.

    Args:
        h5_data_path: str, path of the h5 file holding the data.
        val_fold_nb: int, fold index passed to ``DataSet``.
        is_train: bool, predict the training split if True, else the
            validation split.
        save_dir: str or None. If truthy, one ``{key:03}.tif`` image per sample
            is written into this directory (created if missing). It shows, side
            by side, the source image, the ground-truth mask overlay and the
            predicted mask overlay.
        color_lst: optional sequence whose first two items are the overlay
            colors for ground truth and prediction. It defaults to
            ``[255, 106, 106]`` and ``[0, 191, 255]``.

    Returns:
        The post-processed, de-preprocessed predicted masks when ``save_dir``
        is falsy; otherwise None.

    Raises:
        IOError: if a visualization image cannot be written.
    """
    dataset = DataSet(h5_data_path, val_fold_nb)
    images = dataset.get_images(is_train=is_train)

    print("predicting ...")
    y_pred = self.predict(dataset.preprocess(images, mode="image"), batch_size=4, use_channels=1)
    y_pred = self.postprocess(y_pred)
    y_pred = DataSet.de_preprocess(y_pred, mode="mask")
    print(y_pred.shape)

    if not save_dir:
        return y_pred

    # Ground-truth masks and the 3-channel source are only needed for the
    # visualizations, so they are loaded only on this path.
    masks = np.squeeze(dataset.get_masks(is_train=is_train, mask_nb=0), axis=-1)
    keys = dataset.get_keys(is_train)

    if color_lst is None:
        color_gt = [255, 106, 106]
        color_pred = [0, 191, 255]
    else:
        color_gt = color_lst[0]
        color_pred = color_lst[1]

    # Tile the images three times along the channel axis so colored overlays
    # can be drawn (presumably the images are single-channel; TODO confirm).
    imgs_src = np.concatenate([images] * 3, axis=-1)
    # BGR to RGB (a no-op when the three channels are identical copies).
    imgs_src = imgs_src[..., ::-1]

    # Copies keep the source panel and both overlays independent even if
    # apply_mask draws into the image it is given.
    image_masks = [apply_mask(image.copy(), mask, color_gt, alpha=0.5)
                   for image, mask in zip(imgs_src, masks)]
    image_preds = [apply_mask(image.copy(), mask, color_pred, alpha=0.5)
                   for image, mask in zip(imgs_src, y_pred)]
    # Source | ground truth | prediction, side by side along the width axis.
    image_mask_preds = np.concatenate([imgs_src, image_masks, image_preds], axis=2)

    os.makedirs(save_dir, exist_ok=True)
    for key, image_mask_pred in zip(keys, image_mask_preds):
        dst_image_path = os.path.join(save_dir, "{:03}.tif".format(int(key)))
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(dst_image_path, image_mask_pred):
            raise IOError("Failed to write image: {}".format(dst_image_path))
    print("Done.")