# Shared imports for the main() functions below.
import os
import time
from collections import OrderedDict

import numpy as np
import torch
from torch.utils import data
# AnchorGenerator lives in anchor_utils for torchvision >= 0.8;
# older releases expose it via torchvision.models.detection.rpn
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor

# project-specific modules; the exact import paths depend on the repository layout
import dataset
import mask_net
import utils


def main(config, step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels

    assert config.mask_type in mask_classes
    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # unpack the configuration: thresholds, directories, NMS settings and the backbone
    confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \
        = config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
        config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation

    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    ckpt = torch.load(config.ckpt, map_location=device)
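    # the checkpoint is expected to hold 'anchor_generator' and 'model_weights',
    # and optionally 'model_name'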

    model_name = None
    if 'model_name' in ckpt.keys():
        model_name = ckpt['model_name']
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)

    # create modules
    # this assumes FPN with 256 channels
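    # the box head flattens the 7x7 RoIAlign output of the 256-channel FPN maps
    # into a 128-dimensional representation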
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
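    # ResNet50+FPN keeps the stock torchvision mask branch (256 channels);
    # the lighter backbones use a single 128-channel MaskRCNNHeads layer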
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256,
                                           dim_reduced=256,
                                           num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256,
                                       layers=(128, ),
                                       dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128,
                                           dim_reduced=128,
                                           num_classes=n_c)

    # keyword arguments
    maskrcnn_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': 100,
        'box_nms_thresh': roi_nms,
        'box_score_thresh': confidence_threshold,
        'rpn_nms_thresh': rpn_nms,
        'box_head': box_head,
        'rpn_anchor_generator': anchor_generator,
        'mask_head': maskrcnn_heads,
        'mask_predictor': mask_predictor,
        'box_predictor': box_predictor
    }

    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name,
                                                  truncation,
                                                  pretrained_backbone=False,
                                                  **maskrcnn_args)
    # Load weights
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    # Set to evaluation mode
    print(maskrcnn_model)
    maskrcnn_model.eval().to(device)

    start_time = time.time()
    # get the correct masks and mask colors
    if mask_type == "ggo":
        ct_classes = {0: '__bgr', 1: 'GGO'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "merge":
        ct_classes = {0: '__bgr', 1: 'Lesion'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "both":
        ct_classes = {0: '__bgr', 1: 'GGO', 2: 'CL'}
        ct_colors = {
            1: 'red',
            2: 'blue',
            'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])
        }

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # resolve the model name: fall back to a default, otherwise prefer the config over the checkpoint
    if model_name is None:
        model_name = "maskrcnn_segmentation"
    elif config.model_name != model_name:
        print("Using model name from the config.")
        model_name = config.model_name

    # run the inference with provided hyperparameters
    test_ims = os.listdir(os.path.join(data_dir, img_dir))
    for j, ims in enumerate(test_ims):
        step(os.path.join(data_dir, img_dir, ims), device, maskrcnn_model,
             model_name, confidence_threshold, mask_threshold, save_dir,
             ct_classes, ct_colors, j)
    end_time = time.time()
    print("Inference took {0:.1f} seconds".format(end_time - start_time))
def main(config, main_step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    truncation_levels = ['0', '1', '2']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    assert config.backbone_name in backbones
    assert config.mask_type in mask_classes
    assert config.truncation in truncation_levels

    # import arguments from the config file
    start_epoch, model_name, use_pretrained_resnet_backbone, num_epochs, save_dir, train_data_dir, val_data_dir, imgs_dir, gt_dir, batch_size, device, save_every, lrate, rpn_nms, mask_type, backbone_name, truncation = \
        config.start_epoch, config.model_name, config.use_pretrained_resnet_backbone, config.num_epochs, config.save_dir, \
        config.train_data_dir, config.val_data_dir, config.imgs_dir, config.gt_dir, config.batch_size, config.device, config.save_every, config.lrate, config.rpn_nms_th, config.mask_type, config.backbone_name, config.truncation

    assert device in devices
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    if batch_size > 1:
        print("The model was implemented for batch size of one")
    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(device)

    # Load the weights if provided
    if config.pretrained_model is not None:
        pretrained_model = torch.load(config.pretrained_model,
                                      map_location=device)
        use_pretrained_resnet_backbone = False
    else:
        pretrained_model = None
    torch.manual_seed(time.time())
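    # seeding from the wall clock (truncated to an int) makes every run non-deterministic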
    ##############################################################################################
    # DATASETS + DATALOADERS
    # Alex: could be added in the config file in the future
    # parameters for the dataset
    dataset_covid_pars_train = {
        'stage': 'train',
        'gt': os.path.join(train_data_dir, gt_dir),
        'data': os.path.join(train_data_dir, imgs_dir),
        'mask_type': mask_type,
        'ignore_small': True
    }
    datapoint_covid_train = dataset.CovidCTData(**dataset_covid_pars_train)

    dataset_covid_pars_eval = {
        'stage': 'eval',
        'gt': os.path.join(val_data_dir, gt_dir),
        'data': os.path.join(val_data_dir, imgs_dir),
        'mask_type': mask_type,
        'ignore_small': True
    }
    datapoint_covid_eval = dataset.CovidCTData(**dataset_covid_pars_eval)
    ###############################################################################################
    dataloader_covid_pars_train = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_train = data.DataLoader(datapoint_covid_train,
                                             **dataloader_covid_pars_train)
    #
    dataloader_covid_pars_eval = {'shuffle': False, 'batch_size': batch_size}
    dataloader_covid_eval = data.DataLoader(datapoint_covid_eval,
                                            **dataloader_covid_pars_eval)
    ###############################################################################################
    # MASK R-CNN model
    # Alex: these settings could also be added to the config
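    # number of output classes: background + GGO + consolidation for 'both',
    # otherwise background + a single lesion class ('ggo' or 'merge')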
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    maskrcnn_args = {
        'min_size': 512,
        'max_size': 1024,
        'rpn_batch_size_per_image': 256,
        'rpn_positive_fraction': 0.75,
        'box_positive_fraction': 0.75,
        'box_fg_iou_thresh': 0.75,
        'box_bg_iou_thresh': 0.5,
        'num_classes': None,
        'box_batch_size_per_image': 256,
        'rpn_nms_thresh': rpn_nms
    }

    # Alex: for ground glass opacity and consolidation segmentation
    # many small anchors
    # use all outputs of FPN
    # IMPORTANT!! For the pretrained weights, this determines the size of the anchor layer in RPN!!!!
    # pretrained model must have anchors
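    # built from scratch: 5 FPN levels, each with 5 sizes x 6 aspect ratios = 30 anchors per location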
    if pretrained_model is None:
        anchor_generator = AnchorGenerator(sizes=tuple([(2, 4, 8, 16, 32)
                                                        for r in range(5)]),
                                           aspect_ratios=tuple([
                                               (0.1, 0.25, 0.5, 1, 1.5, 2)
                                               for rh in range(5)
                                           ]))
    else:
        print("Loading the anchor generator")
        sizes = pretrained_model['anchor_generator'].sizes
        aspect_ratios = pretrained_model['anchor_generator'].aspect_ratios
        anchor_generator = AnchorGenerator(sizes=sizes,
                                           aspect_ratios=aspect_ratios)
        print(anchor_generator, anchor_generator.num_anchors_per_location())
    # num_classes:3 (1+2)
    # in_channels = 256: number of channels from the FPN
    # For ResNet50+FPN: keep the torchvision architecture, but with 128 features
    # For lightweight models: re-implement MaskRCNNHeads with a single layer
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256,
                                           dim_reduced=256,
                                           num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256,
                                       layers=(128, ),
                                       dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128,
                                           dim_reduced=128,
                                           num_classes=n_c)

    maskrcnn_args['box_head'] = box_head
    maskrcnn_args['rpn_anchor_generator'] = anchor_generator
    maskrcnn_args['mask_head'] = maskrcnn_heads
    maskrcnn_args['mask_predictor'] = mask_predictor
    maskrcnn_args['box_predictor'] = box_predictor
    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(
        backbone_name,
        truncation,
        pretrained_backbone=use_pretrained_resnet_backbone,
        **maskrcnn_args)
    # pretrained?
    print(maskrcnn_model.backbone.out_channels)
    if pretrained_model is not None:
        print("Loading pretrained weights")
        maskrcnn_model.load_state_dict(pretrained_model['model_weights'])
        if 'epoch' in pretrained_model.keys():
            start_epoch = int(pretrained_model['epoch']) + 1
        if 'model_name' in pretrained_model.keys():
            model_name = str(pretrained_model['model_name'])

    # Set to training mode
    print(maskrcnn_model)
    maskrcnn_model.train().to(device)

    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(list(maskrcnn_model.parameters()),
                                 **optimizer_pars)
    if pretrained_model is not None and 'optimizer_state' in pretrained_model.keys():
        optimizer.load_state_dict(pretrained_model['optimizer_state'])

    start_time = time.time()
    if start_epoch > 0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        train_loss_epoch = main_step("train", e, dataloader_covid_train,
                                     optimizer, device, maskrcnn_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval,
                                    optimizer, device, maskrcnn_model,
                                    save_every, lrate, model_name,
                                    anchor_generator, save_dir)
        print("Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".
              format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))
def main(config, main_step):
    torch.manual_seed(time.time())
    start_time = time.time()
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels

    start_epoch, pretrained_classifier, pretrained_segment, model_name, num_epochs, save_dir, train_data_dir, val_data_dir, \
    batch_size, device, save_every, lrate, rpn_nms, roi_nms, backbone_name, truncation, roi_batch_size, n_c, s_features = \
                                            config.start_epoch, config.pretrained_classification_model, \
                                            config.pretrained_segmentation_model, \
                                            config.model_name, config.num_epochs, config.save_dir, \
                                            config.train_data_dir, config.val_data_dir, \
                                            config.batch_size, config.device, config.save_every, \
                                            config.lrate, config.rpn_nms_th, config.roi_nms_th, \
                                            config.backbone_name, config.truncation, \
                                            config.roi_batch_size, config.num_classes, config.s_features

    if pretrained_classifier is not None and pretrained_segment is not None:
        print("Not clear which model to use, switching to the classifier")
        pretrained_model = pretrained_classifier
    elif pretrained_classifier is not None and pretrained_segment is None:
        pretrained_model = pretrained_classifier
    else:
        pretrained_model = pretrained_segment

    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    ##############################################################################################
    # DATASETS+DATALOADERS
    # Alex: could be added in the config file in the future
    # parameters for the dataset
    # 512x512 is the recommended image size input
    dataset_covid_pars_train_cl = {
        'stage': 'train',
        'data': train_data_dir,
        'img_size': (512, 512)
    }
    datapoint_covid_train_cl = dataset.COVID_CT_DATA(
        **dataset_covid_pars_train_cl)
    #
    dataset_covid_pars_eval_cl = {
        'stage': 'eval',
        'data': val_data_dir,
        'img_size': (512, 512)
    }
    datapoint_covid_eval_cl = dataset.COVID_CT_DATA(
        **dataset_covid_pars_eval_cl)
    #
    dataloader_covid_pars_train_cl = {
        'shuffle': True,
        'batch_size': batch_size
    }
    dataloader_covid_train_cl = data.DataLoader(
        datapoint_covid_train_cl, **dataloader_covid_pars_train_cl)
    #
    dataloader_covid_pars_eval_cl = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_eval_cl = data.DataLoader(datapoint_covid_eval_cl,
                                               **dataloader_covid_pars_eval_cl)
    #
    ##### LOAD PRETRAINED WEIGHTS FROM MASK R-CNN MODEL
    # This must be the full path to the checkpoint with the anchor generator and model weights
    # Assumed that the keys in the checkpoint are model_weights and anchor_generator
    ckpt = torch.load(pretrained_model, map_location=device)
    # keyword arguments
    # box_score_thresh is negative, so no predictions are filtered out
    # keep both NMS thresholds high (e.g. 0.75) so that adjacent RoIs survive
    # box_detections_per_img: batch size for the classifier (module S)
    covid_mask_net_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': roi_batch_size,
        'box_nms_thresh': roi_nms,
        'box_score_thresh': -0.01,
        'rpn_nms_thresh': rpn_nms
    }

    # copy the anchor generator parameters and create a new one to avoid clashes between implementations
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    # out_channels:256, FPN
    # num_classes:3 (1+2)
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)

    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    covid_mask_net_args['s_representation_size'] = s_features
    # Instantiate the model
    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(
        backbone_name, truncation, **covid_mask_net_args)
    # which parameters to train?
    trained_pars = []
    # if the weights are loaded from the segmentation model:
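    # state_dict() returns tensors that share storage with the model, so copy_()
    # writes the matching segmentation weights directly into the classifier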
    if pretrained_classifier is None:
        for _n, _par in covid_mask_net_model.state_dict().items():
            if _n in ckpt['model_weights']:
                print('Loading parameter', _n)
                _par.copy_(ckpt['model_weights'][_n])
    # if the weights are loaded from the classification model
    else:
        covid_mask_net_model.load_state_dict(ckpt['model_weights'])
        if 'epoch' in ckpt.keys():
            start_epoch = int(ckpt['epoch']) + 1
        if 'model_name' in ckpt.keys():
            model_name = str(ckpt['model_name'])

    # Evaluation mode, no labels!
    covid_mask_net_model.eval()
    # set the model to training mode without triggering the 'training' mode of Mask R-CNN
    # set up the optimizer
    utils.switch_model_on(covid_mask_net_model, ckpt, trained_pars)
    utils.set_to_train_mode(covid_mask_net_model)
    print(covid_mask_net_model)
    covid_mask_net_model = covid_mask_net_model.to(device)
    total_trained_pars = sum([x.numel() for x in trained_pars])
    print("Total trained pars {0:d}".format(total_trained_pars))
    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(trained_pars, **optimizer_pars)
    if pretrained_classifier is not None and 'optimizer_state' in ckpt.keys():
        optimizer.load_state_dict(ckpt['optimizer_state'])

    if start_epoch > 0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        train_loss_epoch = main_step("train", e, dataloader_covid_train_cl,
                                     optimizer, device, covid_mask_net_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval_cl,
                                    optimizer, device, covid_mask_net_model,
                                    save_every, lrate, model_name,
                                    anchor_generator, save_dir)
        print("Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".
              format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))
def main(config, step):
    torch.manual_seed(time.time())
    start_time = time.time()
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']

    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels

    pretrained_model, model_name, test_data_dir, device, rpn_nms, roi_nms, backbone_name, truncation, roi_batch_size, n_c, s_features\
              = config.ckpt, config.model_name, config.test_data_dir, config.device, config.rpn_nms_th, \
                config.roi_nms_th, config.backbone_name, config.truncation, config.roi_batch_size, config.num_classes, config.s_features

    if torch.cuda.is_available() and device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # either 2+1 or 1+1 classes
    ckpt = torch.load(pretrained_model, map_location=device)
    # 'box_detections_per_img': batch size input in module S
    # 'box_score_thresh': negative to accept all predictions
    covid_mask_net_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': roi_batch_size,
        'box_nms_thresh': roi_nms,
        'box_score_thresh': -0.01,
        'rpn_nms_thresh': rpn_nms
    }

    print(covid_mask_net_args)
    # extract anchor generator from the checkpoint
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    # Faster R-CNN interfaces, masks not implemented at this stage
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
    # Mask prediction is not necessary, keep it for future extensions
    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    # representation size of the S classification module
    # these should be provided in the config
    covid_mask_net_args['s_representation_size'] = s_features
    # Instance of the model, copy weights
    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(
        backbone_name, truncation, **covid_mask_net_args)
    covid_mask_net_model.load_state_dict(ckpt['model_weights'])
    covid_mask_net_model.eval().to(device)
    print(covid_mask_net_model)
    # confusion matrix
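    # rows are assumed to index the ground-truth class and columns the prediction,
    # so the row-normalisation below yields per-class sensitivity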
    confusion_matrix = torch.zeros(3, 3, dtype=torch.int32).to(device)

    for idx, f in enumerate(os.listdir(test_data_dir)):
        step(f, covid_mask_net_model, test_data_dir, device, confusion_matrix)

    print("------------------------------------------")
    print("Confusion Matrix for 3-class problem:")
    print("0: Control, 1: Normal Pneumonia, 2: COVID")
    print(confusion_matrix)
    print("------------------------------------------")
    # row-normalise the confusion matrix to get per-class sensitivity
    cm = confusion_matrix.float()
    cm[0, :].div_(cm[0, :].sum())
    cm[1, :].div_(cm[1, :].sum())
    cm[2, :].div_(cm[2, :].sum())
    print("------------------------------------------")
    print("Class Sensitivity:")
    print(cm)
    print("------------------------------------------")
    print('Overall accuracy:')
    print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum()))
    end_time = time.time()
    print("Evaluation took {0:.1f} seconds".format(end_time - start_time))
def main(config, step):
    torch.manual_seed(time.time())
    start_time = time.time()
    devices = ['cpu', 'cuda']
    pretrained_model, model_name, test_data_dir, device = config.ckpt, config.model_name, config.test_data_dir, config.device
    assert device in devices

    if torch.cuda.is_available() and device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    ckpt = torch.load(pretrained_model, map_location=device)
    # 'box_detections_per_img': batch size input in module S
    # 'box_score_thresh': negative to accept all predictions
    covid_mask_net_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': 256,
        'box_nms_thresh': 0.75,
        'box_score_thresh': -0.01,
        'rpn_nms_thresh': 0.75
    }
    print(covid_mask_net_args)
    # extract anchor generator from the checkpoint
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
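    # 256 FPN channels x 7x7 RoIAlign output feed the two-layer box head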
    box_head_input_size = 256 * 7 * 7
    box_head = TwoMLPHead(in_channels=box_head_input_size,
                          representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=3)
    # Mask prediction is not necessary, keep it for future extensions
    mask_predictor = MaskRCNNPredictor(in_channels=256,
                                       dim_reduced=256,
                                       num_classes=3)

    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['mask_predictor'] = mask_predictor
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    # Instance of the model, copy weights
    covid_mask_net_model = mask_net.maskrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=False,
        progress=False,
        **covid_mask_net_args)
    covid_mask_net_model.load_state_dict(ckpt['model_weights'])
    covid_mask_net_model.eval().to(device)
    # confusion matrix
    confusion_matrix = torch.zeros(3, 3, dtype=torch.int32).to(device)

    for idx, f in enumerate(os.listdir(test_data_dir)):
        step(f, covid_mask_net_model, test_data_dir, device, confusion_matrix)

    print("------------------------------------------")
    print("Confusion Matrix for 3-class problem:")
    print("0: Control, 1: Normal Pneumonia, 2: COVID")
    print(confusion_matrix)
    print("------------------------------------------")
    # row-normalise the confusion matrix to get per-class sensitivity
    cm = confusion_matrix.float()
    cm[0, :].div_(cm[0, :].sum())
    cm[1, :].div_(cm[1, :].sum())
    cm[2, :].div_(cm[2, :].sum())
    print("------------------------------------------")
    print("Class Sensitivity:")
    print(cm)
    print("------------------------------------------")
    print('Overall accuracy:')
    print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum()).item())
    end_time = time.time()
    print("Evaluation took {0:.1f} seconds".format(end_time - start_time))
def main(config, step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels

    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # the checkpoint may store the model name
    model_name = None
    ckpt = torch.load(config.ckpt, map_location=device)
    if 'model_name' in ckpt.keys():
        model_name = ckpt['model_name']

    # get the thresholds
    confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \
        = config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
        config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation

    if model_name is None:
        model_name = "maskrcnn_segmentation"
    elif config.model_name != model_name:
        print("Using model name from the config.")
        model_name = config.model_name

    # either 2+1 or 1+1 classes
    assert mask_type in mask_classes
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    # dataset interface
    dataset_covid_eval_pars = {
        'stage': 'eval',
        'gt': os.path.join(data_dir, gt_dir),
        'data': os.path.join(data_dir, img_dir),
        'mask_type': mask_type,
        'ignore_small': True
    }
    datapoint_eval_covid = dataset.CovidCTData(**dataset_covid_eval_pars)
    dataloader_covid_eval_pars = {'shuffle': False, 'batch_size': 1}
    dataloader_eval_covid = data.DataLoader(datapoint_eval_covid,
                                            **dataloader_covid_eval_pars)
    # MASK R-CNN model (weights and anchors come from the checkpoint loaded above)
    # Alex: these settings could also be added to the config
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)

    # create modules
    # this assumes FPN with 256 channels
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256,
                                           dim_reduced=256,
                                           num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256,
                                       layers=(128, ),
                                       dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128,
                                           dim_reduced=128,
                                           num_classes=n_c)

    # keyword arguments
    maskrcnn_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': 128,
        'box_nms_thresh': roi_nms,
        'box_score_thresh': confidence_threshold,
        'rpn_nms_thresh': rpn_nms,
        'box_head': box_head,
        'rpn_anchor_generator': anchor_generator,
        'mask_head': maskrcnn_heads,
        'mask_predictor': mask_predictor,
        'box_predictor': box_predictor
    }

    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name,
                                                  truncation,
                                                  pretrained_backbone=False,
                                                  **maskrcnn_args)
    # Load weights
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    # Set to evaluation mode
    print(maskrcnn_model)
    maskrcnn_model.eval().to(device)
    # IoU thresholds. By default the model computes AP at each threshold from 0.5 to 0.95 in steps of 0.05
    thresholds = torch.arange(0.5, 1, 0.05).to(device)
    mean_aps_all_th = torch.zeros(thresholds.size()[0]).to(device)
    ap_th = OrderedDict()
    # run the loop for all thresholds
    for t, th in enumerate(thresholds):
        # main method
        ap = step(maskrcnn_model, th, dataloader_eval_covid, device,
                  mask_threshold)
        mean_aps_all_th[t] = ap
        th_name = 'AP@{0:.2f}'.format(th)
        ap_th[th_name] = ap
    print("Done evaluation for {}".format(model_name))
    print("mAP:{0:.4f}".format(mean_aps_all_th.mean().item()))
    for k, aps in ap_th.items():
        print("{0:}:{1:.4f}".format(k, aps))
def main(config, main_step):
    torch.manual_seed(time.time())
    start_time = time.time()
    devices = ['cpu', 'cuda']
    updates = ['heads', 'heads_bn', 'full']
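    # update_type selects which parameter groups are trained; it is passed to
    # utils.switch_model_on / utils.set_to_train_mode below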
    start_epoch, update_type, pretrained_classifier, pretrained_segmenter, model_name, num_epochs, save_dir, train_data_dir, val_data_dir, \
    batch_size, device, save_every, lrate = config.start_epoch, config.update_type, config.pretrained_classification_model, \
                                            config.pretrained_segmentation_model, \
                                            config.model_name, config.num_epochs, config.save_dir, \
                                            config.train_data_dir, config.val_data_dir, \
                                            config.batch_size, config.device, config.save_every, config.lrate

    if pretrained_classifier is not None and pretrained_segmenter is not None:
        print("Not clear which model to use, switching to the classifier")
        pretrained_model = pretrained_classifier
    elif pretrained_classifier is not None and pretrained_segmenter is None:
        pretrained_model = pretrained_classifier
    else:
        pretrained_model = pretrained_segmenter

    assert device in devices
    assert update_type in updates
    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    ##############################################################################################
    # DATASETS+DATALOADERS
    # Alex: could be added in the config file in the future
    # parameters for the dataset
    # 512x512 is the recommended image size input
    dataset_covid_pars_train_cl = {
        'stage': 'train',
        'data': train_data_dir,
        'img_size': 512
    }
    datapoint_covid_train_cl = dataset.COVID_CT_DATA(
        **dataset_covid_pars_train_cl)
    #
    dataset_covid_pars_eval_cl = {
        'stage': 'eval',
        'data': val_data_dir,
        'img_size': 512
    }
    datapoint_covid_eval_cl = dataset.COVID_CT_DATA(
        **dataset_covid_pars_eval_cl)
    #
    dataloader_covid_pars_train_cl = {
        'shuffle': True,
        'batch_size': batch_size
    }
    dataloader_covid_train_cl = data.DataLoader(
        datapoint_covid_train_cl, **dataloader_covid_pars_train_cl)
    #
    dataloader_covid_pars_eval_cl = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_eval_cl = data.DataLoader(datapoint_covid_eval_cl,
                                               **dataloader_covid_pars_eval_cl)
    #
    ##### LOAD PRETRAINED WEIGHTS FROM MASK R-CNN MODEL
    # This must be the full path to the checkpoint with the anchor generator and model weights
    # Assumed that the keys in the checkpoint are model_weights and anchor_generator
    ckpt = torch.load(pretrained_model, map_location=device)
    # keyword arguments
    # box_score_thresh is negative, so no predictions are filtered out
    # both NMS thresholds are set to 0.75 so that adjacent RoIs survive
    # box_detections_per_img: batch size for the classifier
    covid_mask_net_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': 256,
        'box_nms_thresh': 0.75,
        'box_score_thresh': -0.01,
        'rpn_nms_thresh': 0.75
    }

    # copy the anchor generator parameters and create a new one to avoid clashes between implementations
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    # out_channels:256
    # num_classes:3 (1+2)
    box_head_input_size = 256 * 7 * 7
    box_head = TwoMLPHead(in_channels=box_head_input_size,
                          representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=3)
    # Mask prediction is not necessary, keep it for future extensions
    mask_predictor = MaskRCNNPredictor(in_channels=256,
                                       dim_reduced=256,
                                       num_classes=3)

    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['mask_predictor'] = mask_predictor
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head

    covid_mask_net_model = mask_net.maskrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=False,
        progress=False,
        **covid_mask_net_args)

    # which parameters to train?
    trained_pars = []

    if pretrained_classifier is None:
        for _n, _par in covid_mask_net_model.state_dict().items():
            if _n in ckpt['model_weights']:
                print('Loading parameter', _n)
                _par.copy_(ckpt['model_weights'][_n])
    else:
        covid_mask_net_model.load_state_dict(ckpt['model_weights'])
        if 'epoch' in ckpt.keys():
            start_epoch = int(ckpt['epoch'])
        if 'model_name' in ckpt.keys():
            model_name = ckpt['model_name']

    # Evaluation mode, no labels!
    covid_mask_net_model.eval()
    # set the model to training mode without triggering the 'training' mode of Mask R-CNN
    utils.switch_model_on(covid_mask_net_model, trained_pars, update_type)
    utils.set_to_train_mode(covid_mask_net_model, update_type)
    print(covid_mask_net_model)
    covid_mask_net_model = covid_mask_net_model.to(device)

    total_trained_pars = sum([x.numel() for x in trained_pars])
    print("Total trained pars {0:d}".format(total_trained_pars))

    optimizer_pars = {'lr': 1e-5, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(trained_pars, **optimizer_pars)
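    # note: the learning rate here is fixed at 1e-5; config.lrate is only forwarded to main_step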
    if pretrained_classifier is not None and 'optimizer_state' in ckpt.keys():
        optimizer.load_state_dict(ckpt['optimizer_state'])

    for e in range(start_epoch, num_epochs):
        train_loss_epoch = main_step("train",
                                     e,
                                     dataloader_covid_train_cl,
                                     optimizer,
                                     device,
                                     covid_mask_net_model,
                                     save_every,
                                     lrate,
                                     model_name,
                                     None,
                                     None,
                                     update_type=update_type)
        eval_loss_epoch = main_step("eval",
                                    e,
                                    dataloader_covid_eval_cl,
                                    optimizer,
                                    device,
                                    covid_mask_net_model,
                                    save_every,
                                    lrate,
                                    model_name,
                                    anchor_generator,
                                    save_dir,
                                    update_type=update_type)
        print("Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".
              format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))