Example #1
def test(args):
    print("Predicting ...")
    test_paths = os.listdir(os.path.join(args.dataset_dir, args.test_img_dir))
    print(len(test_paths), 'test images found')
    test_df = pd.DataFrame({'ImageId': test_paths, 'EncodedPixels': None})

    from skimage.morphology import binary_opening, disk

    test_df = test_df[:5000]
    test_loader = make_dataloader(test_df,
                                  args,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  transform=None,
                                  mode='predict')

    model = UNet()
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()
    run_id = 1
    print("Resuming run #{}...".format(run_id))
    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    state = torch.load(str(model_path))
    state = {
        key.replace('module.', ''): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)
    model.eval()

    out_pred_rows = []

    for batch_id, (inputs,
                   image_paths) in enumerate(tqdm(test_loader,
                                                  desc='Predict')):
        if args.gpu and torch.cuda.is_available():
            inputs = inputs.cuda()
        with torch.no_grad():
            outputs = model(inputs)
        for i, image_name in enumerate(image_paths):
            mask = torch.sigmoid(outputs[i, 0]).data.cpu().numpy()
            cur_seg = binary_opening(mask > 0.5, disk(2))
            cur_rles = multi_rle_encode(cur_seg)
            if len(cur_rles) > 0:
                for c_rle in cur_rles:
                    out_pred_rows += [{
                        'ImageId': image_name,
                        'EncodedPixels': c_rle
                    }]
            else:
                out_pred_rows += [{
                    'ImageId': image_name,
                    'EncodedPixels': None
                }]

    submission_df = pd.DataFrame(out_pred_rows)[['ImageId', 'EncodedPixels']]
    submission_df.to_csv('submission.csv', index=False)
    print("done.")
Example #2
def test(args):
    model = UNet(3, 1)
    model.load_state_dict(torch.load(args.weight, map_location='cpu'))
    verse_data = DatasetVerse(dir_img,
                              dir_mask,
                              transform=x_transform,
                              target_transform=y_transform)
    dataloaders = DataLoader(verse_data, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x).sigmoid()
            img_y = torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()
Example #3
    def init_model(self, CHANNELS_IN, CHANNELS_OUT, LOAD_MODEL,
                   MODEL_LOAD_PATH, MODEL_NAME, MODEL_SUFFIX,
                   USE_DECONV_LAYERS):
        """
		Initialization and loading model if needed.
			Int: CHANNELS_IN -> Number of input channels in UNet
			Int: CHANNELS_OUT -> Number of output channels in UNet
			Bool: LOAD_MODEL -> If True we need to load existing parameters.
			Str: MODEL_LOAD_PATH -> Path where models are stored
			Str: MODEL_NAME -> Name of loading model

		Returns: Model
		"""

        model = UNet(CHANNELS_IN, CHANNELS_OUT, not USE_DECONV_LAYERS)
        if LOAD_MODEL:
            model_state_dict = torch.load(MODEL_LOAD_PATH + MODEL_NAME +
                                          MODEL_SUFFIX)
            model.load_state_dict(model_state_dict)
        return model
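
A hypothetical call sketch for the method above; every argument value is a placeholder chosen for illustration, not taken from the original project:

# Placeholder values only; the real paths, names and channel counts depend on the project.
model = self.init_model(CHANNELS_IN=3, CHANNELS_OUT=1,
                        LOAD_MODEL=True,
                        MODEL_LOAD_PATH='checkpoints/',
                        MODEL_NAME='unet_best',
                        MODEL_SUFFIX='.pth',
                        USE_DECONV_LAYERS=False)
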
Example #4
def predict(args):
    batch_size = 16
    num_workers = 4
    postprocess = args.postprocess == "True"

    model = UNet(1, 1, first_out_channels=16)
    model.eval()
    if args.model_path is not None:
        model_weights = torch.load(args.model_path)
        model.load_state_dict(model_weights)
    model = nn.DataParallel(model).cuda()

    transforms = [tsfm.Window(-200, 1000), tsfm.MinMaxNorm(-200, 1000)]

    image_path_list = sorted([
        os.path.join(args.image_dir, file)
        for file in os.listdir(args.image_dir) if "nii" in file
    ])
    image_id_list = [
        os.path.basename(path).split("-")[0] for path in image_path_list
    ]

    progress = tqdm(total=len(image_id_list))
    pred_info_list = []
    for image_id, image_path in zip(image_id_list, image_path_list):
        dataset = FracNetInferenceDataset(image_path, transforms=transforms)
        dataloader = FracNetInferenceDataset.get_dataloader(
            dataset, batch_size, num_workers)
        pred_arr = _predict_single_image(model, dataloader, postprocess,
                                         args.prob_thresh, args.bone_thresh,
                                         args.size_thresh)
        pred_image, pred_info = _make_submission_files(pred_arr, image_id,
                                                       dataset.image_affine)
        pred_info_list.append(pred_info)
        pred_path = os.path.join(args.pred_dir, f"{image_id}_pred.nii.gz")
        nib.save(pred_image, pred_path)

        progress.update()

    pred_info = pd.concat(pred_info_list, ignore_index=True)
    pred_info.to_csv(os.path.join(args.pred_dir, "pred_info.csv"), index=False)
Example #5
    # We do not expect these to be non-zero for an accurate mask,
    # so this should not harm the score.
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] = runs[1::2] - runs[:-1:2]
    return runs


def submit(net):
    """Used for Kaggle submission: predicts and encode all test images"""
    dir = 'data/test/'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    N = len(list(os.listdir(dir)))
    with open('SUBMISSION.csv', 'w') as f:
        f.write('img,rle_mask\n')
        for index, i in enumerate(os.listdir(dir)):
            print('{}/{}'.format(index, N))

            img = Image.open(dir + i)

            mask = predict_img(net, img, device)
            enc = rle_encode(mask)
            f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))


if __name__ == '__main__':
    net = UNet(3, 1).cuda()
    net.load_state_dict(torch.load('MODEL.pth'))
    submit(net)
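
Example #5 only shows the encoder side. As a complement, here is a minimal sketch of the usual inverse (decoding a space-separated RLE string back into a mask); it assumes the common Kaggle convention of 1-indexed run starts over a column-major flattened mask, which may differ in detail from the project's own encoder above.

import numpy as np

def rle_decode(rle: str, shape) -> np.ndarray:
    """Decode 'start length start length ...' (1-indexed, column-major) into a binary mask."""
    s = rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1
    lengths = np.asarray(s[1::2], dtype=int)
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        mask[start:start + length] = 1
    return mask.reshape(shape[::-1]).T  # undo the column-major flattening

# Tiny hand-written check: runs starting at pixels 1 and 7 of a 3x4 mask.
print(rle_decode('1 2 7 3', (3, 4)))
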
Example #6
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    cudnn.benchmark = True

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
            #transforms.RandomResizedCrop(224, interpolation=Image.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.ToTensor(),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    data_dir = 'datasets/miccaiSegRefined'
    # json path for class definitions
    json_path = 'datasets/miccaiSegClasses.json'

    image_datasets = {x: miccaiSegDataset(os.path.join(data_dir, x), data_transforms[x],
                        json_path) for x in ['train', 'test']}

    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)
                  for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = UNet(num_classes)

    # # Optionally resume from a checkpoint
    # if args.resume:
    #     if os.path.isfile(args.resume):
    #         print("=> loading checkpoint '{}'".format(args.resume))
    #         checkpoint = torch.load(args.resume)
    #         #args.start_epoch = checkpoint['epoch']
    #         pretrained_dict = checkpoint['state_dict']
    #         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
    #         model.state_dict().update(pretrained_dict)
    #         model.load_state_dict(model.state_dict())
    #         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    #     else:
    #         print("=> no checkpoint found at '{}'".format(args.resume))
    #
    #     # # Freeze the encoder weights
    #     # for param in model.encoder.parameters():
    #     #     param.requires_grad = False
    #
    #     optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)
    # else:
    optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)

    # Load the saved model
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # Use a learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on the validation set

        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(dataloaders['test'], model, criterion, epoch, key, evaluator)

        # Calculate the metrics
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        evaluator.reset()

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
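
The 'train' transform in Example #6 ends with TenCrop plus a Lambda that stacks the crops, so each batch arrives as a 5D tensor of shape (batch, ncrops, C, H, W). The train() function itself is not shown here; a self-contained sketch of the usual way such a batch is fed through a model (the pattern suggested in the torchvision TenCrop documentation) is:

import torch
import torch.nn as nn

model = nn.Conv2d(3, 5, kernel_size=3, padding=1)   # stand-in for UNet(num_classes=5)
batch = torch.randn(2, 10, 3, 64, 64)                # (batch, ncrops, C, H, W) from TenCrop

bs, ncrops, c, h, w = batch.size()
out = model(batch.view(-1, c, h, w))                 # fuse batch and crop dimensions
out = out.view(bs, ncrops, *out.shape[1:]).mean(1)   # average the per-crop predictions
print(out.shape)                                     # torch.Size([2, 5, 64, 64])
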
Example #7
def train(args):
    print("Traning")

    print("Prepaing data")
    masks = pd.read_csv(os.path.join(args.dataset_dir, args.train_masks))
    unique_img_ids = get_unique_img_ids(masks, args)
    train_df, valid_df = get_balanced_train_valid(masks, unique_img_ids, args)

    if args.stage == 0:
        train_shape = (256, 256)
        batch_size = args.stage0_batch_size
        extra_epoch = args.stage0_epochs
    elif args.stage == 1:
        train_shape = (384, 384)
        batch_size = args.stage1_batch_size
        extra_epoch = args.stage1_epochs
    elif args.stage == 2:
        train_shape = (512, 512)
        batch_size = args.stage2_batch_size
        extra_epoch = args.stage2_epochs
    elif args.stage == 3:
        train_shape = (768, 768)
        batch_size = args.stage3_batch_size
        extra_epoch = args.stage3_epochs

    print("Stage {}".format(args.stage))

    train_transform = DualCompose([
        Resize(train_shape),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Shift(),
        Transpose(),
        # ImageOnly(RandomBrightness()),
        # ImageOnly(RandomContrast()),
    ])
    val_transform = DualCompose([
        Resize(train_shape),
    ])

    train_dataloader = make_dataloader(train_df,
                                       args,
                                       batch_size,
                                       args.shuffle,
                                       transform=train_transform)
    val_dataloader = make_dataloader(valid_df,
                                     args,
                                     batch_size // 2,
                                     args.shuffle,
                                     transform=val_transform)

    # Build model
    model = UNet()
    optimizer = Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=args.decay_fr, gamma=0.1)
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()

    # Restore model ...
    run_id = 4

    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    if not model_path.exists() and args.stage > 0:
        raise ValueError(
            'model_{run_id}.pt does not exist, initial train first.'.format(
                run_id=run_id))
    if model_path.exists():
        state = torch.load(str(model_path))
        last_epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restore model, epoch {}, step {:,}'.format(last_epoch, step))
    else:
        last_epoch = 1
        step = 0

    log_file = open('train_{run_id}.log'.format(run_id=run_id),
                    'at',
                    encoding='utf8')

    loss_fn = LossBinary(jaccard_weight=args.iou_weight)

    valid_losses = []

    print("Start training ...")
    for _ in range(last_epoch):
        scheduler.step()

    for epoch in range(last_epoch, last_epoch + extra_epoch):
        scheduler.step()
        model.train()
        random.seed()
        tq = tqdm(total=len(train_dataloader) * batch_size)
        tq.set_description('Run Id {}, Epoch {} of {}, lr {}'.format(
            run_id, epoch, last_epoch + extra_epoch,
            args.lr * (0.1**(epoch // args.decay_fr))))
        losses = []
        try:
            mean_loss = 0.
            for i, (inputs, targets) in enumerate(train_dataloader):
                inputs, targets = torch.tensor(inputs), torch.tensor(targets)
                if args.gpu and torch.cuda.is_available():
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-args.log_fr:])
                tq.set_postfix(loss="{:.5f}".format(mean_loss))

                if i and (i % args.log_fr) == 0:
                    write_event(log_file, step, loss=mean_loss)
            write_event(log_file, step, loss=mean_loss)
            tq.close()
            save_model(model, epoch, step, model_path)

            valid_metrics = validation(args, model, loss_fn, val_dataloader)
            write_event(log_file, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)

        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save_model(model, epoch, step, model_path)
            print('Terminated.')
    print('Done.')
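
Example #7 calls a validation() helper that is not shown. A minimal sketch of what it plausibly returns (a dict containing at least the 'valid_loss' key read above), assuming the same loss and dataloader interfaces as the training loop:

import numpy as np
import torch

def validation(args, model, loss_fn, val_dataloader):
    # Average the loss over the validation set; no gradients needed.
    model.eval()
    losses = []
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            if args.gpu and torch.cuda.is_available():
                inputs, targets = inputs.cuda(), targets.cuda()
            losses.append(loss_fn(model(inputs), targets).item())
    valid_loss = float(np.mean(losses))
    print('Valid loss: {:.5f}'.format(valid_loss))
    return {'valid_loss': valid_loss}
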
Example #8
def predict(input_images: list = None,
            target_images: list = None,
            config_file: str = None,
            save_file_suffix=None):
    '''
    Compute the output masks and their contour images for the given list of input image filenames.
    Args:
       input_images (list[str]): list of input image filenames; if None, they are taken from the command-line arguments instead
       target_images (list[str]): list of target mask filenames; if None, they are taken from the command-line arguments instead
       config_file (str): path to the configuration file that specifies the evaluation details;
            if None, the config file path is taken from the command-line arguments instead
    Returns:
        out_files (list[str]): list of output mask filenames
        countors_outs_files (list[str]): list of contour image filenames for the output masks
        dc_val_records (list[float]): Dice coefficient of each target mask against its predicted mask
    '''
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # specify configuration file
    if config_file is None:
        parser = argparse.ArgumentParser()
        parser.add_argument("--config",
                            type=str,
                            help="Path to (.yml) config file.")
        parser.add_argument('--input_images',
                            '-i',
                            metavar='INPUT',
                            nargs='+',
                            help='filenames of input images')

        parser.add_argument('--target_images',
                            '-t',
                            metavar='INPUT',
                            nargs='+',
                            help='filenames of target mask images')
        configargs = parser.parse_args()
        config_file = configargs.config

    # Read config file.
    cfg = None
    with open(config_file, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        print(cfg_dict)
    # set up network in/out channels details
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    #print(net)
    net.to(device=device)
    net.load_state_dict(
        torch.load(cfg_dict.get('model_weights', None), map_location=device))
    # If these parameters were not passed in, the filenames are taken from the command-line arguments instead
    if input_images is None:
        input_images = configargs.input_images
    if target_images is None:
        target_images = configargs.target_images
    logging.info("Model loaded !")
    out_files = get_output_filenames(in_files=input_images,
                                     output_dir=cfg_dict.get(
                                         'output_dir', None),
                                     suffix=save_file_suffix)
    countors_outs_files = []
    dc_val_records = []
    # start evaluating
    for i, (filename,
            target_filename) in enumerate(zip(input_images, target_images)):
        logging.info(
            f"\nPredicting image {filename}, Target image {target_filename}")

        img = Image.open(filename)
        target = Image.open(target_filename)

        mask, dc_val = predict_img(net=net,
                                   full_img=img,
                                   target_img=target,
                                   scale_factor=cfg_dict.get('scale', 1),
                                   out_threshold=cfg_dict.get(
                                       'mask_threshold', 0.5),
                                   device=device)
        if cfg_dict.get('save', True):
            out_filename = out_files[i]
            result, contours = mask_to_image(mask, fn=contours_fn)
            result.save(out_files[i])
            out_contour = out_files[i].replace(".jpg", "-contour.jpg")
            contours.save(out_contour)
            countors_outs_files.append(out_contour)
            # Record DC value for evaluation
            dc_val_records.append(dc_val)
            logging.info(
                f"\nMask saved to {out_files[i]}, Countour saved to {out_contour}"
            )
    return out_files, countors_outs_files, dc_val_records
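
A hypothetical invocation of the predict() function above. The filenames and the config path are placeholders; the YAML keys listed in the comment are the ones cfg_dict is queried for in the code:

# The config file is expected to provide (at least) these keys, as read above:
#   model_weights:  path to the trained .pth file
#   output_dir:     directory for the predicted masks
#   scale:          input scale factor (default 1)
#   mask_threshold: probability threshold for the mask (default 0.5)
#   save:           whether to write the mask/contour images (default True)
out_files, contour_files, dice_values = predict(
    input_images=['data/imgs/sample_001.jpg'],      # placeholder filenames
    target_images=['data/masks/sample_001.jpg'],
    config_file='configs/eval.yml',                  # placeholder path
    save_file_suffix='_pred')
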
Example #9
    return Image.fromarray((mask * 255).astype(np.uint8))


if __name__ == "__main__":
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=1)

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
Example #10
resume_path = './checkpoint/best_' + 'unet' + '_model.pkl'  #fcn,unet
root_path = './image'
img_path = root_path + '.png'
mask_path = root_path + '_mask.png'

image = cv2.imread(img_path)
mask = cv2.imread(mask_path, 0)
img, mask = Compose([Scale(224)])(image.copy(), mask)
# image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
mask = mask.astype(np.uint8)
img, mask = transform(img, mask)
img, mask = torch.unsqueeze(img, 0), torch.unsqueeze(mask, 0)
# resume
if osp.isfile(resume_path):
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint["model_state"])
    best_iou = checkpoint['best_iou']
    print(
        "=====>",
        "Loaded checkpoint '{}' (iter {})".format(resume_path,
                                                  checkpoint["epoch"]))
    print("=====> best mIoU: %.4f best mean dice: %.4f" %
          (best_iou, (best_iou * 2) / (best_iou + 1)))

else:
    raise ValueError("can't find model")

crf = True

with torch.no_grad():
    img, mask = img.to(device), mask.to(device)
Example #11
        msg_p = "Test Class {} Precision".format(i)
        msg_f = "Test Class {} F1-score".format(i)

        print(msg_r + " " + str(r))
        print(msg_s + " " + str(s))
        print(msg_p + " " + str(p))
        print(msg_f + " " + str(f))

        if wandb_track:
            wandb.log({
                msg_r: r,
                msg_s: s,
                msg_p: p,
                msg_f: f,
            })


if __name__ == "__main__":
    n_classes = utils.params.n_classes
    n_channels = utils.params.n_channels
    net = UNet(n_channels=n_channels, n_classes=n_classes)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device=device)
    os.makedirs(save_path, exist_ok=True)
    net.load_state_dict(
        torch.load(model_path + model_name + ".pth", map_location=device))
    predict(net, n_channels, n_classes=n_classes)
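
The head of Example #11's loop (where r, s, p and f are computed) is cut off. Assuming they stand for per-class recall, specificity, precision and F1-score, here is a self-contained sketch of how such values are typically derived from predicted and ground-truth label maps:

import numpy as np

def per_class_metrics(pred: np.ndarray, true: np.ndarray, cls: int):
    # Binary confusion-matrix counts for one class.
    tp = np.sum((pred == cls) & (true == cls))
    fp = np.sum((pred == cls) & (true != cls))
    fn = np.sum((pred != cls) & (true == cls))
    tn = np.sum((pred != cls) & (true != cls))
    eps = 1e-8
    r = tp / (tp + fn + eps)        # recall (sensitivity)
    s = tn / (tn + fp + eps)        # specificity
    p = tp / (tp + fp + eps)        # precision
    f = 2 * p * r / (p + r + eps)   # F1-score
    return r, s, p, f

pred = np.array([[0, 1], [1, 1]])
true = np.array([[0, 1], [0, 1]])
print(per_class_metrics(pred, true, cls=1))  # recall 1.0, specificity 0.5, precision ~0.67
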
Example #12

def mask_to_image(mask):
    return Image.fromarray((mask * 255).astype(np.uint8))


if __name__ == "__main__":
    in_files = [r'data/test/butterfly (78).jpg']
    out_files = ['predict2.png']

    net = UNet(n_channels=3, n_classes=1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net.to(device=device)
    net.load_state_dict(
        torch.load(r"checkpoints/CP_epoch200.pth", map_location=device))

    for i, file in enumerate(in_files):

        img = Image.open(file)

        # img = cv2.imread(fn)
        # img = cv2.resize(img,(112,112))
        # img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

        mask = predict_img(net=net,
                           full_img=img,
                           out_threshold=0.5,
                           device=device)

        out_fn = out_files[i]
Example #13
class BaseModel:
    losses = {'train': [], 'val': []}
    acces = {'train': [], 'val': []}
    scores = {'train': [], 'val': []}
    pred = {'train': [], 'val': []}
    true = {'train': [], 'val': []}

    def __init__(self, args):
        self.args = args
        self.net = None
        print(args.model_name)
        if args.model_name == 'UNet':
            self.net = UNet(args.in_channels, args.num_classes)
            self.net.apply(weights_init)
        elif args.model_name == 'UNetResNet34':
            self.net = UNetResNet34(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNetResNet152':
            self.net = UNetResNet152(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNet11':
            self.net = UNet11(args.num_classes, pretrained=True)
        elif args.model_name == 'UNetVGG16':
            self.net = UNetVGG16(args.num_classes,
                                 pretrained=True,
                                 dropout_2d=0.0,
                                 is_deconv=True)
        elif args.model_name == 'deeplab50_v2':
            if args.ms:
                raise NotImplementedError
            else:
                self.net = deeplab50_v2(args.num_classes,
                                        pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v2':
            if args.ms:
                self.net = ms_deeplab_v2(args.num_classes,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v2(args.num_classes,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3':
            if args.ms:
                self.net = ms_deeplab_v3(args.num_classes,
                                         out_stride=args.out_stride,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v3(args.num_classes,
                                      out_stride=args.out_stride,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3_plus':
            if args.ms:
                self.net = ms_deeplab_v3_plus(args.num_classes,
                                              out_stride=args.out_stride,
                                              pretrained=args.pretrained,
                                              scales=args.ms_scales)
            else:
                self.net = deeplab_v3_plus(args.num_classes,
                                           out_stride=args.out_stride,
                                           pretrained=args.pretrained)

        self.interp = nn.Upsample(size=args.size, mode='bilinear')

        self.iterations = args.epochs
        self.lr_current = args.lr
        self.cuda = args.cuda
        self.phase = args.phase
        self.lr_policy = args.lr_policy
        self.cyclic_m = args.cyclic_m
        if self.lr_policy == 'cyclic':
            print('using cyclic')
            assert self.iterations % self.cyclic_m == 0
        if args.loss == 'CELoss':
            self.criterion = nn.CrossEntropyLoss(size_average=True)
        elif args.loss == 'DiceLoss':
            self.criterion = DiceLoss(num_classes=args.num_classes)
        elif args.loss == 'MixLoss':
            self.criterion = MixLoss(args.num_classes,
                                     weights=args.loss_weights)
        elif args.loss == 'LovaszLoss':
            self.criterion = LovaszSoftmax(per_image=args.loss_per_img)
        elif args.loss == 'FocalLoss':
            self.criterion = FocalLoss(args.num_classes, alpha=None, gamma=2)
        else:
            raise RuntimeError('must define loss')

        if 'deeplab' in args.model_name:
            self.optimizer = optim.SGD(
                [{
                    'params': get_1x_lr_params_NOscale(self.net),
                    'lr': args.lr
                }, {
                    'params': get_10x_lr_params(self.net),
                    'lr': 10 * args.lr
                }],
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        else:
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.net.parameters()),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        self.iters = 0
        self.best_val = 0.0
        self.count = 0

    def init_model(self):
        if self.args.resume_model:
            saved_state_dict = torch.load(
                self.args.resume_model,
                map_location=lambda storage, loc: storage)
            if self.args.ms:
                new_params = self.net.Scale.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # print i_parts
                    if not (not i_parts[0] == 'layer5') and (not i_parts[0]
                                                             == 'decoder'):
                        new_params[i] = saved_state_dict[i]
                self.net.Scale.load_state_dict(new_params)
            else:
                new_params = self.net.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # print i_parts
                    if (not i_parts[0] == 'layer5') and (not i_parts[0]
                                                         == 'decoder'):
                        # if not i_parts[0] == 'layer5':
                        new_params[i] = saved_state_dict[i]
                self.net.load_state_dict(new_params)

            print('Resuming training, image net loading {}...'.format(
                self.args.resume_model))
            # self.load_weights(self.net, self.args.resume_model)

        if self.args.mGPUs:
            self.net = nn.DataParallel(self.net)

        if self.args.cuda:
            self.net = self.net.cuda()
            cudnn.benchmark = True

    def _adjust_learning_rate(self, epoch):
        """Sets the learning rate to the initial LR decayed by 10 at every specified step
        # Adapted from PyTorch Imagenet example:
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py
        """
        if epoch < int(self.iterations * 0.5):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-4)
        elif epoch < int(self.iterations * 0.85):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-5)
        else:
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-6)
        self.optimizer.param_groups[0]['lr'] = self.lr_current
        self.optimizer.param_groups[1]['lr'] = self.lr_current * 10

    def save_network(self, net, net_name, epoch, label=''):
        save_fname = '%s_%s_%s.pth' % (epoch, net_name, label)
        save_path = os.path.join(self.args.save_folder, self.args.exp_name,
                                 save_fname)
        torch.save(net.state_dict(), save_path)

    def load_weights(self, net, base_file):
        other, ext = os.path.splitext(base_file)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            net.load_state_dict(
                torch.load(base_file,
                           map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    def load_trained_model(self):
        path = os.path.join(self.args.save_folder, self.args.exp_name,
                            self.args.trained_model)
        print('eval cls, image net loading {}...'.format(path))
        if self.args.ms:
            self.load_weights(self.net.Scale, path)
        else:
            self.load_weights(self.net, path)

    def eval(self, dataloader):
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        output = []

        for i, image in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)

            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))  # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)  # [bs, nclass, H, W]
            out = F.softmax(out)
            output.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
        return np.array(output)

    def tta(self, dataloaders):
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return np.argmax(results, 1)

    def tta_output(self, dataloaders):
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return results

    def test_val(self, dataloader):
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        predict = []
        true = []
        t1 = time.time()

        for i, (image, mask) in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
                label_image = Variable(mask.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)
                label_image = Variable(mask, volatile=True)

            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != label_image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))  # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)  # [bs, nclass, H, W]
            out = F.softmax(out)
            if self.args.aug == 'heng':
                out = out[:, :, 11:11 + 202, 11:11 + 202]
            predict.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
            # predict.extend([pred[1, :101, :101].data.cpu().numpy() for pred in out])
            # pred.extend(out.data.cpu().numpy())
            true.extend(label_image.data.cpu().numpy())
        # pred_all = np.argmax(np.array(pred), 1)
        for t in np.arange(0.05, 0.51, 0.01):
            pred_all = np.array(predict) > t
            true_all = np.array(true).astype(np.int)
            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)
            mean_iou, iou_t = mIoU(true_all, pred_all)
            print('threshold : {:.4f}'.format(t))
            print('mean IoU : {:.4f}, IoU threshold : {:.4f}'.format(
                mean_iou, iou_t))

        return predict, true

    def run_epoch(self, dataloader, writer, epoch, train=True, metrics=True):
        if train:
            self.net.train()
            flag = 'train'
        else:
            self.net.eval()
            flag = 'val'
        t2 = time.time()
        for image, mask in dataloader:
            if train and self.lr_policy != 'step':
                adjust_learning_rate(self.args.lr, self.optimizer, self.iters,
                                     self.iterations * len(dataloader), 0.9,
                                     self.cyclic_m, self.lr_policy)
                self.iters += 1

            if self.cuda:
                image = Variable(image.cuda(), volatile=(not train))
                label_image = Variable(mask.cuda(), volatile=(not train))
            else:
                image = Variable(image, volatile=(not train))
                label_image = Variable(mask, volatile=(not train))
            # cls forward
            out = self.net(image)

            if isinstance(out, list):
                out_max = None
                loss = 0.0
                for i, out_scale in enumerate(out):
                    if out_scale.size(2) != label_image.size(2):
                        out_scale = self.interp(out_scale)
                    if i == (len(out) - 1):
                        out_max = out_scale
                    loss += self.criterion(out_scale, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out_max.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)

                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)

            else:
                if out.size(-1) != label_image.size(-1):
                    out = self.interp(out)

                loss = self.criterion(out, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)

                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if metrics:
            n = len(self.losses[flag])
            loss = sum(self.losses[flag]) / n
            scalars = [
                loss,
            ]
            names = [
                'loss',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_loss')

            all_acc = sum(self.acces[flag]) / n
            scalars = [
                all_acc,
            ]
            names = [
                'all_acc',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_acc')

            # all_score = sum(self.scores[flag]) / n
            # scalars = [all_score, ]
            # names = ['all_score', ]
            # write_scalars(writer, scalars, names, epoch, tag=flag + '_score')

            pred_all = np.argmax(np.array(self.pred[flag]), 1)
            true_all = np.array(self.true[flag]).astype(np.int)
            mean_iou, iou_t = mIoU(true_all, pred_all)

            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)

            scalars = [
                mean_iou,
                iou_t,
            ]
            names = [
                'mIoU',
                'mIoU_threshold',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_IoU')

            scalars = [
                self.optimizer.param_groups[0]['lr'],
            ]
            names = [
                'learning_rate',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_lr')

            print(
                '{} loss: {:.4f} | acc: {:.4f} | mIoU: {:.4f} | mIoU_threshold: {:.4f} |  n_iter: {} |  learning_rate: {} | time: {:.2f}'
                .format(flag, loss, all_acc, mean_iou, iou_t, epoch,
                        self.optimizer.param_groups[0]['lr'],
                        time.time() - t2))

            self.losses[flag] = []
            self.pred[flag] = []
            self.true[flag] = []
            self.acces[flag] = []
            self.scores[flag] = []

            if (not train) and (iou_t >= self.best_val):
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(self.net.module.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                else:
                    if self.args.mGPUs:
                        self.save_network(self.net.module,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                print(
                    'val improve from {:.4f} to {:.4f} saving in best val_iteration {}'
                    .format(self.best_val, iou_t, epoch))
                self.best_val = iou_t
                self.count = 0

            if (not train) and (self.best_val - iou_t > 0.003) and (
                    self.count < 10) and (self.lr_policy == 'step'):
                self.count += 1
            if (not train) and (self.count >= 10) and (self.lr_policy
                                                       == 'step'):
                self._adjust_learning_rate(epoch)
                self.count = 0

    def train_val(self, dataloader_train, dataloader_val, writer):
        val_epoch = 0
        for epoch in range(self.iterations):
            if (self.lr_policy == 'cyclic') and (
                    epoch % int(self.iterations / self.cyclic_m) == 0):
                print('-------start cycle {}------------'.format(
                    epoch // int(self.iterations / self.cyclic_m)))
                self.best_val = 0.0
            self.run_epoch(dataloader_train,
                           writer,
                           epoch,
                           train=True,
                           metrics=True)
            self.run_epoch(dataloader_val,
                           writer,
                           val_epoch,
                           train=False,
                           metrics=True)
            val_epoch += 1
            if (epoch + 1) % self.args.save_freq == 0:
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                else:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                print('saving in val_iteration {}'.format(val_epoch))
Example #14
from model.train import train
from model.unet import UNet
from model.dataloader import testloader
from model.evaluate import evaluate
import torch

COLAB = True
BATCH_SIZE = 1
PATH = 'unet_augment.pt'
# PATH = '../drive/My Drive/Colab Notebooks/im2height.pt'

test_loader = testloader(colab=COLAB, batch_size=BATCH_SIZE)

net = UNet()
net.load_state_dict(torch.load(PATH))
if torch.cuda.is_available():
    net.cuda()
criterion = torch.nn.L1Loss()
evaluate(net, test_loader, criterion=criterion)
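
Example #14 imports evaluate from model.evaluate, which is not shown. A minimal sketch of what such a helper might do with the L1 criterion used here (average the loss over the test loader); this is an assumption about the interface, not the project's actual implementation:

import torch

def evaluate_sketch(net, loader, criterion):
    net.eval()
    device = next(net.parameters()).device
    total, batches = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            total += criterion(net(x), y).item()
            batches += 1
    mean_loss = total / max(batches, 1)
    print('Mean test L1 loss: {:.4f}'.format(mean_loss))
    return mean_loss
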
Example #15
import utils.metrics as m
from model.unet import UNet
from utils.preprocessing_utils import nii2labels, nii2slices
from utils.surface import Surface
from utils.test_utils import (draw_contours, draw_many_slices, imwrite,
                              remove_fragment)

if __name__ == '__main__':
    LITS_data_path = 'LITS/'
    model_path = 'checkpoints/liver_segmentation_U-Net_on_LITS_datasetiter_300000.pth'

    prediction_path = 'results/'

    device = torch.device('cuda:0')
    model = UNet(1, 2).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    sm = nn.Softmax(dim=1)

    idx_list = []
    dice_list = []
    iou_list = []
    voe_list = []
    rvd_list = []
    assd_list = []
    msd_list = []

    for i in range(31):
        print(i)
        idx_list.append(i)