Example #1
def eval(args, tasks_archive, model, eval_epoch, iterations):
    tasks = args.tasks  # list

    model.eval()
    for task_idx in range(len(tasks)):
        config.task_idx = task_idx  # needed for u2net3d().
        task = tasks[task_idx]
        config_task = config.config_tasks[task]
        st_time = time.time()

        # run evaluation (TensorBoard visualization is embedded in evaluate.evaluate())
        dices = evaluate.evaluate(config_task,
                                  tasks_archive[task]['fold' +
                                                      str(args.fold)]['val'],
                                  model,
                                  epoch_num=eval_epoch,
                                  outdir=config.eval_out_dir)

        fo = open(os.path.join(config.eval_out_dir,
                               '{}_eval_res.csv'.format(args.trainMode)),
                  mode='a+')
        wo = csv.writer(fo, delimiter=',')
        for k, v in dices.items():
            config.writer.add_scalar('data/dices/{}_{}'.format(task, k), v,
                                     iterations)
            wo.writerow([
                args.trainMode, task, eval_epoch, config.step_per_epoch, k, v,
                tinies.datestr()
            ])
        fo.flush()
        logger.info('Eval time elapsed:{}'.format(
            tinies.timer(st_time, time.time())))
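Not part of the source: the CSV above is written without a header row, so a hypothetical reader has to re-supply the column names, which are inferred here from the writerow() call (config.eval_out_dir and args.trainMode as above):

import csv
import os

columns = ['trainMode', 'task', 'epoch', 'step_per_epoch', 'class', 'dice', 'date']
csv_path = os.path.join(config.eval_out_dir, '{}_eval_res.csv'.format(args.trainMode))
with open(csv_path) as f:
    rows = [dict(zip(columns, r)) for r in csv.reader(f)]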
Example #2
        def prep(files, outDir, with_gt=True):
            print("ids[0]:{}, current time:{}".format(
                os.path.basename(files[0]), str(tinies.datestr())))
            for img_path in files:
                # tinies.ForkedPdb().set_trace()
                ID = os.path.basename(img_path).split('.')[0]
                if with_gt:
                    lab_path = os.path.join(config.base_dir, task, 'labelsTr',
                                            ID)
                else:
                    lab_path = None
                volume_list, label, weight, original_shape, [
                    bbmin, bbmax
                ] = utils.preprocess(img_path,
                                     lab_path,
                                     config_task,
                                     with_gt=with_gt)
                volumes = np.asarray(volume_list)
                np.save(os.path.join(outDir, ID + '_volumes.npy'), volumes)
                if with_gt:
                    np.save(os.path.join(outDir, ID + '_label.npy'), label)
                np.save(os.path.join(outDir, ID + '_weight.npy'), weight)

                json_info = dict()
                json_info['original_shape'] = str(original_shape)  # stored as str(); recover with eval()/ast.literal_eval()
                json_info['bbox'] = str([bbmin, bbmax])  # stored as str(); recover with eval()/ast.literal_eval()
                with open(os.path.join(outDir, ID + '.json'), 'w') as f:
                    json.dump(json_info, f, indent=4)
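Not part of the source: a minimal sketch of loading the artifacts that prep() writes back into memory, assuming outDir and ID match the values used at save time and that original_shape / bbox were plain Python lists or tuples when stringified (ast.literal_eval is used here as a safer stand-in for the eval() mentioned in the comments):

import ast
import json
import os
import numpy as np

def load_prepped(outDir, ID, with_gt=True):
    volumes = np.load(os.path.join(outDir, ID + '_volumes.npy'))
    label = np.load(os.path.join(outDir, ID + '_label.npy')) if with_gt else None
    weight = np.load(os.path.join(outDir, ID + '_weight.npy'))
    with open(os.path.join(outDir, ID + '.json')) as f:
        info = json.load(f)
    original_shape = ast.literal_eval(info['original_shape'])
    bbmin, bbmax = ast.literal_eval(info['bbox'])
    return volumes, label, weight, original_shape, (bbmin, bbmax)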
Example #3
def fuse(files, outDir, with_gt=True):
    print("ids[0]:{}, current time:{}".format(
        os.path.basename(files[0]), str(tinies.datestr())))
    for lab_path in files:
        print('loading:{}'.format(lab_path))
        # tinies.ForkedPdb().set_trace()
        label = np.load(lab_path)
        label[label == 2] = 1  # fuse cancer (label 2) into organ (label 1)
        np.save(lab_path, label)  # overwrite the label file in place
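Not part of the source: a hypothetical call site for fuse(); prepDir and the glob pattern are assumptions about where prep() stored the label arrays.

import glob
import os

label_files = sorted(glob.glob(os.path.join(prepDir, '*_label.npy')))
fuse(label_files, outDir=prepDir)  # relabels the files in place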
Example #4
    def gen_batch(self, batch_size, patch_size):
        batchImg = np.zeros([
            batch_size, self.config_task.num_modality, patch_size[0],
            patch_size[1], patch_size[2]
        ])  # n,mod,d,h,w
        batchLabel = np.zeros(
            [batch_size, patch_size[0], patch_size[1],
             patch_size[2]])  # n,d,h,w
        batchWeight = np.zeros(
            [batch_size, patch_size[0], patch_size[1],
             patch_size[2]])  # n,d,h,w
        batchAugs = list()

        # import ipdb; ipdb.set_trace()
        for i in range(batch_size):
            temp_prob = np.random.uniform()
            st_time = time.time()

            handler = 0  # set to 1 once a usable patch has been drawn from the queue
            while handler == 0:

                t_wait = 0
                if self.trainQueue.qsize() == 0:
                    logger.info(
                        '{} self.trainQueue size = {}, filling....(start time:{})'
                        .format(self.task, self.trainQueue.qsize(),
                                tinies.datestr()))
                while self.trainQueue.qsize() == 0:
                    time.sleep(1)
                    t_wait += 1
                if t_wait > 0:
                    logger.info('{} time to fill self.trainQueue: {}'.format(
                        self.task, t_wait))

                patches = self.trainQueue.get()
                # logger.info('{} trainQueue size:{}'.format(self.task, str(self.trainQueue.qsize())))
                # nnU-Net heuristic: at least 1/3 of the samples in a batch
                # must contain at least one foreground class
                if i <= math.ceil(batch_size / 3):
                    if temp_prob < self.config_task.small_prob and patches['small'] is not None:
                        patch = patches['small']
                        handler = 1
                    elif patches['fore'] is not None:
                        patch = patches['fore']
                        handler = 1
                    else:
                        handler = 0
                        logger.warning('handler={}'.format(handler))
                else:  # i > math.ceil(batch_size / 3): any patch type is allowed
                    if temp_prob < self.config_task.small_prob and patches['small'] is not None:
                        patch = patches['small']
                        handler = 1
                    elif 1 - temp_prob < self.config_task.fore_prob and patches['fore'] is not None:
                        patch = patches['fore']
                        handler = 1
                    else:
                        patch = patches['any']
                        handler = 1
                if handler == 0:
                    logger.info('handler is 0, going back')
            if handler == 0:
                logger.error('handler is 0')

            # fill in a batch
            batchImg[i, ...] = patch['image']
            batchLabel[i, ...] = patch['label']
            batchWeight[i, ...] = patch['weight']
            batchAugs.append(patch['augs'])

        return (batchImg, batchLabel, batchWeight, batchAugs)
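The items placed on self.trainQueue are not defined in this excerpt; from the way gen_batch() indexes them, each item appears to be a dict with 'small', 'fore' and 'any' candidate crops, each either None or a patch dict. A hypothetical item, sketched for illustration only (shapes follow the n,mod,d,h,w comments above):

import numpy as np

patch_size = (64, 64, 64)   # hypothetical d, h, w
num_modality = 2            # hypothetical number of input modalities

def make_patch(tag):
    # one candidate patch, laid out as gen_batch() reads it
    return {
        'image': np.zeros((num_modality,) + patch_size),  # mod, d, h, w
        'label': np.zeros(patch_size),                    # d, h, w
        'weight': np.ones(patch_size),                    # d, h, w
        'augs': [tag],                                    # names of applied augmentations
    }

queue_item = {
    'small': make_patch('small_object_crop'),  # crop around a small structure, or None
    'fore': make_patch('foreground_crop'),     # crop containing foreground, or None
    'any': make_patch('random_crop'),          # unconstrained random crop
}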
Example #5
def train(args, tasks_archive, model):
    torch.backends.cudnn.benchmark = True

    if args.resume_ckp != '':
        logger.info('==> loading checkpoint: {}'.format(args.resume_ckp))
        checkpoint = torch.load(args.resume_ckp)

    model = nn.parallel.DataParallel(model)

    logger.info('  + model num_params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    if config.use_gpu:
        model.cuda()  # required before optimizer?
    #     cudnn.benchmark = True

    print(model)  # especially useful for debugging model structure.
    # summary(model, input_size=tuple([config.num_modality]+config.patch_size)) # takes some time. comment during debugging. ouput each layer's out shape.
    # for name, m in model.named_modules():
    #     logger.info('module name:{}'.format(name))
    #     print(m)

    # lr
    lr = config.base_lr
    if args.resume_ckp != '':
        optimizer = checkpoint['optimizer']
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=config.weight_decay)

    # loss
    dice_loss = MulticlassDiceLoss()
    ce_loss = nn.CrossEntropyLoss()
    focal_loss = FocalLoss(gamma=2)

    # prep data
    tasks = args.tasks  # list
    tb_loaders = list()  # train batch loader
    len_loader = list()
    for task in tasks:
        tb_loader = tb_load(task)
        tb_loader.enQueue(tasks_archive[task]['fold' + str(args.fold)],
                          config.patch_size)
        tb_loaders.append(tb_loader)
        len_loader.append(len(tb_loader))
    min_len_loader = np.min(len_loader)

    # init train values
    if args.resume_ckp != '':
        trLoss_queue = checkpoint['trLoss_queue']
        last_trLoss_ma = checkpoint['last_trLoss_ma']
    else:
        trLoss_queue = deque(maxlen=config.trLoss_win)  # last N epoch losses, for a moving average of the total loss
        last_trLoss_ma = None  # moving average from the previous check
    trLoss_queue_list = [
        deque(maxlen=config.trLoss_win) for i in range(len(tasks))
    ]
    last_trLoss_ma_list = [None] * len(tasks)
    trLoss_ma_list = [None] * len(tasks)

    if args.resume_epoch > 0:
        start_epoch = args.resume_epoch + 1
        iterations = args.resume_epoch * config.step_per_epoch + 1
    else:
        start_epoch = 1
        iterations = 1
    logger.info('start epoch: {}'.format(start_epoch))

    ## run train
    for epoch in range(start_epoch, config.max_epoch + 1):
        logger.info('    ----- training epoch {} -----'.format(epoch))
        epoch_st_time = time.time()
        model.train()
        loss_epoch = 0.0
        loss_epoch_list = [0] * len(tasks)
        num_batch_processed = 0  # batches processed so far in this epoch
        num_batch_processed_list = [0] * len(tasks)

        for step in tqdm(range(config.step_per_epoch),
                         desc='{}: epoch{}'.format(args.trainMode, epoch)):
            config.step = iterations
            config.task_idx = (iterations - 1) % len(tasks)
            config.task = tasks[config.task_idx]
            # import ipdb; ipdb.set_trace()

            # tb show lr
            config.writer.add_scalar('data/lr', lr, iterations - 1)

            st_time = time.time()
            for idx in range(len(tasks)):
                tb_loaders[idx].check_process()
            # import ipdb; ipdb.set_trace()
            (batchImg, batchLabel, batchWeight,
             batchAugs) = tb_loaders[config.task_idx].gen_batch(
                 config.batch_size, config.patch_size)
            # logger.info('idx{}_{}, gen_batch time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            st_time = time.time()
            # cast all inputs to the same torch tensor type
            batchImg = torch.from_numpy(batchImg).float()
            batchLabel = torch.from_numpy(batchLabel).float()
            batchWeight = torch.from_numpy(batchWeight).float()

            if config.use_gpu:
                batchImg = batchImg.cuda()
                batchLabel = batchLabel.cuda()
                batchWeight = batchWeight.cuda()
            # logger.info('idx{}_{}, .cuda time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            optimizer.zero_grad()

            st_time = time.time()
            if config.trainMode in ["universal"]:
                output, share_map, para_map = model(batchImg)
            else:
                output = model(batchImg)
            # logger.info('idx{}_{}, model() time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            st_time = time.time()
            # tensorboard visualization of training
            for i in range(len(tasks)):
                if iterations > 200 and iterations % 1000 == i:
                    tb_images([
                        batchImg[0, 0, ...], batchLabel[0, ...],
                        torch.argmax(output[0, ...], dim=0)
                    ], [False, True, True], ['image', 'GT', 'PS'],
                              iterations,
                              tag='Train_idx{}_{}_batch{}_{}'.format(
                                  config.task_idx, config.task, 0,
                                  '_'.join(batchAugs[0])))

                    tb_images([
                        batchImg[config.batch_size - 1, 0, ...],
                        batchLabel[config.batch_size - 1, ...],
                        torch.argmax(output[config.batch_size - 1, ...], dim=0)
                    ], [False, True, True], ['image', 'GT', 'PS'],
                              iterations,
                              tag='Train_idx{}_{}_batch{}_{}_step{}'.format(
                                  config.task_idx, config.task,
                                  config.batch_size - 1,
                                  '_'.join(batchAugs[config.batch_size - 1]),
                                  iterations - 1))
                    if config.trainMode == "universal":
                        logger.info(
                            'share_map shape:{}, para_map shape:{}'.format(
                                str(share_map.shape), str(para_map.shape)))
                        tb_images([
                            para_map[0, :, 64, ...], share_map[0, :, 64, ...]
                        ], [False, False], ['last_para_map', 'last_share_map'],
                                  iterations,
                                  tag='Train_idx{}_{}_para_share_maps_channels'
                                  .format(config.task_idx, config.task))

            logger.info(
                '----- {}, train epoch {} time elapsed:{} -----'.format(
                    config.task, epoch, tinies.timer(epoch_st_time,
                                                     time.time())))

            st_time = time.time()

            output_softmax = F.softmax(output, dim=1)

            loss = lovasz_softmax(output_softmax, batchLabel,
                                  ignore=10) + focal_loss(output, batchLabel)

            loss.backward()
            optimizer.step()

            # logger.info('idx{}_{}, backward time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            # loss.data.item()
            config.writer.add_scalar('data/loss_step', loss.item(), iterations)
            config.writer.add_scalar(
                'data/loss_step_idx{}_{}'.format(config.task_idx, config.task),
                loss.item(), iterations)

            loss_epoch += loss.item()
            num_batch_processed += 1

            loss_epoch_list[config.task_idx] += loss.item()
            num_batch_processed_list[config.task_idx] += 1

            iterations += 1

        # import ipdb; ipdb.set_trace()
        if epoch % config.save_epoch == 0:
            ckp_path = os.path.join(
                config.log_dir,
                '{}_{}_epoch{}_{}.pth.tar'.format(args.trainMode,
                                                  '_'.join(args.tasks), epoch,
                                                  tinies.datestr()))
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'trLoss_queue': trLoss_queue,
                    'last_trLoss_ma': last_trLoss_ma
                }, ckp_path)

        loss_epoch /= num_batch_processed

        config.writer.add_scalar('data/loss_epoch', loss_epoch, iterations - 1)
        for idx in range(len(tasks)):
            task = tasks[idx]
            loss_epoch_list[idx] /= num_batch_processed_list[idx]
            config.writer.add_scalar(
                'data/loss_epoch_idx{}_{}'.format(idx, task),
                loss_epoch_list[idx], iterations - 1)
        # import ipdb; ipdb.set_trace()

        ### lr decay
        trLoss_queue.append(loss_epoch)
        trLoss_ma = np.asarray(trLoss_queue).mean()  # simple moving average over the last trLoss_win epochs
        config.writer.add_scalar('data/trLoss_ma', trLoss_ma, iterations - 1)

        for idx in range(len(tasks)):
            task = tasks[idx]
            trLoss_queue_list[idx].append(loss_epoch_list[idx])
            trLoss_ma_list[idx] = np.asarray(trLoss_queue_list[idx]).mean()  # per-task moving average
            config.writer.add_scalar(
                'data/trLoss_ma_idx{}_{}'.format(idx, task),
                trLoss_ma_list[idx], iterations - 1)

        # import ipdb; ipdb.set_trace()
        #### online eval
        Eval_bool = False
        if epoch >= config.start_val_epoch and epoch % config.val_epoch == 0:
            Eval_bool = True
        elif lr < 1e-8:
            Eval_bool = True
            logger.info(
                'lr is reduced to {}. Will do the last evaluation for all samples!'
                .format(lr))

        if Eval_bool:
            eval(args, tasks_archive, model, epoch, iterations - 1)

        ## stop if lr is too low
        if lr < 1e-8:
            logger.info('lr is reduced to {}. Job Done!'.format(lr))
            break

        ###### lr decay based on the moving average of the overall training loss
        if len(trLoss_queue) == trLoss_queue.maxlen:
            if last_trLoss_ma and last_trLoss_ma - trLoss_ma < 1e-4:  # 5e-3
                lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            last_trLoss_ma = trLoss_ma

        ## save model when lr < 1e-8
        if lr < 1e-8:
            ckp_path = os.path.join(
                config.log_dir,
                '{}_{}_epoch{}_{}.pth.tar'.format(args.trainMode,
                                                  '_'.join(args.tasks), epoch,
                                                  tinies.datestr()))
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'trLoss_queue': trLoss_queue,
                    'last_trLoss_ma': last_trLoss_ma
                }, ckp_path)
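Not part of the source: train() above resumes by unpickling the whole optimizer object; a minimal state-dict based resume from the same checkpoint file would look roughly like this, with ckp_path standing for one of the saved *.pth.tar files.

import torch

checkpoint = torch.load(ckp_path, map_location='cpu')
# the model was saved wrapped in nn.DataParallel, so load into a wrapped model
# as well (or strip the 'module.' prefix from the state-dict keys first)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
trLoss_queue = checkpoint['trLoss_queue']
last_trLoss_ma = checkpoint['last_trLoss_ma']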
Example #6
def evaluate(config_task, ids, model, outdir='eval_out', epoch_num=0):
    """
    evalutation
    """
    files = load_files(ids)
    files = list(files)

    datDir = os.path.join(config.prepData_dir, config_task.task, "Tr")
    dices_list = []

    # files = files[:2] # debugging.
    logger.info('Evaluating epoch{} for {}--- {} cases:\n{}'.format(
        epoch_num, config_task.task, len(files),
        str([obj['id'] for obj in files])))
    for obj in tqdm(files, desc='Eval epoch{}'.format(epoch_num)):
        ID = obj['id']
        # logger.info('evaluating {}:'.format(ID))
        img_path = os.path.join(config.base_dir, config_task.task, "imagesTr", ID)
        gt_path = os.path.join(config.base_dir, config_task.task, "labelsTr", ID)
        obj['im'] = img_path
        obj['gt'] = gt_path

        data = get_eval_data(obj, datDir)
        # final_label, probs = segment_one_image(config_task, data, model) # final_label: d, h, w, num_classes

        try:
            final_label = segment_one_image(
                config_task, data, model,
                ID)  # final_label: d, h, w, num_classes
            save_to_nii(final_label,
                        filename=ID + '.nii.gz',
                        refer_file_path=img_path,
                        outdir=outdir,
                        mode="label",
                        prefix='Epoch{}_'.format(epoch_num))

            gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))  # d, h, w
            # treat cancer as organ for Task03_Liver and Task07_Pancreas
            if config_task.task in ['Task03_Liver', 'Task07_Pancreas']:
                gt[gt == 2] = 1

            # cal dices
            dices = multiClassDice(gt, final_label, config_task.num_class)
            dices_list.append(dices)

            tinies.sureDir(outdir)
            fo = open(os.path.join(outdir,
                                   '{}_eval_res.csv'.format(config_task.task)),
                      mode='a+')
            wo = csv.writer(fo, delimiter=',')
            wo.writerow([epoch_num, tinies.datestr(), ID] + dices)
            fo.flush()

            ## for tensorboard visualization
            tb_img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))  # d,h,w
            if tb_img.ndim == 4:
                tb_img = tb_img[0, ...]
            train.tb_images([tb_img, gt, final_label], [False, True, True],
                            ['image', 'GT', 'PS'],
                            epoch_num * config.step_per_epoch,
                            tag='Eval_{}_epoch_{}_dices_{}'.format(
                                ID, epoch_num, str(dices)))
        except Exception as e:
            logger.error('Eval failed for {}: {}'.format(ID, str(e)))

    labels = config_task.labels
    dices_all = np.asarray(dices_list)
    dices_mean = dices_all.mean(axis=0)
    logger.info('Eval mean dices:')
    dices_res = {}
    for i in range(config_task.num_class):
        tag = labels[str(i)]
        dices_res[tag] = dices_mean[i]
        logger.info('    {}, {}'.format(tag, dices_mean[i]))

    return dices_res
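multiClassDice() is called above but not included in this excerpt. A hypothetical per-class Dice with the same call signature (label maps in, a list of per-class scores out), sketched for illustration only:

import numpy as np

def multi_class_dice(gt, pred, num_class, eps=1e-7):
    # gt, pred: integer label volumes of identical shape (d, h, w)
    dices = []
    for c in range(num_class):
        gt_c = (gt == c)
        pred_c = (pred == c)
        inter = np.logical_and(gt_c, pred_c).sum()
        denom = gt_c.sum() + pred_c.sum()
        dices.append((2.0 * inter + eps) / (denom + eps))
    return dices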