Example 1
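The four snippets below are excerpted from a larger training/evaluation script and share its module-level context: repository-local helpers (ModelBuilder, the ValDataset/TestDataset/TrainDataset classes, segment, test, train, save_checkpoint, visualize_result) and the test_dir/train_dir file lists are defined elsewhere. A plausible import preamble, sketched from the names the snippets use:

import os

import nibabel as nib
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
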
def main(args):
    # import network architecture
    builder = ModelBuilder()
    model = builder.build_net(arch=args.id,
                              num_input=args.num_input,
                              num_classes=args.num_classes,
                              num_branches=args.num_branches,
                              padding_list=args.padding_list,
                              dilation_list=args.dilation_list)
    model = torch.nn.DataParallel(model, device_ids=list(range(
        args.num_gpus))).cuda()
    cudnn.benchmark = True

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            model.load_state_dict(state_dict)
            print("=> Loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            raise Exception("=> No checkpoint found at '{}'".format(
                args.resume))

    tf = ValDataset(test_dir, args)
    test_loader = DataLoader(tf,
                             batch_size=args.batch_size,
                             shuffle=args.shuffle,
                             num_workers=args.num_workers,
                             pin_memory=False)
    test(test_loader, model, args)
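
A caveat on the resume path above: the checkpoint is loaded after the model is wrapped in torch.nn.DataParallel, so the keys in its state_dict must carry the "module." prefix. Checkpoints saved from the wrapped model (as in Example 4) already do; one saved from an unwrapped model would need a remap first. A minimal sketch (the helper name is hypothetical, not from this repository):

from collections import OrderedDict

def add_module_prefix(state_dict):
    # DataParallel registers parameters under "module.<name>", so prefix
    # plain keys to match the wrapped model before load_state_dict.
    return OrderedDict(("module." + k, v) for k, v in state_dict.items())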
Example 2
def main(args):
    # import network architecture
    builder = ModelBuilder()
    model = builder.build_net(
            arch=args.id, 
            num_input=args.num_input + 1, 
            num_classes=args.num_classes, 
            num_branches=args.num_branches,
            padding_list=args.padding_list, 
            dilation_list=args.dilation_list)
    
    device_ids = [0, 2]
    model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    cudnn.benchmark = True
    
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            model.load_state_dict(state_dict)           
            print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            raise Exception("=> No checkpoint found at '{}'".format(args.resume))         
    
    # initialization
    margin = [args.crop_size[k] - args.center_size[k] for k in range(3)]
    num_images = int(len(test_dir) / args.num_input)

    for i in tqdm(range(num_images)):
        print("image id: %d" % i)
        # load the images and mask
        im = []
        for j in range(args.num_input):
            direct, _ = test_dir[args.num_input * i + j].split("\n")
            name = direct
            image = nib.load(args.root_path + direct + ".gz").get_data()
            image = np.expand_dims(image, axis=0)
            im.append(image)
            if j == 0:
                mask = nib.load(args.root_path + direct[:-15] + ".nii" + "/mask.nii.gz").get_data()

        images = np.concatenate(im, axis=0).astype(float)

        # divide the input images into small image segments
        # return the padded input images, which divide into segments exactly
        image_pad, mask_pad, num_segments, padding_index, index = segment(images, mask, args)

        # initialize prediction for the whole image as background
        mask_shape = list(mask.shape)
        mask_shape.append(args.num_classes)
        pred = np.zeros(mask_shape)
        pred[:, :, :, 0] = 1

        # initialize the prediction for each small segment as background
        pad_shape = [int(num_segments[k] * args.center_size[k]) for k in range(3)]
        pad_shape.append(args.num_classes)
        pred_pad = np.zeros(pad_shape)
        pred_pad[:, :, :, 0] = 1

        # iterate over the z dimension
        for idz in range(num_segments[2]):
            tf = TestDataset(image_pad, mask_pad, num_segments, idz, args)
            test_loader = DataLoader(tf,
                                     batch_size=args.batch_size,
                                     shuffle=args.shuffle,
                                     num_workers=args.num_workers,
                                     pin_memory=False)
            pred_seg = test(test_loader, model, num_segments, args)
            pred_pad[:, :, idz * args.center_size[2]:(idz + 1) * args.center_size[2], :] = pred_seg

        # decide the start and end point in the original image
        for k in range(3):
            if index[0][k] == 0:
                index[0][k] = int(margin[k]/2 - padding_index[0][k])
            else:
                index[0][k] = int(margin[k]/2 + index[0][k])

            index[1][k] = int(min(index[0][k] + num_segments[k] * args.center_size[k], mask.shape[k]))

        dist = [index[1][k] - index[0][k] for k in range(3)]
        pred[index[0][0]:index[1][0], index[0][1]:index[1][1], index[0][2]:index[1][2]] = pred_pad[:dist[0], :dist[1], :dist[2]]
            
        if args.visualize:
            vis = np.argmax(pred, axis=3)
            vis = np.swapaxes(vis, 0, 2).astype(dtype=np.uint8)
            visualize_result(name, vis, args)
           
    print('Evaluation Done!')
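
The segment helper used above is defined elsewhere in the repository; what the surrounding code implies is a sliding-window scheme in which each crop_size patch contributes only its center_size core, so pred_pad spans num_segments[k] * center_size[k] voxels per axis. A per-axis sketch of that tiling arithmetic (illustrative only, under that assumption; not the repository's actual segment function):

import math

def tile_axis(dim, crop, center):
    # Number of center-sized cores needed to cover an axis of length `dim`.
    num_segments = math.ceil(dim / center)
    # Padded extent: the cores plus the crop/center margin shared by both ends.
    margin = crop - center
    return num_segments, num_segments * center + margin

# e.g. a 155-voxel axis with crop_size 64 and center_size 48:
# tile_axis(155, 64, 48) -> (4, 208)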
Example 3
def main(args):
    # import network architecture
    builder = ModelBuilder()
    model = builder.build_net(arch=args.id,
                              num_input=args.num_input,
                              num_classes=args.num_classes,
                              num_branches=args.num_branches,
                              padding_list=args.padding_list,
                              dilation_list=args.dilation_list)
    device_ids = [0, 2]
    model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    # alternatively: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpus))).cuda()
    cudnn.benchmark = True

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            model.load_state_dict(state_dict)
            print("=> Loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            raise Exception("=> No checkpoint found at '{}'".format(
                args.resume))

    # initialization
    num_ignore = 0
    margin = [args.crop_size[k] - args.center_size[k] for k in range(3)]
    num_images = int(len(test_dir) / args.num_input)
    dice_score = np.zeros([num_images, 3]).astype(float)

    for i in range(num_images):
        # load the images, label and mask
        im = []
        for j in range(args.num_input):
            direct, _ = test_dir[args.num_input * i + j].split("\n")
            name = direct
            if j < args.num_input - 1:
                image = nib.load(args.root_path + direct + '.gz').get_data()
                image = np.expand_dims(image, axis=0)
                im.append(image)
                if j == 0:
                    mask = nib.load(args.root_path + direct +
                                    "/mask.nii.gz").get_data()
            else:
                labels = nib.load(args.root_path + direct + '.gz').get_data()

        images = np.concatenate(im, axis=0).astype(float)

        # divide the input images into small image segments
        # return the padded input images, which divide into segments exactly
        image_pad, mask_pad, label_pad, num_segments, padding_index, index = segment(
            images, mask, labels, args)

        # initialize prediction for the whole image as background
        labels_shape = list(labels.shape)
        labels_shape.append(args.num_classes)
        pred = np.zeros(labels_shape)
        pred[:, :, :, 0] = 1

        # initialize the prediction for a small segmentation as background
        pad_shape = [
            int(num_segments[k] * args.center_size[k]) for k in range(3)
        ]
        pad_shape.append(args.num_classes)
        pred_pad = np.zeros(pad_shape)
        pred_pad[:, :, :, 0] = 1

        # score_per_image stores the sum of each image
        score_per_image = np.zeros([3, 3])
        # iterate over the z dimension
        for idz in range(num_segments[2]):
            tf = ValDataset(image_pad, label_pad, mask_pad, num_segments, idz,
                            args)
            test_loader = DataLoader(tf,
                                     batch_size=args.batch_size,
                                     shuffle=args.shuffle,
                                     num_workers=args.num_workers,
                                     pin_memory=False)
            score_seg, pred_seg = test(test_loader, model, num_segments, args)
            pred_pad[:, :, idz * args.center_size[2]:(idz + 1) *
                     args.center_size[2], :] = pred_seg
            score_per_image += score_seg

        # decide the start and end point in the original image
        for k in range(3):
            if index[0][k] == 0:
                index[0][k] = int(margin[k] / 2 - padding_index[0][k])
            else:
                index[0][k] = int(margin[k] / 2 + index[0][k])

            index[1][k] = int(
                min(index[0][k] + num_segments[k] * args.center_size[k],
                    labels.shape[k]))

        dist = [index[1][k] - index[0][k] for k in range(3)]
        pred[index[0][0]:index[1][0], index[0][1]:index[1][1],
             index[0][2]:index[1][2]] = pred_pad[:dist[0], :dist[1], :dist[2]]

        if np.sum(score_per_image[0, :]) == 0 or np.sum(
                score_per_image[1, :]) == 0 or np.sum(
                    score_per_image[2, :]) == 0:
            num_ignore += 1
            continue
        # compute the Enhance, Core and Whole dice score
        dice_score_per = [
            2 * np.sum(score_per_image[k, 2]) /
            (np.sum(score_per_image[k, 0]) + np.sum(score_per_image[k, 1]))
            for k in range(3)
        ]
        print(
            'Image: %d, Enhance score: %.4f, Core score: %.4f, Whole score: %.4f'
            % (i, dice_score_per[0], dice_score_per[1], dice_score_per[2]))

        dice_score[i, :] = dice_score_per

        if args.visualize:
            vis = np.argmax(pred, axis=3)
            vis = np.swapaxes(vis, 0, 2).astype(dtype=np.uint8)
            visualize_result(name, vis, args)

    count_image = num_images - num_ignore
    dice_score = dice_score[:count_image, :]
    mean_dice = np.mean(dice_score, axis=0)
    std_dice = np.std(dice_score, axis=0)
    print('Evaluation Done!')
    print(
        'Enhance score: %.4f, Core score: %.4f, Whole score: %.4f, Mean Dice score: %.4f'
        % (mean_dice[0], mean_dice[1], mean_dice[2], np.mean(mean_dice)))
    print(
        'Enhance std: %.4f, Core std: %.4f, Whole std: %.4f, Mean Std: %.4f' %
        (std_dice[0], std_dice[1], std_dice[2], np.mean(std_dice)))
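
Judging from the Dice expression above, row k of score_per_image accumulates the predicted-voxel count, the ground-truth-voxel count, and their intersection for region k (Enhance, Core, Whole), so dice = 2 * intersection / (pred + truth). A worked instance with made-up counts:

pred_vox, gt_vox, inter_vox = 1200, 1000, 900  # hypothetical voxel counts
dice = 2.0 * inter_vox / (pred_vox + gt_vox)   # 2 * 900 / 2200
print("Dice: %.4f" % dice)                     # 0.8182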
Example 4
def main(args):
    # import network architecture
    builder = ModelBuilder()
    model = builder.build_net(arch=args.id,
                              num_input=args.num_input,
                              num_classes=args.num_classes,
                              num_branches=args.num_branches,
                              padding_list=args.padding_list,
                              dilation_list=args.dilation_list)
    model = torch.nn.DataParallel(model, device_ids=list(range(
        args.num_gpus))).cuda()
    cudnn.benchmark = True

    # collect the number of parameters in the network
    print("------------------------------------------")
    print("Network Architecture of Model %s:" % (args.id))
    num_para = 0
    for name, param in model.named_parameters():
        num_mul = 1
        for x in param.size():
            num_mul *= x
        num_para += num_mul

    print(model)
    print("Number of trainable parameters %d in Model %s" %
          (num_para, args.id))
    print("------------------------------------------")

    # set the optimizer and loss
    optimizer = optim.RMSprop(model.parameters(),
                              args.lr,
                              alpha=args.alpha,
                              eps=args.eps,
                              weight_decay=args.weight_decay,
                              momentum=args.momentum)
    criterion = nn.CrossEntropyLoss()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']  # the training loop below resumes from args.start_epoch
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['opt_dict'])
            print("=> Loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> No checkpoint found at '{}'".format(args.resume))

    # loading data
    tf = TrainDataset(train_dir, args)
    train_loader = DataLoader(tf,
                              batch_size=args.batch_size,
                              shuffle=args.shuffle,
                              num_workers=args.num_workers,
                              pin_memory=True)

    print("Start training ...")
    for epoch in range(args.start_epoch + 1, args.num_epochs + 1):
        train(train_loader, model, criterion, optimizer, epoch, args)

        # save models
        if epoch > args.particular_epoch:
            if epoch % args.save_epochs_steps == 0:
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'opt_dict': optimizer.state_dict()
                    }, epoch, args)

    print("Training Done")