Пример #1
0
def build_model(args):
    """Construct the 3D segmentation network selected by the CLI flags.

    Args:
        args: namespace with ``residual`` (bool — pick the residual variant
            with 8 feature maps) and ``device`` (target torch device).

    Returns:
        The instantiated model, already moved to ``args.device``.
    """
    # The residual variant is built with a reduced feature-map count.
    net = ResidualUNet3D(f_maps=8) if args.residual else UNet3D()
    return net.to(args.device)
Пример #2
0
 def initGraph(self, node):
     """Build the UNet3D and restore its weights from ``self.network_path``.

     Returns 0 on completion. If no checkpoint file exists the process is
     terminated via ``exit()`` — the net is useless without restored weights.
     """
     ## Please feel free to add an example for a simple usage in /home/trs/sofa/build/unstable//home/trs/sofa/src/sofa/applications/plugins/SofaPython/scn2python.py
     # NOTE(review): ``use_network`` is a free variable here — presumably a
     # module-level flag; confirm it is defined before this method runs.
     if use_network:
         self.net = UNet3D(in_channels=self.in_channels, out_channels=self.out_channels)
         # Load previous model if requested
         if self.network_path.exists():
             # Checkpoint is a dict storing the weights under the 'model' key.
             state = torch.load(str(self.network_path))
             self.net.load_state_dict(state['model'])
             self.net = self.net.to(self.device)
             print('Restored model')
         else:
             print('Failed to restore model')
             exit()
         # Inference only: freeze dropout/batch-norm behavior.
         self.net.eval()
     return 0
Пример #3
0
def get_model(config):
    """Build the network named by ``config['training']['model_name']``.

    Supported names are 'UNet' and 'UNetAtt'; any other value falls back
    to a plain VNet. The model is constructed on the configured device.
    """
    name = config['training']['model_name']
    dev = config['training']['device']

    # Both U-Net variants share the exact same hyper-parameters.
    unet_kwargs = dict(in_channels=1,
                       out_channels=2,
                       layer_order='crg',
                       f_maps=32,
                       num_groups=8,
                       final_sigmoid=False,
                       device=dev)

    if name == 'UNet':
        return UNet3D(**unet_kwargs)
    if name == "UNetAtt":
        return UNet3D_attention(**unet_kwargs)
    # model_name == VNet
    return VNet()
Пример #4
0
if __name__ == '__main__':
    # Get input arguments ----------------------------------#
    # Usage: script.py <T1_input> <b0_input> <b0_output> <model_path>
    T1_input_path = sys.argv[1]
    b0_input_path = sys.argv[2]
    b0_output_path = sys.argv[3]
    model_path = sys.argv[4]

    print('T1 input path: ' + T1_input_path)
    print('b0 input path: ' + b0_input_path)
    print('b0 output path: ' + b0_output_path)
    print('Model path: ' + model_path)

    # Run code ---------------------------------------------#

    # Get device
    # NOTE(review): CUDA is hard-coded; on a CPU-only host torch.load below
    # would also need map_location — confirm a GPU is always available here.
    device = torch.device("cuda")

    # Get model
    # UNet3D(2, 1): presumably two input channels (T1 + b0) and one output
    # channel — confirm against the checkpoint's architecture.
    model = UNet3D(2, 1).to(device)
    model.load_state_dict(torch.load(model_path))

    # Inference
    img_model = inference(T1_input_path, b0_input_path, model, device)

    # Save
    # Reuse the input b0 affine/header so the prediction stays spatially aligned.
    nii_template = nib.load(b0_input_path)
    nii = nib.Nifti1Image(util.torch2nii(img_model.detach().cpu()),
                          nii_template.affine, nii_template.header)
    nib.save(nii, b0_output_path)
Пример #5
0
    os.makedirs(FLAGS.log_dir)

# Record wall-clock start time for the run.
start_time = datetime.datetime.now()
print(start_time)

if FLAGS.run_seg:
    # Default TF1 session config; adjust here (e.g. GPU memory growth) if needed.
    run_config = tf.ConfigProto()
    with tf.Session(config=run_config) as sess:
        #with tf.Session() as sess:

        # Build the 3D U-Net segmentation model bound to this session.
        unet = UNet3D(sess,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      log_dir=FLAGS.log_dir,
                      training_paths=training_paths,
                      testing_paths=testing_paths,
                      batch_size=FLAGS.batch_size,
                      layers=FLAGS.layers,
                      features_root=FLAGS.seg_features_root,
                      conv_size=FLAGS.conv_size,
                      dropout=FLAGS.dropout,
                      loss_type=FLAGS.loss_type)

        if FLAGS.train:
            # Print a per-variable parameter summary before training starts.
            model_vars = tf.trainable_variables()
            slim.model_analyzer.analyze_vars(model_vars, print_info=True)

            train_config = {}
            train_config['epoch'] = FLAGS.epoch

            unet.train(train_config)
Пример #6
0
def _metric_columns(per_batch_metrics):
    # Helper: transpose a list of per-batch metric rows into per-class columns,
    # each converted to a CPU numpy array so np.mean/np.std can operate on it.
    columns = list(zip(*per_batch_metrics))
    return [torch.tensor(col, device="cpu").numpy() for col in columns]


def main():
    """Evaluate saved BraTS checkpoints (OH / LS / SVLS variants) on the
    validation split and report Dice, surface Dice, ECE and TACE metrics,
    plus reliability diagrams for most models.
    """
    parser = argparse.ArgumentParser(description='SVLS Brats Training')
    parser.add_argument('--batch_size',
                        default=2,
                        type=int,
                        help='mini-batch size')
    parser.add_argument('--num_classes',
                        default=4,
                        type=int,
                        help="num of classes")
    parser.add_argument('--in_channels',
                        default=4,
                        type=int,
                        help="num of input channels")
    parser.add_argument('--train_option',
                        default='SVLS',
                        help="options:[SVLS, LS, OH]")
    parser.add_argument('--epochs',
                        default=200,
                        type=int,
                        help='number of total epochs to run')
    parser.add_argument('--data_root',
                        default='MICCAI_BraTS_2019_Data_Training/HGG_LGG',
                        help='data directory')
    parser.add_argument('--ckpt_dir',
                        default='ckpt_brats19',
                        help='ckpt directory')
    args = parser.parse_args()

    # Validation data only — this entry point evaluates, it does not train.
    _, val_dataset = get_datasets_brats(data_root=args.data_root)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=2)

    print('valid sample:', len(val_dataset), 'valid minibatch:',
          len(val_loader))

    model = UNet3D(inplanes=args.in_channels,
                   num_classes=args.num_classes).cuda()

    # NOTE(review): checkpoints are loaded into the DataParallel wrapper, so
    # they presumably carry 'module.'-prefixed keys — confirm.
    model = torch.nn.DataParallel(model)
    criterion_dice = EDiceLoss().cuda()

    legends = ['OH', 'LS(0.1)', 'LS(0.2)', 'LS(0.3)', 'SVLS']
    model_list = [
        'best_oh.pth.tar', 'best_ls0.1.pth.tar', 'best_ls0.2.pth.tar',
        'best_ls0.3.pth.tar', 'best_svls.pth.tar'
    ]
    for model_name, legend in zip(model_list, legends):
        model.load_state_dict(
            torch.load(os.path.join(args.ckpt_dir, model_name)))
        model.eval()
        with torch.no_grad():
            dice_metrics, metrics_sd, ece_avg, acc_avg, conf_avg, tace_avg = step_valid(
                val_loader, model, criterion_dice)
        if legend != 'LS(0.3)':
            reliability_diagram(conf_avg, acc_avg, legend=legend)

        # Per-class mean/std over the validation batches (the duplicated
        # transpose-to-numpy code is factored into _metric_columns above).
        dice_metrics = _metric_columns(dice_metrics)
        avg_dices = np.mean(dice_metrics, 1)
        avg_std = np.std(dice_metrics, 1)

        metrics_sd = _metric_columns(metrics_sd)
        avg_sd = np.mean(metrics_sd, 1)
        avg_std_sd = np.std(metrics_sd, 1)

        print(
            'model:%s , dice[ET:%.3f ± %.3f, TC:%.3f ± %.3f, WT:%.3f ± %.3f], ECE:%.4f, TACE:%.4f'
            % (model_name, avg_dices[0], avg_std[0], avg_dices[1], avg_std[1],
               avg_dices[2], avg_std[2], ece_avg, tace_avg))

        print(
            'model:%s , Surface dice[ET:%.3f ± %.3f, TC:%.3f ± %.3f, WT:%.3f ± %.3f]'
            % (model_name, avg_sd[0], avg_std_sd[0], avg_sd[1], avg_std_sd[1],
               avg_sd[2], avg_std_sd[2]))
Пример #7
0
                    default=1,
                    help='relative weight of positive samples for bce loss')
parser.add_argument('--epochs', type=int, default=300,
                    help="the total number of training epochs")
parser.add_argument('--restart', type=int, default=50,
                    help='restart learning rate every <restart> epochs')
parser.add_argument('--resume_model',
                    type=str,
                    default=None,
                    help='path to load previously saved model')
args = parser.parse_args(argv)
print(args)

is_cuda = torch.cuda.is_available()

# Single-channel input/output 3D U-Net with 16 base planes and biased convs.
net = UNet3D(1, 1, use_bias=True, inplanes=16)
# Optionally initialize weights from a previously saved model.
if args.resume_model is not None:
    transfer_weights(net, args.resume_model)
# Loss components combined by criterion() below; the most recent raw values
# are kept in last_bce_loss / last_dice_loss for logging.
bce_crit = nn.BCELoss()
dice_crit = DiceLoss()
last_bce_loss = 0
last_dice_loss = 0


def criterion(pred, labels, weights=(0.1, 0.9)):
    """Weighted sum of BCE and Dice losses.

    Side effect: stores the raw component values in the module-level
    ``last_bce_loss`` / ``last_dice_loss`` so callers can log them.

    Args:
        pred: predicted probabilities (same shape as ``labels``).
        labels: ground-truth masks.
        weights: (bce_weight, dice_weight). The default is now an immutable
            tuple (same values as before) — a mutable list default is shared
            across calls and is a classic Python pitfall.

    Returns:
        Scalar tensor ``weights[0] * bce + weights[1] * dice``.
    """
    _bce_loss = bce_crit(pred, labels)
    _dice_loss = dice_crit(pred, labels)
    global last_bce_loss, last_dice_loss
    last_bce_loss = _bce_loss.item()
    last_dice_loss = _dice_loss.item()
    return weights[0] * _bce_loss + weights[1] * _dice_loss
Пример #8
0
def _batch_to_device(batch, device):
    # Helper: unpack the first four loader entries (rdf, mask, w,
    # dipole_kernel) and move each to ``device`` as a float tensor.
    # Replaces the 6-line chained expression previously duplicated in the
    # training and validation loops.
    return tuple(batch[i][0].to(device, dtype=torch.float) for i in range(4))


def main():
    """Train a 3D U-Net for dipole inversion with an NDI + TV loss.

    Each epoch: train over all patches, decay the LR, checkpoint the model,
    then run a no-grad validation pass that also writes predicted
    susceptibility maps to NIfTI files.
    """
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    if not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    device = torch.device("cuda")
    #device = torch.device("cpu")

    torch.backends.cudnn.benchmark = True

    # get the data files
    data_files = fetch_data_files(config['data_path'],
                                  config['data_filenames'])
    print("num of datasets %d" % (len(data_files)))

    # -------------------------------
    # create data generator for training and validatation, it can load the data from memory pretty fast using multiple workers and buffers if you need to load your data batch by batch
    training_list, validation_list = get_validation_split(data_files,
                                                          'training.pkl',
                                                          'val.pkl',
                                                          data_split=0.8,
                                                          overwrite=True)
    # To make sure the num of training and validation cases is dividable by num_gpus when doing multi-GPUs training
    training_set = DataGenerator(data_files,
                                 training_list,
                                 patch_size=config["train_patch_size"],
                                 voxel_size=config["voxel_size"],
                                 batch_size=config["batch_size"],
                                 shuffle=True)
    validation_set = DataGeneratorValid(data_files,
                                        validation_list,
                                        patch_size=config["test_patch_size"],
                                        voxel_size=config["voxel_size"],
                                        batch_size=1,
                                        shuffle=False)

    params_training = {
        'batch_size': config["batch_size"],
        'shuffle': True,
        'num_workers': 4
    }
    params_valid = {'batch_size': 1, 'shuffle': True, 'num_workers': 1}
    training_generator = torch.utils.data.DataLoader(training_set,
                                                     **params_training)
    validation_generator = torch.utils.data.DataLoader(validation_set,
                                                       **params_valid)

    # Regression network: no final activation, linear conv layers ('cl').
    model = UNet3D(in_channels=1,
                   out_channels=1,
                   final_sigmoid=False,
                   f_maps=config["n_base_filters"],
                   layer_order='cl',
                   num_groups=1,
                   num_levels=config["layer_depth"],
                   is_segmentation=False)
    if torch.cuda.is_available():
        model.cuda()

    print(model)

    optimizer = optim.Adam(model.parameters(),
                           lr=config["initial_learning_rate"])
    # Multiply the LR by 0.8 after every epoch.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

    for epoch in range(config["epochs"]):
        epoch_loss = 0
        iteration = 0

        for batch in training_generator:
            optimizer.zero_grad()

            rdf, mask, w, dipole_kernel = _batch_to_device(batch, device)

            # Forward: predicted susceptibility, masked to the brain region.
            chi = model(rdf)
            chi = DoMask()([chi, mask])

            # Physics consistency: simulate the field map from chi and compare
            # against the measured RDF (NDI error), plus TV regularization.
            fm_chi = CalFMLayer()([chi, dipole_kernel])
            fm_chi = DoMask()([fm_chi, mask])

            ndi = NDIErr()([rdf, fm_chi, w])

            loss_ndi = torch.mean(torch.square(ndi))
            loss_tv = tv_loss(chi)

            loss = loss_ndi + 0.001 * loss_tv

            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

            print(
                "===> Epoch[{}]({}/{}): Loss: {:.4f} NDI Loss: {:.4f} TV Loss: {:.4f} "
                .format(epoch, iteration, len(training_generator), loss.item(),
                        loss_ndi.item(), loss_tv.item()))
            iteration += 1

        scheduler.step()
        for param_group in optimizer.param_groups:
            print("Current learning rate is: {}".format(param_group['lr']))

        print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
            epoch, epoch_loss / len(training_generator)))

        # NOTE(review): torch.save(model, ...) pickles the whole module (a
        # state_dict is the more portable convention) — left as-is since the
        # loading side is not visible here.
        model_out_path = "model_epoch_{}.pth".format(epoch)
        torch.save(model, model_out_path)
        model_out_path = "model.pth"
        torch.save(model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

        avg_loss = 0
        with torch.no_grad():
            idx = 0
            for batch in validation_generator:
                rdf, mask, w, dipole_kernel = _batch_to_device(batch, device)

                chi = model(rdf)
                chi = DoMask()([chi, mask])

                # Save the predicted map (rescaled) for visual inspection.
                chi_pred = (chi.cpu().detach().numpy()[0, 0]) / 100 * 3.0
                saveNifti(
                    chi_pred, 'chi_pred_' + 'epoch' + str(epoch) + '_data' +
                    str(idx) + '_.nii.gz')

                fm_chi = CalFMLayer()([chi, dipole_kernel])
                fm_chi = DoMask()([fm_chi, mask])

                ndi = NDIErr()([rdf, fm_chi, w])

                loss_ndi = torch.mean(torch.square(ndi))
                loss_tv = tv_loss(chi)

                loss = loss_ndi + 0.001 * loss_tv

                avg_loss += loss.item()
                idx += 1

        print("===> Avg. Loss: {:.4f}".format(avg_loss /
                                              len(validation_generator)))
Пример #9
0
def main_test(model=None, args=None, val_mode=False):
    """Run test-set inference and per-exam performance reporting.

    Args:
        model: optional pre-built network; when None (and not ``val_mode``) a
            UNet3D is constructed and restored from ``model_best.pth``.
        args: parsed CLI namespace (work_dir, exp, file_name, test_root,
            f_maps, depth_stride, ...).
        val_mode: when True, model loading and the test-set walk are both
            skipped — only the input stats are loaded.
    """
    work_dir = os.path.join(args.work_dir, args.exp)
    file_name = args.file_name
    if not val_mode:
        result_dir = os.path.join(work_dir, file_name)
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

        # load model and input stats
        # Note: here, the model should be given manually
        # TODO: try to import model configuration later
        if model is None:

            model = UNet3D(1, 1, f_maps=args.f_maps, depth_stride=args.depth_stride, conv_layer_order='cbr',
                           num_groups=1)
            model = nn.DataParallel(model).cuda()

        # load model
        checkpoint_path = os.path.join(work_dir, 'model_best.pth')
        state = torch.load(checkpoint_path)

        model.load_state_dict(state['state_dict'])
        cudnn.benchmark = True

    # Normalization statistics saved at training time, used by predict().
    input_stats = np.load(os.path.join(work_dir, 'input_stats.npy'), allow_pickle=True).tolist()


    if not val_mode:

        allData_dic = dict_for_datas()

        # list exam ids
        collated_performance = {}
        for i in range(len(args.test_root)):
            exam_ids = os.listdir(os.path.join(args.test_root[i], 'images'))
            for exam_id in exam_ids:
                print('Processing {}'.format(exam_id))
                exam_path = os.path.join(args.test_root[i], 'images', exam_id)  # '/data2/test_3d/images/403'
                prediction_list, org_input_list, org_target_list = predict(model, exam_path, input_stats, args=args)

                # measure performance
                performance = performance_by_slice(prediction_list, org_target_list)

                # find folder
                # Locate the single non-'overall' bucket this exam belongs to;
                # the assert below enforces exactly one match.
                find_folder = ''
                count = 0
                for data_no, level_no in allData_dic.items():
                    for level_key, level_val in level_no.items():
                        if exam_id in level_val:
                            if 'overall' in level_key.split('_'):  # prevent duplicate data save
                                continue
                            find_folder = level_key
                            count += 1
                assert count == 1, 'duplicate folder'

                result_dir_sep = os.path.join(result_dir, find_folder)
                save_fig(exam_id, org_input_list, org_target_list, prediction_list, performance, result_dir_sep)

                collated_performance[exam_id] = performance

        # Aggregate per-bucket performance; buckets missing any exam are skipped.
        for data_no, level_no in allData_dic.items():
            for level_key, level_val in level_no.items():
                sep_dict = seperate_dict(collated_performance, level_val)
                if len(sep_dict) == 0 or len(sep_dict)!= len(level_val):
                    continue

                sep_performance = compute_overall_performance(sep_dict)

                with open(os.path.join(result_dir, '{}_performance.json'.format(level_key)), 'w') as f:
                    json.dump(sep_performance, f)
Пример #10
0
def main():
    """Predict label maps with a trained 3D U-Net (Chainer): tile each input
    volume into fixed-size patches, run inference per patch, then write the
    assembled prediction and per-class probability maps as MHD files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--base',
                        '-B',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='base directory path of program files')
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--out',
                        '-o',
                        default='results/prediction',
                        help='Directory to output the result')

    parser.add_argument('--model',
                        '-m',
                        default='',
                        help='Load model data(snapshot)')

    parser.add_argument('--root',
                        '-R',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='Root directory path of input image')
    parser.add_argument('--test_list',
                        default='configs/test_list.txt',
                        help='Path to test image list file')
    args = parser.parse_args()

    # NOTE(review): yaml.load without an explicit Loader is deprecated/unsafe
    # on PyYAML >= 5.1 — consider yaml.safe_load for config parsing.
    config = yaml_utils.Config(
        yaml.load(open(os.path.join(args.base, args.config_path))))
    print('GPU: {}'.format(args.gpu))
    print('')

    # Restore the trained network from the snapshot passed via --model.
    unet = UNet3D(config.unet['number_of_label'])
    chainer.serializers.load_npz(args.model, unet)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        unet.to_gpu()
    xp = unet.xp  # array module matching the device (numpy or cupy)

    # Read test list
    # Each non-empty line holds whitespace-separated paths; i[0] is the
    # original image, i[1] the label image.
    path_pairs = []
    with open(os.path.join(args.base, args.test_list)) as paths_file:
        for line in paths_file:
            line = line.split()
            if not line: continue
            path_pairs.append(line[:])

    for i in path_pairs:
        print('   Org   from: {}'.format(i[0]))
        print('   label from: {}'.format(i[1]))
        sitkOrg = sitk.ReadImage(os.path.join(args.root, 'data', i[0]))
        org = sitk.GetArrayFromImage(sitkOrg).astype("float32")

        # Calculate maximum of number of patch at each side
        ze, ye, xe = org.shape
        xm = int(math.ceil((float(xe) / float(config.patch['patchside']))))
        ym = int(math.ceil((float(ye) / float(config.patch['patchside']))))
        zm = int(math.ceil((float(ze) / float(config.patch['patchside']))))

        # Pad one extra patch length per axis so edge patches are full-sized.
        margin = ((0, config.patch['patchside']),
                  (0, config.patch['patchside']), (0,
                                                   config.patch['patchside']))
        org = np.pad(org, margin, 'edge')
        org = chainer.Variable(
            xp.array(org[np.newaxis, np.newaxis, :], dtype=xp.float32))

        # Output buffers sized to the padded volume; cropped back below.
        prediction_map = np.zeros(
            (ze + config.patch['patchside'], ye + config.patch['patchside'],
             xe + config.patch['patchside']))
        probability_map = np.zeros(
            (config.unet['number_of_label'], ze + config.patch['patchside'],
             ye + config.patch['patchside'], xe + config.patch['patchside']))

        # Patch loop
        # Linear index s enumerates patches with x varying fastest, then y, then z.
        for s in range(xm * ym * zm):
            xi = int(s % xm) * config.patch['patchside']
            yi = int((s % (ym * xm)) / xm) * config.patch['patchside']
            zi = int(s / (ym * xm)) * config.patch['patchside']
            # Extract patch from original image
            patch = org[:, :, zi:zi + config.patch['patchside'],
                        yi:yi + config.patch['patchside'],
                        xi:xi + config.patch['patchside']]
            # Inference only: disable train-mode behavior and autograd.
            with chainer.using_config('train', False), chainer.using_config(
                    'enable_backprop', False):
                probability_patch = unet(patch)

            # Generate probability map
            probability_patch = probability_patch.data
            if args.gpu >= 0:
                probability_patch = chainer.cuda.to_cpu(probability_patch)
            for ch in range(probability_patch.shape[1]):
                probability_map[ch, zi:zi + config.patch['patchside'],
                                yi:yi + config.patch['patchside'], xi:xi +
                                config.patch['patchside']] = probability_patch[
                                    0, ch, :, :, :]

            # Hard labels: argmax over the channel (label) axis.
            prediction_patch = np.argmax(probability_patch, axis=1)

            prediction_map[zi:zi + config.patch['patchside'],
                           yi:yi + config.patch['patchside'], xi:xi +
                           config.patch['patchside']] = prediction_patch[
                               0, :, :, :]

        print('Save image')
        # Crop the padding away to restore the original volume size.
        probability_map = probability_map[:, :ze, :ye, :xe]
        prediction_map = prediction_map[:ze, :ye, :xe]

        # Save prediction map
        # Preserve the input geometry (spacing/origin) on all outputs.
        imagePrediction = sitk.GetImageFromArray(prediction_map)
        imagePrediction.SetSpacing(sitkOrg.GetSpacing())
        imagePrediction.SetOrigin(sitkOrg.GetOrigin())
        result_dir = os.path.join(args.base, args.out, os.path.dirname(i[0]))
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        fn = os.path.splitext(os.path.basename(i[0]))[0]
        sitk.WriteImage(imagePrediction, '{}/{}.mhd'.format(result_dir, fn))

        # Save probability map
        for ch in range(probability_map.shape[0]):
            imageProbability = sitk.GetImageFromArray(probability_map[ch, :])
            imageProbability.SetSpacing(sitkOrg.GetSpacing())
            imageProbability.SetOrigin(sitkOrg.GetOrigin())
            sitk.WriteImage(
                imageProbability,
                '{}/{}_probability_{}.mhd'.format(result_dir, fn, ch))
Пример #11
0
def main():
    """Train a 3D U-Net on BraTS19 with one of three labeling options
    (SVLS / LS / OH) and keep the checkpoint with the best validation loss.
    """
    parser = argparse.ArgumentParser(description='SVLS Brats Training')
    parser.add_argument('--lr',
                        default=1e-4,
                        type=float,
                        help='initial learning rate')
    parser.add_argument('--weight_decay',
                        '--weight-decay',
                        default=0.,
                        type=float,
                        help='weight decay')
    parser.add_argument('--batch_size',
                        default=2,
                        type=int,
                        help='mini-batch size')
    parser.add_argument('--num_classes',
                        default=4,
                        type=int,
                        help="num of classes")
    parser.add_argument('--in_channels',
                        default=4,
                        type=int,
                        help="num of input channels")
    parser.add_argument('--svls_smoothing',
                        default=1.0,
                        type=float,
                        help='SVLS smoothing factor')
    parser.add_argument('--ls_smoothing',
                        default=0.1,
                        type=float,
                        help='LS smoothing factor')
    parser.add_argument('--train_option',
                        default='SVLS',
                        help="options:[SVLS, LS, OH]")
    parser.add_argument('--epochs',
                        default=200,
                        type=int,
                        help='number of total epochs to run')
    parser.add_argument('--data_root',
                        default='MICCAI_BraTS_2019_Data_Training/HGG_LGG',
                        help='data directory')
    parser.add_argument('--ckpt_dir',
                        default='ckpt_brats19',
                        help='ckpt directory')
    args = parser.parse_args()
    args.save_folder = pathlib.Path(args.ckpt_dir)
    args.save_folder.mkdir(parents=True, exist_ok=True)
    train_dataset, val_dataset = get_datasets_brats(data_root=args.data_root)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=2)

    print('train sample:',len(train_dataset), 'train minibatch:',len(train_loader),\
          'valid sample:',len(val_dataset), 'valid minibatch:',len(val_loader))

    model = UNet3D(inplanes=args.in_channels,
                   num_classes=args.num_classes).cuda()

    model = torch.nn.DataParallel(model)
    criterion_dice = EDiceLoss().cuda()

    # Select the training loss and checkpoint name for the chosen option.
    print('train_option', args.train_option)
    if args.train_option == 'SVLS':
        criterion = CELossWithSVLS(classes=args.num_classes,
                                   sigma=args.svls_smoothing).cuda()
        best_ckpt_name = 'model_best_svls.pth.tar'
    elif args.train_option == 'LS':
        criterion = CELossWithLS(classes=args.num_classes,
                                 smoothing=args.ls_smoothing).cuda()
        # NOTE(review): unlike the SVLS/OH branches this name carries no
        # '.pth.tar' suffix — confirm whether that is intentional.
        best_ckpt_name = 'model_best_ls{}'.format(args.ls_smoothing)
    elif args.train_option == 'OH':
        # One-hot targets: LS loss with smoothing forced to zero.
        args.ls_smoothing = 0.0
        criterion = CELossWithLS(classes=args.num_classes,
                                 smoothing=args.ls_smoothing).cuda()
        best_ckpt_name = 'model_best_oh.pth.tar'
    else:
        raise ValueError(args.train_option)

    print('ckpt name:', best_ckpt_name)
    best_ckpt_dir = os.path.join(str(args.save_folder), best_ckpt_name)
    metric = criterion_dice.metric_brats
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 eps=1e-4)
    best_loss, best_epoch, best_dices = np.inf, 0, [0, 0, 0]
    args.start_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        step_train(train_loader, model, criterion, metric, optimizer)
        with torch.no_grad():
            validation_loss, dice_metrics = step_valid(val_loader, model,
                                                       criterion, metric)
            # Transpose per-batch rows into per-class columns (ET, TC, WT)
            # before averaging.
            dice_metrics = list(zip(*dice_metrics))
            dice_metrics = [
                torch.tensor(dice, device="cpu").numpy()
                for dice in dice_metrics
            ]
            avg_dices = np.mean(dice_metrics, 1)
        # Checkpoint whenever validation loss improves.
        if validation_loss < best_loss:
            best_loss = validation_loss
            best_epoch = epoch
            best_dices = avg_dices
            # NOTE(review): 'arhi' looks like a typo for 'archi'; any loader
            # must use the same key, so it is left unchanged here.
            torch.save(
                dict(epoch=epoch, arhi='unet', state_dict=model.state_dict()),
                best_ckpt_dir)
        print('epoch:%d/%d, loss:%.4f, best epoch:%d, best loss:%.4f, dice[ET:%.4f, TC:%.4f, WT:%.4f], best dice[ET:%.4f, TC:%.4f, WT:%.4f]' \
                %(epoch, args.epochs, validation_loss, best_epoch, best_loss, avg_dices[0], avg_dices[1], avg_dices[2], best_dices[0], best_dices[1], best_dices[2]))
Пример #12
0
def main(_):
    """Prepare BraTS17 training/testing splits plus survival CSV data, then
    run the segmentation 3D U-Net and/or the survival VAE depending on FLAGS.
    """
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    # Train
    # Collect every BraTS17 case directory under the training data root.
    all_train_paths = []
    for dirpath, dirnames, files in os.walk(FLAGS.train_data_dir):
        if os.path.basename(dirpath)[0:7] == 'Brats17':
            all_train_paths.append(dirpath)

    if FLAGS.split_train:
        if os.path.exists(os.path.join(FLAGS.train_patch_dir, 'files.log')):
            # Reuse a previously cached 80/20 split.
            # Fix: pickle requires binary mode ('rb', not 'r') on Python 3.
            with open(os.path.join(FLAGS.train_patch_dir, 'files.log'),
                      'rb') as f:
                training_paths, testing_paths = pickle.load(f)
        else:
            # NOTE(review): entries are listed from train_data_dir but joined
            # onto train_patch_dir — confirm the two directories mirror each
            # other.
            all_paths = [
                os.path.join(FLAGS.train_patch_dir, p)
                for p in sorted(os.listdir(FLAGS.train_data_dir))
            ]
            np.random.shuffle(all_paths)
            n_training = int(len(all_paths) * 4 / 5)
            training_paths = all_paths[:n_training]
            testing_paths = all_paths[n_training:]
            # Save the training paths and testing paths.
            # Fix: write the cache to train_patch_dir — the same location the
            # branch above reads it from (it previously went to
            # train_data_dir, so the cache never round-tripped) — and use
            # binary mode for pickle.
            with open(os.path.join(FLAGS.train_patch_dir, 'files.log'),
                      'wb') as f:
                pickle.dump([training_paths, testing_paths], f)

        training_ids = [os.path.basename(i) for i in training_paths]
        testing_ids = [os.path.basename(i) for i in testing_paths]

        # Map case id -> (row[1], row[2]) from the survival CSV, split by set.
        training_survival_data = {}
        testing_survival_data = {}
        with open(FLAGS.train_csv, 'r') as csvfile:
            reader = csv.reader(csvfile)
            for row in reader:
                if row[0] in training_ids:
                    training_survival_data[row[0]] = (row[1], row[2])
                elif row[0] in testing_ids:
                    testing_survival_data[row[0]] = (row[1], row[2])

        # Only cases with survival records feed the survival network.
        training_survival_paths = [
            p for p in all_train_paths
            if os.path.basename(p) in training_survival_data.keys()
        ]
        testing_survival_paths = [
            p for p in all_train_paths
            if os.path.basename(p) in testing_survival_data.keys()
        ]
    else:
        # No split: every non-log entry in the patch dir is training data.
        training_paths = [
            os.path.join(FLAGS.train_patch_dir, name)
            for name in os.listdir(FLAGS.train_patch_dir) if '.log' not in name
        ]
        testing_paths = None

        training_ids = [os.path.basename(i) for i in training_paths]
        training_survival_paths = []
        testing_survival_paths = None
        training_survival_data = {}
        testing_survival_data = None

        with open(FLAGS.train_csv, 'r') as csvfile:
            reader = csv.reader(csvfile)
            for row in reader:
                if row[0] in training_ids:
                    training_survival_data[row[0]] = (row[1], row[2])
        training_survival_paths = [
            p for p in all_train_paths
            if os.path.basename(p) in training_survival_data.keys()
        ]

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    # Segmentation net
    if FLAGS.run_seg:
        run_config = tf.ConfigProto()
        with tf.Session(config=run_config) as sess:
            unet = UNet3D(sess,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          log_dir=FLAGS.log_dir,
                          training_paths=training_paths,
                          testing_paths=testing_paths,
                          batch_size=FLAGS.batch_size,
                          layers=FLAGS.layers,
                          features_root=FLAGS.seg_features_root,
                          conv_size=FLAGS.conv_size,
                          dropout=FLAGS.dropout,
                          loss_type=FLAGS.loss_type)

            if FLAGS.train:
                model_vars = tf.trainable_variables()
                slim.model_analyzer.analyze_vars(model_vars, print_info=True)

                train_config = {}
                train_config['epoch'] = FLAGS.epoch

                unet.train(train_config)
            else:
                # Deploy
                if not os.path.exists(FLAGS.deploy_output_dir):
                    os.makedirs(FLAGS.deploy_output_dir)
                unet.deploy(FLAGS.deploy_data_dir, FLAGS.deploy_output_dir)

        # Clear the graph so the survival net can build its own.
        tf.reset_default_graph()

    # Survival net
    if FLAGS.run_survival:
        run_config = tf.ConfigProto()
        with tf.Session(config=run_config) as sess:
            survivalvae = SurvivalVAE(
                sess,
                checkpoint_dir=FLAGS.checkpoint_dir,
                log_dir=FLAGS.log_dir,
                training_paths=training_survival_paths,
                testing_paths=testing_survival_paths,
                training_survival_data=training_survival_data,
                testing_survival_data=testing_survival_data)

            if FLAGS.train:
                model_vars = tf.trainable_variables()
                slim.model_analyzer.analyze_vars(model_vars, print_info=True)

                train_config = {}
                train_config['epoch'] = FLAGS.epoch * 100

                survivalvae.train(train_config)
            else:
                all_deploy_paths = []
                for dirpath, dirnames, files in os.walk(FLAGS.deploy_data_dir):
                    if os.path.basename(dirpath)[0:7] == 'Brats17':
                        all_deploy_paths.append(dirpath)
                deploy_survival_data = {}
                with open(FLAGS.deploy_csv, 'r') as csvfile:
                    reader = csv.reader(csvfile)
                    for row in reader:
                        if row[0] != 'Brats17ID':  # skip the header row
                            deploy_survival_data[row[0]] = row[1]
                deploy_survival_paths = [
                    p for p in all_deploy_paths
                    if os.path.basename(p) in deploy_survival_data.keys()
                ]
                # Fix: deploy on the instance constructed above
                # ('survivalnet' was an undefined name) and pass the locally
                # computed paths/data rather than nonexistent FLAGS attrs.
                survivalvae.deploy(deploy_survival_paths,
                                   deploy_survival_data)
Пример #13
0
def main():
    """Train a 3D U-Net for multi-label segmentation with Chainer.

    Parses command-line options, loads the YAML configuration, builds the
    training/validation datasets and iterators, sets up an Adam optimizer
    and a Trainer with snapshot/logging/evaluation extensions, then runs
    the training loop.
    """
    parser = argparse.ArgumentParser(description='Train 3D-Unet')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--base',
                        '-B',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='base directory path of program files')
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--out',
                        '-o',
                        default='results/training',
                        help='Directory to output the result')

    parser.add_argument('--model', '-m', default='', help='Load model data')
    parser.add_argument('--resume',
                        '-res',
                        default='',
                        help='Resume the training from snapshot')

    parser.add_argument('--root',
                        '-R',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='Root directory path of input image')
    parser.add_argument('--training_list',
                        default='configs/training_list.txt',
                        help='Path to training image list file')
    parser.add_argument('--validation_list',
                        default='configs/validation_list.txt',
                        help='Path to validation image list file')

    args = parser.parse_args()

    # Close the config file promptly instead of leaving the handle open
    # (the original passed an anonymous open(...) that was never closed).
    with open(os.path.join(args.base, args.config_path)) as config_file:
        config = yaml_utils.Config(yaml.load(config_file))
    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(config.batchsize))
    print('# iteration: {}'.format(config.iteration))
    print('Learning Rate: {}'.format(config.adam['alpha']))
    print('')

    # Load the datasets
    train = UnetDataset(args.root, os.path.join(args.base, args.training_list),
                        config.patch['patchside'],
                        config.unet['number_of_label'])
    train_iter = chainer.iterators.SerialIterator(train,
                                                  batch_size=config.batchsize)

    val = UnetDataset(args.root, os.path.join(args.base, args.validation_list),
                      config.patch['patchside'],
                      config.unet['number_of_label'])
    # Validation must visit each sample exactly once, in order.
    val_iter = chainer.iterators.SerialIterator(val,
                                                batch_size=config.batchsize,
                                                repeat=False,
                                                shuffle=False)

    # Set up a neural network to train
    print('Set up a neural network to train')
    unet = UNet3D(config.unet['number_of_label'])
    if args.model:
        # Bug fix: the original passed an undefined name `gen` here, which
        # raised NameError whenever --model was supplied.
        chainer.serializers.load_npz(args.model, unet)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        unet.to_gpu()

    # Set up an optimizer (Adam; hyper-parameters come from the config)
    def make_optimizer(model, alpha=0.00001, beta1=0.9, beta2=0.999):
        optimizer = chainer.optimizers.Adam(alpha=alpha,
                                            beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        return optimizer

    opt_unet = make_optimizer(model=unet,
                              alpha=config.adam['alpha'],
                              beta1=config.adam['beta1'],
                              beta2=config.adam['beta2'])
    # Set up a trainer
    updater = Unet3DUpdater(models=(unet),
                            iterator=train_iter,
                            optimizer={'unet': opt_unet},
                            device=args.gpu)

    def create_result_dir(base_dir, output_dir, config_path, config):
        """Create the result dir and copy config/model/updater sources into
        it for reproducibility.

        https://github.com/pfnet-research/sngan_projection/blob/master/train.py
        """
        # Use the base_dir parameter; the original closed over args.base,
        # silently ignoring its own argument.
        result_dir = os.path.join(base_dir, output_dir)
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

        def copy_to_result_dir(fn, result_dir):
            bfn = os.path.basename(fn)
            shutil.copy(fn, '{}/{}'.format(result_dir, bfn))

        copy_to_result_dir(os.path.join(base_dir, config_path), result_dir)
        copy_to_result_dir(os.path.join(base_dir, config.unet['fn']),
                           result_dir)
        copy_to_result_dir(os.path.join(base_dir, config.updater['fn']),
                           result_dir)

    create_result_dir(args.base, args.out, args.config_path, config)

    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=os.path.join(args.base, args.out))

    # Set up logging
    snapshot_interval = (config.snapshot_interval, 'iteration')
    display_interval = (config.display_interval, 'iteration')
    evaluation_interval = (config.evaluation_interval, "iteration")
    # Full-trainer snapshot (enables --resume) and a model-only snapshot.
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        unet, filename=unet.__class__.__name__ + '_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=display_interval))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Print selected entries of the log to stdout
    report_keys = [
        'epoch', 'iteration', 'unet/loss', 'unet/val/loss', 'unet/val/dice'
    ]
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=display_interval)

    trainer.extend(Unet3DEvaluator(val_iter,
                                   unet,
                                   config.unet['number_of_label'],
                                   device=args.gpu),
                   trigger=evaluation_interval)

    # Linearly decay the learning rate from alpha to 0 over the final
    # iterations.
    ext_opt_unet = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.),
        (config.iteration_decay_start, config.iteration), opt_unet)
    trainer.extend(ext_opt_unet)

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['unet/loss', 'unet/val/loss'],
                                  'iteration',
                                  file_name='unet_loss.png',
                                  trigger=display_interval))
        trainer.extend(
            extensions.PlotReport(['unet/val/dice'],
                                  'iteration',
                                  file_name='unet_dice_score.png',
                                  trigger=display_interval))

    if args.resume:
        # Resume from a snapshot
        print("Resume training with snapshot:{}".format(args.resume))
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    chainer.config.autotune = True
    print('Start training')
    trainer.run()
Пример #14
0
def main(_):
    """Train or deploy a TF1 3D U-Net over patch directories.

    Resolves training/testing patch paths — either from a persisted random
    80/20 split or by walking the patch directories — then either trains
    the segmentation network or runs deployment, depending on FLAGS.train.
    """
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    # Train
    if FLAGS.split_train:
        if os.path.exists(os.path.join(FLAGS.train_patch_dir, 'files.log')):
            # pickle data is binary: open with 'rb' (the original 'r'
            # breaks under Python 3 with a UnicodeDecodeError).
            with open(os.path.join(FLAGS.train_patch_dir, 'files.log'),
                      'rb') as f:
                training_paths, testing_paths = pickle.load(f)
        else:
            # NOTE(review): entries are listed from train_data_dir but
            # joined onto train_patch_dir, and the split is saved under
            # train_data_dir while it is looked up under train_patch_dir
            # above — confirm the intended directories.
            all_paths = [
                os.path.join(FLAGS.train_patch_dir, p)
                for p in sorted(os.listdir(FLAGS.train_data_dir))
            ]
            np.random.shuffle(all_paths)
            # 80/20 train/test split.
            n_training = int(len(all_paths) * 4 / 5)
            training_paths = all_paths[:n_training]
            testing_paths = all_paths[n_training:]
            # Save the training paths and testing paths ('wb' for pickle).
            with open(os.path.join(FLAGS.train_data_dir, 'files.log'),
                      'wb') as f:
                pickle.dump([training_paths, testing_paths], f)

        training_ids = [os.path.basename(i) for i in training_paths]
        testing_ids = [os.path.basename(i) for i in testing_paths]

    else:
        # train_patch_dir = data/ATLAS_R1.1/train/
        # Collect every 'patches*' directory under the train/test roots.
        training_paths = []
        for dirpath, dirnames, files in os.walk(FLAGS.train_patch_dir):
            if os.path.basename(dirpath)[0:7] == 'patches':
                training_paths.append(dirpath)

        testing_paths = []
        for dirpath, dirnames, files in os.walk(FLAGS.testing_data_dir):
            if os.path.basename(dirpath)[0:7] == 'patches':
                testing_paths.append(dirpath)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    # Segmentation net
    if FLAGS.run_seg:
        run_config = tf.ConfigProto()
        with tf.Session(config=run_config) as sess:
            unet = UNet3D(sess,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          log_dir=FLAGS.log_dir,
                          training_paths=training_paths,
                          testing_paths=testing_paths,
                          batch_size=FLAGS.batch_size,
                          layers=FLAGS.layers,
                          features_root=FLAGS.seg_features_root,
                          conv_size=FLAGS.conv_size,
                          dropout=FLAGS.dropout,
                          loss_type=FLAGS.loss_type)

            if FLAGS.train:
                # Print a summary of trainable variables before training.
                model_vars = tf.trainable_variables()
                slim.model_analyzer.analyze_vars(model_vars, print_info=True)

                train_config = {}
                train_config['epoch'] = FLAGS.epoch

                unet.train(train_config)
            else:
                # Deploy
                unet.deploy(FLAGS.deploy_data_dir)

        # Clear the graph so a later model can be built in the same process.
        tf.reset_default_graph()
Пример #15
0
    parser.add_argument('--file',
                        '-f',
                        metavar='FILE',
                        dest='file',
                        default='16.mat',
                        help='filename')

    return parser.parse_args()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    args = args_setting()

    net = UNet3D(residual='conv')

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)

    net = torch.nn.DataParallel(net)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    print(args.input + args.file)
    with h5py.File(args.input + args.file, 'r') as f:
        dset = f['RawImage']
Пример #16
0
    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :], im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :], im[1:H:2, 0:W:2, :]),
                         axis=2)
    return out


sess = tf.Session()
in_image = tf.placeholder(tf.float32, [None, None, None, 4])
gt_image = tf.placeholder(tf.float32, [None, None, None, 3])
net = UNet3D()
out_image = net.construct_model(in_image)

saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

if not os.path.isdir(result_dir + 'final/'):
    os.makedirs(result_dir + 'final/')

psnr = {}
ssim = {}
for test_id in test_ids:
Пример #17
0
def main():
    """Train a 3D U-Net for binary segmentation and optionally run a test pass.

    Reads options from the module-level `args`; prepares the train and
    validation loaders, builds the model/loss/optimizer, trains for the
    configured schedule with step LR decay, checkpoints every epoch while
    tracking the best IoU, and finally draws the training curves.
    """

    # Per-experiment working directory: logs, checkpoints and config copies.
    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Training exam ids: all exams are used for training; the 80/20 split
    # below is commented out because a separate validation root is used.
    trn_image_root = os.path.join(args.trn_root, 'images')
    exam_ids = os.listdir(trn_image_root)
    random.shuffle(exam_ids)
    train_exam_ids = exam_ids

    #train_exam_ids = exam_ids[:int(len(exam_ids)*0.8)]
    #val_exam_ids = exam_ids[int(len(exam_ids) * 0.8):]

    # train_dataset
    # NOTE(review): input_stats=[0.5, 0.5] presumably is (mean, std) used
    # for input normalization — confirm against DatasetTrain.
    trn_dataset = DatasetTrain(args.trn_root,
                               train_exam_ids,
                               options=args,
                               input_stats=[0.5, 0.5])
    trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # save input stats for later use (inference must normalize identically)
    np.save(os.path.join(work_dir, 'input_stats.npy'), trn_dataset.input_stats)

    # Validation exam ids come from a separate root directory.
    val_image_root = os.path.join(args.val_root, 'images')
    val_exam = os.listdir(val_image_root)
    random.shuffle(val_exam)
    val_exam_ids = val_exam

    # val_dataset — reuses the training input stats so both splits are
    # normalized the same way.
    val_dataset = DatasetVal(args.val_root,
                             val_exam_ids,
                             options=args,
                             input_stats=trn_dataset.input_stats)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # make logger (train, per-batch raw train, and validation logs)
    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # model_select — only 'unet' is supported; 1 input channel, 1 output.
    if args.model == 'unet':
        net = UNet3D(1,
                     1,
                     f_maps=args.f_maps,
                     depth_stride=args.depth_stride,
                     conv_layer_order=args.conv_layer_order,
                     num_groups=args.num_groups)

    else:
        raise ValueError('Not supported network.')

    # loss_select
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'dice':
        criterion = DiceLoss().cuda()
    elif args.loss_function == 'weight_bce':
        # Hard-coded positive-class weight of 5.
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.FloatTensor([5])).cuda()
    else:
        raise ValueError('{} loss is not supported yet.'.format(
            args.loss_function))

    # optim_select
    if args.optim == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=False)

    elif args.optim == 'adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise ValueError('{} optim is not supported yet.'.format(args.optim))

    net = nn.DataParallel(net).cuda()
    cudnn.benchmark = True

    # lr_schedule: milestones plus total epoch count as the last element.
    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    best_iou = 0
    for epoch in range(lr_schedule[-1]):

        train(trn_loader, net, criterion, optimizer, epoch, trn_logger,
              trn_raw_logger)
        iou = validate(val_loader, net, criterion, epoch, val_logger)

        lr_scheduler.step()

        # save model parameter; a copy is flagged as best when IoU improves
        is_best = iou > best_iou
        best_iou = max(iou, best_iou)
        checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(epoch + 1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, work_dir, checkpoint_filename)

    # visualize curve
    draw_curve(work_dir, trn_logger, val_logger)

    if args.inplace_test:
        # calc overall performance and save figures
        print('Test mode ...')
        main_test(model=net, args=args)
Пример #18
0
def main(_):
    """Two-step 3D U-Net training/testing for multi-organ segmentation (TF1).

    Step 1 trains a single network over all classes; step 2 trains one
    binary network per ROI.  Ground truth is assumed available for the
    testing set only when the test and train directories coincide (the
    testing paths then come from a persisted 2/3 : 1/3 split).
    """
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.test_data_dir == FLAGS.train_data_dir:
        testing_gt_available = True
        if os.path.exists(os.path.join(FLAGS.train_data_dir, 'files.log')):
            # pickle data is binary: open with 'rb' (the original 'r'
            # breaks under Python 3 with a UnicodeDecodeError).
            with open(os.path.join(FLAGS.train_data_dir, 'files.log'), 'rb') as f:
                training_paths, testing_paths = pickle.load(f)
        else:
            # Phase 0: first run — create and persist the subject split.
            all_subjects = [os.path.join(FLAGS.train_data_dir, name) for name in os.listdir(FLAGS.train_data_dir)]
            n_training = int(np.rint(len(all_subjects) * 2 / 3))
            training_paths = all_subjects[:n_training]
            testing_paths = all_subjects[n_training:]
            # Save the training paths and testing paths ('wb' for pickle).
            with open(os.path.join(FLAGS.train_data_dir, 'files.log'), 'wb') as f:
                pickle.dump([training_paths, testing_paths], f)
    else:
        testing_gt_available = False
        training_paths = [os.path.join(FLAGS.train_data_dir, name)
                          for name in os.listdir(FLAGS.train_data_dir) if '.hdf5' in name]
        testing_paths = [os.path.join(FLAGS.test_data_dir, name)
                         for name in os.listdir(FLAGS.test_data_dir) if '.hdf5' in name]

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    # First step: one network over all classes (background + N_CLASSES).
    run_config = tf.ConfigProto()
    with tf.Session(config=run_config) as sess:
        unet_all = UNet3D(sess, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir, training_paths=training_paths,
                          testing_paths=testing_paths, nclass=N_CLASSES + 1, layers=FLAGS.layers,
                          features_root=FLAGS.step1_features_root, conv_size=FLAGS.conv_size, dropout=FLAGS.dropout_ratio,
                          loss_type=FLAGS.loss_type, roi=(-1, 'All'), im_size=ALL_IM_SIZE,
                          testing_gt_available=testing_gt_available, class_weights=(1.0, 2.0, 1.0, 1.0, 1.0, 3.0))
        if FLAGS.train:
            train_config = {}
            train_config['epoch'] = FLAGS.epoch
            unet_all.train(train_config)
        else:
            if not os.path.exists(FLAGS.output_dir):
                os.makedirs(FLAGS.output_dir)

            unet_all.test(testing_paths, FLAGS.output_dir)

    # Clear the graph before building the per-ROI models below.
    tf.reset_default_graph()

    # Second step training: one binary network per ROI, each with its own
    # crop size and class weights.
    rois = ['SpinalCord', 'Lung_R', 'Lung_L', 'Heart', 'Esophagus']
    im_sizes = [(160, 128, 64), (72, 192, 120), (72, 192, 120), (32, 160, 192), (80, 80, 64)]
    weights = [(1.0, 2.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (1.0, 3.0)]

    for roi in range(5):
        run_config = tf.ConfigProto()
        # Build model
        with tf.Session(config=run_config) as sess:
            unet = UNet3D(sess, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir, training_paths=training_paths,
                          testing_paths=testing_paths, nclass=2, layers=FLAGS.layers, features_root=FLAGS.step2_features_root,
                          conv_size=FLAGS.conv_size, dropout=FLAGS.dropout_ratio, loss_type=FLAGS.loss_type,
                          roi=(roi, rois[roi]), im_size=im_sizes[roi], testing_gt_available=testing_gt_available,
                          class_weights=weights[roi])

            if FLAGS.train:
                train_config = {}
                train_config['epoch'] = FLAGS.epoch
                unet.train(train_config)
            else:
                if not os.path.exists(FLAGS.output_dir):
                    os.makedirs(FLAGS.output_dir)

                # Get result for single ROI
                unet.test(testing_paths, FLAGS.output_dir)

        tf.reset_default_graph()
Пример #19
0
def main():
    """Evaluate a trained SVLS/LS/OH 3D U-Net on the BraTS19 validation set.

    Loads the checkpoint matching --train_option, runs a no-grad validation
    pass with the Dice metric, and prints the average ET/TC/WT Dice scores.
    """
    parser = argparse.ArgumentParser(description='SVLS Brats Training')
    parser.add_argument('--batch_size',
                        default=2,
                        type=int,
                        help='mini-batch size')
    parser.add_argument('--num_classes',
                        default=4,
                        type=int,
                        help="num of classes")
    parser.add_argument('--in_channels',
                        default=4,
                        type=int,
                        help="num of input channels")
    parser.add_argument('--train_option',
                        default='SVLS',
                        help="options:[SVLS, LS, OH]")
    parser.add_argument('--epochs',
                        default=200,
                        type=int,
                        help='number of total epochs to run')
    # Bug fix: args.ls_smoothing is read below when --train_option LS is
    # selected, but the argument was never declared, so that branch crashed
    # with AttributeError.  Declared with a conventional default.
    parser.add_argument('--ls_smoothing',
                        default=0.1,
                        type=float,
                        help='label-smoothing factor used by the LS option')
    parser.add_argument('--data_root',
                        default='MICCAI_BraTS_2019_Data_Training/HGG_LGG',
                        help='data directory')
    parser.add_argument('--ckpt_dir',
                        default='ckpt_brats19',
                        help='ckpt directory')
    args = parser.parse_args()

    # Only the validation split is needed for evaluation.
    _, val_dataset = get_datasets_brats(data_root=args.data_root)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=2)

    print('valid sample:', len(val_dataset), 'valid minibatch:',
          len(val_loader))

    model = UNet3D(inplanes=args.in_channels,
                   num_classes=args.num_classes).cuda()

    # DataParallel wrapping must match how the checkpoint was saved
    # ('module.'-prefixed keys).
    model = torch.nn.DataParallel(model)
    criterion_dice = EDiceLoss().cuda()

    print('train_option', args.train_option)
    if args.train_option == 'SVLS':
        best_ckpt_name = 'best_svls.pth.tar'
    elif args.train_option == 'LS':
        # NOTE(review): unlike the other options, this name carries no
        # '.pth.tar' extension — confirm against the training script.
        best_ckpt_name = 'best_ls{}'.format(args.ls_smoothing)
    elif args.train_option == 'OH':
        best_ckpt_name = 'best_oh.pth.tar'
    else:
        raise ValueError(args.train_option)

    print('ckpt name:', best_ckpt_name)
    best_ckpt_dir = os.path.join(args.ckpt_dir, best_ckpt_name)
    model.load_state_dict(torch.load(best_ckpt_dir))
    metric = criterion_dice.metric_brats
    with torch.no_grad():
        # step_valid yields per-sample metric tuples; transpose to
        # per-region lists, then average each region over the dataset.
        dice_metrics = step_valid(val_loader, model, metric)
        dice_metrics = list(zip(*dice_metrics))
        dice_metrics = [
            torch.tensor(dice, device="cpu").numpy() for dice in dice_metrics
        ]
        avg_dices = np.mean(dice_metrics, 1)

    print('dice[ET:%.4f, TC:%.4f, WT:%.4f]' %
          (avg_dices[0], avg_dices[1], avg_dices[2]))