예제 #1
0
from torch.nn import DataParallel

# Load the project's model factory module and build the network together with
# its loss object (fixed to 5 classes here).
model = import_module('models.model_loader')
precise_net, loss = model.get_full_model(
    net_name, loss_name, n_classes=5, alpha=alpha)
# A second loss object constructed directly from the same alpha.
c_loss = DependentLoss(alpha)
# Restore the trained weights from the checkpoint's 'state_dict' entry.
# The weights are loaded into the bare network, before the DataParallel wrap.
checkpoint = torch.load(precise_net_path)
precise_net.load_state_dict(checkpoint['state_dict'])
precise_net = precise_net.to(DEVICE)
precise_net = DataParallel(precise_net)
# Switch to evaluation mode.
precise_net.eval()

################# first get the threshold #####################
# NOTE(review): this transform pipeline is built but not passed to the active
# THOR_Data construction below; only the disabled variant uses it.
composed_transforms_tr = transforms.Compose([
    tr.Normalize(mean=(0.12, 0.12, 0.12), std=(0.018, 0.018, 0.018)),
    tr.ToTensor2(5)
])
test_files = ["Patient_01"]
# The string literal below is a disabled variant of the dataset construction
# that passes the transform; as a bare string it has no runtime effect.
'''
eval_dataset = THOR_Data(
       transform=composed_transforms_tr,
    path=data_path,
    file_list=test_files)'''
# Active dataset construction: no transform is supplied.
eval_dataset = THOR_Data(
    path=data_path,
    file_list=test_files)

eval_loader = DataLoader(
    eval_dataset,
    batch_size=1,
    shuffle=False,
예제 #2
0
def evaluation(args, net, loss, epoch, save_dir, test_files, saved_thresholds):
    """Evaluate ``net`` on ``test_files`` and log the resulting metrics.

    Args:
        args: Run configuration; ``n_class``, ``data_path`` and ``batch_size``
            are read.
        net: Model called as ``net(data)``; must return
            ``(output_s, output_c)``.
        loss: Callable invoked as
            ``loss(output_s, output_c, target_s, target_c)``. Its first return
            value is used as the batch loss and its fourth as the class
            prediction handed to ``metric``.
        epoch: Epoch number, used only in the log messages.
        save_dir: Unused; kept so existing callers keep working.
        test_files: File list forwarded to ``THOR_Data``.
        saved_thresholds: Thresholds forwarded to ``metric``.

    Returns:
        Tuple ``(mean of the dice scores, mean of the precision values)``.

    Raises:
        ValueError: If the evaluation loader yields no batches.
    """
    start_time = time.time()
    net.eval()

    composed_transforms_tr = transforms.Compose([
        tr.Normalize(mean=(0.12, 0.12, 0.12), std=(0.018, 0.018, 0.018)),
        tr.ToTensor2(args.n_class)
    ])
    eval_dataset = THOR_Data(transform=composed_transforms_tr,
                             path=args.data_path,
                             file_list=test_files)
    evalloader = DataLoader(eval_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=4)

    eval_loss = []
    cur_target = []
    cur_predict = []
    class_predict = []
    class_target = []
    for sample in evalloader:
        data = sample['image'].to(DEVICE)
        target_c = sample['label_c'].to(DEVICE)
        target_s = sample['label_s'].to(DEVICE)
        with torch.no_grad():
            output_s, output_c = net(data)
            cur_loss, _, _, c_p = loss(output_s, output_c, target_s, target_c)
        eval_loss.append(cur_loss.item())
        cur_target.append(torch.argmax(target_s, 1).cpu().numpy())
        cur_predict.append(torch.argmax(output_s, 1).cpu().numpy())
        class_target.append(target_c.cpu().numpy())
        class_predict.append(c_p.cpu().numpy())

    if not eval_loss:
        raise ValueError('evaluation loader produced no batches')

    # Only the metric over the complete set of predictions was ever used, so
    # compute it once here instead of re-concatenating everything per batch.
    cur_precision, cur_recall = metric(np.concatenate(class_predict, 0),
                                       np.concatenate(class_target, 0),
                                       saved_thresholds)
    # Stacking the single result on axis 1 keeps one row per class, which is
    # what the per-class indexing in the log lines below relies on.
    total_precision = np.stack([np.array(cur_precision)], 1)
    total_recall = np.stack([np.array(cur_recall)], 1)

    TPVFs, dices, PPVs, _ = segmentation_metrics(
        np.concatenate(cur_predict, 0), np.concatenate(cur_target, 0))

    separator = (
        '***************************************************************************'
    )
    logging.info(separator)
    # Labels carry their original trailing padding so the emitted log lines
    # are unchanged.
    organ_labels = ('Esophagus ', 'heart    ', 'trachea  ', 'aorta    ')
    for idx, label in enumerate(organ_labels):
        logging.info(
            '%s--> Global dice is [%.5f], TPR is [%.5f], Precision is [%.5f] ',
            label, dices[idx], TPVFs[idx], PPVs[idx])
    logging.info(
        'Epoch[%d], [precision=%.4f, -->%.3f, -->%.3f, -->%.3f, -->%.3f], using %.1f s!',
        epoch, np.mean(total_precision), np.mean(total_precision[0]),
        np.mean(total_precision[1]), np.mean(total_precision[2]),
        np.mean(total_precision[3]), time.time() - start_time)
    logging.info(
        'Epoch[%d], [recall=%.4f, -->%.3f, -->%.3f, -->%.3f, -->%.3f], using %.1f s!',
        epoch, np.mean(total_recall), np.mean(total_recall[0]),
        np.mean(total_recall[1]), np.mean(total_recall[2]),
        np.mean(total_recall[3]), time.time() - start_time)
    logging.info(
        'Epoch[%d], [total loss=%.6f], mean_dice=%.4f, using %.1f s!',
        epoch, np.mean(eval_loss), np.mean(dices), time.time() - start_time)
    logging.info(separator)
    return np.mean(dices), np.mean(total_precision)
예제 #3
0
def main(args):
    """Train the model, evaluating and checkpointing as training proceeds.

    Evaluation starts once ``epoch`` reaches ``args.untest_epoch``. A
    checkpoint is written whenever the evaluation dice matches or beats the
    best seen so far; training stops after more than ``args.patient``
    consecutive evaluated epochs without such an improvement.

    Args:
        args: Run configuration. Attributes read here: ``gpu``, ``data_path``,
            ``test_flag``, ``n_class``, ``otsu``, ``if_dependent``,
            ``model_name``, ``loss_name``, ``if_closs``, ``start_epoch``,
            ``save_dir``, ``resume``, ``lr``, ``momentum``, ``weight_decay``,
            ``batch_size``, ``epochs``, ``untest_epoch`` and ``patient``.
            ``args.lr`` is mutated by the learning-rate schedule.
    """
    max_precision = 0.
    torch.manual_seed(123)
    cudnn.benchmark = True
    setgpu(args.gpu)
    data_path = args.data_path
    train_files, test_files = get_cross_validation_paths(args.test_flag)
    # Evaluated once: the same predicate decides both whether to wrap the
    # network in DataParallel and how to unwrap it when saving.
    multi_gpu = len(args.gpu.split(',')) > 1 or args.gpu == 'all'

    eval_transforms = transforms.Compose([
        tr.Normalize(mean=(0.12, 0.12, 0.12), std=(0.018, 0.018, 0.018)),
        tr.ToTensor2(args.n_class)
    ])
    eval_dataset = THOR_Data(transform=eval_transforms,
                             path=args.data_path,
                             file_list=test_files,
                             otsu=args.otsu)

    if args.if_dependent == 1:
        alpha = get_global_alpha(train_files, data_path)
        alpha = torch.from_numpy(alpha).float().to(DEVICE)
        alpha.requires_grad = False
    else:
        alpha = None
    model = import_module('models.model_loader')
    net, loss = model.get_full_model(args.model_name,
                                     args.loss_name,
                                     n_classes=args.n_class,
                                     alpha=alpha,
                                     if_closs=args.if_closs,
                                     class_weights=torch.FloatTensor(
                                         [1.0, 4.0, 2.0, 5.0, 3.0]).to(DEVICE))
    start_epoch = args.start_epoch
    save_dir = args.save_dir
    logging.info(args)
    if args.resume:
        # NOTE(review): only the epoch and model weights are restored; the
        # 'optimizer' entry written to checkpoints below is not loaded back.
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch'] + 1
        net.load_state_dict(checkpoint['state_dict'])

    net = net.to(DEVICE)
    loss = loss.to(DEVICE)
    if multi_gpu:
        net = DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        """Return the current learning rate, decaying ``args.lr`` by 5% per call.

        No decay is applied at epoch 0 or once ``args.lr`` is below 1e-4.
        """
        if args.lr < 0.0001:
            return args.lr
        if epoch > 0:
            args.lr = args.lr * 0.95
            logging.info('current learning rate is %f' % args.lr)
        return args.lr

    train_transforms = transforms.Compose([
        tr.RandomZoom((512, 512)),
        tr.RandomHorizontalFlip(),
        tr.Normalize(mean=(0.12, 0.12, 0.12), std=(0.018, 0.018, 0.018)),
        tr.ToTensor2(args.n_class)
    ])
    train_dataset = THOR_Data(transform=train_transforms,
                              path=data_path,
                              file_list=train_files,
                              otsu=args.otsu)
    trainloader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=4)
    break_flag = 0
    high_dice = 0.
    selected_thresholds = np.zeros((args.n_class - 1, ))
    run_id = str(uuid.uuid4())
    cur_train_stats_path = train_stats_path.format(run_id)
    cur_eval_stats_path = eval_stats_path.format(run_id)
    # Start both stats files with the header row. newline='' is what the csv
    # module requires so no blank rows are emitted on platforms that
    # translate line endings.
    for stats_path in (cur_train_stats_path, cur_eval_stats_path):
        with open(stats_path, 'w', newline='') as f:
            csv.writer(f).writerow(stats_fields)

    for epoch in range(start_epoch, args.epochs + 1):
        train_loss, adaptive_thresholds = train(trainloader, net, loss, epoch,
                                                optimizer, get_lr, save_dir,
                                                cur_train_stats_path)
        if epoch < args.untest_epoch:
            continue
        break_flag += 1
        eval_dice, eval_precision = evaluation(args, net, loss, epoch,
                                               eval_dataset,
                                               selected_thresholds,
                                               cur_eval_stats_path)
        if max_precision <= eval_precision:
            selected_thresholds = adaptive_thresholds
            max_precision = eval_precision
            logging.info(
                '************************ dynamic threshold saved successful ************************** !'
            )
        if eval_dice >= high_dice:
            high_dice = eval_dice
            break_flag = 0
            # Save the underlying module's weights when wrapped in
            # DataParallel, mirroring how checkpoints are loaded on resume.
            state_dict = (net.module if multi_gpu else net).state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'save_dir': save_dir,
                    'state_dict': state_dict,
                    'optimizer': optimizer.state_dict(),
                    'args': args
                }, os.path.join(save_dir, '%d.ckpt' % epoch))
            logging.info(
                '************************ model saved successful ************************** !'
            )
        if break_flag > args.patient:
            break
예제 #4
0
def main(args):
    """Run the two-phase training: a "normal" pretraining phase, then lmser.

    A checkpoint holding every network's and optimizer's state is written
    after each pretraining epoch. No checkpoint is written during the lmser
    phase.

    Args:
        args: Run configuration. Attributes read here: ``gpu``, ``data_path``,
            ``test_flag``, ``model_name``, ``loss_name``, ``n_class``,
            ``save_dir``, ``lr``, ``momentum``, ``weight_decay``,
            ``batch_size``, ``normal_epochs`` and ``lmser_epochs``.
            ``args.lr`` is mutated by the learning-rate schedule.
    """
    torch.manual_seed(123)
    cudnn.benchmark = True
    setgpu(args.gpu)
    data_path = args.data_path
    train_files, _ = get_cross_validation_paths(args.test_flag)

    model = import_module('models.model_loader')
    net_dict, loss = model.get_full_model(
        args.model_name,
        args.loss_name,
        n_classes=args.n_class)

    save_dir = args.save_dir
    logging.info(args)

    # The returned value is deliberately discarded: rebinding a loop variable
    # would not update net_dict, so this relies on ``.to`` moving each
    # network in place (as torch.nn.Module.to does).
    for net in net_dict.values():
        net.to(DEVICE)
    loss = loss.to(DEVICE)

    # One SGD optimizer per network, all sharing the same hyper-parameters.
    optimizer_dict = {
        net_name: torch.optim.SGD(net.parameters(),
                                  args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
        for net_name, net in net_dict.items()
    }

    def get_lr(epoch):
        """Return the current learning rate, decaying ``args.lr`` by 5% per call.

        No decay is applied at epoch 0 or once ``args.lr`` is below 1e-4.
        """
        if args.lr < 0.0001:
            return args.lr
        if epoch > 0:
            args.lr = args.lr * 0.95
            logging.info('current learning rate is %f' % args.lr)
        return args.lr

    composed_transforms_tr = transforms.Compose([
        tr.RandomZoom((512, 512)),
        tr.RandomHorizontalFlip(),
        tr.Normalize(mean=(0.12, 0.12, 0.12), std=(0.018, 0.018, 0.018)),
        tr.ToTensor2(args.n_class)
    ])
    train_dataset = THOR_Data(
        transform=composed_transforms_tr, path=data_path, file_list=train_files)
    trainloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4)

    # Pretrain ResUNet101
    for epoch_normal in range(args.normal_epochs):
        train_normal(trainloader, net_dict, loss, epoch_normal, optimizer_dict,
                     get_lr, save_dir)
        print('Pretrain ResUNet101 epoch %d done.' % epoch_normal)

        # Save state. The lists follow the iteration order of net_dict and
        # optimizer_dict, keeping the existing checkpoint layout.
        net_state_dict = [net.state_dict() for net in net_dict.values()]
        optimizer_state_dict = [
            optimizer.state_dict() for optimizer in optimizer_dict.values()
        ]

        torch.save({
            'epoch': epoch_normal,
            'save_dir': save_dir,
            'state_dict': net_state_dict,
            'optimizer': optimizer_state_dict,
            'args': args
        }, os.path.join(save_dir, '%d.ckpt' % epoch_normal))

    # Train ResUNet101_lmser based on ResUNet101
    for epoch_lmser in range(args.lmser_epochs):
        train_lmser(trainloader, net_dict, loss, epoch_lmser, optimizer_dict,
                    get_lr, save_dir)
        print('Train ResUNet101_lmser epoch %d done.' % epoch_lmser)