def diceLoss(y_true, y_pred):
    """Return the negated dice score of ``(y_true, y_pred)``.

    The sign is flipped so that an optimizer minimizing this value is
    maximizing the dice coefficient computed by ``dice``.
    """
    score = dice(y_true, y_pred)
    return -score
Exemple #2
0
 def metric(self, logit, truth):
     """Return the evaluation metric, used in particular for early stopping.

     The value is whatever ``dice(logit, truth)`` yields.
     """
     # Alternative kept for reference: iou_pytorch(logit, truth)
     score = dice(logit, truth)
     return score
Exemple #3
0
def validation(trained_net,
               val_set,
               criterion,
               device,
               batch_size,
               ignore_idx,
               name=None,
               epoch=None):
    """Run one pass over the validation loader and average loss and dice.

    params trained_net: network called on every image batch
    params val_set: dataset wrapper; ``len()`` gives the progress-bar total
        and ``load()`` gives the iterable of samples
    params criterion: loss applied to ``(outputs, segs)``
    params device: device the ``'image'`` and ``'seg'`` tensors are moved to
    params batch_size: combined with ``len(val_set)`` to form the averaging
        denominator ``ceil(n_val / batch_size)``
    params ignore_idx: forwarded to ``dice`` as ``ignore_idx``
    params name, epoch: forwarded to ``heatmap_plot`` for the first batch

    return: ``(mean loss, mean 'avg' dice, val_info)`` where ``val_info``
        maps 'bg', 'wm', 'gm', 'csf' and 'tm' to their mean dice score
    """
    class_keys = ('bg', 'wm', 'gm', 'csf', 'tm')

    n_val = len(val_set)
    val_loader = val_set.load()

    loss_sum = 0
    avg_dice_sum = 0
    class_sums = {key: 0 for key in class_keys}
    val_info = {key: [] for key in class_keys}

    with tqdm(total=n_val, desc='Validation round', unit='patch',
              leave=False) as pbar:
        with torch.no_grad():
            for batch_idx, sample in enumerate(val_loader):
                images = sample['image'].to(device=device)
                segs = sample['seg'].to(device=device)

                outputs = trained_net(images)
                val_loss = criterion(outputs, segs)

                # Only the first batch of the round is visualised.
                if batch_idx == 0:
                    first_image = images.detach().cpu().numpy()[0]
                    first_seg = segs.detach().cpu().numpy()[0]
                    first_pred = outputs.detach().cpu().numpy()[0]
                    heatmap_plot(image=first_image,
                                 mask=first_seg,
                                 pred=first_pred,
                                 name=name,
                                 epoch=epoch,
                                 is_train=False)

                # Fold every leading dimension into one, keeping the last
                # four dims of the outputs and the last three of the masks.
                out_shape = outputs.shape
                outputs = outputs.view(-1, out_shape[-4], out_shape[-3],
                                       out_shape[-2], out_shape[-1])
                seg_shape = segs.shape
                segs = segs.view(-1, seg_shape[-3], seg_shape[-2],
                                 seg_shape[-1])

                # Index of the maximum along dim 1 is the predicted label.
                preds = torch.max(outputs.data, 1)[1]
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=ignore_idx)

                for key in class_keys:
                    class_sums[key] += dice_score[key]

                batch_loss = val_loss.detach().item()
                loss_sum += batch_loss
                avg_dice_sum += dice_score['avg']

                pbar.set_postfix(
                    **{
                        'validation loss (images)': batch_loss,
                        'val_acc_avg': dice_score['avg']
                    })
                pbar.update(images.shape[0])

        n_batches = np.ceil(n_val / batch_size)
        for key in class_keys:
            val_info[key] = class_sums[key] / n_batches

    return loss_sum / n_batches, avg_dice_sum / n_batches, val_info
Exemple #4
0
def dice_loss(y_pred, y, c_weights=None):
    """Return ``1 - dice(y_pred, y, c_weights)``.

    ``c_weights`` is passed through to ``dice`` untouched (``None`` by
    default).
    """
    overlap = dice(y_pred, y, c_weights)
    return 1 - overlap
Exemple #5
0
def train(args):
    """Train the segmentation network selected by ``args``, one checkpoint per epoch.

    Settings come from two places:

    * attributes of ``args``: ``model_name``, ``model_type``, ``lstmbase``,
      ``unetbase``, ``layer_num``, ``nb_shortcut``, ``loss_fn``,
      ``world_size``, ``rank``, ``base_channels``, ``crop_size``,
      ``ignore_idx``, ``return_sequence``, ``LSTM_variant``, ``epoch`` and
      ``is_pretrain``;
    * the ``PARAMETERS`` and ``PATH`` sections of ``config.yaml``.

    Every epoch rebuilds the train/validation loaders, trains over the train
    loader, runs ``validation`` whenever the global step count reaches a
    multiple of the number of batches per epoch, steps the LR scheduler on
    the mean training loss, saves a checkpoint through ``save_ckp``, dumps
    the training/validation records as JSON and logs parameter histograms.
    Training stops early when ``save_ckp`` returns a truthy value.

    Raises:
        NotImplementedError: for an unsupported ``model_type``,
            ``LSTM_variant``, ``loss_fn`` or optimizer name.

    Returns:
        None
    """

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    # A learning rate is fractional: int() would truncate it to 0 (or raise
    # on a string such as '0.001'), so it has to be parsed as a float.
    lr = float(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    # makedirs also creates missing parents, which plain mkdir would not.
    for directory in (model_path, best_path, intermidiate_data_save,
                      log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    # (local names chosen so that they do not shadow this function)
    data_class = data_split()
    train_split, val_split, _ = data_construction(data_class)
    train_dict = time_parser(train_split, time_patch=time_step)
    val_dict = time_parser(val_split, time_patch=time_step)

    # LSTM initilization

    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone,
                         input_dim=input_modalites,
                         output_dim=output_channels,
                         hidden_dim=base_channel,
                         kernel_size=3,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites,
                           hidden_dim=base_channel,
                           output_dim=output_channels,
                           kernel_size=3,
                           num_layers=layer_num,
                           conv_type=conv_type,
                           lstm_backbone=lstm_backbone,
                           unet_module=unet_backbone,
                           base_channel=base_channel,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
            logger.info(
                'the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites,
                             output_channels=output_channels,
                             base_channel=base_channel,
                             num_layers=layer_num,
                             conv_type=conv_type,
                             return_sequence=return_sequence,
                             is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               connect=connect,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites,
                                   output_channels=output_channels,
                                   base_channel=base_channel,
                                   num_layers=layer_num,
                                   conv_type=conv_type,
                                   return_sequence=return_sequence,
                                   is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites,
                                     output_channels=output_channels,
                                     base_channel=base_channel,
                                     num_layers=layer_num,
                                     connect=connect,
                                     conv_type=conv_type,
                                     return_sequence=return_sequence,
                                     is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites,
                                output_channels=output_channels,
                                base_channel=base_channel,
                                num_layers=layer_num,
                                conv_type=conv_type,
                                return_sequence=return_sequence,
                                is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites,
                                  output_channels=output_channels,
                                  base_channel=base_channel,
                                  num_layers=layer_num,
                                  connect=connect,
                                  conv_type=conv_type,
                                  return_sequence=return_sequence,
                                  is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               num_connects=nb_shortcut,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        else:
            # Without this branch ``net`` stays unbound and the failure only
            # shows up later as a NameError.
            raise NotImplementedError(
                'unsupported LSTM variant: {}'.format(variant))
    else:
        raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        # NOTE(review): Adam uses a fixed lr of 0.001 and ignores the
        # configured ``lr`` -- confirm this is intended.
        optimizer = optim.Adam(net.parameters(), lr=0.001)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    else:
        # Otherwise the plain config string would be handed to the scheduler.
        raise NotImplementedError(
            'unsupported optimizer: {}'.format(optimizer))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net,
                                                  find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))

        except Exception as err:
            # Resuming stays best-effort, but Ctrl-C / SystemExit are no
            # longer swallowed and the actual reason is recorded.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resuming failed with: {!r}'.format(err))

    for epoch in range(start_epoch, epochs):

        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                model_type='RNN')
        n_train = len(train_set)

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              model_type='CNN')
        n_val = len(val_set)

        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):

                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # Clear the gradients of the previous step; backward() adds
                # to .grad, so without this they pile up across batches.
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()

                # Fold every leading dimension into one before scoring.
                outputs = outputs.view(-1, outputs.shape[-4],
                                       outputs.shape[-3], outputs.shape[-2],
                                       outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2],
                                 segs.shape[-1])
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=None)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                # Validate once the step count reaches a whole number of
                # epochs' worth of batches.
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net,
                                                             val_set,
                                                             criterion,
                                                             device,
                                                             batch_size,
                                                             ignore_idx=None,
                                                             name=model_name,
                                                             epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save best trained model
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        # "Best" is judged on the mean training loss of the epoch.
        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'
            .format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                # A parameter that took no part in the last backward pass
                # has no gradient to plot.
                if layer.grad is not None:
                    writer.add_histogram(name + '_grad',
                                         layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data',
                                     layer.cpu().data.numpy(), epoch)
        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    # Flush pending summary events to disk.
    writer.close()
    logger.info('finish training!')

    return
Exemple #6
0
def evulate_dir_nii_weakly_new():
    '''
    New version of the evaluation code.

    Builds the UNet inference graph, restores the checkpoint at
    ``restore_path`` and evaluates the ``volume-*.nii`` cases found in
    ``nii_dir`` (currently restricted to ``volume-0.nii``). Each slice is
    run at every scale in the module-level ``scales`` and the class-1 score
    maps are averaged. On slices whose ground-truth mask is non-empty the
    averaged map is thresholded at 0.6, cleaned up and additionally refined
    with grabcut; other slices get all-zero predictions. Dice and IoU are
    printed per case and over all processed slices, for both the plain and
    the grabcut prediction.

    :return: None
    '''
    from metrics import dice, IoU
    from datasets.medicalImage import convertCase2PNGs, image_expand
    # Imported once here rather than on every slice of the loop below.
    from grabcut import grabcut
    nii_dir = '/home/give/Documents/dataset/ISBI2017/Training_Batch_1'
    save_dir = '/home/give/Documents/dataset/ISBI2017/weakly_label_segmentation_V4/Batch_1/DLSC_0/niis'
    # restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/ISBI2017_V2/1s_agumentation_weakly-upsampling-2/model.ckpt-168090'

    restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/ISBI2017_V2/1s_agumentation_weakly_V3-upsampling-1/model.ckpt-167824'

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.float32, shape=[None, None, 3])
        image_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        input_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image,
            None,
            None,
            out_shape=input_shape_placeholder,
            is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)

        net = UNetBlocksMS.UNet(b_image,
                                None,
                                None,
                                is_training=False,
                                decoder=FLAGS.decoder,
                                update_center_flag=FLAGS.update_center,
                                batch_size=2,
                                init_center_value=None,
                                update_center_strategy=1,
                                num_centers_k=FLAGS.num_centers_k,
                                full_annotation_flag=False,
                                output_shape_tensor=input_shape_placeholder)

        # print slim.get_variables_to_restore()
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    # NOTE(review): ``variables_to_restore`` is computed but the Saver below
    # is created without it, so the default variable set is restored --
    # confirm whether ``var_list=variables_to_restore`` was intended.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    saver = tf.train.Saver()

    nii_pathes = glob(os.path.join(nii_dir, 'volume-*.nii'))

    checkpoint = restore_path
    # pixel_recovery_features = tf.image.resize_images(net.pixel_recovery_features, image_shape_placeholder)
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)

        global_gt = []
        global_pred = []
        global_pred_grabcut = []

        case_dices = []
        case_IoUs = []
        case_dices_grabcut = []
        case_IoUs_grabcut = []

        for nii_path in nii_pathes:
            # if os.path.basename(nii_path) in ['volume-15.nii', 'volume-25.nii']:
            #     continue
            if os.path.basename(nii_path) != 'volume-0.nii':
                continue
            nii_path_basename = os.path.basename(nii_path)
            # 'volume-<id>.nii' -> '<id>'. Taken from the basename so that a
            # '.' or '-' in a directory name cannot corrupt the id.
            case_id = nii_path_basename.split('.')[0].split('-')[1]

            pred_dir = os.path.join(save_dir, case_id, 'pred')
            pred_vis_dir = os.path.join(save_dir, case_id, 'pred_vis')
            recovery_img_dir = os.path.join(save_dir, case_id, 'recovery_img')
            if not os.path.exists(pred_dir):
                os.makedirs(pred_dir)
            if not os.path.exists(pred_vis_dir):
                os.makedirs(pred_vis_dir)
            if not os.path.exists(recovery_img_dir):
                os.makedirs(recovery_img_dir)

            seg_path = os.path.join(nii_dir,
                                    'segmentation-' + case_id + '.nii')

            case_preds = []
            case_gts = []
            case_preds_grabcut = []

            # case_recover_features = []
            print(nii_path, seg_path)
            imgs, tumor_masks, liver_masks, tumor_weak_masks = convertCase2PNGs(
                nii_path, seg_path, save_dir=None)
            print(len(imgs), len(tumor_masks), len(liver_masks),
                  len(tumor_weak_masks))

            for slice_idx, (image_data, liver_mask, whole_mask) in enumerate(
                    zip(imgs, liver_masks, tumor_masks)):
                pixel_cls_scores_ms = []
                pixel_recover_feature_ms = []
                # Multi-scale inference: one forward pass per input scale.
                for single_scale in scales:
                    pixel_recover_feature, pixel_cls_scores, b_image_v, global_step_v, net_centers = sess.run(
                        [
                            net.pixel_recovery_features, net.pixel_cls_scores,
                            b_image, global_step, net.centers
                        ],
                        feed_dict={
                            image: image_data,
                            image_shape_placeholder: np.shape(image_data)[:2],
                            input_shape_placeholder: single_scale
                        })
                    # cv2.resize takes (width, height), hence the reversal
                    # of the (height, width) image shape.
                    pixel_cls_scores_ms.append(
                        cv2.resize(pixel_cls_scores[0, :, :, 1],
                                   tuple(np.shape(image_data)[:2][::-1])))
                    pixel_recover_feature_ms.append(pixel_recover_feature[0])
                    del pixel_recover_feature
                pixel_cls_scores = np.mean(pixel_cls_scores_ms, axis=0)
                pixel_recover_feature = np.mean(pixel_recover_feature_ms,
                                                axis=0)
                # case_recover_features.append(pixel_recover_feature)

                if np.sum(whole_mask) != 0:
                    pred = np.asarray(pixel_cls_scores > 0.6, np.uint8)
                    # Opening: erode first, then dilate.
                    # Closing: dilate first, then erode.
                    # pred = close_operation(pred, kernel_size=3)
                    pred = open_operation(pred, kernel_size=3)
                    pred = fill_region(pred)
                    pred_grabcut = np.asarray(
                        image_expand(pred, kernel_size=5), np.uint8)
                    xs, ys = np.where(pred_grabcut == 1)
                    print(np.shape(pred_grabcut))
                    if len(xs) == 0:
                        # Nothing predicted: no box to seed grabcut with.
                        pred_grabcut = np.zeros_like(whole_mask)
                        print(np.min(pred_grabcut), np.max(pred_grabcut),
                              np.sum(pred_grabcut))
                    else:
                        # Bounding box of the expanded prediction.
                        min_xs = np.min(xs)
                        max_xs = np.max(xs)
                        min_ys = np.min(ys)
                        max_ys = np.max(ys)
                        pred_grabcut = grabcut(
                            np.asarray(image_data * 255., np.uint8),
                            [min_xs, min_ys, max_xs, max_ys])
                        pred_grabcut = np.asarray(pred_grabcut == 255,
                                                  np.uint8)
                        print(np.unique(pred_grabcut))
                        # Debug dumps; './tmp' is not created here.
                        cv2.imwrite('./tmp/%d_gt.png' % slice_idx,
                                    np.asarray(whole_mask * 200, np.uint8))
                        cv2.imwrite('./tmp/%d_pred.png' % slice_idx,
                                    np.asarray(pred * 200, np.uint8))
                        cv2.imwrite('./tmp/%d_pred_grabcut.png' % slice_idx,
                                    np.asarray(pred_grabcut * 200, np.uint8))

                else:
                    pred = np.zeros_like(whole_mask)
                    pred_grabcut = np.zeros_like(whole_mask)

                global_gt.append(whole_mask)
                case_gts.append(whole_mask)
                case_preds.append(pred)
                case_preds_grabcut.append(pred_grabcut)
                global_pred.append(pred)
                global_pred_grabcut.append(pred_grabcut)

                print('%d / %d: %s' % (slice_idx + 1, len(imgs),
                                       os.path.basename(nii_path)),
                      np.shape(pixel_cls_scores), np.max(pixel_cls_scores),
                      np.min(pixel_cls_scores),
                      np.shape(pixel_recover_feature))
                del pixel_recover_feature, pixel_recover_feature_ms
                gc.collect()

            case_dice = dice(case_gts, case_preds)
            case_IoU = IoU(case_gts, case_preds)
            case_dice_grabcut = dice(case_gts, case_preds_grabcut)
            case_IoU_grabcut = IoU(case_gts, case_preds_grabcut)
            print('case dice: ', case_dice)
            print('case IoU: ', case_IoU)
            print('case dice grabcut: ', case_dice_grabcut)
            print('case IoU grabcut: ', case_IoU_grabcut)
            case_dices.append(case_dice)
            case_IoUs.append(case_IoU)
            case_dices_grabcut.append(case_dice_grabcut)
            case_IoUs_grabcut.append(case_IoU_grabcut)
        print('global dice is ', dice(global_gt, global_pred))
        print('global IoU is ', IoU(global_gt, global_pred))

        print('mean of case dice is ', np.mean(case_dices))
        print('mean of case IoU is ', np.mean(case_IoUs))

        print('mean of case dice (grabcut) is ', np.mean(case_dices_grabcut))
        print('mean of case IoU (grabcut) is ', np.mean(case_IoUs_grabcut))

        print('global dice (grabcut) is ', dice(global_gt,
                                                global_pred_grabcut))
        print('global IoU (grabcut) is ', IoU(global_gt, global_pred_grabcut))
Exemple #7
0
def validation(trained_net,
               val_set,
               criterion,
               device,
               batch_size,
               model_type=None,
               softmax=True,
               ignore_idx=None):
    """Evaluate ``trained_net`` on the validation set during training.

    Args:
        trained_net: network to evaluate; it is called on the reshaped
            image batch.
        val_set: validation dataset wrapper exposing ``len()`` and ``load()``.
        criterion: loss function, called as ``criterion(preds, segs)``.
        device: device the image and segmentation tensors are moved to.
        batch_size: batch size, used to derive the number of batches the
            accumulated values are averaged over.
        model_type: when equal to ``'SkipDenseSeg'`` and ``softmax`` is
            false, the segmentation labels are cast to ``long``.
        softmax: see ``model_type``.
        ignore_idx: forwarded unchanged to ``dice``.

    Returns:
        A tuple ``(mean_loss, mean_avg_dice, val_info)`` where ``val_info``
        maps ``'bg'``, ``'wm'``, ``'gm'``, ``'csf'`` and ``'tm'`` to the mean
        dice score of that class. Every mean divides the accumulated sum by
        ``ceil(len(val_set) / batch_size)``.
    """
    class_keys = ('bg', 'wm', 'gm', 'csf', 'tm')

    n_val = len(val_set)
    val_loader = val_set.load()

    loss_sum = 0
    avg_dice_sum = 0
    class_sums = {key: 0 for key in class_keys}

    # The label cast only depends on the arguments, so decide it once.
    cast_labels = model_type == 'SkipDenseSeg' and not softmax

    progress = tqdm(total=n_val,
                    desc='Validation round',
                    unit='patch',
                    leave=False)
    with progress as pbar, torch.no_grad():
        for sample in val_loader:
            images = sample['image'].to(device=device)
            segs = sample['seg'].to(device=device)
            if cast_labels:
                segs = segs.long()

            # Images arrive with six axes; the two leading ones are folded
            # into a single batch axis before the forward pass.
            _, _, channel, z, y, x = images.shape
            images = images.view(-1, channel, z, y, x)
            segs = segs.view(-1, z, y, x)

            logits = trained_net(images)
            val_loss = criterion(logits, segs)
            batch_loss = val_loss.detach().item()

            # Index of the maximum along axis 1 is the predicted label.
            hard_preds = torch.max(logits, 1)[1]
            dice_score = dice(hard_preds.data.cpu(), segs.data.cpu(),
                              ignore_idx)

            for key in class_keys:
                class_sums[key] += dice_score[key]
            loss_sum += batch_loss
            avg_dice_sum += dice_score['avg']

            pbar.set_postfix(
                **{
                    'validation loss': batch_loss,
                    'val_acc_avg': dice_score['avg']
                })
            pbar.update(images.shape[0])

    n_batches = np.ceil(n_val / batch_size)
    val_info = {key: class_sums[key] / n_batches for key in class_keys}

    return loss_sum / n_batches, avg_dice_sum / n_batches, val_info
Exemple #8
0
def train(args):
    """Train a segmentation network and checkpoint it after every epoch.

    The function reads its settings from ``args`` and from ``config.yaml``,
    builds the network, loss, optimizer and LR scheduler, optionally resumes
    from the last checkpoint, then runs the epoch loop. After each epoch it
    records training/validation statistics, steps the scheduler, saves a
    checkpoint via ``save_ckp``, dumps the statistics to JSON, plots the loss
    and writes parameter histograms to the summary writer. Training stops
    early when ``save_ckp`` returns a truthy value.

    Args:
        args: namespace providing ``model_name``, ``model_type``, ``loss``,
            ``world_size``, ``rank``, ``base_channels``, ``crop_size``,
            ``ignore_idx`` and ``epoch``.

    Raises:
        NotImplementedError: if ``args.model_type``, ``args.loss`` or the
            configured optimizer name is not one of the supported values.
    """

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    # A learning rate is fractional: int() would truncate e.g. 0.001 to 0
    # and fails outright on a string such as '1e-3'.
    lr = float(config['PARAMETERS']['lr'])
    optimizer_name = config['PARAMETERS']['optimizer']
    softmax = True
    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    # exist_ok avoids the check-then-create race and also creates a missing
    # parent directory instead of failing.
    for directory in (model_path, best_path, intermidiate_data_save,
                      log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    # (local names deliberately do not reuse ``train``, this function's name)
    data_class = data_split()
    train_split, val_split, _test_split = data_construction(data_class)
    train_dict = time_parser(train_split, time_patch=time_step)
    val_dict = time_parser(val_split, time_patch=time_step)

    # groups = 4
    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel,
                         pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel,
                      pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel,
                       pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels,
                                  base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel,
                             softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels,
                                        base_channel, softmax)
    else:
        # Fail with a clear message instead of a NameError on ``net`` below.
        raise NotImplementedError(
            'Unsupported model type: {}'.format(model_type))

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net, input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size,
                             crop_size).to(device)
    writer.add_graph(net, (dummy_input, ))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        criterion = WeightedCrossEntropyLoss(labels=labels)
        if not softmax:
            # Follow the selected device so a CPU-only host does not crash.
            criterion = nn.CrossEntropyLoss().to(device)
    else:
        raise NotImplementedError()

    if optimizer_name == 'adam':
        # Adam keeps its library-default learning rate; ``lr`` is SGD-only.
        optimizer = optim.Adam(net.parameters())
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              momentum=0.9,
                              lr=lr,
                              weight_decay=1e-5)
    else:
        # Previously the raw config string reached the scheduler.
        raise NotImplementedError(
            'Unsupported optimizer: {}'.format(optimizer_name))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs avaliable'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        logger.info('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net)
    else:
        logger.info('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))
            # min_loss = float('Inf')

        except Exception as err:
            # Resuming is best effort, but KeyboardInterrupt/SystemExit must
            # propagate and the real cause should be visible in the log.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed with: {!r}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # every epoch generate a new set of images
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                modalities=modalities,
                                model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              modalities=modalities,
                              model_type='CNN')
        n_val = len(val_set)

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for data in train_loader:
                images, segs = data['image'].to(device), data['seg'].to(device)

                if model_type == 'SkipDenseSeg' and not softmax:
                    segs = segs.long()

                # combine the batch and time step
                _, _, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()
                outputs = net(images)

                loss = criterion(outputs, segs)
                loss.backward()
                # with amp.scale_loss(loss, optimizer) as scaled_loss:
                #     scaled_loss.backward()
                optimizer.step()

                batch_loss = loss.detach().item()
                running_loss += batch_loss
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=ignore_idx)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': batch_loss,
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                del images, segs

                global_step += 1
                # NOTE(review): validation only fires when the step counter
                # hits a multiple of ceil(n_train / batch_size); this assumes
                # the loader yields exactly that many batches per epoch,
                # otherwise val_loss below may be stale or unbound - confirm.
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(
                        net,
                        val_set,
                        criterion,
                        device,
                        batch_size,
                        model_type=model_type,
                        softmax=softmax,
                        ignore_idx=ignore_idx)
                    # Restore train mode so any batches that follow in this
                    # epoch are not run with the network in eval mode.
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # adjust the learning rate from this epoch's validation loss
        if model_type == 'SkipDenseSeg':
            scheduler.step()
        else:
            scheduler.step(val_loss)
        # debug
        for param_group in optimizer.param_groups:
            logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

        # An epoch counts as "best" only if the mean training loss beats the
        # previous minimum by more than 1e-2.
        if min_loss > running_loss / nb_batches + 1e-2:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        # save the check point
        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation accuracy: {}'.format(
            val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        loss_plot(train_info_file, name=model_name)
        for name, layer in net.named_parameters():
            # Parameters that received no gradient have ``grad`` set to None.
            if layer.grad is not None:
                writer.add_histogram(name + '_grad',
                                     layer.grad.cpu().data.numpy(), epoch)
            writer.add_histogram(name + '_data',
                                 layer.cpu().data.numpy(), epoch)

        # save_ckp signals that the early-stop patience has been exhausted
        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
                                height: height + crop_size[1],
                                width: width + crop_size[2]] += 1
                    

    whole_pred = whole_pred / count_used
    whole_pred = whole_pred[0, :, :, :, :]
    whole_pred = np.argmax(whole_pred, axis=0)
    print (whole_pred.shape, label.shape)
    label=label.transpose(0,2,1)
    print(whole_pred.shape, label.shape)

    
    dsc = []
    print ('-------------------------')
    for i in range(1, num_classes):
        dsc_i = dice(whole_pred, label, i)
        dsc_i=round(dsc_i*100,2)
        dsc.append(dsc_i)
    
    datetime= time.strftime("%d/%m/%Y")
    print('Data       | Note   | class1| class2|class3| Avg.|')
    print('%s | %s | %2.2f | %2.2f | %2.2f | %2.2f |' % ( \
                datetime,
                note,
                dsc[0],
                dsc[1],
                dsc[2],
                np.mean(dsc)))