Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--limit", "-l", default=None, type=int, help="Limits the number of occurrences of each class")
    parser.add_argument("--load_data", "-ld", action="store_true", help="Loads all the videos into RAM")
    parser.add_argument("--filters", "-f", nargs='*', default=None, type=str,
                        help="Filters data (for example: 'subfolder1'), only usable for images.")
    args = parser.parse_args()

    # rmtree with ignore_errors=True can return before the files are actually gone
    # (e.g. when another process still holds one open), hence the retry loops below.
    if not DataConfig.KEEP_TB:
        while DataConfig.TB_DIR.exists():
            shutil.rmtree(DataConfig.TB_DIR, ignore_errors=True)
            time.sleep(0.5)
    DataConfig.TB_DIR.mkdir(parents=True, exist_ok=True)

    if DataConfig.USE_CHECKPOINT:
        if not DataConfig.KEEP_CHECKPOINTS:
            while DataConfig.CHECKPOINT_DIR.exists():
                shutil.rmtree(DataConfig.CHECKPOINT_DIR, ignore_errors=True)
                time.sleep(0.5)
        try:
            DataConfig.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            print(f"The checkpoint dir {DataConfig.CHECKPOINT_DIR} already exists")
            return -1

        # Makes a copy of all the code (and config) so that the checkpoints are easy to load and use
        output_folder = DataConfig.CHECKPOINT_DIR / "PyTorch-Video-Classification"
        for filepath in list(Path(".").glob("**/*.py")):
            destination_path = output_folder / filepath
            destination_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(filepath, destination_path)
        misc_files = ["README.md", "requirements.txt", "setup.cfg", ".gitignore"]
        for misc_file in misc_files:
            shutil.copy(misc_file, output_folder / misc_file)
        print("Finished copying files")

    torch.backends.cudnn.benchmark = True   # Makes training a bit faster
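    # cudnn.benchmark lets cuDNN autotune its convolution algorithms; this pays off
    # as long as the input shapes stay constant from batch to batch.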

    train_dataloader = VideoDataloader(DataConfig.DATA_PATH / "Train", DataConfig.DALI, DataConfig.LOAD_FROM_IMAGES,
                                       DataConfig.LABEL_MAP, drop_last=ModelConfig.MODEL.__name__ == "LRCN",
                                       num_workers=DataConfig.NUM_WORKERS, dali_device_id=DataConfig.DALI_DEVICE_ID,
                                       limit=args.limit, filters=args.filters, load_data=args.load_data,
                                       **get_model_config_dict())

    val_dataloader = VideoDataloader(DataConfig.DATA_PATH / "Validation", DataConfig.DALI, DataConfig.LOAD_FROM_IMAGES,
                                     DataConfig.LABEL_MAP, drop_last=ModelConfig.MODEL.__name__ == "LRCN",
                                     num_workers=DataConfig.NUM_WORKERS, dali_device_id=DataConfig.DALI_DEVICE_ID,
                                     limit=args.limit, filters=args.filters, load_data=args.load_data,
                                     **get_model_config_dict())

    print(f"Loaded {len(train_dataloader)} train data and", f"{len(val_dataloader)} validation data", flush=True)
    print("Building model. . .", end="\r")

    model = build_model(ModelConfig.MODEL, DataConfig.OUTPUT_CLASSES, **get_model_config_dict())
    # The summary does not work with an LSTM for some reason
    if ModelConfig.MODEL.__name__ != "LRCN":
        summary(model, (ModelConfig.SEQUENCE_LENGTH, 1 if ModelConfig.GRAYSCALE else 3,
                        ModelConfig.IMAGE_SIZES[0], ModelConfig.IMAGE_SIZES[1]))

    train(model, train_dataloader, val_dataloader)
Example #2
def main():
    parser = ArgumentParser()
    parser.add_argument("model_path", type=Path, help="Path to the checkpoint to use")
    parser.add_argument("data_path", type=Path, help="Path to the test dataset")
    parser.add_argument("--show", "-s", action="store_true", help="Show the images where the network failed.")
    args = parser.parse_args()

    inference_start_time = time.perf_counter()

    # Create and load the model
    model = build_model(ModelConfig.MODEL, DataConfig.NB_CLASSES,
                        model_path=args.model_path, eval=True, **get_config_as_dict(ModelConfig))
    print("Weights loaded", flush=True)

    data, labels, paths = data_loader(args.data_path, DataConfig.LABEL_MAP,
                                      data_preprocessing_fn=default_load_data, return_img_paths=True)
    base_cpu_pipeline = (transforms.resize(ModelConfig.IMAGE_SIZES), )
    base_gpu_pipeline = (transforms.to_tensor(), transforms.normalize(labels_too=True))
    data_transformations = transforms.compose_transformations((*base_cpu_pipeline, *base_gpu_pipeline))
    print("\nData loaded", flush=True)

    results = []  # Variable used to keep track of the classification results
    for img, label, img_path in zip(data, labels, paths):
        clean_print(f"Processing image {img_path}", end="\r")
        img, label = data_transformations([img], [label])
        with torch.no_grad():
            output = model(img)
            output = torch.nn.functional.softmax(output, dim=-1)
            prediction = torch.argmax(output)
            pred_correct = label == prediction
            results.append(1 if pred_correct else 0)

            if args.show and not pred_correct:
                out_img = draw_pred_img(img, output, label, DataConfig.LABEL_MAP, size=ModelConfig.IMAGE_SIZES)
                out_img = cv2.cvtColor(out_img[0], cv2.COLOR_RGB2BGR)
                while True:
                    cv2.imshow("Image", out_img)
                    key = cv2.waitKey(10)
                    if key == ord("q"):
                        cv2.destroyAllWindows()
                        break

    results = np.asarray(results)
    total_time = time.perf_counter() - inference_start_time
    print("\nFinished running inference on the test dataset.")
    print(f"Total inference time was {total_time:.3f}s, which averages to {total_time/len(results):.5f}s per image")
    print(f"Accuracy: {np.mean(results)}")
Example #3
def main():
    parser = ArgumentParser(description="Segmentation inference")
    parser.add_argument("model_path",
                        type=Path,
                        help="Path to the checkpoint to use")
    parser.add_argument("data_path",
                        type=Path,
                        help="Path to the test dataset")
    parser.add_argument("--json_path",
                        "-j",
                        type=Path,
                        help="Path to the classes.json file")
    parser.add_argument("--show_imgs",
                        "-s",
                        action="store_true",
                        help="Show predicted segmentation masks")
    parser.add_argument(
        "--use_blob_detection",
        "-b",
        action="store_true",
        help=
        "Use blob detection on predicted masks to get a binary classification")
    parser.add_argument("--show_missed",
                        "-sm",
                        action="store_true",
                        help="Show samples where the blob detection failed")
    parser.add_argument("--verbose_level",
                        "-v",
                        choices=["debug", "info", "error"],
                        default="info",
                        type=str,
                        help="Logger level.")
    args = parser.parse_args()

    model_path: Path = args.model_path
    data_path: Path = args.data_path
    json_path: Path = args.json_path
    verbose_level: str = args.verbose_level
    show_imgs: bool = args.show_imgs
    show_missed: bool = args.show_missed
    use_blob_detection: bool = args.use_blob_detection

    model_config = get_model_config()
    logger = create_logger("Inference", verbose_level=verbose_level)

    label_map, color_map = get_label_maps(
        json_path if json_path else data_path.parent / "classes.json", logger)
    denormalize_img_fn = partial(denormalize_tensor,
                                 mean=torch.Tensor(model_config.MEAN),
                                 std=torch.Tensor(model_config.STD))

    data, labels = default_loader(data_path,
                                  get_mask_path_fn=get_mask_path,
                                  verbose=False)
    assert len(labels) > 0, f"Did not find any image in {data_path}, exiting"
    logger.info(f"Data loaded, found {len(labels)} images.")

    common_pipeline = albumentation_wrapper(
        albumentations.Compose([
            albumentations.Normalize(mean=model_config.MEAN,
                                     std=model_config.STD,
                                     max_pixel_value=255.0,
                                     p=1.0),
            albumentations.Resize(*model_config.IMAGE_SIZES,
                                  interpolation=cv2.INTER_LINEAR)
        ]))
    with BatchGenerator(data,
                        labels,
                        1,
                        nb_workers=2,
                        data_preprocessing_fn=default_load_data,
                        labels_preprocessing_fn=default_load_labels,
                        cpu_pipeline=common_pipeline,
                        gpu_pipeline=transforms.to_tensor(),
                        shuffle=False) as dataloader:
        logger.debug("Dataloader created")

        # Creates and load the model
        print("Building model. . .", end="\r")
        model = build_model(model_config.MODEL,
                            len(label_map),
                            model_path=model_path,
                            eval_mode=True,
                            **get_dataclass_as_dict(model_config))
        logger.info("Weights loaded     ")

        if use_blob_detection:
            # Variables used to keep track of the classification results
            true_negs = 0.0
            true_pos = 0.0
            pos_elts = 0
            neg_elts = 0
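            # Convention for the counters above: a "positive" sample is defect-free
            # (its label mask contains no blobs), a "negative" one contains defects.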

            # Set up the SimpleBlobDetector parameters.
            params = cv2.SimpleBlobDetector_Params()
            params.filterByArea = True
            params.minArea = 1
            params.maxArea = 5000
            params.minThreshold = 0
            params.maxThreshold = 255
            params.filterByCircularity = False
            params.filterByColor = False
            params.filterByConvexity = False
            params.filterByInertia = False

            detector: cv2.SimpleBlobDetector = cv2.SimpleBlobDetector_create(
                params)

        with torch.no_grad():
            # Compute some segmentation metrics
            logger.info("Computing the confusion matrix")
            metrics = ClassificationMetrics(model,
                                            None,
                                            dataloader,
                                            label_map,
                                            max_batches=None,
                                            segmentation=True)
            metrics.compute_confusion_matrix(mode="Validation")
            avg_acc = metrics.get_avg_acc()
            logger.info(f"Average accuracy: {avg_acc}")

            per_class_acc = metrics.get_class_accuracy()
            per_class_acc_msg = [
                "\n\t" + label_map[key] + f": {acc}"
                for key, acc in enumerate(per_class_acc)
            ]
            logger.info("Per Class Accuracy:" + "".join(per_class_acc_msg))

            per_class_iou = metrics.get_class_iou()
            per_class_iou_msg = [
                "\n\t" + label_map[key] + f": {iou}"
                for key, iou in enumerate(per_class_iou)
            ]
            logger.info("Per Class IOU:" + "".join(per_class_iou_msg))

            if show_imgs:
                confusion_matrix = metrics.get_confusion_matrix()
                show_image(confusion_matrix, "Confusion Matrix")

            # Redo a pass over the dataset to get more information if requested
            if show_imgs or use_blob_detection:
                for _step, (inputs, labels) in enumerate(dataloader, start=1):
                    predictions = model(inputs)
                    inputs = denormalize_img_fn(inputs)

                    if use_blob_detection:
                        one_hot_masks_preds = rearrange(
                            predictions, "b c w h -> b w h c")
                        masks_preds: np.ndarray = torch.argmax(
                            one_hot_masks_preds,
                            dim=-1).cpu().detach().numpy()
                        one_hot_masks_labels = rearrange(
                            labels, "b c w h -> b w h c")
                        masks_labels: np.ndarray = torch.argmax(
                            one_hot_masks_labels,
                            dim=-1).cpu().detach().numpy()

                        for img, pred_mask, label_mask in zip(
                                inputs, masks_preds, masks_labels):
                            # Turn the class-index masks back into RGB segmentation masks
                            pred_mask_rgb = np.asarray(color_map[pred_mask],
                                                       dtype=np.uint8)
                            label_mask_rgb = np.asarray(color_map[label_mask],
                                                        dtype=np.uint8)

                            # Run the blob detector on the image and store the results
                            keypoints_pred = detector.detect(pred_mask_rgb)
                            keypoints_label = detector.detect(label_mask_rgb)
                            if len(keypoints_label) > 0:
                                if len(keypoints_pred) > 0:
                                    true_negs += 1
                                elif show_missed:
                                    out_img = draw_blobs(
                                        img, pred_mask_rgb, label_mask_rgb,
                                        keypoints_pred)
                                    show_image(out_img,
                                               "Sample with missed defect")
                                neg_elts += 1
                            else:
                                if len(keypoints_pred) == 0:
                                    true_pos += 1
                                elif show_missed:
                                    out_img = draw_blobs(
                                        img, pred_mask_rgb, label_mask_rgb,
                                        keypoints_pred)
                                    show_image(out_img,
                                               "Clean sample misclassified")
                                pos_elts += 1
                            if show_imgs:
                                out_img = draw_blobs(img, pred_mask_rgb,
                                                     label_mask_rgb,
                                                     keypoints_pred)
                                show_image(out_img, "Output image")
                    elif show_imgs:
                        out_imgs = draw_segmentation(inputs,
                                                     predictions,
                                                     labels,
                                                     color_map=color_map)
                        for out_img in out_imgs:
                            show_image(out_img)

    if use_blob_detection:
        precision = true_pos / max(1, (true_pos + (neg_elts - true_negs)))
        recall = true_pos / max(1, pos_elts)
        acc = (true_pos + true_negs) / max(1, neg_elts + pos_elts)
        pos_acc = true_pos / max(1, pos_elts)
        neg_acc = true_negs / max(1, neg_elts)

        stats = (precision, recall, acc, pos_acc, neg_acc)
        stats_names = ("Precision", "Recall", "Accuracy", "Positive accuracy",
                       "Negative accuracy")

        logger.info(
            f"Dataset was composed of {pos_elts} good samples and {neg_elts} bad samples."
        )
        logger.info(
            "Results obtained using blob detection for classification:")
        logger.info(
            f"Bad samples misclassified: {neg_elts-true_negs}, Good samples misclassified: {pos_elts-true_pos}"
        )
        for stat_name, stat in zip(stats_names, stats):
            logger.info(f"{stat_name}: {stat:.2f}")
Example #4
def main():
    parser = ArgumentParser(description="Segmentation inference")
    parser.add_argument("model_path",
                        type=Path,
                        help="Path to the checkpoint to use")
    parser.add_argument("data_path",
                        type=Path,
                        help="Path to the test dataset")
    parser.add_argument("--output_path",
                        "-o",
                        type=Path,
                        default=None,
                        help="Save results to that folder if given.")
    parser.add_argument(
        "--json_path",
        "-j",
        type=Path,
        help=
        "Json file with the index mapping, defaults to data_path.parent / 'classes.json'"
    )
    parser.add_argument("--tile_size",
                        "-ts",
                        nargs=2,
                        default=[256, 256],
                        type=int,
                        help="Size of the tiles (w, h)")
    parser.add_argument("--stride",
                        "-s",
                        nargs=2,
                        default=[100, 100],
                        type=int,
                        help="Strides (w, h)")
    parser.add_argument("--display_image",
                        "-d",
                        action="store_true",
                        help="Show result.")
    parser.add_argument("--verbose_level",
                        "-v",
                        choices=["debug", "info", "error"],
                        default="info",
                        type=str,
                        help="Logger level.")
    args = parser.parse_args()

    model_path: Path = args.model_path
    data_path: Path = args.data_path
    output_folder: Path = args.output_path
    json_path: Path = args.json_path
    tile_width: int
    tile_height: int
    tile_width, tile_height = args.tile_size
    stride_width: int
    stride_height: int
    stride_width, stride_height = args.stride
    display_img: bool = args.display_image
    verbose_level: str = args.verbose_level

    model_config = get_model_config()
    logger = create_logger("Inference", verbose_level=verbose_level)
    label_map, color_map = get_label_maps(
        json_path if json_path else data_path.parent / "classes.json", logger)

    imgs_paths, masks_paths = default_loader(data_path,
                                             get_mask_path_fn=get_mask_path,
                                             verbose=False)
    assert len(
        imgs_paths) > 0, f"Did not find any image in {data_path}, exiting"
    logger.info(f"Data loaded, found {(nb_imgs := len(imgs_paths))} images.")

    preprocess_fn = albumentations.Compose([
        albumentations.Normalize(mean=model_config.MEAN,
                                 std=model_config.STD,
                                 max_pixel_value=255.0,
                                 p=1.0),
        albumentations.Resize(*model_config.IMAGE_SIZES,
                              interpolation=cv2.INTER_LINEAR)
    ])
    resize_to_original = albumentations.Resize(tile_height,
                                               tile_width,
                                               interpolation=cv2.INTER_LINEAR)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create and load the model
    print("Building model. . .", end="\r")
    model = build_model(model_config.MODEL,
                        len(label_map),
                        model_path=model_path,
                        eval_mode=True,
                        **get_dataclass_as_dict(model_config))
    logger.info(
        "Weights loaded. Starting to process images (this might take a while)."
    )

    for i, (img_path, mask_path) in enumerate(zip(imgs_paths, masks_paths)):
        logger.info(f"Processing image {img_path.name} ({i+1}/{nb_imgs})")

        img = default_load_data(img_path)
        one_hot_mask = default_load_labels(
            mask_path)  # TODO: Make this step optional ?
        height, width, _ = img.shape
        assert one_hot_mask.shape[0] == height and one_hot_mask.shape[
            1] == width, (
                f"\nShape of the image and the mask do not match for image {img_path}"
            )

        nb_tiles_per_img = (1 + (width - tile_width) // stride_width) * (
            1 + (height - tile_height) // stride_height)
        tile_idx = 0
        pred_mask = np.zeros_like(one_hot_mask)
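        # Note: the strided loops below stop before the right/bottom edge, so a thin
        # border may never be covered by a tile and keeps its all-zero prediction.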
        for x in range(0, width - tile_width, stride_width):
            for y in range(0, height - tile_height, stride_height):
                logger.debug(
                    f"Processing tile {(tile_idx := tile_idx + 1)} / {nb_tiles_per_img}"
                )
                tile = img[y:y + tile_height, x:x + tile_width]

                with torch.no_grad():
                    tile = np.expand_dims(resized_tile :=
                                          preprocess_fn(image=tile)["image"],
                                          axis=0)
                    tile = tile.transpose((0, 3, 1, 2))
                    tile = torch.from_numpy(tile).float().to(device)
                    oh_tile_pred = model(tile)

                oh_tile_pred = rearrange(oh_tile_pred, "b c w h -> b w h c")
                oh_tile_pred = np.squeeze(oh_tile_pred.cpu().detach().numpy(),
                                          axis=0)
                oh_tile_pred = resize_to_original(image=resized_tile,
                                                  mask=oh_tile_pred)["mask"]

                # Effectively averages the predictions from overlapping tiles.
                pred_mask[y:y + tile_height, x:x + tile_width] += oh_tile_pred
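                # Summing the logits is enough: for a given pixel all classes share the
                # same overlap count, so the later argmax is invariant to averaging.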

        # Skew the results towards over-detection if desired. Needs to be tweaked manually though
        # pred_mask[..., 1:] *= 4
        # Small post processing. Erode to remove small areas.
        pred_mask = cv2.GaussianBlur(pred_mask, (5, 5), 0)
        kernel = np.ones((3, 3), np.uint8)
        pred_mask = cv2.erode(pred_mask, kernel, iterations=1)
        # kernel = np.ones((3, 3), np.uint8)
        # pred_mask = cv2.dilate(pred_mask, kernel, iterations=1)

        # Go from summed logits to per-pixel class indices
        pred_mask = np.argmax(pred_mask, axis=-1)
        # Turn the class-index mask back into an RGB segmentation mask
        pred_mask_rgb = cv2.cvtColor(
            np.asarray(color_map[pred_mask], dtype=np.uint8),
            cv2.COLOR_RGB2BGR)
        label_mask = cv2.imread(str(mask_path))

        label_bboxes = get_cc_bboxes(label_mask, logger)
        # Again, remove small predicted areas (tweak value depending on the project).
        pred_bboxes = get_cc_bboxes(pred_mask_rgb, logger, area_threshold=70)

        if output_folder:
            rel_path = img_path.relative_to(data_path)
            output_path = output_folder / rel_path.parent / img_path.name
            output_path.parent.mkdir(parents=True, exist_ok=True)
            drawn_img = draw_blobs_from_bboxes(img, pred_bboxes, (0, 0, 255))
            logger.info(f"Saving result image at {output_path}")
            cv2.imwrite(str(output_path), drawn_img)
        if (output_folder and logger.getEffectiveLevel()
                == logging.DEBUG) or display_img:
            drawn_img = draw_blobs_from_bboxes(img, label_bboxes, (0, 255, 0))
            drawn_img = draw_blobs_from_bboxes(drawn_img, pred_bboxes,
                                               (0, 0, 255))
            result_img = concat_imgs(drawn_img, pred_mask_rgb, label_mask)
            if display_img:
                show_img(result_img)
            if output_folder:
                output_path = output_path.with_stem(img_path.stem + "_debug")
                cv2.imwrite(str(output_path), result_img)

        tp, fp, fn = get_confusion_matrix_from_bboxes(label_bboxes,
                                                      pred_bboxes)
        logger.info(
            f"Results for image {img_path}: TP: {tp}, FP: {fp}, FN: {fn}")

        iou = get_iou_masks(label_mask, pred_mask_rgb, color=(0, 0, 255))
        logger.info(f"\tIoU: {iou:.3f}")
Example #5
def main():
    parser = ArgumentParser()
    parser.add_argument("model_path",
                        type=Path,
                        help="Path to the checkpoint to use")
    parser.add_argument("data_path",
                        type=Path,
                        help="Path to the test dataset")
    args = parser.parse_args()

    # Create and load the model
    model = build_model(ModelConfig.MODEL,
                        DataConfig.NB_CLASSES,
                        model_path=args.model_path,
                        eval=True,
                        **get_config_as_dict(ModelConfig))
    print("Weights loaded", flush=True)

    data, labels, paths = data_loader(args.data_path,
                                      DataConfig.LABEL_MAP,
                                      data_preprocessing_fn=default_load_data,
                                      return_img_paths=True)
    base_cpu_pipeline = (transforms.resize(ModelConfig.IMAGE_SIZES), )
    base_gpu_pipeline = (transforms.to_tensor(),
                         transforms.normalize(labels_too=True))
    data_transformations = transforms.compose_transformations(
        (*base_cpu_pipeline, *base_gpu_pipeline))
    print("\nData loaded", flush=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for img, label, img_path in zip(data, labels, paths):
        clean_print(f"Processing image {img_path}", end="\r")
        img, label = data_transformations([img], [label])

        # Feed the image to the model
        output = model(img)
        output = torch.nn.functional.softmax(output, dim=-1)

        # Get the top prediction and turn it into a one-hot vector
        prediction = output.argmax(dim=1)
        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0][prediction] = 1
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.to(device) * output)
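        # Multiplying the one-hot vector by the output keeps only the predicted
        # class's score, so the backward pass yields gradients w.r.t. that class.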

        # Get gradients and activations
        model.zero_grad()
        one_hot.backward(retain_graph=True)
        grads_val = model.get_gradients()[-1].cpu().data.numpy()

        activations = model.get_activations()
        activations = activations.cpu().data.numpy()[0, :]

        # Make gradcam mask
        weights = np.mean(grads_val, axis=(1, 2))
        cam = np.zeros(activations.shape[1:], dtype=np.float32)
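        # Grad-CAM: weight each activation channel by its gradient averaged over the
        # spatial dimensions, sum the weighted channels, then clamp at zero (ReLU).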

        for i, w in enumerate(weights):
            cam += w * activations[i, :, :]

        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, ModelConfig.IMAGE_SIZES)
        cam = cam - np.min(cam)
        cam = cam / np.max(cam)

        # Draw the prediction (softmax probabilities) on the image
        img = draw_pred_img(img,
                            output,
                            label,
                            DataConfig.LABEL_MAP,
                            size=ModelConfig.IMAGE_SIZES)

        # Fuse input image and gradcam mask
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap)
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)

        while True:
            cv2.imshow("Image", cam)
            key = cv2.waitKey(10)
            if key == ord("q"):
                cv2.destroyAllWindows()
                break
Example #6
def main():
    parser = argparse.ArgumentParser(description="Segmentation training")
    parser.add_argument("--limit",
                        default=None,
                        type=int,
                        help="Limits the number of occurrences of each class")
    parser.add_argument("--load_data",
                        action="store_true",
                        help="Loads all the data into RAM")
    parser.add_argument(
        "--name",
        type=str,
        default="Train",
        help=
        "Used to identify the training run when using ps. Also the name of the logger."
    )
    parser.add_argument("--verbose_level",
                        "-v",
                        choices=["debug", "info", "error"],
                        default="info",
                        type=str,
                        help="Logger level.")
    args = parser.parse_args()

    name: str = args.name
    verbose_level: str = args.verbose_level

    data_config = get_data_config()
    model_config = get_model_config()

    prepare_folders(
        data_config.TB_DIR if data_config.USE_TB else None,
        data_config.CHECKPOINTS_DIR if data_config.USE_CHECKPOINTS else None,
        repo_name="Segmentation-PyTorch")
    log_dir = data_config.CHECKPOINTS_DIR / "print_logs" if data_config.USE_CHECKPOINTS else None
    logger = create_logger(name, log_dir=log_dir, verbose_level=verbose_level)
    logger.info("Finished preparing tensorboard and checkpoints folders.")

    torch.backends.cudnn.benchmark = True  # Makes training quite a bit faster

    train_data, train_labels = default_loader(
        data_config.DATA_PATH / "Train",
        get_mask_path_fn=get_mask_path,
        limit=args.limit,
        load_data=args.load_data,
        data_preprocessing_fn=default_load_data if args.load_data else None,
        labels_preprocessing_fn=default_load_labels
        if args.load_data else None)
    logger.info("Train data loaded")

    val_data, val_labels = default_loader(
        data_config.DATA_PATH / "Validation",
        get_mask_path_fn=get_mask_path,
        limit=args.limit,
        load_data=args.load_data,
        data_preprocessing_fn=default_load_data if args.load_data else None,
        labels_preprocessing_fn=default_load_labels
        if args.load_data else None)
    logger.info("Validation data loaded")

    # Data augmentation done on cpu.
    augmentation_pipeline = albumentation_wrapper(
        albumentations.Compose([
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            # albumentations.RandomRotate90(p=0.2),
            # albumentations.CLAHE(),
            albumentations.RandomBrightnessContrast(brightness_limit=0.1,
                                                    contrast_limit=0.1,
                                                    p=0.5),
            albumentations.HueSaturationValue(hue_shift_limit=10,
                                              sat_shift_limit=15,
                                              val_shift_limit=10,
                                              p=0.5),
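            # With border_mode=BORDER_CONSTANT, border pixels created by the affine
            # transform get value=0 in the image and, in the mask, a one-hot vector
            # selecting the first class (presumably the background).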
            albumentations.ShiftScaleRotate(
                scale_limit=0.05,
                rotate_limit=10,
                shift_limit=0.06,
                p=0.5,
                border_mode=cv2.BORDER_CONSTANT,  # cv2.BORDER_REFLECT_101
                value=0,
                mask_value=[1] + [0] * (data_config.OUTPUT_CLASSES - 1)),
            # albumentations.GridDistortion(p=0.5),
        ]))

    common_pipeline = albumentation_wrapper(
        albumentations.Compose([
            albumentations.Normalize(mean=model_config.MEAN,
                                     std=model_config.STD,
                                     max_pixel_value=255.0,
                                     p=1.0),
            albumentations.Resize(*model_config.IMAGE_SIZES,
                                  interpolation=cv2.INTER_LINEAR)
        ]))
    train_pipeline = transforms.compose_transformations(
        (augmentation_pipeline, common_pipeline))

    with BatchGenerator(train_data,
                        train_labels,
                        model_config.BATCH_SIZE,
                        nb_workers=data_config.NB_WORKERS,
                        data_preprocessing_fn=default_load_data if not args.load_data else None,
                        labels_preprocessing_fn=default_load_labels if not args.load_data else None,
                        cpu_pipeline=train_pipeline,
                        gpu_pipeline=transforms.to_tensor(),
                        shuffle=True) as train_dataloader, \
        BatchGenerator(val_data,
                       val_labels,
                       model_config.BATCH_SIZE,
                       nb_workers=data_config.NB_WORKERS,
                       data_preprocessing_fn=default_load_data if not args.load_data else None,
                       labels_preprocessing_fn=default_load_labels if not args.load_data else None,
                       cpu_pipeline=common_pipeline,
                       gpu_pipeline=transforms.to_tensor(),
                       shuffle=False) as val_dataloader:

        print(f"\nLoaded {len(train_dataloader)} train data and",
              f"{len(val_dataloader)} validation data",
              flush=True)

        print("Building model. . .", end="\r")
        model = build_model(model_config.MODEL, data_config.OUTPUT_CLASSES,
                            **get_dataclass_as_dict(model_config))

        logger.info(f"{'-'*24} Starting train {'-'*24}")
        logger.info("From command : " + ' '.join(sys.argv))
        logger.info(f"Input shape: {train_dataloader.data_shape}")
        logger.info("")
        logger.info("Using model:")
        for line in summary(model, train_dataloader.data_shape):
            logger.info(line)
        logger.info("")

        loss_fn = DiceBCELoss()
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=model_config.LR,
                                      weight_decay=model_config.WEIGHT_DECAY)
        trainer = Trainer(model, loss_fn, optimizer, train_dataloader,
                          val_dataloader)
        # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=model_config.LR_DECAY)
        scheduler = CosineAnnealingLR(optimizer,
                                      model_config.MAX_EPOCHS,
                                      eta_min=5e-6)
        # TODO: Try this https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/cosine_lr.py

        if data_config.USE_TB:
            metrics = ClassificationMetrics(model,
                                            train_dataloader,
                                            val_dataloader,
                                            data_config.LABEL_MAP,
                                            max_batches=10,
                                            segmentation=True)
            tensorboard = TensorBoard(model,
                                      data_config.TB_DIR,
                                      model_config.IMAGE_SIZES,
                                      metrics,
                                      data_config.LABEL_MAP,
                                      color_map=data_config.COLOR_MAP,
                                      denormalize_img_fn=partial(
                                          denormalize_np,
                                          mean=model_config.MEAN,
                                          std=model_config.STD))

        best_loss = 1000
        last_checkpoint_epoch = 0
        train_start_time = time.time()
        try:
            for epoch in range(model_config.MAX_EPOCHS):
                epoch_start_time = time.perf_counter()
                print()  # logger doesn't handle \n super well
                logger.info(f"Epoch {epoch}/{model_config.MAX_EPOCHS}")

                epoch_loss = trainer.train_epoch()
                if data_config.USE_TB:
                    tensorboard.write_loss(epoch, epoch_loss)
                    tensorboard.write_lr(epoch, scheduler.get_last_lr()[0])

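                # Save a checkpoint only when the loss improved, RECORD_START has been
                # reached and enough epochs have passed since the last checkpoint.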
                if (epoch_loss < best_loss and data_config.USE_CHECKPOINTS
                        and epoch >= data_config.RECORD_START
                        and (epoch - last_checkpoint_epoch) >=
                        data_config.CHECKPT_SAVE_FREQ):
                    save_path = data_config.CHECKPOINTS_DIR / f"train_{epoch}.pt"
                    logger.info(
                        f"\nLoss improved from {best_loss:.5e} to {epoch_loss:.5e}, "
                        f"saving model to {save_path}")
                    best_loss, last_checkpoint_epoch = epoch_loss, epoch
                    torch.save(model.state_dict(), save_path)

                logger.info(
                    f"Epoch loss: {epoch_loss:.5e}  -  Took {time.perf_counter() - epoch_start_time:.5f}s"
                )

                # Validation and other metrics
                if epoch % data_config.VAL_FREQ == 0 and epoch >= data_config.RECORD_START:
                    # if data_config.USE_TB:
                    #     tensorboard.write_weights_grad(epoch)
                    with torch.no_grad():
                        validation_start_time = time.perf_counter()
                        epoch_loss = trainer.val_epoch()

                        if data_config.USE_TB:
                            print("Starting to compute TensorBoard metrics",
                                  end="\r",
                                  flush=True)
                            tensorboard.write_weights_grad(epoch)
                            tensorboard.write_loss(epoch,
                                                   epoch_loss,
                                                   mode="Validation")

                            # Metrics for the Train dataset
                            tensorboard.write_segmentation(
                                epoch, train_dataloader)
                            tensorboard.write_metrics(epoch)
                            train_acc = metrics.get_avg_acc()

                            # Metrics for the Validation dataset
                            tensorboard.write_segmentation(epoch,
                                                           val_dataloader,
                                                           mode="Validation")
                            tensorboard.write_metrics(epoch, mode="Validation")
                            val_acc = metrics.get_avg_acc()

                            logger.info(
                                f"Train accuracy: {train_acc:.3f}  -  Validation accuracy: {val_acc:.3f}"
                            )

                        logger.info(
                            f"Validation loss: {epoch_loss:.5e}  -  "
                            f"Took {time.perf_counter() - validation_start_time:.5f}s"
                        )
                scheduler.step()
        except KeyboardInterrupt:
            print("\n")
        except Exception as error:
            logger.error(''.join(traceback.format_exception(*sys.exc_info())))
            raise error

    if data_config.USE_TB:
        tensorboard.close_writers()

    train_stop_time = time.time()
    end_msg = f"Finished Training\n\tTraining time : {train_stop_time - train_start_time:.03f}s"
    try:
        memory_peak, gpu_memory = resource_usage()
        end_msg += f"\n\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}"
    except CalledProcessError:
        pass
    logger.info(end_msg)
Example #7
def main():
    parser = argparse.ArgumentParser(
        description="Training script",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--limit",
                        "-l",
                        default=None,
                        type=int,
                        help="Limits the number of occurrences of each class.")
    parser.add_argument(
        "--name",
        type=str,
        default="Train",
        help=
        "Used to identify the training run when using ps. Also the name of the logger.")
    parser.add_argument("--verbose_level",
                        "-v",
                        choices=["debug", "info", "error"],
                        default="info",
                        type=str,
                        help="Logger level.")
    args = parser.parse_args()

    limit: int = args.limit
    name: str = args.name
    verbose_level: str = args.verbose_level

    data_config = get_data_config()
    model_config = get_model_config()

    prepare_folders(
        data_config.TB_DIR if data_config.USE_TB else None,
        data_config.CHECKPOINTS_DIR if data_config.USE_CHECKPOINTS else None,
        repo_name="Classification-PyTorch",
        extra_files=[
            Path("config/data_config.py"),
            Path("config/model_config.py")
        ])
    log_dir = data_config.CHECKPOINTS_DIR / "print_logs" if data_config.USE_CHECKPOINTS else None
    logger = create_logger(name, log_dir=log_dir, verbose_level=verbose_level)
    logger.info("Finished preparing tensorboard and checkpoints folders.")

    torch.backends.cudnn.benchmark = True  # Makes training quite a bit faster

    train_data, train_labels = data_loader(data_config.DATA_PATH / "Train",
                                           data_config.LABEL_MAP,
                                           limit=limit)
    logger.info("Train data loaded")
    val_data, val_labels = data_loader(data_config.DATA_PATH / "Validation",
                                       data_config.LABEL_MAP,
                                       limit=limit)
    logger.info("Validation data loaded")

    # Data augmentation done on cpu.
    augmentation_pipeline = transforms.albumentation_wrapper(
        albumentations.Compose([
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.RandomRotate90(),
            albumentations.ShiftScaleRotate(),
            # albumentations.CLAHE(),
            # albumentations.AdvancedBlur(),
            # albumentations.GaussNoise(),
            albumentations.RandomBrightnessContrast(brightness_limit=0.3,
                                                    contrast_limit=0.3,
                                                    p=0.5),
            albumentations.HueSaturationValue(hue_shift_limit=30,
                                              sat_shift_limit=45,
                                              val_shift_limit=30,
                                              p=0.5),
            # albumentations.ImageCompression(),
        ]))
    common_pipeline = transforms.albumentation_wrapper(
        albumentations.Compose([
            albumentations.Normalize(mean=model_config.IMG_MEAN,
                                     std=model_config.IMG_STD,
                                     p=1.0),
            albumentations.Resize(*model_config.IMAGE_SIZES,
                                  interpolation=cv2.INTER_LINEAR)
        ]))

    train_pipeline = transforms.compose_transformations(
        (augmentation_pipeline, common_pipeline))

    denormalize_imgs_fn = transforms.destandardize_img(model_config.IMG_MEAN,
                                                       model_config.IMG_STD)

    with BatchGenerator(train_data, train_labels,
                        model_config.BATCH_SIZE, nb_workers=data_config.NB_WORKERS,
                        data_preprocessing_fn=default_load_data,
                        cpu_pipeline=train_pipeline,
                        gpu_pipeline=transforms.to_tensor(),
                        shuffle=True) as train_dataloader, \
        BatchGenerator(val_data, val_labels, model_config.BATCH_SIZE, nb_workers=data_config.NB_WORKERS,
                       data_preprocessing_fn=default_load_data,
                       cpu_pipeline=common_pipeline,
                       gpu_pipeline=transforms.to_tensor(),
                       shuffle=False) as val_dataloader:

        print(f"\nLoaded {len(train_dataloader)} train data and",
              f"{len(val_dataloader)} validation data",
              flush=True)

        print("Building model. . .", end="\r")
        model = build_model(model_config.MODEL, data_config.NB_CLASSES,
                            **dict(get_dataclass_as_dict(model_config)))

        logger.info(f"{'-'*24} Starting train {'-'*24}")
        logger.info("From command : %s", ' '.join(sys.argv))
        logger.info(f"Input shape: {train_dataloader.data_shape}")
        logger.info("")
        logger.info("Using model:")
        for line in summary(model, train_dataloader.data_shape):
            logger.info(line)
        logger.info("")

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        weights = torch.Tensor(model_config.LOSS_WEIGTHS).to(
            device) if model_config.LOSS_WEIGTHS else None
        loss_fn = nn.CrossEntropyLoss(weight=weights)
        # loss_fn = SmoothCrossEntropyLoss(model_config.LABEL_SMOOTHING)

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=model_config.START_LR,
                                      betas=(0.9, 0.95),
                                      weight_decay=model_config.WEIGHT_DECAY)
        trainer = Trainer(model, loss_fn, optimizer, train_dataloader,
                          val_dataloader)
        scheduler = CosineAnnealingLR(optimizer,
                                      model_config.MAX_EPOCHS,
                                      eta_min=model_config.END_LR)
        # TODO: Try this https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/cosine_lr.py

        if data_config.USE_TB:
            metrics = ClassificationMetrics(model,
                                            train_dataloader,
                                            val_dataloader,
                                            data_config.LABEL_MAP,
                                            max_batches=None)
            tensorboard = ClassificationTensorBoard(model, data_config.TB_DIR,
                                                    train_dataloader,
                                                    val_dataloader, logger,
                                                    metrics,
                                                    denormalize_imgs_fn)

        best_loss = 1000
        last_checkpoint_epoch = 0
        train_start_time = time.time()

        try:
            for epoch in range(model_config.MAX_EPOCHS):
                epoch_start_time = time.perf_counter()
                print()  # logger doesn't handle \n super well
                logger.info(f"Epoch {epoch}/{model_config.MAX_EPOCHS}")

                epoch_loss = trainer.train_epoch()

                if data_config.USE_TB:
                    tensorboard.write_loss(epoch, epoch_loss)
                    tensorboard.write_lr(epoch, scheduler.get_last_lr()[0])

                if (epoch_loss < best_loss and data_config.USE_CHECKPOINTS
                        and epoch >= data_config.RECORD_START
                        and (epoch - last_checkpoint_epoch) >=
                        data_config.CHECKPT_SAVE_FREQ):
                    save_path = data_config.CHECKPOINTS_DIR / f"train_{epoch}.pt"
                    logger.info(
                        f"Loss improved from {best_loss:.5e} to {epoch_loss:.5e}, "
                        f"saving model to {save_path}")
                    best_loss, last_checkpoint_epoch = epoch_loss, epoch
                    torch.save(model.state_dict(), save_path)

                logger.info(
                    f"Epoch loss: {epoch_loss:.5e}  -  Took {time.perf_counter() - epoch_start_time:.5f}s"
                )

                # Validation and other metrics
                if epoch % data_config.VAL_FREQ == 0 and epoch >= data_config.RECORD_START:
                    if data_config.USE_TB:
                        tensorboard.write_weights_grad(epoch)
                    with torch.no_grad():
                        validation_start_time = time.perf_counter()
                        val_epoch_loss = trainer.val_epoch()

                        if data_config.USE_TB:
                            print("Starting to compute TensorBoard metrics",
                                  end="\r",
                                  flush=True)
                            tensorboard.write_loss(epoch,
                                                   val_epoch_loss,
                                                   mode="Validation")

                            # Metrics for the Train dataset
                            tensorboard.write_images(epoch)
                            tensorboard.write_metrics(epoch)
                            train_acc = metrics.get_avg_acc()

                            # Metrics for the Validation dataset
                            tensorboard.write_images(epoch, mode="Validation")
                            tensorboard.write_metrics(epoch, mode="Validation")
                            val_acc = metrics.get_avg_acc()

                            logger.info(
                                f"Train accuracy: {train_acc:.3f}  -  Validation accuracy: {val_acc:.3f}"
                            )

                        logger.info(
                            f"Validation loss: {val_epoch_loss:.5e}  -  "
                            f"Took {time.perf_counter() - validation_start_time:.5f}s"
                        )
                scheduler.step()
        except KeyboardInterrupt:
            print("\n")
        except Exception as error:
            logger.error(''.join(traceback.format_exception(*sys.exc_info())))
            raise error

    if data_config.USE_TB:
        metrics = {
            "Z - Final Results/Train loss": epoch_loss,
            "Z - Final Results/Validation loss": val_epoch_loss,
            "Z - Final Results/Train accuracy": train_acc,
            "Z - Final Results/Validation accuracy": val_acc
        }
        tensorboard.write_config(get_dataclass_as_dict(model_config), metrics)
        tensorboard.close_writers()

    train_stop_time = time.time()
    end_msg = f"Finished Training\n\tTraining time : {train_stop_time - train_start_time:.03f}s"
    memory_peak, gpu_memory = resource_usage()
    end_msg += f"\n\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}"
    logger.info(end_msg)