Example #1
 def __init__(self, config):
     super(LesionModel, self).__init__()
     if config["A"] == "unet":
         self.model = smp.Unet(
             encoder_name=config["EN"],  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
             encoder_weights=config["encoder_weights"],  # use `imagenet` pre-trained weights for encoder initialization
             in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
             classes=config["OC"],  # model output channels (number of classes in your dataset)
         )
     elif config["A"] == "DeepLabV3Plus":
         self.model = smp.DeepLabV3Plus(
             encoder_name=config["EN"],
             encoder_weights=config["encoder_weights"],
             in_channels=3,
             classes=config["OC"],
         )
     self.model_cfg = config
     self.loss_function = DiceLoss()
     self.save_hyperparameters()
     self.iou_function = IOU()
     self.checkpoint_path = ""
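Note: every example on this page assumes a DiceLoss class imported from elsewhere; none of the snippets define it. The following is a minimal sketch of one common binary formulation (the sigmoid activation and the smooth constant are assumptions, and multi-class variants differ):

import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    """Minimal soft Dice loss sketch: loss = 1 - Dice coefficient."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # avoids division by zero on empty masks

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)  # map raw logits to [0, 1]
        probs = probs.view(probs.size(0), -1)  # flatten per sample
        targets = targets.view(targets.size(0), -1).float()
        intersection = (probs * targets).sum(dim=1)
        union = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()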
Example #2
def train(img_path, ori_seg_path, label_path, ckpt_path, xls_path):
    wb = xlwt.Workbook()
    ws = wb.add_sheet('dice loss')

    model = UNet(channels_in=4, channels_out=1)
    model.apply(init_weight)

    train_set = Dataset4Layers(img_path, ori_seg_path, label_path, 'train')
    train_loader = DataLoader(train_set, batch_size=3, shuffle=True)
    val_set = Dataset4Layers(img_path, ori_seg_path, label_path, 'val')
    val_loader = DataLoader(val_set, batch_size=3, shuffle=False)

    opt = Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
    sch = StepLR(opt, step_size=20, gamma=0.7)
    loss = DiceLoss()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.float().to(device)

    max_epoch = 151
    cnt = 0
    stop_count = 15
    min_dice_loss = 1.
    stop_flag = False
    for i in range(max_epoch):
        dice_loss_train = epoch_step(train_loader, model, opt, loss, 'train',
                                     device)
        dice_loss_val = epoch_step(val_loader, model, opt, loss, 'val', device)
        loss_list = [dice_loss_train, dice_loss_val]
        for j in range(len(loss_list)):
            ws.write(i, j, loss_list[j])

        print(f'in epoch {i}: train dice loss is {dice_loss_train}, val dice loss is {dice_loss_val}')

        if dice_loss_val < min_dice_loss:
            min_dice_loss = dice_loss_val
            save_ckpt(ckpt_path, i, model.state_dict())
            cnt = 0
        else:
            cnt = cnt + 1
        if cnt == stop_count:
            stop_flag = True
            break
        sch.step()

    if not stop_flag:
        save_ckpt(ckpt_path, max_epoch - 1, model.state_dict())

    if not os.path.exists(xls_path):
        os.mkdir(xls_path)
    wb.save(os.path.join(xls_path, 'seg_of_rectum_unet.xls'))
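Examples #2 and #3 call an epoch_step helper that is not shown (Example #3's variant additionally takes a vgg_tag argument). A minimal sketch matching Example #2's call signature, assuming each batch is an (image, label) pair; everything beyond that signature is a hypothetical reconstruction:

import torch

def epoch_step(loader, model, opt, loss_fn, mode, device):
    # one full pass over `loader`; returns the mean dice loss for the epoch
    model.train() if mode == 'train' else model.eval()
    total = 0.0
    for img, label in loader:
        img = img.float().to(device)
        label = label.float().to(device)
        with torch.set_grad_enabled(mode == 'train'):
            pred = model(img)
            loss = loss_fn(pred, label)
        if mode == 'train':
            opt.zero_grad()
            loss.backward()
            opt.step()
        total += loss.item()
    return total / len(loader)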
Example #3
def train(img_path, label_path, ckpt_path, xls_path, ws, model, model_name,
          Dataset, vgg_tag):
    train_set = Dataset(img_path, label_path, 'train')
    train_loader = DataLoader(train_set, batch_size=3, shuffle=True)
    val_set = Dataset(img_path, label_path, 'val')
    val_loader = DataLoader(val_set, batch_size=3, shuffle=False)

    opt = Adam(model.parameters(), 1e-4, betas=(0.9, 0.999))
    sch = StepLR(opt, step_size=30, gamma=0.8)
    loss = DiceLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.float().to(device)

    max_epoch = 151
    # max_epoch = 2
    cnt = 0
    stop_count = 15
    min_dice_loss = 1.
    stop_flag = False
    for i in range(max_epoch):
        dice_loss_train = epoch_step(train_loader, model, opt, vgg_tag, loss,
                                     'train', device)
        # dice_loss_train = epoch_step(val_loader, model, opt, vgg_tag, loss, 'train', device)
        dice_loss_val = epoch_step(val_loader, model, opt, vgg_tag, loss,
                                   'val', device)
        loss_list = [dice_loss_train, dice_loss_val]
        for j in range(len(loss_list)):
            ws.write(i, j, loss_list[j])

        print(f'in epoch {i}: train dice loss is {dice_loss_train}, val dice loss is {dice_loss_val}')

        if dice_loss_val < min_dice_loss:
            min_dice_loss = dice_loss_val
            save_ckpt(ckpt_path, model_name, i, model.state_dict())
            cnt = 0
        else:
            cnt = cnt + 1
        if cnt == stop_count:
            stop_flag = True
            break
        sch.step()

    if not stop_flag:
        save_ckpt(ckpt_path, model_name, max_epoch - 1, model.state_dict())
    return ws
Example #4
 def train(self):
     model = Net_design(self.used_model_list)
     device = torch.cuda.current_device()
     model = model.to(device)  # move the model to the GPU; the inputs are sent there below
     dataset = trainSampler(self.data_path)
     train_loader = DataLoader(dataset,
                               batch_size=self.batch_size,
                               shuffle=True,
                               num_workers=8,
                               pin_memory=True)
     if self.opreater == 'Adam':
         optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)
     elif self.opreater == 'SGD':
         optimizer = optim.SGD(model.parameters(), lr=self.learning_rate)
     else:
         optimizer = optim.SGD(model.parameters(),
                               lr=self.learning_rate,
                               momentum=0.9)
     if self.loss_func == 'CE':
         criterion = nn.CrossEntropyLoss()
     else:
         criterion = DiceLoss()
     with open(os.path.join(self.save_dir, 'train_log.txt'), 'a') as f:
         for ep in range(self.epoch):
             model.train()
             epoch_loss = 0.0
             for batch in train_loader:
                 imgs = batch['image']
                 true_masks = batch['mask']
                 imgs = imgs.to(device=device, dtype=torch.float32)
                 true_masks = true_masks.to(device=device,
                                            dtype=torch.float32)
                 masks_pred = model(imgs)
                 loss = criterion(masks_pred, true_masks)
                 f.write(f"epoch is {ep + 1}: loss is {loss.item()}\n")
                 logging.info(f"epoch is {ep + 1}: loss is {loss.item()}")
                 epoch_loss += loss.item()
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
             if ep % 5000 == 0:
                 torch.save(model.state_dict(),
                            self.save_dir + f'CP_epoch{ep + 1}.pth')
Example #5
def train_unit():
    #model = SeTr(image_size=50, patch_size=10, dim=32, depth=4, heads=8, mlp_dim=400, channels=4).to(device)
    model = UNIT().to(device)
    #model = UNet(4).to(device)
    parameter_count = count_parameters(model)
    print('model parameter count: ' + str(parameter_count))

    #criterion = nn.CrossEntropyLoss()
    criterion = DiceLoss()

    optimizer = optim.Adam(model.parameters(), lr=0.003)
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    epoch = 0
    while epoch != 10000:
        #if(epoch == 500):
        #    criterion = DiceLoss()
        batch_size = 2
        current_batch_index = epoch % (340 // batch_size)
        lower_index = current_batch_index * batch_size + 11
        upper_index = (1 + current_batch_index) * batch_size + 11
        
        if epoch < 100:
            lower_index = 12
            upper_index = 14

        if epoch == 80:
            optimizer = optim.Adam(model.parameters(), lr=0.00003)

        if epoch == 400:
            optimizer = optim.Adam(model.parameters(), lr=0.000003)

        #print('lower_index:' + str(lower_index))
        #print('upper_index:' + str(upper_index))
        chunked_images, chunked_segmented_images = load_data(lower_index, upper_index, chunk_x, chunk_y)

        for batch_epoch in range(0, 1):

            running_loss = 0.0

            # the sub-batch slicing that used to be here was immediately
            # overwritten below; the whole loaded chunk is used as one batch
            model_input = np.copy(chunked_images)
            model_output = np.copy(chunked_segmented_images)

            model_input = model_input.transpose((0, 3, 1, 2))
            model_input = torch.from_numpy(model_input).to(device)

            # transpose((0, 1, 2)) is an identity on a 3-D array, so it is dropped
            model_output = torch.LongTensor(model_output).to(device)

            outputs = model(model_input)
            
            optimizer.zero_grad()
            loss = criterion(outputs, model_output)
            loss.backward()
            optimizer.step()
            scheduler.step(loss)

            running_loss += loss.item()

            print('[%d] loss: %.3f' % (epoch + 1, running_loss))

            if epoch % 500 == 3:
                evaluate(model, True)
                model.train()
            elif epoch % 30 == 2:
                evaluate(model, False)
                model.train()
            elif epoch % 500 == 4:
                evaluate_train(model, True)
                model.train()
            elif epoch % 30 == 5:
                evaluate_train(model, False)
                model.train()

            
            epoch = epoch + 1 

            del model_input
            del model_output
            del outputs
            del loss
            torch.cuda.empty_cache()
        
        del chunked_images
        del chunked_segmented_images
        torch.cuda.empty_cache()

    return model
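A side note on the schedule in train_unit: re-creating the Adam optimizer at epochs 80 and 400 also resets its internal moment estimates. If only the learning rate is meant to change, a LambdaLR reproduces the same values without that side effect (a sketch continuing the example, with `model` taken from train_unit):

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.003)

def lr_factor(epoch):
    # mirrors the hard-coded drops in train_unit above
    if epoch >= 400:
        return 0.000003 / 0.003  # 3e-6 from epoch 400 on
    if epoch >= 80:
        return 0.00003 / 0.003  # 3e-5 from epoch 80 on
    return 1.0  # 3e-3 initially

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
# call scheduler.step() once per epoch instead of rebuilding the optimizer

Note that train_unit also steps a ReduceLROnPlateau on the running loss; the two schedulers would have to be reconciled rather than combined blindly.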
Example #6
def train_net(net,
              device,
              output_dir,
              train_date='',
              epochs=20,
              iters=900,
              bs=4,
              lr=0.01,
              save_cp=True,
              only_lastandbest=False,
              eval_freq=5,
              fold_idx=None,
              site='A',
              eval_site=None,
              gpus=None,
              save_folder='',
              aug=False,
              zoom=False,
              whitening=True,
              nonlinear='relu',
              norm_type='BN',
              pretrained=False,
              loaded_model_file_name='model_best',
              spade_seg_mode='soft',
              spade_inferred_mode='mask',
              spade_aux_blocks='',
              freeze_except=None,
              ce_weighted=False,
              spade_reduction=2,
              excluded_classes=None,
              dataset_name=None):
    net.apply(weights_init)
    logging.info('Model Parameters Reset!')
    net.to(device=device)

    if pretrained:
        pretrained_model_dir = pretrained + f'/Fold_{fold_idx}/{loaded_model_file_name}.pth'
        pretrained_dict_load = torch.load(pretrained_model_dir)
        model_dict = net.state_dict()

        pretrained_dict = {}
        for k, v in pretrained_dict_load.items():
            if (k in model_dict) and (model_dict[k].shape
                                      == pretrained_dict_load[k].shape):
                pretrained_dict[k] = v

        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)

        print('Freeze Excluded Keywords:', freeze_except)
        if freeze_except is not None:
            for k, v in net.named_parameters():
                v.requires_grad = False
                for except_key in freeze_except:
                    if except_key in k:
                        v.requires_grad = True
                        print(k, ' requires grad')
                        break
        logging.info(f'Model loaded from {pretrained_model_dir}')

    pretrain_suffix = ''
    if pretrained:
        pretrain_suffix = '_Pretrained'

    block_names = [
        'inc', 'down1', 'down2', 'down3', 'down4', 'mid', 'up1', 'up2', 'up3',
        'up4'
    ]

    spade_blocks_suffix = ''
    if spade_aux_blocks != '':
        if spade_inferred_mode == 'mask':
            spade_blocks_suffix += f'_SPADE_R{spade_reduction}_{spade_seg_mode}_Aux_'
        else:
            spade_blocks_suffix += f'_SPADE_R{spade_reduction}_{spade_inferred_mode}_Aux_'

        for blockname in spade_aux_blocks:
            block_idx = block_names.index(blockname)
            spade_blocks_suffix += str(block_idx)

    excluded_classes_suffix = ''

    if excluded_classes is not None:
        excluded_classes_string = [str(c) for c in excluded_classes]
        excluded_classes_suffix = '_exclude' + ''.join(excluded_classes_string)

    dir_results = output_dir

    tensorboard_logdir = dir_results + 'logs/' + f'{save_folder}/' + f'{train_date}_Site_{site}_GPUs_{gpus}/' + \
                         f'BS_{bs}_Epochs_{epochs}_Aug_{aug}_Zoom_{zoom}_Nonlinear_{nonlinear}_Norm_{norm_type}' + \
                         excluded_classes_suffix + pretrain_suffix + spade_blocks_suffix + f'/Fold_{fold_idx}'
    writer = SummaryWriter(log_dir=tensorboard_logdir)

    dir_checkpoint = dir_results + 'checkpoints/' + f'{save_folder}/' + \
                     f'{train_date}_Site_{site}_GPUs_{gpus}/' + f'Epochs_{epochs}_Aug_{aug}_Zoom_{zoom}_Nonlinear_{nonlinear}_Norm_{norm_type}' + \
                       excluded_classes_suffix + pretrain_suffix + spade_blocks_suffix + f'/Fold_{fold_idx}' + '/'
    print(tensorboard_logdir)
    dir_eval_csv = dir_results + 'eval_csv/' + f'{save_folder}/' + \
                   f'{train_date}_Site_{site}_GPUs_{gpus}/' + f'Epochs_{epochs}_Aug_{aug}_Zoom_{zoom}_Nonlinear_{nonlinear}_Norm_{norm_type}' + \
                   excluded_classes_suffix + pretrain_suffix + spade_blocks_suffix + '/'
    csv_files_prefix = f'{train_date}_Site_{site}_GPUs_{gpus}_'
    print(dir_eval_csv)
    if not os.path.exists(dir_eval_csv):
        os.makedirs(dir_eval_csv)
    train_list = {}
    val_list = {}
    test_list = {}

    train_list['Overall'] = []
    val_list['Overall'] = []
    test_list['Overall'] = []

    if eval_site is None:
        sites_inferred = list(site)
    else:
        sites_inferred = list(set(site + eval_site))
    sites_inferred.sort()
    print(sites_inferred)
    for site_idx in sites_inferred:
        train_list[site_idx] = all_list[site_idx][split_list[site_idx][fold_idx][0]].tolist()
        val_list[site_idx] = all_list[site_idx][split_list[site_idx][fold_idx][1]].tolist()
        test_list[site_idx] = all_list[site_idx][split_list[site_idx][fold_idx][1]].tolist()
        if site_idx in site:
            train_list['Overall'].append(train_list[site_idx])
            val_list['Overall'].append(val_list[site_idx])
        test_list['Overall'].append(test_list[site_idx])

    print('-----------------------------------------')
    print('Dataset Info:')
    for site_key in train_list.keys():
        if site_key in ['Overall', 'ABC_mixed']:
            case_total_train = 0
            case_total_test = 0
            for site_list_train, site_list_test in zip(train_list[site_key],
                                                       test_list[site_key]):
                case_total_train += len(site_list_train)
                case_total_test += len(site_list_test)
            print(
                f'{site_key}: {len(train_list[site_key])} sites  '
                f'Train: {case_total_train} cases, Test: {case_total_test} cases'
            )
        else:
            print(
                f'Site {site_key} Train:  {len(train_list[site_key])} cases, Test:  {len(test_list[site_key])} cases'
            )
    print('-----------------------------------------')
    n_train = iters * bs
    if len(site) > 1:
        train_set = SiteSet(train_list['Overall'],
                            iters=n_train,
                            training=True,
                            augmentation=aug,
                            source="Overall",
                            zoom_crop=zoom,
                            whitening=whitening,
                            batchsize=bs // len(site),
                            site_num=len(site),
                            n_classes=net.n_classes,
                            excluded_classes=excluded_classes)
        val_set = SiteSet(val_list['Overall'],
                          iters=n_train,
                          training=True,
                          augmentation=False,
                          source="Overall",
                          zoom_crop=False,
                          whitening=whitening,
                          batchsize=bs // len(site),
                          site_num=len(site),
                          n_classes=net.n_classes,
                          excluded_classes=excluded_classes)
    else:
        train_set = SiteSet(train_list[site],
                            iters=n_train,
                            training=True,
                            augmentation=aug,
                            source=site,
                            zoom_crop=zoom,
                            whitening=whitening,
                            batchsize=bs // len(site),
                            site_num=len(site),
                            n_classes=net.n_classes,
                            excluded_classes=excluded_classes)
        val_set = SiteSet(val_list[site],
                          iters=n_train,
                          training=True,
                          augmentation=False,
                          source=site,
                          zoom_crop=False,
                          whitening=whitening,
                          batchsize=bs // len(site),
                          site_num=len(site),
                          n_classes=net.n_classes,
                          excluded_classes=excluded_classes)

    train_loader = DataLoader(train_set,
                              batch_size=bs,
                              shuffle=False,
                              num_workers=2,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=bs,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True)

    global_step = 0

    logging.info(f'''Starting training:
        Output Path:     {output_dir}
        Epochs:          {epochs}
        Iterations:      {iters}
        Batch size:      {bs}
        Learning rate:   {lr}
        Checkpoints:     {save_cp}
        Only_Last_Best:  {only_lastandbest}
        Eval_Frequency:  {eval_freq}
        Device:          {device.type}
        GPU ids:         {gpus}
        Site:            {site}
        Shift+Rotation:  {aug}
        Zoom+Crop:       {zoom}
        Whitening:       {whitening}
        Pretrain:        {pretrained}
        Classes:         {net.n_classes}
        Excluded Classes: {excluded_classes}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     factor=0.5,
                                                     patience=5)

    criterion = DiceLoss()

    best_score = 0
    csv_header = True
    for epoch in range(epochs):

        losses = []
        losses_first_forward = []
        epoch_loss = 0
        with tqdm(total=iters,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='batch') as pbar:
            # for batch in train_loader:
            time.sleep(2)
            for (train_batch, val_batch) in zip(train_loader, val_loader):
                net.train()
                imgs = train_batch[0]
                all_masks = train_batch[1]
                site_label = train_batch[2]

                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                all_masks = all_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                if spade_aux_blocks == '':
                    loss, loss_hard = criterion(masks_pred,
                                                all_masks,
                                                num_classes=net.n_classes,
                                                return_hard_dice=True,
                                                softmax=True)
                    writer.add_scalar('Backwarded Loss/Dice_Loss', loss.item(),
                                      global_step)

                    if net.n_classes > 2:
                        loss_forward_weighted_ce = Weighted_Cross_Entropy_Loss()(
                            masks_pred,
                            all_masks,
                            num_classes=net.n_classes,
                            weighted=ce_weighted,
                            softmax=True)
                        loss += loss_forward_weighted_ce
                        writer.add_scalar('Backwarded Loss/Weighted_CE_Loss',
                                          loss_forward_weighted_ce.item(),
                                          global_step)
                    train_loss = loss_hard.item()
                else:
                    # start first forward
                    loss_first_forward, loss_first_forward_hard = criterion(
                        masks_pred,
                        all_masks,
                        num_classes=net.n_classes,
                        return_hard_dice=True,
                        softmax=True)
                    referred_mask_loss = loss_first_forward_hard.item()
                    writer.add_scalar('Backwarded Loss/Dice_Loss_First',
                                      loss_first_forward.item(), global_step)
                    if net.n_classes > 2:
                        loss_first_forward_weighted_ce = Weighted_Cross_Entropy_Loss()(
                            masks_pred,
                            all_masks,
                            num_classes=net.n_classes,
                            weighted=ce_weighted,
                            softmax=True)
                        loss_first_forward += loss_first_forward_weighted_ce
                        writer.add_scalar(
                            'Backwarded Loss/Weighted_CE_Loss_First',
                            loss_first_forward_weighted_ce.item(), global_step)
                    optimizer.zero_grad()
                    loss_first_forward.backward()
                    optimizer.step()

                    # start second forward
                    mask_pred_first_forward = masks_pred.detach()
                    mask_pred_first_forward = torch.softmax(
                        mask_pred_first_forward, dim=1)
                    masks_pred_second_forward = net(
                        imgs, seg=mask_pred_first_forward)
                    loss, loss_hard = criterion(masks_pred_second_forward,
                                                all_masks,
                                                num_classes=net.n_classes,
                                                return_hard_dice=True,
                                                softmax=True)

                    writer.add_scalar('Backwarded Loss/Dice_Loss_Second',
                                      loss.item(), global_step)
                    if net.n_classes > 2:
                        loss_second_forward_weighted_ce = Weighted_Cross_Entropy_Loss()(
                            masks_pred_second_forward,
                            all_masks,
                            num_classes=net.n_classes,
                            weighted=ce_weighted,
                            softmax=True)
                        loss += loss_second_forward_weighted_ce
                        writer.add_scalar(
                            'Backwarded Loss/Weighted_CE_Loss_Second',
                            loss_second_forward_weighted_ce.item(),
                            global_step)
                    train_loss = loss_hard.item()

                epoch_loss += train_loss

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.set_postfix(**{'Dice': train_loss})
                pbar.update(1)
                global_step += 1

                net.eval()
                imgs = val_batch[0]
                all_masks = val_batch[1]
                site_label = val_batch[2]

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                all_masks = all_masks.to(device=device, dtype=mask_type)

                with torch.no_grad():
                    masks_pred = net(imgs)
                    # loss = criterion(masks_pred, all_masks)
                    # val_loss = loss.item()
                    if spade_aux_blocks == '':
                        loss, loss_hard = criterion(masks_pred,
                                                    all_masks,
                                                    num_classes=net.n_classes,
                                                    return_hard_dice=True,
                                                    softmax=True)
                        val_loss = loss_hard.item()
                    else:
                        mask_pred_first_forward = masks_pred.detach()

                        _, loss_hard_first_forward = criterion(
                            mask_pred_first_forward,
                            all_masks,
                            num_classes=net.n_classes,
                            return_hard_dice=True,
                            softmax=True)
                        mask_pred_first_forward = torch.softmax(
                            mask_pred_first_forward, dim=1)
                        masks_pred_second_forward = net(
                            imgs, seg=mask_pred_first_forward)
                        loss, loss_hard = criterion(masks_pred_second_forward,
                                                    all_masks,
                                                    num_classes=net.n_classes,
                                                    return_hard_dice=True,
                                                    softmax=True)
                        val_loss = loss_hard.item()
                        val_loss_first_forward = loss_hard_first_forward.item()

                if spade_aux_blocks == '':
                    writer.add_scalars('Loss/Dice_Loss', {
                        'train': train_loss,
                        'val': val_loss
                    }, global_step)
                else:
                    writer.add_scalars(
                        'Loss/Dice_Loss', {
                            'train': train_loss,
                            'val_first': val_loss_first_forward,
                            'val_second': val_loss
                        }, global_step)
                if global_step % (n_train // (eval_freq * bs)) == 0:
                    if net.n_classes == 2:
                        writer.add_images(
                            'masks/true',
                            (all_masks[:4, ...].cpu().unsqueeze(1)),
                            global_step)
                        writer.add_images(
                            'masks/pred',
                            torch.softmax(masks_pred[:4, ...],
                                          dim=1)[:, 1:2, :, :].cpu(),
                            global_step)

                    elif net.n_classes > 2:
                        writer.add_images(
                            'masks/true',
                            (all_masks[:4, ...].cpu().unsqueeze(1)),
                            global_step)
                        writer.add_images(
                            'masks/pred',
                            torch.argmax(torch.softmax(masks_pred[:4, ...],
                                                       dim=1),
                                         dim=1).unsqueeze(1).cpu(),
                            global_step)
                if spade_aux_blocks != '':
                    losses_first_forward.append(loss_first_forward.item())

                losses.append(train_loss)

                if global_step % 50 == 0 and optimizer.param_groups[0]['lr'] > 1e-5:
                    scheduler.step(np.mean(losses[-50:]))
                if global_step % (n_train // (eval_freq * bs)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                    if eval_site is None:
                        eval_site = site
                    test_scores, test_asds = eval_net(
                        net,
                        test_list,
                        device,
                        fold_idx,
                        global_step,
                        dir_eval_csv,
                        csv_files_prefix=csv_files_prefix,
                        whitening=whitening,
                        eval_site=eval_site,
                        spade_aux=(spade_aux_blocks != ''),
                        excluded_classes=excluded_classes,
                        dataset_name=dataset_name)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)
                    if epoch >= 0:
                        if len(site) == 1:
                            metric_site = site
                        else:
                            metric_site = 'Overall'
                        if net.n_classes == 2:
                            is_best = test_scores[metric_site] > best_score
                            best_score = max(test_scores[metric_site],
                                             best_score)
                            if is_best:
                                try:
                                    os.makedirs(dir_checkpoint)
                                    logging.info(
                                        'Created checkpoint directory')
                                except OSError:
                                    pass
                                torch.save(net.state_dict(),
                                           dir_checkpoint + f'model_best.pth')
                                logging.info(
                                    f'Best model saved ! ( Dice Score:{best_score} on Site {metric_site})'
                                )

                        elif net.n_classes > 2:
                            scores_all_classes = 0
                            for c in test_scores.keys():
                                scores_all_classes += test_scores[c][
                                    metric_site]
                            scores_all_classes /= len(test_scores.keys())
                            is_best = scores_all_classes > best_score
                            best_score = max(scores_all_classes, best_score)
                            if is_best:
                                try:
                                    os.makedirs(dir_checkpoint)
                                    logging.info(
                                        'Created checkpoint directory')
                                except OSError:
                                    pass

                                torch.save(net.state_dict(),
                                           dir_checkpoint + f'model_best.pth')
                                logging.info(
                                    f'Best model saved ! ( Dice Score:{best_score} on Site {metric_site})'
                                )
                    if net.n_classes == 2:
                        if len(eval_site) > 1:
                            sites_print = list(eval_site) + ['Overall']
                        else:
                            sites_print = list(eval_site)

                        test_performance_dict = {}
                        test_performance_dict['fold'] = fold_idx
                        test_performance_dict['global_step'] = global_step
                        for st in sites_print:
                            print('\nSite: {}'.format(st))
                            print('\nTest Dice Coeff: {}'.format(
                                test_scores[st]))
                            print('\nTest ASD: {}'.format(test_asds[st]))
                        for st in sites_print:
                            test_performance_dict[f'Dice_{st}'] = [
                                format(test_scores[st], '.4f')
                            ]
                        for st in sites_print:
                            test_performance_dict[f'ASD_{st}'] = [
                                format(test_asds[st], '.2f')
                            ]
                        if spade_aux_blocks != '':
                            for st in sites_print:
                                test_performance_dict[
                                    f'Dice_{st}_first_forward'] = [
                                        format(
                                            test_scores[st + '_first_forward'],
                                            '.4f')
                                    ]
                            for st in sites_print:
                                test_performance_dict[
                                    f'ASD_{st}_first_forward'] = [
                                        format(
                                            test_asds[st + '_first_forward'],
                                            '.2f')
                                    ]
                        df = pd.DataFrame.from_dict(test_performance_dict)
                        df.to_csv(dir_eval_csv + csv_files_prefix +
                                  f'site_performance.csv',
                                  mode='a',
                                  header=csv_header,
                                  index=False)
                        csv_header = False
                        print('\n' + tensorboard_logdir)
                        # write with tensorboard
                        writer.add_scalars('Test/Dice_Score', test_scores,
                                           global_step)
                        writer.add_scalars('Test/ASD', test_asds, global_step)

                    elif net.n_classes > 2:
                        if dataset_name == 'ABD-8':
                            abdominal_organ_dict = {
                                1: 'spleen',
                                2: 'r_kidney',
                                3: 'l_kidney',
                                4: 'gallbladder',
                                5: 'pancreas',
                                6: 'liver',
                                7: 'stomach',
                                8: 'aorta'
                            }
                            if excluded_classes is None:
                                organ_dict = abdominal_organ_dict
                            else:
                                print('Original Organ dict')
                                print(abdominal_organ_dict)
                                post_mapping_dict = {}
                                original_classes = list(
                                    range(net.n_classes +
                                          len(excluded_classes)))
                                remain_classes = [
                                    item for item in original_classes
                                    if item not in excluded_classes
                                ]
                                for new_value, value in enumerate(
                                        remain_classes):
                                    post_mapping_dict[value] = new_value
                                organ_dict = {}
                                for c in remain_classes:
                                    if c == 0:
                                        continue
                                    organ_dict[post_mapping_dict[
                                        c]] = abdominal_organ_dict[c]
                                print('Current Organ dict')
                                print(organ_dict)
                        elif dataset_name == 'ABD-6':
                            abdominal_organ_dict = {
                                1: 'spleen',
                                2: 'l_kidney',
                                3: 'gallbladder',
                                4: 'liver',
                                5: 'stomach',
                                6: 'pancreas'
                            }
                            if excluded_classes is None:
                                organ_dict = abdominal_organ_dict
                            else:
                                print('Original Organ dict')
                                print(abdominal_organ_dict)
                                post_mapping_dict = {}
                                original_classes = list(
                                    range(net.n_classes +
                                          len(excluded_classes)))
                                remain_classes = [
                                    item for item in original_classes
                                    if item not in excluded_classes
                                ]
                                for new_value, value in enumerate(
                                        remain_classes):
                                    post_mapping_dict[value] = new_value
                                organ_dict = {}
                                for c in remain_classes:
                                    if c == 0:
                                        continue
                                    organ_dict[post_mapping_dict[
                                        c]] = abdominal_organ_dict[c]
                                print('Current Organ dict')
                                print(organ_dict)
                        print(f'Organ Dict:{organ_dict}')
                        test_performance_dict = {}
                        test_performance_dict['fold'] = fold_idx
                        test_performance_dict['global_step'] = global_step

                        for organ_class in range(1, net.n_classes):
                            if len(eval_site) > 1:
                                sites_print = list(eval_site) + ['Overall']
                            else:
                                sites_print = list(eval_site)

                            for st in sites_print:
                                test_performance_dict[
                                    f'{organ_dict[organ_class]}_Dice_{st}'] = [
                                        format(test_scores[organ_class][st],
                                               '.4f')
                                    ]
                            for st in sites_print:
                                test_performance_dict[
                                    f'{organ_dict[organ_class]}_ASD_{st}'] = [
                                        format(test_asds[organ_class][st],
                                               '.2f')
                                    ]

                            writer.add_scalars(
                                f'Test/Dice_Score_{organ_dict[organ_class]}',
                                test_scores[organ_class], global_step)
                            writer.add_scalars(
                                f'Test/ASD_{organ_dict[organ_class]}',
                                test_asds[organ_class], global_step)

                        test_scores['AVG'] = {}
                        test_asds['AVG'] = {}
                        for st in sites_print:
                            scores_all_classes_avg = 0
                            asds_all_classes_avg = 0
                            for c in range(1, net.n_classes):
                                scores_all_classes_avg += test_scores[c][st]
                                asds_all_classes_avg += test_asds[c][st]
                            scores_all_classes_avg /= (net.n_classes - 1)
                            asds_all_classes_avg /= (net.n_classes - 1)
                            test_scores['AVG'][st] = scores_all_classes_avg
                            test_asds['AVG'][st] = asds_all_classes_avg
                            print(f'\nSite:{st}')
                            print(f'Average:')
                            print('Test Dice Coeff: {}'.format(
                                test_scores['AVG'][st]))
                            print('Test ASD: {}'.format(test_asds['AVG'][st]))
                        for st in sites_print:
                            test_performance_dict[f'Average_Dice_{st}'] = [
                                format(test_scores['AVG'][st], '.4f')
                            ]
                        for st in sites_print:
                            test_performance_dict[f'Average_ASD_{st}'] = [
                                format(test_asds['AVG'][st], '.2f')
                            ]
                        writer.add_scalars(f'Test/Dice_Score_Average',
                                           test_scores['AVG'], global_step)
                        writer.add_scalars(f'Test/ASD_Average',
                                           test_asds['AVG'], global_step)

                        if spade_aux_blocks != '':
                            for organ_class in range(1, net.n_classes):
                                for st in sites_print:
                                    test_performance_dict[
                                        f'{organ_dict[organ_class]}_Dice_{st}_first_forward'] = [
                                            format(
                                                test_scores[organ_class]
                                                [st + '_first_forward'], '.4f')
                                        ]
                                for st in sites_print:
                                    test_performance_dict[
                                        f'{organ_dict[organ_class]}_ASD_{st}_first_forward'] = [
                                            format(
                                                test_asds[organ_class]
                                                [st + '_first_forward'], '.2f')
                                        ]
                        df = pd.DataFrame.from_dict(test_performance_dict)
                        df.to_csv(dir_eval_csv + csv_files_prefix +
                                  f'site_performance.csv',
                                  mode='a',
                                  header=csv_header,
                                  index=False)
                        csv_header = False
                        print('\n' + tensorboard_logdir)

        if save_cp:
            try:
                os.makedirs(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            if (epoch + 1) < epochs:
                torch.save(net.state_dict(),
                           dir_checkpoint + f'CP_epoch{epoch + 1}.pth')

            else:
                torch.save(net.state_dict(),
                           dir_checkpoint + f'model_last.pth')

            if only_lastandbest:
                if os.path.exists(dir_checkpoint + f'CP_epoch{epoch}.pth'):
                    os.remove(dir_checkpoint + f'CP_epoch{epoch}.pth')
            print(dir_checkpoint)
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
    print("GPU is available")
    device = torch.device("cuda")
else:
    print("GPU is not available")
    device = torch.device("cpu")

nets = [
    UNet(1, 1, final_activation=nn.Sigmoid()),
    UNet_BatchNorm(1, 1, final_activation=nn.Sigmoid()),
    UNet_ELU(1, 1, final_activation=nn.Sigmoid()),
    UNet_Five_Layers(1, 1, final_activation=nn.Sigmoid())
]

n_epochs = 15
for net in nets:
    for loss_fn in [torch.nn.MSELoss(), DiceLoss(), DiceCoefficient()]:
        name = net.__class__.__name__
        loss_name = loss_fn.__class__.__name__
        print(f"\n\nTraining network {name} with loss {loss_name}")
        logger = SummaryWriter(f'runs/log_{name}_{loss_name}')
        net.to(device)

        loss_function = loss_fn
        loss_function.to(device)

        # use adam optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=1.e-3)

        # build the dice coefficient metric
        metric = DiceCoefficient()
        # train each configuration for n_epochs epochs
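The snippet breaks off before the actual training call. Purely as a sketch of how the grid might continue (train_epoch, validate, train_loader, and val_loader are hypothetical names not present in the source):

        # hypothetical continuation; the helpers and loaders are assumptions
        for epoch in range(n_epochs):
            train_loss = train_epoch(net, train_loader, loss_function,
                                     optimizer, device)
            val_dice = validate(net, val_loader, metric, device)
            logger.add_scalar('train/loss', train_loss, epoch)
            logger.add_scalar('val/dice', val_dice, epoch)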
Example #8
def dice_loss(input, target):
    dice = DiceLoss()
    loss = dice(input, target)
    return loss
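A quick usage sketch for the wrapper above (the tensor shapes here are assumptions; what DiceLoss accepts depends on the implementation in use):

import torch

pred = torch.rand(4, 1, 64, 64)  # model output for a batch of 4
mask = (torch.rand(4, 1, 64, 64) > 0.5).float()  # binary ground-truth masks
print(dice_loss(pred, mask))  # scalar loss tensor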
Example #9
def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)  # predicted fine-grained tariff codes
    gather_predict_all = np.array([], dtype=int)  # predicted coarse tariff codes
    labels_all = np.array([], dtype=int)
    gather_labels_all = np.array([], dtype=int)
    focalLoss = FocalLoss(config.num_classes)
    diceloss = DiceLoss(config.num_classes)  # dice loss
    with torch.no_grad():
        for texts, labels, gather_labels in data_iter:
            outputs, gather_outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            # loss = focalLoss(outputs, labels)  # dic loss
            # loss = diceloss(outputs, labels)  # dice loss
            loss_total += loss

            labels = labels.data.cpu().numpy()  # fine-grained labels
            gather_labels = gather_labels.data.cpu().numpy()  # coarse tariff-code labels

            predic = torch.max(outputs.data, 1)[1].cpu().numpy()  # fine-grained predictions
            gather_predic = torch.max(gather_outputs.data, 1)[1].cpu().numpy()  # coarse predictions

            labels_all = np.append(labels_all, labels)
            gather_labels_all = np.append(gather_labels_all, gather_labels)
            predict_all = np.append(predict_all, predic)
            gather_predict_all = np.append(gather_predict_all, gather_predic)

    haiguan_labels_all = [config.class_list[x] for x in labels_all]
    haiguan_predic_all = [config.class_list[x] for x in predict_all]
    haiguan_gather_labels_all = [
        config.gather_class_list[x] for x in gather_labels_all
    ]
    haiguan_gather_predict_labels_all = [
        config.gather_class_list[x] for x in gather_predict_all
    ]
    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        # report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        report = metrics.classification_report(labels_all,
                                               predict_all,
                                               digits=4)
        # confusion = metrics.confusion_matrix(labels_all, predict_all)
        # class_map = {}  # mapping from fine-grained to coarse tariff codes
        wrong_number = 0  # number of wrongly predicted fine-grained codes
        # '''
        # class_map_list = [x.strip().split('\t') for x in open(config.class_map_path, encoding='utf-8').readlines()]
        # for lin in class_map_list:
        #     class_map[lin[0]] = lin[1]
        list1, list2, list3, list4, index_list = [], [], [], [], []  # five lists holding the results
        for i, (pre, lab, galab, pre_galab) in enumerate(
                zip(haiguan_predic_all, haiguan_labels_all,
                    haiguan_gather_labels_all,
                    haiguan_gather_predict_labels_all)):
            if pre_galab == galab and pre == lab:  # both correct
                list1.append((pre, lab, pre_galab, galab))
            elif pre_galab == galab and pre != lab:  # coarse correct, fine wrong
                index_list.append([str(i), pre, lab, galab])
                list2.append((pre, lab, pre_galab, galab))
            elif pre_galab != galab and pre == lab:  # coarse wrong, fine correct
                list3.append((pre, lab, pre_galab, galab))
            elif pre_galab != galab and pre != lab:  # both wrong
                index_list.append([str(i), pre, lab, galab])
                list4.append((pre, lab, pre_galab, galab))
        print('both correct: {} samples'.format(len(list1)))
        print('coarse correct, fine wrong: {} samples'.format(len(list2)))
        print('coarse wrong, fine correct: {} samples'.format(len(list3)))
        print('both wrong: {} samples'.format(len(list4)))
        """
        with open('./84-85-90/base_wrong.txt', 'w', encoding='utf-8')as f:
            f.write('index, prelab, lab, galab')
            f.write('\n')
            for line in index_list:
                line = '\t'.join(line)
                f.write(line)
                f.write('\n')
        print('index_finish')
        """
        wrong_list = list2 + list4  # all samples with a wrong fine-grained prediction
        test_number = 0
        for line in wrong_list:
            if line[3] not in line[0]:  # predicted fine code is not under the correct coarse code
                test_number += 1
        print('among wrong fine-code predictions, {} are not under the correct coarse code'.format(test_number))
        print('\n')
        print('\n')
        print('\n')
        # '''
        wrong_rate1 = test_number / len(wrong_list)
        # return acc, loss_total / len(data_iter), report, confusion, wrong_list, wrong_number, wrong_rate1
        return acc, loss_total / len(data_iter), wrong_list, wrong_number, wrong_rate1, report
    return acc, loss_total / len(data_iter)
Example #10
    # Model
    if lossf == 'contour':
        model = CleanU_Net(in_channels=1, out_channels=1)
        wo_mask = False
    else:
        model = CleanU_Net(in_channels=1, out_channels=2)
        wo_mask = False
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count()))).cuda()

    # Loss function
    if lossf == 'crossentropy':
        criterion = nn.CrossEntropyLoss()
    elif lossf == 'dice':
        criterion = DiceLoss()
    elif lossf == 'contour':
        criterion = ContourLoss()
    else:
        raise ValueError('Undefined loss type')

    # Optimizer
    if torch.cuda.is_available():
        if opt == 'rmsprop':
            optimizer = torch.optim.RMSprop(model.module.parameters(), lr=lr)
        if opt == 'adam':
            optimizer = torch.optim.Adam(model.module.parameters(), lr=lr)
    else:
        if opt == 'rmsprop':
            optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
        if opt == 'adam':