Example #1
import glob
import os

import nibabel as nib
import numpy as np
import torch


# UNet, img_loader, test, upsample and save are defined elsewhere in this repository.
def main(img_id):
    model = UNet(n_channels=1, n_classes=1)
    # model_file = max(glob.glob('models/*'), key=os.path.getctime)  # detect latest checkpoint
    model_file = 'models/intensity_filtering_continued/Checkpoint_e1_d0.9755_l0.0008_2018-11-13_11:06:28.pth'  # best-performing checkpoint
    model.load_state_dict(torch.load(model_file))
    model = model.double()
    img_path = 'data/testing/slices/img/'
    img_vol_path = 'data/testing/img/'  # full volumes, used to recover an accurate header
    data_test = [img_path, img_id]
    test_loader = torch.utils.data.DataLoader(img_loader(data_test))

    # Copy the voxel dimensions from the full-volume header so the saved
    # prediction keeps the correct pixel spacing.
    hdr = nib.load(img_path + img_id).header
    vol_hdr = nib.load(img_vol_path + img_id[0:7] + '.nii.gz').header
    hdr['pixdim'] = vol_hdr['pixdim']
    prediction = test(model, test_loader).numpy()
    prediction = np.reshape(prediction, (256, 256))
    prediction = upsample(prediction, 2)  # upsample the 256x256 prediction by a factor of 2
    save(prediction, 'data/testing/slices/pred/' + img_id, hdr)
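
A minimal usage sketch (the slice filename below is hypothetical; the real
naming scheme comes from the repository's preprocessing code):

# Hypothetical call: segment one slice of volume 'img0001' and write the
# prediction to data/testing/slices/pred/.
main('img0001_slice_042.nii.gz')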
Example #2
import argparse
import csv
import datetime
import os
import time

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler


# UNet, img_loader, train and test are defined elsewhere in this repository;
# dice_loss, training_loss, test_loss and start_time are assumed to be
# module-level values maintained elsewhere in the script.
def main():
    best_dice = np.array([0.0])  # mutable array so test() can update the running best in place

    parser = argparse.ArgumentParser(
        description='PyTorch Multiclass Classification')
    parser.add_argument('--batch-size',
                        type=int,
                        default=1,
                        metavar='N',
                        help='input batch size for training (default: 1)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1,
                        metavar='N',
                        help='input batch size for testing (default: 1)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='Beta1 for Adam (default: 0.9)')
    parser.add_argument('--beta2',
                        type=float,
                        default=0.999,
                        help='Beta2 for Adam (default: 0.999)')
    parser.add_argument('--eps',
                        type=float,
                        default=1e-8,
                        help='Epsilon for Adam (default: 1e-8)')
    #parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
    #                    help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)  # also seed NumPy so the train/test split below is reproducible

    device = torch.device("cuda" if use_cuda else "cpu")

    # Parse the per-slice spleen probabilities, stored on disk as a bracketed,
    # comma-separated list (only used by the commented-out weighted sampler below).
    with open('spleen_probs.txt', 'r') as f:
        spleen_probs = f.read().replace('[', '').replace(']', '').split(',')
    spleen_probs = list(map(float, spleen_probs))
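    # Note: if spleen_probs.txt really holds a single Python-style list,
    # ast.literal_eval would parse it more robustly than stripping brackets
    # by hand (a sketch, not what this repository does):
    #     import ast
    #     with open('spleen_probs.txt', 'r') as f:
    #         spleen_probs = ast.literal_eval(f.read())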

    num_total_samples = 3779  # hard-coded count of training slices on disk
    training_img_folder = ['data/training/slices/img'] * num_total_samples
    training_label_folder = ['data/training/slices/label'] * num_total_samples
    training_img_files = sorted(os.listdir('data/training/slices/img'))
    training_label_files = sorted(os.listdir('data/training/slices/label'))

    # Random 80/20 train/test split over slice indices.
    indices = list(range(num_total_samples))
    num_testing_samples = round(.2 * num_total_samples)
    testing_indices = list(
        np.random.choice(indices, size=num_testing_samples, replace=False))
    training_indices = list(set(indices) - set(testing_indices))
    training_sampler = SubsetRandomSampler(training_indices)
    testing_sampler = SubsetRandomSampler(testing_indices)
    #training_sampler = WeightedRandomSampler([spleen_probs[i] for i in training_indices], num_total_samples-num_testing_samples)

    training_data = [
        training_img_folder, training_label_folder, training_img_files,
        training_label_files
    ]
    # A DataLoader given a sampler must not also set shuffle; each sampler
    # draws from its index subset in random order every epoch.
    train_loader = torch.utils.data.DataLoader(img_loader(training_data),
                                               batch_size=args.batch_size,
                                               sampler=training_sampler)
    test_loader = torch.utils.data.DataLoader(img_loader(training_data),
                                              batch_size=args.test_batch_size,
                                              sampler=testing_sampler)

    # .to(device) respects the --no-cuda flag; double precision matches the loader's tensors.
    model = UNet(n_channels=1, n_classes=1).double().to(device)
    #    model.load_state_dict(torch.load('models/intensity_filtering_continued/Checkpoint_e1_d0.9755_l0.0008_2018-11-13_11:06:28.pth'))
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2),
                           eps=args.eps)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader, best_dice, epoch)

    current_daytime = str(datetime.datetime.now()).replace(" ", "_")[:-7]
    model_file = 'models/UNetModel_' + current_daytime + '.pth'
    torch.save(model.state_dict(), model_file)
    loss_file = 'loss_outputs/loss_' + current_daytime
    with open(loss_file, 'w', newline='') as csvfile:
        losswriter = csv.writer(csvfile,
                                dialect='excel',
                                delimiter=' ',
                                quotechar='|',
                                quoting=csv.QUOTE_MINIMAL)
        # csv.writer.writerow expects a sequence; passing a bare string would
        # write each character as a separate field.
        losswriter.writerow(['Batch Size'])
        losswriter.writerow([str(args.batch_size)])
        losswriter.writerow(['Test Batch Size'])
        losswriter.writerow([str(args.test_batch_size)])
        losswriter.writerow(['Epochs'])
        losswriter.writerow([str(args.epochs)])
        losswriter.writerow(['Learning Rate'])
        losswriter.writerow([str(args.lr)])
        losswriter.writerow(['Beta 1'])
        losswriter.writerow([str(args.beta1)])
        losswriter.writerow(['Beta 2'])
        losswriter.writerow([str(args.beta2)])
        losswriter.writerow(['Epsilon'])
        losswriter.writerow([str(args.eps)])

        losswriter.writerow(['DICE'])
        for item in dice_loss:
            losswriter.writerow([str(round(float(item), 4))])

        losswriter.writerow(['training'])
        for item in training_loss:
            losswriter.writerow([str(round(item, 4))])

        losswriter.writerow(['testing'])
        for item in test_loss:
            losswriter.writerow([str(round(item, 4))])

    end_time = time.time()
    print('\nElapsed Time: {:.02f} seconds\n'.format(end_time - start_time))
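
For completeness, a minimal entry point (not part of the original snippet) so
the script can be launched directly, e.g. python train.py --epochs 20 --lr 0.001
(the filename is hypothetical):

if __name__ == '__main__':
    main()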