示例#1
0
    def build_model(self):
        """Instantiate the segmentation network selected by ``self.model_type``.

        Reads ``self.model_type``, ``self.output_ch`` and ``self.t`` (passed to
        the R2 variants). When CUDA is available the network is wrapped in
        ``torch.nn.DataParallel`` and the three criteria are moved to the GPU;
        finally the network is moved to ``self.device``.

        Note: ``smp.Unet``, ``Unet_t`` and ``OctaveUnet`` are not given
        ``self.output_ch`` (they use their constructors' default class count),
        and ``smp.PSPNet`` is built with ``classes=1`` explicitly.

        Raises:
            ValueError: if ``self.model_type`` is not a supported name.
        """
        print("Using model: {}".format(self.model_type))
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=self.output_ch)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3,
                                   output_ch=self.output_ch,
                                   t=self.t)

        elif self.model_type == 'unet_resnet34':
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif self.model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)

        elif self.model_type == 'linknet':
            self.unet = LinkNet34(num_classes=self.output_ch)
        elif self.model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=self.output_ch)
        elif self.model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)
        else:
            # Fail loudly here instead of passing None to DataParallel / .to().
            raise ValueError("Unknown model_type: {}".format(self.model_type))

        if torch.cuda.is_available():
            self.unet = torch.nn.DataParallel(self.unet)
            self.criterion = self.criterion.cuda()
            self.criterion_stage2 = self.criterion_stage2.cuda()
            self.criterion_stage3 = self.criterion_stage3.cuda()
        self.unet.to(self.device)
示例#2
0
    def build_model(self):
        """Instantiate the network named by ``self.model_type`` on ``self.device``.

        Variants that take a class-count argument are built with a single
        output channel; ``smp.Unet``, ``Unet_t`` and ``OctaveUnet`` use their
        constructors' default class count. The network is not wrapped in
        ``DataParallel``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)

        elif self.model_type == 'unet_resnet34':
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif self.model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)

        elif self.model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)
        elif self.model_type == 'linknet':
            self.unet = LinkNet34(num_classes=1)
        elif self.model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=1)
        else:
            # Without this, self.unet stays None and fails on .to() below.
            raise ValueError("Unknown model_type: {}".format(self.model_type))

        self.unet.to(self.device)
示例#3
0
class Test(object):
    """Two-stage, k-fold ensembled inference and visualization.

    For each fold the same architecture is loaded twice from
    ``checkpoints/<model_type>/``: once with the weights of training stage
    ``stage_cla`` (used to decide whether an image contains any mask) and once
    with the weights of stage ``stage_seg`` (used to segment images that pass
    the first check). Fold predictions are then averaged or majority-voted.
    """

    def __init__(self, model_type, image_size, mean, std, t=None):
        """Store inference settings; the network is created by ``build_model``.

        Args:
            model_type: architecture name understood by ``build_model``.
            image_size: network input size; predictions are reshaped to
                ``(image_size, image_size)``.
            mean: per-channel mean passed to ``transforms.Normalize``.
            std: per-channel std passed to ``transforms.Normalize``.
            t: kept on the instance; no method of this class reads it.
        """
        self.unet = None
        self.image_size = image_size  # model input size

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = model_type
        self.t = t
        self.mean = mean
        self.std = std

    def build_model(self):
        """Instantiate the network named by ``self.model_type`` on ``self.device``.

        Variants that take a class-count argument are built with a single
        output channel; ``smp.Unet``, ``Unet_t`` and ``OctaveUnet`` use their
        constructors' default class count. The network is not wrapped in
        ``DataParallel``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)

        elif self.model_type == 'unet_resnet34':
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif self.model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)

        elif self.model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)
        elif self.model_type == 'linknet':
            self.unet = LinkNet34(num_classes=1)
        elif self.model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=1)
        else:
            # Without this, self.unet stays None and fails on .to() below.
            raise ValueError("Unknown model_type: {}".format(self.model_type))

        self.unet.to(self.device)

    def test_model(self,
                   thresholds_classify,
                   thresholds_seg,
                   average_threshold,
                   stage_cla,
                   stage_seg,
                   n_splits,
                   test_best_model=True,
                   less_than_sum=2048 * 2,
                   seg_average_vote=True,
                   images_path=None,
                   masks_path=None):
        """Predict every test image with all folds and display the result.

        Args:
            thresholds_classify: per-fold thresholds (indexed by fold value) or
                one scalar for all folds; pixels of the first-stage prediction
                above it become 1, others 0.
            thresholds_seg: per-fold (or scalar) thresholds applied to the
                second-stage prediction when voting.
            average_threshold: threshold applied to the fold average when
                ``seg_average_vote`` is True.
            stage_cla: training stage whose weights form the first stage.
            stage_seg: training stage whose weights form the second stage.
            n_splits: iterable of fold indices to ensemble.
            test_best_model: load ``*_best.pth`` checkpoints if True, otherwise
                the latest ``*.pth`` ones.
            less_than_sum: per-fold (or scalar) minimum count of positive
                pixels; a first-stage prediction with fewer is zeroed out.
            seg_average_vote: True to average fold predictions, False to take a
                strict majority vote over thresholded fold predictions.
            images_path: sequence of image file paths.
            masks_path: sequence of ground-truth mask paths, parallel to
                ``images_path``.

        Raises:
            ValueError: if ``n_splits`` is empty or a path list is missing.
        """
        n_splits = list(n_splits)
        if not n_splits:
            raise ValueError('n_splits must contain at least one fold')
        if images_path is None or masks_path is None:
            raise ValueError('images_path and masks_path are required')

        with torch.no_grad():
            # Load each fold's two networks once, not once per image. This keeps
            # 2 * len(n_splits) networks in memory but avoids re-reading every
            # checkpoint from disk for every test image.
            fold_models = []
            for fold in n_splits:
                cla_model, seg_model = self._load_fold_models(
                    fold, stage_cla, stage_seg, test_best_model)
                fold_models.append((fold, cla_model, seg_model))

            for image_path, mask_path in tqdm(zip(images_path, masks_path),
                                              total=len(images_path)):
                with Image.open(image_path) as raw_image:
                    img = raw_image.convert('RGB')

                pred_nfolds = 0
                for fold, cla_model, seg_model in fold_models:
                    pred = self.tta(img, cla_model)

                    # Binarize, then treat tiny positive areas as "no mask".
                    pred = np.where(
                        pred > self._per_fold(thresholds_classify, fold), 1, 0)
                    if np.sum(pred) < self._per_fold(less_than_sum, fold):
                        pred[:] = 0

                    # Only images flagged by the first stage are segmented.
                    if np.sum(pred) > 0:
                        pred = self.tta(img, seg_model)
                        # For voting, each fold contributes a 0/1 mask.
                        if not seg_average_vote:
                            pred = np.where(
                                pred > self._per_fold(thresholds_seg, fold), 1,
                                0)
                    pred_nfolds += pred

                pred = self._combine_folds(pred_nfolds, len(n_splits),
                                           seg_average_vote, average_threshold)

                # np.where yields int64, which cv2.resize does not accept;
                # a uint8 0/1 mask resizes to a 0/1 mask.
                pred = cv2.resize(pred.astype(np.uint8), (1024, 1024))
                with Image.open(mask_path) as raw_mask:
                    mask = np.around(np.array(raw_mask.convert('L')) / 256.)

                self.combine_display(img, mask, pred, 'demo')

    def _load_fold_models(self, fold, stage_cla, stage_seg, test_best_model):
        """Return ``(first_stage_model, second_stage_model)`` for one fold, in eval mode."""
        self.unet = None
        self.build_model()
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        cla_path = self._checkpoint_path(stage_cla, fold, test_best_model)
        self.unet.load_state_dict(
            torch.load(cla_path, map_location=self.device)['state_dict'])
        self.unet.eval()
        cla_model = self.unet

        seg_model = copy.deepcopy(cla_model)
        seg_path = self._checkpoint_path(stage_seg, fold, test_best_model)
        seg_model.load_state_dict(
            torch.load(seg_path, map_location=self.device)['state_dict'])
        seg_model.eval()
        return cla_model, seg_model

    def _checkpoint_path(self, stage, fold, test_best_model):
        """Return ``checkpoints/<model_type>/<model_type>_<stage>_<fold>[_best].pth``."""
        suffix = '_best.pth' if test_best_model else '.pth'
        return os.path.join(
            'checkpoints', self.model_type,
            '{}_{}_{}{}'.format(self.model_type, stage, fold, suffix))

    @staticmethod
    def _per_fold(value, fold):
        """Return ``value`` itself if it is a scalar, else ``value[fold]``."""
        if np.isscalar(value):
            return value
        return value[fold]

    @staticmethod
    def _combine_folds(pred_sum, n_models, seg_average_vote, average_threshold):
        """Turn the sum of per-fold predictions into a final 0/1 mask.

        Averaging thresholds ``pred_sum / n_models`` at ``average_threshold``.
        Voting keeps pixels marked by a strict majority of the folds.
        """
        if seg_average_vote:
            return np.where(pred_sum / n_models > average_threshold, 1, 0)
        return np.where(pred_sum > n_models / 2.0, 1, 0)

    def image_transform(self, image):
        """Resize, convert to a tensor and normalize a PIL image."""
        resize = transforms.Resize(self.image_size)
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(self.mean, self.std)

        transform_compose = transforms.Compose([resize, to_tensor, normalize])

        return transform_compose(image)

    def detection(self, image, model):
        """Run ``model`` on one image and return its sigmoid probabilities.

        Args:
            image: PIL image to predict.
            model: network to use.
        Returns:
            numpy array reshaped to ``(image_size, image_size)``.
        """
        image = self.image_transform(image)
        image = torch.unsqueeze(image, dim=0)
        image = image.float().to(self.device)
        pred = torch.sigmoid(model(image))
        pred = pred.view(self.image_size, self.image_size)
        pred = pred.detach().cpu().numpy()

        return pred

    def tta(self, image, model):
        """Average predictions over three test-time views of ``image``.

        Views: horizontal flip (prediction mirrored back), CLAHE-enhanced
        image, and the original image.

        Args:
            image: PIL image.
            model: network to use.
        Returns:
            numpy array of shape ``(image_size, image_size)``.
        """
        preds = np.zeros([self.image_size, self.image_size])

        # Horizontal flip: predict on the mirrored image, then mirror the
        # prediction back into the original orientation.
        image_hflip = image.transpose(Image.FLIP_LEFT_RIGHT)
        preds += np.fliplr(self.detection(image_hflip, model))

        # CLAHE contrast enhancement.
        aug = CLAHE(p=1.0)
        clahe_image = aug(image=np.asarray(image))['image']
        preds += self.detection(Image.fromarray(clahe_image), model)

        # Original image.
        preds += self.detection(image, model)

        return preds / 3.0

    def dice_overall(self, preds, targs):
        """Mean Dice coefficient over a batch (used for threshold selection).

        Args:
            preds: binary prediction tensor whose first dim is the batch.
            targs: binary target tensor with the same elements per sample.
        Returns:
            0-d tensor: batch mean of ``2 * |P & T| / (|P| + |T|)``; a sample
            where prediction and target are both empty scores 1.
        """
        n = preds.shape[0]  # batch size
        preds = preds.view(n, -1)
        targs = targs.view(n, -1)
        preds, targs = preds.cpu(), targs.cpu()

        # Product of binary masks: per-sample intersection size.
        intersect = (preds * targs).sum(-1).float()
        # Sum of binary masks: |P| + |T| (intersection counted twice), hence
        # the factor 2 in the numerator; the ratio is at most 1.
        union = (preds + targs).sum(-1).float()
        # For binary masks union == 0 only when both are empty; score those
        # samples as a perfect match via 2 * 1 / 2 == 1.
        u0 = union == 0
        intersect[u0] = 1
        union[u0] = 2

        return (2. * intersect / union).mean()

    def combine_display(self, image_raw, mask, pred, title_diplay):
        """Show the raw image, ground-truth mask and prediction side by side."""
        plt.suptitle(title_diplay)
        plt.subplot(1, 3, 1)
        plt.title('image_raw')
        plt.imshow(image_raw)

        plt.subplot(1, 3, 2)
        plt.title('mask')
        plt.imshow(mask)

        plt.subplot(1, 3, 3)
        plt.title('pred')
        plt.imshow(pred)

        plt.show()
示例#4
0
class Train(object):
    def __init__(self, config, train_loader, valid_loader):
        """Store configuration, loaders and losses, then build the network.

        Args:
            config: hyper-parameter namespace; every attribute read below must
                exist on it.
            train_loader: loader of ``(images, masks)`` training batches.
            valid_loader: loader used by the validation step.
        """
        # -- data
        self.train_loader, self.valid_loader = train_loader, valid_loader

        # -- network (created by build_model), optimizer and per-stage losses
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        # One loss instance per training stage, each built with weight=[0.25, 0.75].
        self.criterion, self.criterion_stage2, self.criterion_stage3 = (
            SoftBCEDiceLoss(weight=[0.25, 0.75]) for _ in range(3))
        self.model_type = config.model_type
        self.t = config.t

        self.mode = config.mode
        self.resume = config.resume

        # -- hyper-parameters
        self.lr = config.lr
        self.lr_stage2 = config.lr_stage2
        self.lr_stage3 = config.lr_stage3
        self.start_epoch = 0
        self.max_dice = 0
        # NOTE(review): every stage uses config.weight_decay; confirm that no
        # per-stage weight-decay setting was intended.
        self.weight_decay = self.weight_decay_stage2 = \
            self.weight_decay_stage3 = config.weight_decay

        # -- output: TensorBoard writer is skipped in threshold-selection modes
        self.save_path = config.save_path
        if 'choose_threshold' not in self.mode:
            stamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
            self.writer = SummaryWriter(log_dir=self.save_path + '/' + stamp)

        # -- per-stage epoch counts and gradient accumulation settings
        self.epoch_stage1 = config.epoch_stage1
        self.epoch_stage1_freeze = config.epoch_stage1_freeze
        self.epoch_stage2 = config.epoch_stage2
        self.epoch_stage2_accumulation = config.epoch_stage2_accumulation
        self.accumulation_steps = config.accumulation_steps
        self.epoch_stage3 = config.epoch_stage3
        self.epoch_stage3_accumulation = config.epoch_stage3_accumulation

        # -- device and network
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.build_model()

    def build_model(self):
        print("Using model: {}".format(self.model_type))
        """Build generator and discriminator."""
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=self.output_ch)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3,
                                   output_ch=self.output_ch,
                                   t=self.t)

        elif self.model_type == 'unet_resnet34':
            # self.unet = Unet(backbone_name='resnet34', pretrained=True, classes=self.output_ch)
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif self.model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)

        elif self.model_type == 'linknet':
            self.unet = LinkNet34(num_classes=self.output_ch)
        elif self.model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=self.output_ch)
        elif self.model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)

        if torch.cuda.is_available():
            self.unet = torch.nn.DataParallel(self.unet)
            self.criterion = self.criterion.cuda()
            self.criterion_stage2 = self.criterion_stage2.cuda()
            self.criterion_stage3 = self.criterion_stage3.cuda()
        self.unet.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.unet.zero_grad()

    def save_checkpoint(self, state, stage, index, is_best):
        # 保存权重,每一epoch均保存一次,若为最优,则复制到最优权重;index可以区分不同的交叉验证
        pth_path = os.path.join(
            self.save_path, '%s_%d_%d.pth' % (self.model_type, stage, index))
        torch.save(state, pth_path)
        if is_best:
            print('Saving Best Model.')
            write_txt(self.save_path, 'Saving Best Model.')
            shutil.copyfile(
                os.path.join(self.save_path,
                             '%s_%d_%d.pth' % (self.model_type, stage, index)),
                os.path.join(
                    self.save_path,
                    '%s_%d_%d_best.pth' % (self.model_type, stage, index)))

    def load_checkpoint(self, load_optimizer=True):
        # Load the pretrained Encoder
        weight_path = os.path.join(self.save_path, self.resume)
        if os.path.isfile(weight_path):
            checkpoint = torch.load(weight_path)
            # 加载模型的参数,学习率,优化器,开始的epoch,最小误差等
            if torch.cuda.is_available:
                self.unet.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.unet.load_state_dict(checkpoint['state_dict'])
            self.start_epoch = checkpoint['epoch']
            self.max_dice = checkpoint['max_dice']
            if load_optimizer:
                self.lr = checkpoint['lr']
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            print('%s is Successfully Loaded from %s' %
                  (self.model_type, weight_path))
            write_txt(
                self.save_path, '%s is Successfully Loaded from %s' %
                (self.model_type, weight_path))
        else:
            raise FileNotFoundError(
                "Can not find weight file in {}".format(weight_path))

    def train(self, index):
        """Stage 1: train fold ``index`` for up to ``self.epoch_stage1`` epochs.

        Uses Adam with cosine annealing over (remaining epochs + 10). After each
        epoch the model is validated, a stage-1 checkpoint is written through
        ``save_checkpoint`` and losses, validation dice and the learning rate
        are logged to TensorBoard.

        Args:
            index: cross-validation fold index, used in checkpoint names.
        """
        # The network is wrapped in DataParallel only when CUDA is available;
        # optimize and save the inner module in that case.
        if isinstance(self.unet, torch.nn.DataParallel):
            network = self.unet.module
        else:
            network = self.unet
        self.optimizer = optim.Adam(network.parameters(),
                                    self.lr,
                                    weight_decay=self.weight_decay)

        # Resume an interrupted run (weights, epoch, best dice, lr, optimizer).
        # TODO: the resumed lr schedule does not match an uninterrupted run.
        if self.resume:
            self.load_checkpoint(load_optimizer=True)
            # CosineAnnealingLR decays from param_groups[...]['initial_lr'] when
            # it exists; the loaded optimizer carries the previous run's value,
            # so overwrite it to decay from the restored lr. Checkpoints store
            # lr as a list (one entry per param group), and a list here would
            # break the scheduler's arithmetic.
            restored_lr = (self.lr[0] if isinstance(self.lr, (list, tuple))
                           else self.lr)
            self.optimizer.param_groups[0]['initial_lr'] = restored_lr

        stage1_epoches = self.epoch_stage1 - self.start_epoch
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, stage1_epoches + 10)
        # Continue the TensorBoard step counter so a resumed run does not
        # overwrite earlier log entries.
        global_step_before = self.start_epoch * len(self.train_loader)

        for epoch in range(self.start_epoch + 1, self.epoch_stage1 + 1):
            self.unet.train(True)

            epoch_loss = 0
            tbar = tqdm.tqdm(self.train_loader)
            for i, (images, masks) in enumerate(tbar):
                images = images.to(self.device)
                masks = masks.to(self.device)

                net_output = self.unet(images)
                # The loss is computed on per-sample flattened tensors.
                net_output_flat = net_output.view(net_output.size(0), -1)
                masks_flat = masks.view(masks.size(0), -1)
                loss_set = self.criterion(net_output_flat, masks_flat)

                # The criterion may return one loss tensor, or a sequence whose
                # first item is the loss to optimize and the rest are logged.
                try:
                    loss_num = len(loss_set)
                except TypeError:  # 0-d tensor / scalar: a single loss
                    loss_num = 1
                if loss_num > 1:
                    for loss_index, loss_item in enumerate(loss_set):
                        if loss_index > 0:
                            loss_name = 'stage1_loss_%d' % loss_index
                            self.writer.add_scalar(loss_name, loss_item.item(),
                                                   global_step_before + i)
                    loss = loss_set[0]
                else:
                    loss = loss_set
                epoch_loss += loss.item()

                self.reset_grad()
                loss.backward()
                self.optimizer.step()

                params_groups_lr = ''.join(
                    'params_group_%d' % group_ind + ': %.12f, ' % param_group['lr']
                    for group_ind, param_group in enumerate(
                        self.optimizer.param_groups))

                # One TensorBoard point per training step.
                self.writer.add_scalar('Stage1_train_loss', loss.item(),
                                       global_step_before + i)

                descript = "Train Loss: %.7f, lr: %s" % (loss.item(),
                                                         params_groups_lr)
                tbar.set_description(desc=descript)
            global_step_before += len(tbar)

            epoch_message = ('Finish Stage1 Epoch [%d/%d], Average Loss: %.7f' %
                             (epoch, self.epoch_stage1, epoch_loss / len(tbar)))
            print(epoch_message)
            write_txt(self.save_path, epoch_message)

            # Validate, checkpoint and log.
            loss_mean, dice_mean = self.validation(stage=1)
            is_best = bool(dice_mean > self.max_dice)
            if is_best:
                self.max_dice = dice_mean

            # lr used during this epoch. lr_scheduler.get_lr() must not be
            # called outside step() in recent PyTorch (it warns and returns a
            # value shifted by one step); the param groups hold the true value.
            self.lr = [group['lr'] for group in self.optimizer.param_groups]
            state = {
                'epoch': epoch,
                'state_dict': network.state_dict(),
                'max_dice': self.max_dice,
                'optimizer': self.optimizer.state_dict(),
                'lr': self.lr
            }

            self.save_checkpoint(state, 1, index, is_best)

            self.writer.add_scalar('Stage1_val_loss', loss_mean, epoch)
            self.writer.add_scalar('Stage1_val_dice', dice_mean, epoch)
            self.writer.add_scalar('Stage1_lr', self.lr[0], epoch)

            lr_scheduler.step()

    def train_stage2(self, index):
        """Stage 2: continue training fold ``index`` with ``self.lr_stage2``.

        During the last ``self.epoch_stage2_accumulation`` epochs, gradients are
        accumulated over ``self.accumulation_steps`` batches per optimizer step.
        Input batches are asserted to have size 1024 along dim 2. Checkpoints
        are saved as stage 2.

        Resume handling (``self.resume`` set): a stage-2 checkpoint continues
        where it stopped (weights, epoch, best dice, lr, optimizer); a stage-1
        checkpoint loads weights only and starts stage 2 from epoch 0.

        Args:
            index: cross-validation fold index, used in checkpoint names.
        """
        # The network is wrapped in DataParallel only when CUDA is available;
        # optimize and save the inner module in that case.
        if isinstance(self.unet, torch.nn.DataParallel):
            network = self.unet.module
        else:
            network = self.unet
        self.optimizer = optim.Adam(network.parameters(),
                                    self.lr_stage2,
                                    weight_decay=self.weight_decay_stage2)

        if self.resume:
            # Checkpoint names are '<model_type>_<stage>_<fold>[_best].pth'.
            # Strip the model_type prefix first, because model names contain a
            # varying number of underscores ('linknet', 'unet_se_resnext50_32x4d').
            resume_name = os.path.basename(self.resume)
            prefix = self.model_type + '_'
            if resume_name.startswith(prefix):
                resume_stage = resume_name[len(prefix):].split('_')[0]
            else:
                resume_stage = self.resume.split('_')[2]

            if resume_stage == '2':
                # Stage 2 was interrupted: restore lr and optimizer as well.
                self.load_checkpoint(load_optimizer=True)
                # CosineAnnealingLR decays from 'initial_lr' when present; the
                # loaded value belongs to the previous run, so overwrite it.
                # Checkpoints store lr as a list, which the scheduler cannot use.
                restored_lr = (self.lr[0] if isinstance(self.lr, (list, tuple))
                               else self.lr)
                self.optimizer.param_groups[0]['initial_lr'] = restored_lr
            elif resume_stage == '1':
                # Stage 1 finished but stage 2 never started: weights only.
                self.load_checkpoint(load_optimizer=False)
                self.start_epoch = 0
                self.max_dice = 0
        else:
            # Stage 2 follows stage 1 directly in the same process.
            self.start_epoch = 0
            self.max_dice = 0

        # Continue the TensorBoard step counter so a resumed run does not
        # overwrite earlier log entries.
        global_step_before = self.start_epoch * len(self.train_loader)

        stage2_epoches = self.epoch_stage2 - self.start_epoch
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, stage2_epoches + 5)

        for epoch in range(self.start_epoch + 1, self.epoch_stage2 + 1):
            self.unet.train(True)
            epoch_loss = 0

            self.reset_grad()  # start every epoch with clean accumulated grads

            tbar = tqdm.tqdm(self.train_loader)
            for i, (images, masks) in enumerate(tbar):
                images = images.to(self.device)
                masks = masks.to(self.device)
                assert images.size(2) == 1024

                net_output = self.unet(images)
                # The loss is computed on per-sample flattened tensors.
                net_output_flat = net_output.view(net_output.size(0), -1)
                masks_flat = masks.view(masks.size(0), -1)
                loss_set = self.criterion_stage2(net_output_flat, masks_flat)

                # The criterion may return one loss tensor, or a sequence whose
                # first item is the loss to optimize and the rest are logged.
                try:
                    loss_num = len(loss_set)
                except TypeError:  # 0-d tensor / scalar: a single loss
                    loss_num = 1
                if loss_num > 1:
                    for loss_index, loss_item in enumerate(loss_set):
                        if loss_index > 0:
                            loss_name = 'stage2_loss_%d' % loss_index
                            self.writer.add_scalar(loss_name, loss_item.item(),
                                                   global_step_before + i)
                    loss = loss_set[0]
                else:
                    loss = loss_set
                epoch_loss += loss.item()

                # Gradient accumulation, see
                # https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20
                if epoch <= self.epoch_stage2 - self.epoch_stage2_accumulation:
                    self.reset_grad()
                    loss.backward()
                    self.optimizer.step()
                else:
                    # The loss is not divided by accumulation_steps here, so the
                    # accumulated gradient is a sum over the accumulated batches.
                    loss.backward()
                    if (i + 1) % self.accumulation_steps == 0:
                        self.optimizer.step()
                        self.reset_grad()

                params_groups_lr = ''.join(
                    'params_group_%d' % group_ind + ': %.12f, ' % param_group['lr']
                    for group_ind, param_group in enumerate(
                        self.optimizer.param_groups))

                # One TensorBoard point per training step.
                self.writer.add_scalar('Stage2_train_loss', loss.item(),
                                       global_step_before + i)

                descript = "Train Loss: %.7f, lr: %s" % (loss.item(),
                                                         params_groups_lr)
                tbar.set_description(desc=descript)
            global_step_before += len(tbar)

            epoch_message = ('Finish Stage2 Epoch [%d/%d], Average Loss: %.7f' %
                             (epoch, self.epoch_stage2, epoch_loss / len(tbar)))
            print(epoch_message)
            write_txt(self.save_path, epoch_message)

            # Validate, checkpoint and log.
            loss_mean, dice_mean = self.validation(stage=2)
            is_best = bool(dice_mean > self.max_dice)
            if is_best:
                self.max_dice = dice_mean

            # lr used during this epoch, read from the optimizer because
            # lr_scheduler.get_lr() outside step() is shifted in recent PyTorch.
            self.lr = [group['lr'] for group in self.optimizer.param_groups]
            state = {
                'epoch': epoch,
                'state_dict': network.state_dict(),
                'max_dice': self.max_dice,
                'optimizer': self.optimizer.state_dict(),
                'lr': self.lr
            }

            self.save_checkpoint(state, 2, index, is_best)

            self.writer.add_scalar('Stage2_val_loss', loss_mean, epoch)
            self.writer.add_scalar('Stage2_val_dice', dice_mean, epoch)
            self.writer.add_scalar('Stage2_lr', self.lr[0], epoch)

            lr_scheduler.step()

    # stage3, 接着stage2的训练,只训练有mask的样本
    def train_stage3(self, index):
        """Run stage-3 training for one fold, continuing from stage 2.

        A fresh Adam optimizer is built from ``self.lr_stage3`` and
        ``self.weight_decay_stage3``; the learning rate then follows a cosine
        annealing schedule. After every epoch the network is validated with the
        stage-3 criterion and a checkpoint is written via ``save_checkpoint``.
        During the last ``self.epoch_stage3_accumulation`` epochs gradients are
        accumulated over ``self.accumulation_steps`` batches per optimizer step.

        Args:
            index: Index of the current fold; forwarded to ``save_checkpoint``.
        """
        # Freezing the BN layers was tried and is disabled; see
        # https://zhuanlan.zhihu.com/p/65439075 and
        # https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/100736591271
        # def set_bn_eval(m):
        #     classname = m.__class__.__name__
        #     if classname.find('BatchNorm') != -1:
        #         m.eval()
        # self.unet.apply(set_bn_eval)

        # self.optimizer = optim.Adam([{'params': self.unet.decoder.parameters(), 'lr': 1e-5}, {'params': self.unet.encoder.parameters(), 'lr': 1e-7},])
        self.optimizer = optim.Adam(self.unet.module.parameters(),
                                    self.lr_stage3,
                                    weight_decay=self.weight_decay_stage3)

        # In 'train_stage23' mode the resume checkpoint only applies to stage 2.
        if self.mode == 'train_stage23':
            self.resume = None

        # A resume checkpoint is either a stage-2 checkpoint (stage 3 has not
        # started yet) or a half-finished stage-3 checkpoint.
        if self.resume:
            # The stage number is the third '_'-separated field of the name.
            resume_stage = self.resume.split('_')[2]
            if resume_stage == '3':
                # Half-finished stage 3: reload weights, optimizer and lr.
                self.load_checkpoint(load_optimizer=True)
                # CosineAnnealingLR decays from param_group['initial_lr'] when
                # that key exists (otherwise it creates it from 'lr'). The
                # reloaded optimizer still carries the old initial_lr, so it is
                # overridden with self.lr to restart the decay from self.lr.
                # NOTE(review): self.lr is set by load_checkpoint; if it is the
                # list stored under state['lr'] below, initial_lr becomes a
                # list, which CosineAnnealingLR cannot use as a base lr - confirm.
                self.optimizer.param_groups[0]['initial_lr'] = self.lr
            elif resume_stage == '2':
                # Stage 2 finished, and stage 3 is now started separately.
                self.load_checkpoint(load_optimizer=False)
                self.start_epoch = 0
                self.max_dice = 0
        else:
            # Stage 3 runs right after stage 2 in the same process.
            print('start stage3 after stage2 directly!')
            self.start_epoch = 0
            self.max_dice = 0

        # Offset the tensorboard step so a resumed run does not overwrite logs.
        global_step_before = self.start_epoch * len(self.train_loader)

        stage3_epoches = self.epoch_stage3 - self.start_epoch
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, stage3_epoches + 5)

        # Epochs are counted from 1 (start_epoch + 1 .. epoch_stage3).
        for epoch in range(self.start_epoch + 1, self.epoch_stage3 + 1):
            self.unet.train(True)
            epoch_loss = 0

            self.reset_grad()  # needed when gradients are accumulated

            tbar = tqdm.tqdm(self.train_loader)
            for i, (images, masks) in enumerate(tbar):
                images = images.to(self.device)
                masks = masks.to(self.device)
                assert images.size(2) == 1024

                net_output = self.unet(images)
                net_output_flat = net_output.view(net_output.size(0), -1)
                masks_flat = masks.view(masks.size(0), -1)
                loss_set = self.criterion_stage3(net_output_flat, masks_flat)

                # The criterion returns either a single loss tensor or a
                # sequence whose first entry is optimised and whose remaining
                # entries are only logged. len() of a 0-d tensor / scalar
                # raises TypeError.
                try:
                    loss_num = len(loss_set)
                except TypeError:
                    loss_num = 1
                if loss_num > 1:
                    for loss_index, loss_item in enumerate(loss_set):
                        if loss_index > 0:
                            loss_name = 'stage3_loss_%d' % loss_index
                            self.writer.add_scalar(loss_name, loss_item.item(),
                                                   global_step_before + i)
                    loss = loss_set[0]
                else:
                    loss = loss_set
                epoch_loss += loss.item()

                # Backprop + optimize; for gradient accumulation see
                # https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20
                if epoch <= self.epoch_stage3 - self.epoch_stage3_accumulation:
                    self.reset_grad()
                    loss.backward()
                    self.optimizer.step()
                else:
                    # Accumulate gradients and step every accumulation_steps
                    # batches. The loss is not normalised (disabled line below).
                    # loss = loss / self.accumulation_steps
                    loss.backward()
                    if (i + 1) % self.accumulation_steps == 0:
                        self.optimizer.step()
                        self.reset_grad()

                params_groups_lr = ''.join(
                    'params_group_%d' % group_ind +
                    ': %.12f, ' % param_group['lr']
                    for group_ind, param_group in enumerate(
                        self.optimizer.param_groups))

                # Log the training loss to tensorboard at every step.
                self.writer.add_scalar('Stage3_train_loss', loss.item(),
                                       global_step_before + i)

                descript = "Train Loss: %.7f, lr: %s" % (loss.item(),
                                                         params_groups_lr)
                tbar.set_description(desc=descript)
            # Advance the global step for the next epoch.
            global_step_before += len(tbar)

            epoch_msg = 'Finish Stage3 Epoch [%d/%d], Average Loss: %.7f' % (
                epoch, self.epoch_stage3, epoch_loss / len(tbar))
            print(epoch_msg)
            write_txt(self.save_path, epoch_msg)

            # Validate, save the weights and log.
            loss_mean, dice_mean = self.validation(stage=3)
            is_best = dice_mean > self.max_dice
            if is_best:
                self.max_dice = dice_mean

            # Read the lr actually in use from the optimizer. Calling
            # lr_scheduler.get_lr() outside of step() is deprecated and, for
            # CosineAnnealingLR's recursive form, returns a value that differs
            # from the lr the optimizer is using.
            self.lr = [group['lr'] for group in self.optimizer.param_groups]
            state = {
                'epoch': epoch,
                'state_dict': self.unet.module.state_dict(),
                'max_dice': self.max_dice,
                'optimizer': self.optimizer.state_dict(),
                'lr': self.lr
            }

            self.save_checkpoint(state, 3, index, is_best)

            self.writer.add_scalar('Stage3_val_loss', loss_mean, epoch)
            self.writer.add_scalar('Stage3_val_dice', dice_mean, epoch)
            self.writer.add_scalar('Stage3_lr', self.lr[0], epoch)

            # Learning-rate decay.
            lr_scheduler.step()

    def validation(self, stage=1):
        """Evaluate the network on ``self.valid_loader``.

        Args:
            stage: Training stage (1, 2 or 3). Selects ``self.criterion``,
                ``self.criterion_stage2`` or ``self.criterion_stage3``.

        Returns:
            Tuple ``(loss_mean, dice_mean)``: the batch loss and the batch mean
            dice (predictions thresholded at 0.5 after sigmoid), each averaged
            over all validation batches.

        Raises:
            ValueError: if ``stage`` is not 1, 2 or 3.
        """
        criterion_names = {
            1: 'criterion',
            2: 'criterion_stage2',
            3: 'criterion_stage3',
        }
        if stage not in criterion_names:
            raise ValueError('stage must be 1, 2 or 3, got %r' % (stage,))
        criterion = getattr(self, criterion_names[stage])

        # eval() switches BN / dropout layers to inference mode; no_grad() skips
        # building the autograd graph, which is faster and uses less memory.
        self.unet.eval()
        tbar = tqdm.tqdm(self.valid_loader)
        loss_sum, dice_sum = 0, 0
        with torch.no_grad():
            for images, masks in tbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                net_output = self.unet(images)
                net_output_flat = net_output.view(net_output.size(0), -1)
                masks_flat = masks.view(masks.size(0), -1)

                # The criterion returns one loss or a sequence whose first
                # entry is the main loss; len() of a 0-d tensor raises TypeError.
                loss_set = criterion(net_output_flat, masks_flat)
                try:
                    loss_num = len(loss_set)
                except TypeError:
                    loss_num = 1
                loss = loss_set[0] if loss_num > 1 else loss_set
                loss_sum += loss.item()

                # Dice on hard predictions: sigmoid, then threshold at 0.5.
                net_output_flat_sign = (torch.sigmoid(net_output_flat) >
                                        0.5).float()
                dice = self.dice_overall(net_output_flat_sign,
                                         masks_flat).mean()
                dice_sum += dice.item()

                descript = "Val Loss: {:.7f}, dice: {:.7f}".format(
                    loss.item(), dice.item())
                tbar.set_description(desc=descript)

        loss_mean, dice_mean = loss_sum / len(tbar), dice_sum / len(tbar)
        print("Val Loss: {:.7f}, dice: {:.7f}".format(loss_mean, dice_mean))
        write_txt(
            self.save_path,
            "Val Loss: {:.7f}, dice: {:.7f}".format(loss_mean, dice_mean))
        return loss_mean, dice_mean

    # dice for threshold selection
    def dice_overall(self, preds, targs):
        """Per-sample Dice coefficient between binary predictions and targets.

        Args:
            preds: Prediction tensor, expected to hold 0/1 values; the first
                dimension is the batch and the rest is flattened.
            targs: Target tensor, expected to hold 0/1 values, with the same
                number of elements per sample as ``preds``.

        Returns:
            1-D float CPU tensor of length ``batch``. A sample whose prediction
            and target are both empty scores 1.
        """
        n = preds.shape[0]
        # reshape (unlike view) also accepts non-contiguous inputs.
        preds = preds.reshape(n, -1).cpu()
        targs = targs.reshape(n, -1).cpu()

        # For 0/1 tensors the element-wise product marks the intersection, and
        # the element-wise sum counts |A| + |B| (overlap counted twice), which
        # is why the Dice formula below uses 2 * intersection.
        intersect = (preds * targs).sum(-1).float()
        union = (preds + targs).sum(-1).float()

        # Both prediction and target empty: define Dice as 2 * 1 / 2 = 1.
        # (A non-empty target with a completely wrong prediction keeps a
        # non-zero denominator and scores 0.)
        empty = union == 0
        intersect[empty] = 1
        union[empty] = 2

        return 2. * intersect / union

    def classify_score(self, preds, targs):
        """Image-level classification accuracy derived from masks.

        A sample is positive when its mask contains any positive value and
        negative otherwise. The score is the fraction of samples on which
        prediction and target agree.

        Args:
            preds: Predicted mask tensor; the first dimension is the batch.
            targs: Ground-truth mask tensor; the first dimension is the batch.

        Returns:
            Classification accuracy as a Python float in [0, 1].
        """
        n = preds.shape[0]
        # reshape (unlike view) also accepts non-contiguous inputs.
        pred_has_mask = preds.reshape(n, -1).sum(1) > 0
        targ_has_mask = targs.reshape(n, -1).sum(1) > 0
        matches = (pred_has_mask.cpu() == targ_has_mask.cpu()).sum()
        return matches.item() / n

    def choose_threshold(self, model_path, index):
        """Line-search the best probability threshold and pixel threshold.

        Loads the weights at ``model_path`` and, on the validation set:

        1. coarse search of the probability threshold over 0.1..0.9 (step 0.1);
        2. fine search within +-0.05 of the coarse optimum (step 0.01);
        3. for stages other than 3, search of the pixel threshold over 0..2048
           (step 256): samples with fewer positive pixels are zeroed.

        The search curves are saved as a figure in ``self.save_path``.

        Args:
            model_path: Path of the checkpoint to evaluate.
            index: Index of the current fold (used in the figure name).

        Returns:
            ``(best_thr, best_pixel_thr, score)`` as floats.
        """
        self.unet.module.load_state_dict(torch.load(model_path)['state_dict'])
        # The stage is the third '_'-separated field of the checkpoint name.
        # int() instead of eval(): never execute text taken from a path.
        stage = int(os.path.basename(model_path).split('_')[2])
        print('Loaded from %s, using choose_threshold!' % model_path)
        self.unet.eval()

        def mean_dice(thr, pixel_thr=None):
            # Mean validation dice at probability threshold `thr`; if
            # `pixel_thr` is given, predictions with fewer positive pixels
            # are treated as noise and zeroed first.
            scores = []
            for images, masks in tqdm.tqdm(self.valid_loader):
                images = images.to(self.device)
                net_output = torch.sigmoid(self.unet(images))
                preds = (net_output > thr).to(self.device).float()
                if pixel_thr is not None:
                    preds[preds.view(preds.shape[0], -1).sum(-1) < pixel_thr,
                          ...] = 0.0
                scores.append(self.dice_overall(preds, masks).mean().item())
            return sum(scores) / len(scores)

        with torch.no_grad():
            # Coarse search of the probability threshold.
            thrs_big = np.arange(0.1, 1, 0.1)
            dices_big = np.array([mean_dice(th) for th in thrs_big])
            best_thrs_big = thrs_big[dices_big.argmax()]

            # Fine search around the coarse optimum.
            thrs_little = np.arange(best_thrs_big - 0.05, best_thrs_big + 0.05,
                                    0.01)
            dices_little = np.array([mean_dice(th) for th in thrs_little])
            best_thr = thrs_little[dices_little.argmax()]

            if stage != 3:
                # Pixel-threshold search at the best probability threshold.
                pixel_thrs = np.arange(0, 2304, 256)
                dices_pixel = np.array(
                    [mean_dice(best_thr, pixel_thr) for pixel_thr in pixel_thrs])
                score = dices_pixel.max()
                best_pixel_thr = pixel_thrs[dices_pixel.argmax()]
            else:
                # Stage 3 does not use a pixel threshold.
                best_pixel_thr, score = 0, dices_little.max()
            print('best_thr:{}, best_pixel_thr:{}, score:{}'.format(
                best_thr, best_pixel_thr, score))

        plt.figure(figsize=(10.4, 4.8))
        plt.subplot(1, 3, 1)
        plt.title('Large-scale search')
        plt.plot(thrs_big, dices_big)
        plt.subplot(1, 3, 2)
        plt.title('Little-scale search')
        plt.plot(thrs_little, dices_little)
        plt.subplot(1, 3, 3)
        plt.title('pixel thrs search')
        if stage != 3:
            plt.plot(pixel_thrs, dices_pixel)
        plt.savefig(
            os.path.join(self.save_path,
                         'stage{}'.format(stage) + '_fold' + str(index)))
        plt.close()
        return float(best_thr), float(best_pixel_thr), float(score)

    def pred_mask_count(self, model_path, masks_bool, val_index, best_thr,
                        best_pixel_thr):
        """Compare true vs. predicted numbers of mask-containing images.

        Loads the weights at ``model_path``, predicts the validation set with
        ``best_thr`` / ``best_pixel_thr`` and prints the mean dice, the number
        of validation images that truly contain a mask and the number predicted
        to contain one. This is only a rough check, not an accuracy measure.

        Args:
            model_path: Path of the checkpoint to evaluate.
            masks_bool: Per-image flags (whole dataset) telling whether the
                image has a mask.
            val_index: Indices into ``masks_bool`` of the validation images.
            best_thr: Probability threshold.
            best_pixel_thr: Pixel threshold; predictions with fewer positive
                pixels are zeroed.

        Returns:
            None (results are printed).
        """
        count_true = sum(1 for idx in val_index if masks_bool[idx])

        self.unet.module.load_state_dict(torch.load(model_path)['state_dict'])
        print('Loaded from %s' % model_path)
        self.unet.eval()

        count_pred = 0
        with torch.no_grad():
            scores = []
            for images, masks in tqdm.tqdm(self.valid_loader):
                images = images.to(self.device)
                net_output = torch.sigmoid(self.unet(images))
                preds = (net_output > best_thr).to(self.device).float()
                # Zero out predictions with too few positive pixels (noise).
                preds[preds.view(preds.shape[0], -1).sum(-1) < best_pixel_thr,
                      ...] = 0.0
                preds = preds.view(preds.shape[0], -1)

                # Count samples with at least one positive pixel in one
                # vectorised op (no per-sample host/device sync).
                count_pred += int((preds.sum(1) > 0).sum().item())

                scores.append(self.dice_overall(preds, masks).mean().item())
            print('score:', sum(scores) / len(scores))

        print('count_true:{}, count_pred:{}'.format(count_true, count_pred))

    def grid_search(self, thrs_big, pixel_thrs):
        """Grid-search the probability threshold and the pixel threshold.

        For every ``(th, pixel_thr)`` pair the validation set is predicted,
        binarised at ``th``, samples with fewer than ``pixel_thr`` positive
        pixels are zeroed, and the mean dice is recorded.

        Args:
            thrs_big: Sequence of probability thresholds (grid rows).
            pixel_thrs: Sequence of pixel thresholds (grid columns).

        Returns:
            ``(best_thr, best_pixel_thr, score, dices_big)``. ``dices_big`` is
            the ``len(thrs_big) x len(pixel_thrs)`` array of mean dice; the
            best pair is its first maximum in row-major order.
        """
        with torch.no_grad():
            # Each row holds the scores of one probability threshold across
            # all pixel thresholds.
            dices_big = []
            for th in thrs_big:
                dices_pixel = []
                for pixel_thr in pixel_thrs:
                    scores = []
                    for images, masks in tqdm.tqdm(self.valid_loader):
                        images = images.to(self.device)
                        net_output = torch.sigmoid(self.unet(images))
                        preds = (net_output > th).to(self.device).float()
                        # Zero out predictions with too few positive pixels.
                        preds[
                            preds.view(preds.shape[0], -1).sum(-1) < pixel_thr,
                            ...] = 0.0
                        scores.append(
                            self.dice_overall(preds, masks).mean().item())
                    dices_pixel.append(sum(scores) / len(scores))
                dices_big.append(dices_pixel)
            dices_big = np.array(dices_big)
            print('粗略挑选最优阈值和最优像素阈值,dices_big_shape:{}'.format(
                np.shape(dices_big)))

            # argmax picks the first maximum in row-major order (the same cell
            # np.where would list first); unravel it into (row, column). The
            # previous np.where handling used a second *row* index as the
            # column whenever several cells tied for the maximum.
            row, col = np.unravel_index(np.argmax(dices_big), dices_big.shape)
            best_thr, best_pixel_thr = thrs_big[row], pixel_thrs[col]
            score = dices_big.max()
        return best_thr, best_pixel_thr, score, dices_big

    def choose_threshold_grid(self, model_path, index):
        """Grid-search the probability and pixel thresholds in two passes.

        A coarse grid is searched first: probability thresholds from 0.60 in
        steps of 0.015 (up to 0.81, exclusive) by pixel thresholds 768..2304
        in steps of 256. Then a finer grid around the coarse optimum is
        searched: +-0.015 in steps of 0.0075 by +-256 in steps of 128. The
        better of the two results is returned, and both score grids are saved
        as heatmaps in ``self.save_path``.

        Args:
            model_path: Path of the checkpoint to evaluate.
            index: Index of the current fold (used in the figure name).

        Returns:
            ``(best_thr, best_pixel_thr, score)`` as floats.
        """
        self.unet.module.load_state_dict(torch.load(model_path)['state_dict'])
        # The stage is the third '_'-separated field of the checkpoint name.
        # int() instead of eval(): never execute text taken from a path.
        stage = int(os.path.basename(model_path).split('_')[2])
        print('Loaded from %s, using choose_threshold_grid!' % model_path)
        self.unet.eval()

        # Coarse grid.
        thrs_big1 = np.arange(0.60, 0.81, 0.015)
        pixel_thrs1 = np.arange(768, 2305, 256)
        best_thr1, best_pixel_thr1, score1, dices_big1 = self.grid_search(
            thrs_big1, pixel_thrs1)
        print('best_thr1:{}, best_pixel_thr1:{}, score1:{}'.format(
            best_thr1, best_pixel_thr1, score1))

        # Fine grid around the coarse optimum.
        thrs_big2 = np.arange(best_thr1 - 0.015, best_thr1 + 0.015, 0.0075)
        pixel_thrs2 = np.arange(best_pixel_thr1 - 256, best_pixel_thr1 + 257,
                                128)
        best_thr2, best_pixel_thr2, score2, dices_big2 = self.grid_search(
            thrs_big2, pixel_thrs2)
        print('best_thr2:{}, best_pixel_thr2:{}, score2:{}'.format(
            best_thr2, best_pixel_thr2, score2))

        if score1 < score2:
            best_thr, best_pixel_thr, score = best_thr2, best_pixel_thr2, score2
        else:
            best_thr, best_pixel_thr, score = best_thr1, best_pixel_thr1, score1

        print('best_thr:{}, best_pixel_thr:{}, score:{}'.format(
            best_thr, best_pixel_thr, score))

        f, (ax1, ax2) = plt.subplots(figsize=(14.4, 4.8), ncols=2)

        cmap = sns.cubehelix_palette(start=1.5, rot=3, gamma=0.8, as_cmap=True)
        data1 = pd.DataFrame(data=dices_big1,
                             index=np.around(thrs_big1, 3),
                             columns=pixel_thrs1)
        sns.heatmap(data1,
                    linewidths=0.05,
                    ax=ax1,
                    vmax=np.max(dices_big1),
                    vmin=np.min(dices_big1),
                    cmap=cmap,
                    annot=True,
                    fmt='.4f')
        ax1.set_title('Large-scale search')

        data2 = pd.DataFrame(data=dices_big2,
                             index=np.around(thrs_big2, 3),
                             columns=pixel_thrs2)
        sns.heatmap(data2,
                    linewidths=0.05,
                    ax=ax2,
                    vmax=np.max(dices_big2),
                    vmin=np.min(dices_big2),
                    cmap=cmap,
                    annot=True,
                    fmt='.4f')
        ax2.set_title('Little-scale search')
        f.savefig(
            os.path.join(self.save_path,
                         'stage{}'.format(stage) + '_fold' + str(index)))
        # Close this specific figure rather than whatever is current.
        plt.close(f)
        return float(best_thr), float(best_pixel_thr), float(score)

    def get_dice_onval(self, model_path, best_thr, pixel_thr):
        """Print the validation dice of a trained model at chosen thresholds.

        Args:
            model_path: Path of the checkpoint to load.
            best_thr: Probability threshold.
            pixel_thr: Pixel threshold; for stages other than 3, predictions
                with fewer positive pixels are zeroed.

        Returns:
            None (the result is printed).
        """
        self.unet.module.load_state_dict(torch.load(model_path)['state_dict'])
        # The stage is the third '_'-separated field of the checkpoint name.
        # int() instead of eval(): never execute text taken from a path.
        stage = int(os.path.basename(model_path).split('_')[2])
        print('Loaded from %s, using get_dice_onval!' % model_path)
        self.unet.eval()

        with torch.no_grad():
            scores = []
            for images, masks in tqdm.tqdm(self.valid_loader):
                images = images.to(self.device)
                net_output = torch.sigmoid(self.unet(images))
                preds = (net_output > best_thr).to(self.device).float()
                if stage != 3:
                    # Zero out predictions with too few positive pixels.
                    preds[preds.view(preds.shape[0], -1).sum(-1) < pixel_thr,
                          ...] = 0.0
                scores.append(self.dice_overall(preds, masks).mean().item())
            score = sum(scores) / len(scores)
        print('best_thr:{}, best_pixel_thr:{}, score:{}'.format(
            best_thr, pixel_thr, score))
示例#5
0
class Test(object):
    def __init__(self, model_type, image_size, mean, std, t=None):
        # Models
        self.unet = None
        self.image_size = image_size  # 模型的输入大小

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = model_type
        self.t = t
        self.mean = mean
        self.std = std

    def build_model(self):
        """Build generator and discriminator."""
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)

        elif self.model_type == 'unet_resnet34':
            # self.unet = Unet(backbone_name='resnet34', classes=1)
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif self.model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)

        elif self.model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)
        elif self.model_type == 'linknet':
            self.unet = LinkNet34(num_classes=1)
        elif self.model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=1)
            # self.unet = DeepLabV3Plus(num_classes=1)

        print('build model done!')

        self.unet.to(self.device)

    def test_model(self,
                   thresholds_classify,
                   thresholds_seg,
                   average_threshold,
                   stage_cla,
                   stage_seg,
                   n_splits,
                   test_best_model=True,
                   less_than_sum=2048 * 2,
                   seg_average_vote=True,
                   csv_path=None,
                   test_image_path=None):
        """Predict the test set with a classify-then-segment fold ensemble.

        For every fold in ``n_splits``, the stage-``stage_cla`` weights act as
        a classifier. Each image's TTA prediction is binarised with
        ``thresholds_classify[fold]`` and cleared when it has fewer positive
        pixels than the fold's ``less_than_sum``. Images that still contain
        positives are segmented with the stage-``stage_seg`` weights of the same
        fold. Fold results are combined by averaging or voting, binarised,
        resized to 1024x1024, RLE-encoded and written to ``submission.csv``.

        Args:
            thresholds_classify: Per-fold probability thresholds for the
                classification pass, indexed by fold.
            thresholds_seg: Per-fold thresholds for the segmentation pass,
                indexed by fold (only used when voting).
            average_threshold: Threshold applied to the fold-averaged
                prediction (averaging mode only).
            stage_cla: Training stage whose weights are used to classify.
            stage_seg: Training stage whose weights are used to segment.
            n_splits: Fold indices to ensemble.
            test_best_model: Load ``*_best.pth`` weights if True, otherwise the
                latest ``*.pth`` weights.
            less_than_sum: Minimum number of positive pixels for an image to
                count as containing a mask; either one number for every fold
                or a per-fold sequence indexed by fold.
            seg_average_vote: True to average fold predictions, False to vote.
            csv_path: CSV listing the test images in an ``ImageId`` column.
            test_image_path: Directory holding ``<ImageId>.jpg`` test images.
        """

        def checkpoint_path(stage, fold):
            # checkpoints/<model_type>/<model_type>_<stage>_<fold>[_best].pth
            suffix = '_best.pth' if test_best_model else '.pth'
            return os.path.join(
                'checkpoints', self.model_type,
                self.model_type + '_{}_{}{}'.format(stage, fold, suffix))

        def min_pixels(fold):
            # The default is a single number; a sequence gives one per fold.
            if np.isscalar(less_than_sum):
                return less_than_sum
            return less_than_sum[fold]

        sample_df = pd.read_csv(csv_path)
        # Accumulated fold predictions at the network resolution.
        preds = np.zeros([len(sample_df), self.image_size, self.image_size])

        for fold in n_splits:
            # Classification model of this fold.
            self.unet = None
            self.build_model()
            unet_path = checkpoint_path(stage_cla, fold)
            print("Load classify weight from %s" % unet_path)
            # map_location lets GPU-saved checkpoints load on CPU-only hosts.
            self.unet.load_state_dict(
                torch.load(unet_path, map_location=self.device)['state_dict'])
            self.unet.eval()

            # Segmentation model of this fold: same architecture, other weights.
            seg_unet = copy.deepcopy(self.unet)
            unet_path = checkpoint_path(stage_seg, fold)
            print('Load segmentation weight from %s.' % unet_path)
            seg_unet.load_state_dict(
                torch.load(unet_path, map_location=self.device)['state_dict'])
            seg_unet.eval()

            count_mask_classify = 0
            with torch.no_grad():
                for index, row in tqdm(sample_df.iterrows(),
                                       total=len(sample_df)):
                    file = row['ImageId']
                    img_path = os.path.join(test_image_path,
                                            file.strip() + '.jpg')
                    # Close the file handle as soon as the image is decoded.
                    with Image.open(img_path) as raw_img:
                        img = raw_img.convert('RGB')

                    # Classification pass: binarise, then drop small detections.
                    pred = self.tta(img, self.unet)
                    pred = np.where(pred > thresholds_classify[fold], 1, 0)
                    if np.sum(pred) < min_pixels(fold):
                        pred[:] = 0

                    # Segment only images classified as containing a mask.
                    if np.sum(pred) > 0:
                        count_mask_classify += 1
                        pred = self.tta(img, seg_unet)
                        # Voting needs hard 0/1 predictions from each fold.
                        if not seg_average_vote:
                            pred = np.where(pred > thresholds_seg[fold], 1, 0)
                    preds[index, ...] += pred
                print('Fold %d Detect %d mask in classify.' %
                      (fold, count_mask_classify))

        if not seg_average_vote:
            vote_model_num = len(n_splits)
            # NOTE(review): Python 3 round() rounds halves to even, so 3 folds
            # give a ticket of 2 and "> ticket" then needs all 3 votes, while
            # 5 folds need 3 of 5. Confirm this is the intended rule.
            vote_ticket = round(vote_model_num / 2.0)
            print("Using voting strategy, Ticket / Vote models: %d / %d" %
                  (vote_ticket, vote_model_num))
        else:
            print('Using average strategy.')
            preds = preds / len(n_splits)

        rle = []
        count_has_mask = 0
        for index, row in tqdm(sample_df.iterrows(), total=len(sample_df)):
            file = row['ImageId']

            pred = preds[index, ...]
            if not seg_average_vote:
                pred = np.where(pred > vote_ticket, 1, 0)
            else:
                pred = np.where(pred > average_threshold, 1, 0)

            # cv2.resize cannot interpolate 64-bit integer masks, so resize a
            # uint8 copy (0/1 values preserved) when the size differs.
            if pred.shape != (1024, 1024):
                pred = cv2.resize(pred.astype(np.uint8), (1024, 1024))
            encoding = mask_to_rle(pred.T, 1024, 1024)
            if encoding == ' ':
                rle.append([file.strip(), '-1'])
            else:
                count_has_mask += 1
                rle.append([file.strip(), encoding[1:]])

        print('The number of masked pictures predicted:', count_has_mask)
        submission_df = pd.DataFrame(rle, columns=['ImageId', 'EncodedPixels'])
        submission_df.to_csv('submission.csv', index=False)

    def image_transform(self, image):
        """Preprocess one sample into a normalized tensor for the network.

        The pipeline resizes with ``transforms.Resize(self.image_size)``,
        converts to a tensor, then normalizes with ``self.mean`` / ``self.std``.
        Note: with an int size, torchvision's Resize matches the *smaller* edge
        and keeps the aspect ratio, so the output is only
        ``image_size x image_size`` when the input is square.

        Args:
            image: input sample, PIL Image.
        Return:
            The tensor produced by the composed transform pipeline.
        """
        pipeline = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
        return pipeline(image)

    def detection(self, image, model):
        """Run the network on a single sample and return a probability map.

        Args:
            image: sample to detect, PIL Image.
            model: network to use.
        Return:
            pred: numpy array of shape (self.image_size, self.image_size)
                holding sigmoid probabilities.
        """
        image = self.image_transform(image)
        # Add a batch dimension, cast to float32 and move to the target device.
        image = torch.unsqueeze(image, dim=0).float().to(self.device)
        # Inference only: disable autograd so no computation graph (and no
        # intermediate activations) are retained for a backward pass that
        # never happens.
        with torch.no_grad():
            pred = torch.sigmoid(model(image))
        # assumes the model emits a single output channel so the result has
        # exactly image_size * image_size elements — TODO confirm for all
        # model types built by this class
        pred = pred.view(self.image_size, self.image_size)
        return pred.detach().cpu().numpy()

    def tta(self, image, model):
        """Run test-time augmentation (TTA) and average the predictions.

        Three views are scored with ``self.detection``: a horizontally flipped
        copy (whose prediction is flipped back), a CLAHE-enhanced copy, and
        the original image.

        Args:
            image: input sample, PIL Image.
            model: network to use.
        Return:
            pred: mean of the per-view predictions, a float64 array of shape
                (self.image_size, self.image_size).
        """
        view_preds = []

        # Horizontal flip: predict on the mirrored image, then mirror the
        # prediction back so it lines up with the original pixels. np.fliplr
        # avoids a PIL round-trip, which only handles dtypes Pillow can wrap
        # (e.g. half-precision arrays are not accepted).
        image_hflip = image.transpose(Image.FLIP_LEFT_RIGHT)
        hflip_pred = self.detection(image_hflip, model)
        view_preds.append(np.fliplr(hflip_pred))

        # CLAHE (contrast-limited adaptive histogram equalization), always applied.
        aug = CLAHE(p=1.0)
        clahe_image = Image.fromarray(aug(image=np.asarray(image))['image'])
        view_preds.append(self.detection(clahe_image, model))

        # Original image.
        view_preds.append(self.detection(image, model))

        # Average over however many views were collected so that adding or
        # removing an augmentation does not require editing a hard-coded divisor.
        preds = np.zeros([self.image_size, self.image_size])
        for view_pred in view_preds:
            preds += view_pred
        return preds / len(view_preds)