Example no. 1
def main():
    # Load configuration
    args = parser.parse_args()

    # Torch stuff
    torch.cuda.set_device(args.rank)
    cudnn.benchmark = True

    # Create model by loading a snapshot
    body, head, cls_state = load_snapshot(args.snapshot)
    model = SegmentationModule(body, head, 256, 65, args.fusion_mode)
    model.cls.load_state_dict(cls_state)
    model = model.cuda().eval()
    print(model)

    # Create data loader
    transformation = SegmentationTransform(
        2048,
        (0.41738699, 0.45732192, 0.46886091),
        (0.25685097, 0.26509955, 0.29067996),
    )
    dataset = SegmentationDataset(args.data, transformation)
    data_loader = DataLoader(
        dataset,
        batch_size=1,
        pin_memory=True,
        sampler=DistributedSampler(dataset, args.world_size, args.rank),
        num_workers=2,
        collate_fn=segmentation_collate,
        shuffle=False,
    )

    # Run testing
    scales = eval(args.scales)
    with torch.no_grad():
        for batch_i, rec in enumerate(data_loader):
            print("Testing batch [{:3d}/{:3d}]".format(batch_i + 1,
                                                       len(data_loader)))

            img = rec["img"].cuda(non_blocking=True)
            probs, preds = model(img, scales, args.flip)

            for i, (prob, pred) in enumerate(
                    zip(torch.unbind(probs, dim=0), torch.unbind(preds,
                                                                 dim=0))):
                out_size = rec["meta"][i]["size"]
                img_name = rec["meta"][i]["idx"]

                # Save prediction
                prob = prob.cpu()
                pred = pred.cpu()
                pred_img = get_pred_image(pred, out_size,
                                          args.output_mode == "palette")
                pred_img.save(path.join(args.output, img_name + ".png"))

                # Optionally save probabilities
                if args.output_mode == "prob":
                    prob_img = get_prob_image(prob, out_size)
                    prob_img.save(
                        path.join(args.output, img_name + "_prob.png"))
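
Note that main() in Example no. 1 relies on a module-level parser that is not part of the excerpt. A minimal sketch of the argument definitions it would need, with option names inferred only from the args.* attributes read above (defaults and help texts are assumptions), could look like this:

import argparse

# Hypothetical parser for the excerpt above; the option names mirror the
# args.* attributes it reads, everything else is an assumption.
parser = argparse.ArgumentParser(description="Multi-scale segmentation testing")
parser.add_argument("--snapshot", required=True, help="path to the model snapshot")
parser.add_argument("--data", required=True, help="directory with the input images")
parser.add_argument("--output", required=True, help="directory where predictions are written")
parser.add_argument("--scales", default="[1.0]",
                    help="list of test scales as a Python literal, e.g. \"[0.7, 1.0, 1.3]\"")
parser.add_argument("--flip", action="store_true",
                    help="also average predictions over horizontally flipped inputs")
parser.add_argument("--fusion-mode", default="mean",
                    help="how multi-scale and flipped outputs are fused")
parser.add_argument("--output-mode", default="palette",
                    help="'palette' saves a color PNG, 'prob' additionally saves probabilities")
parser.add_argument("--rank", type=int, default=0, help="rank of this process")
parser.add_argument("--world-size", type=int, default=1, help="total number of processes")
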
Example no. 2
def train(train_config_file):
    """ Medical image segmentation training engine
    :param train_config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(train_config_file), 'Config not found: {}'.format(
        train_config_file)

    # load config file
    train_cfg = load_config(train_config_file)

    # clean the existing folder if training from scratch
    model_folder = os.path.join(train_cfg.general.save_dir,
                                train_cfg.general.model_scale)
    if os.path.isdir(model_folder):
        if train_cfg.general.resume_epoch < 0:
            shutil.rmtree(model_folder)
            os.makedirs(model_folder)
    else:
        os.makedirs(model_folder)

    # copy training and inference config files to the model folder
    shutil.copy(train_config_file, os.path.join(model_folder,
                                                'train_config.py'))
    infer_config_file = os.path.join(os.path.dirname(__file__), 'config',
                                     'infer_config.py')
    shutil.copy(infer_config_file,
                os.path.join(train_cfg.general.save_dir, 'infer_config.py'))

    # enable logging
    log_file = os.path.join(model_folder, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(train_cfg.general.seed)
    torch.manual_seed(train_cfg.general.seed)
    if train_cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(train_cfg.general.seed)

    # dataset
    train_dataset = SegmentationDataset(
        mode='train',
        im_list=train_cfg.general.train_im_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=train_cfg.train.batchsize,
                                   num_workers=train_cfg.train.num_threads,
                                   pin_memory=True,
                                   shuffle=True)

    val_dataset = SegmentationDataset(
        mode='val',
        im_list=train_cfg.general.val_im_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)
    val_data_loader = DataLoader(val_dataset,
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False)

    # define network
    net = GlobalLocalNetwork(train_dataset.num_modality(),
                             train_cfg.dataset.num_classes)
    net.apply(kaiming_weight_init)
    max_stride = net.max_stride()

    if train_cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(net,
                                       device_ids=list(
                                           range(train_cfg.general.num_gpus)))
        net = net.cuda()

    assert np.all(np.array(train_cfg.dataset.crop_size) %
                  max_stride == 0), 'crop size not divisible by max stride'

    # training optimizer
    opt = optim.Adam(net.parameters(),
                     lr=train_cfg.train.lr,
                     betas=train_cfg.train.betas)

    # load checkpoint if resume epoch >= 0
    if train_cfg.general.resume_epoch >= 0:
        last_save_epoch = load_checkpoint(train_cfg.general.resume_epoch, net,
                                          opt, model_folder)
    else:
        last_save_epoch = 0

    if train_cfg.loss.name == 'Focal':
        # focal loss is currently the only supported loss function
        loss_func = FocalLoss(class_num=train_cfg.dataset.num_classes,
                              alpha=train_cfg.loss.obj_weight,
                              gamma=train_cfg.loss.focal_gamma,
                              use_gpu=train_cfg.general.num_gpus > 0)
    else:
        raise ValueError('Unknown loss function')

    writer = SummaryWriter(os.path.join(model_folder, 'tensorboard'))

    max_avg_dice = 0
    for epoch_idx in range(1, train_cfg.train.epochs + 1):
        train_one_epoch(net, train_cfg.loss.branch_weight, opt,
                        train_data_loader, train_cfg.dataset.down_sample_ratio,
                        loss_func, train_cfg.general.num_gpus,
                        epoch_idx + last_save_epoch, logger, writer,
                        train_cfg.train.print_freq,
                        train_cfg.debug.save_inputs,
                        os.path.join(model_folder, 'debug'))

        # evaluation
        if epoch_idx % train_cfg.train.save_epochs == 0:
            avg_dice = evaluate_one_epoch(
                net, val_data_loader, train_cfg.dataset.crop_size,
                train_cfg.dataset.down_sample_ratio,
                train_cfg.dataset.crop_normalizers[0], Metrics(),
                [idx for idx in range(1, train_cfg.dataset.num_classes)],
                train_cfg.loss.branch_type)

            if max_avg_dice < avg_dice:
                max_avg_dice = avg_dice
                save_checkpoint(net, opt, epoch_idx, train_cfg, max_stride, 1)
                msg = 'epoch: {}, best dice ratio: {}'
            else:
                msg = 'epoch: {}, dice ratio: {}'

            msg = msg.format(epoch_idx, avg_dice)
            logger.info(msg)
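
train() only needs the path to a Python configuration file (the assert at the top requires it to exist on disk). A minimal, hypothetical command-line entry point for it, where the script layout and flag names are assumptions and only the configuration-file argument comes from the example, might be:

import argparse

if __name__ == '__main__':
    # Hypothetical wrapper around the training engine above; train() itself
    # only requires a path to an existing training configuration file.
    parser = argparse.ArgumentParser(description='3D segmentation training')
    parser.add_argument('-i', '--input', default='config/train_config.py',
                        help='path to the training configuration file')
    args = parser.parse_args()
    train(args.input)
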
Example no. 3
def main():
    # Load configuration
    args = parser.parse_args()

    # Torch stuff
    device = torch.device("cuda:0")
    cudnn.benchmark = True

    # Create model by loading a snapshot
    body, head, cls_state = load_snapshot(args.snapshot)
    model = SegmentationModule(body, head, 256, 65, args.fusion_mode)
    model.cls.load_state_dict(cls_state)
    model = nn.DataParallel(model, output_device=-1).cuda(device).eval()
    print(model)

    # Create data loader
    transformation = SegmentationTransform(
        2048,
        (1024, 2048),
        (0.41738699, 0.45732192, 0.46886091),
        (0.25685097, 0.26509955, 0.29067996),
    )
    dataset = SegmentationDataset(args.data, transformation)
    data_loader = DataLoader(dataset,
                             batch_size=torch.cuda.device_count(),
                             pin_memory=True,
                             num_workers=torch.cuda.device_count(),
                             collate_fn=segmentation_collate,
                             shuffle=False)

    # Run testing
    scales = eval(args.scales)
    with torch.no_grad():
        for batch_i, rec in enumerate(data_loader):
            print("Testing batch [{:3d}/{:3d}]".format(batch_i + 1,
                                                       len(data_loader)))

            img = rec["img"].cuda(device, non_blocking=True)
            probs, preds = model(img, scales, args.flip)

            for i, (prob, pred) in enumerate(
                    zip(torch.unbind(probs, dim=0), torch.unbind(preds,
                                                                 dim=0))):
                crop_bbx = rec["meta"][i]["valid_bbx"]
                out_size = rec["meta"][i]["size"]
                img_name = rec["meta"][i]["idx"]

                # Crop to valid area
                pred = pred[crop_bbx[0]:crop_bbx[2], crop_bbx[1]:crop_bbx[3]]

                # Save prediction
                pred_img = get_pred_image(pred, out_size,
                                          args.output_mode == "palette")
                pred_img.save(path.join(args.output, img_name + ".png"))

                # Optionally save probabilities
                if args.output_mode == "prob":
                    prob = prob[crop_bbx[0]:crop_bbx[2],
                                crop_bbx[1]:crop_bbx[3]]
                    prob_img = get_prob_image(prob, out_size)
                    prob_img.save(
                        path.join(args.output, img_name + "_prob.png"))
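
Both testing examples parse the scales flag with eval(args.scales). If that flag is always a plain Python literal such as a list of floats, ast.literal_eval is a safer substitute; the snippet below is only a sketch of that substitution (parse_scales is a hypothetical helper, not part of the original scripts):

import ast

def parse_scales(scales_arg):
    # Safer equivalent of eval(args.scales): literal_eval only accepts Python
    # literals (lists, tuples, numbers, ...), so arbitrary expressions in the
    # flag value raise an error instead of being executed.
    scales = ast.literal_eval(scales_arg)
    assert isinstance(scales, (list, tuple)), "scales must be a list or tuple"
    return scales

# e.g. scales = parse_scales(args.scales) instead of scales = eval(args.scales)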