def build_model(self):
    """Build the segmentation network selected by ``self.model_type``.

    Reads ``self.model_type``, ``self.output_ch``, ``self.t``, ``self.device``
    and the three criteria ``self.criterion``, ``self.criterion_stage2`` and
    ``self.criterion_stage3``.

    The network is stored on ``self.unet``. When CUDA is available it is
    wrapped in ``torch.nn.DataParallel`` and the three criteria are moved to
    the GPU. Finally the network is moved to ``self.device``.

    Raises:
        ValueError: if ``self.model_type`` is not one of the supported names.
            Previously such a name left ``self.unet`` unset/stale and the
            failure surfaced later as an unrelated ``AttributeError``.
    """
    print("Using model: {}".format(self.model_type))

    # Factories are zero-argument callables so that only the selected
    # network is constructed, and only the attributes it needs are read.
    factories = {
        'U_Net': lambda: U_Net(img_ch=3, output_ch=self.output_ch),
        'R2U_Net': lambda: R2U_Net(img_ch=3, output_ch=self.output_ch,
                                   t=self.t),
        'AttU_Net': lambda: AttU_Net(img_ch=3, output_ch=self.output_ch),
        'R2AttU_Net': lambda: R2AttU_Net(img_ch=3, output_ch=self.output_ch,
                                         t=self.t),
        'unet_resnet34': lambda: smp.Unet(
            'resnet34', encoder_weights='imagenet', activation=None),
        'unet_resnet50': lambda: smp.Unet(
            'resnet50', encoder_weights='imagenet', activation=None),
        'unet_se_resnext50_32x4d': lambda: smp.Unet(
            'se_resnext50_32x4d', encoder_weights='imagenet',
            activation=None),
        'unet_densenet121': lambda: smp.Unet(
            'densenet121', encoder_weights='imagenet', activation=None),
        'unet_resnet34_t': lambda: Unet_t(
            'resnet34', encoder_weights='imagenet', activation=None,
            use_ConvTranspose2d=True),
        'unet_resnet34_oct': lambda: OctaveUnet(
            'resnet34', encoder_weights='imagenet', activation=None),
        'linknet': lambda: LinkNet34(num_classes=self.output_ch),
        'deeplabv3plus': lambda: DeepLabV3Plus(
            model_backbone='res50_atrous', num_classes=self.output_ch),
        # NOTE(review): PSPNet is hard-coded to 1 class rather than
        # self.output_ch; kept as-is to preserve behaviour.
        'pspnet_resnet34': lambda: smp.PSPNet(
            'resnet34', encoder_weights='imagenet', classes=1,
            activation=None),
    }
    try:
        factory = factories[self.model_type]
    except KeyError:
        raise ValueError(
            "Unsupported model_type: {!r}".format(self.model_type)) from None

    self.unet = factory()
    if torch.cuda.is_available():
        self.unet = torch.nn.DataParallel(self.unet)
        self.criterion = self.criterion.cuda()
        self.criterion_stage2 = self.criterion_stage2.cuda()
        self.criterion_stage3 = self.criterion_stage3.cuda()
    self.unet.to(self.device)
def build_model(self):
    """Build the inference network selected by ``self.model_type``.

    Constructors that take a channel/class count are given ``1``; the
    ``smp.Unet`` variants rely on that library's default class count.
    The network is stored on ``self.unet`` and moved to ``self.device``.

    Raises:
        ValueError: if ``self.model_type`` is not one of the supported names.
            Previously ``self.unet`` was left unset/``None`` and the failure
            surfaced as an unrelated ``AttributeError`` on ``.to``.
    """
    # Zero-argument factories: only the requested network is constructed.
    factories = {
        'U_Net': lambda: U_Net(img_ch=3, output_ch=1),
        'AttU_Net': lambda: AttU_Net(img_ch=3, output_ch=1),
        'unet_resnet34': lambda: smp.Unet(
            'resnet34', encoder_weights='imagenet', activation=None),
        'unet_resnet50': lambda: smp.Unet(
            'resnet50', encoder_weights='imagenet', activation=None),
        'unet_se_resnext50_32x4d': lambda: smp.Unet(
            'se_resnext50_32x4d', encoder_weights='imagenet',
            activation=None),
        'unet_densenet121': lambda: smp.Unet(
            'densenet121', encoder_weights='imagenet', activation=None),
        'unet_resnet34_t': lambda: Unet_t(
            'resnet34', encoder_weights='imagenet', activation=None,
            use_ConvTranspose2d=True),
        'unet_resnet34_oct': lambda: OctaveUnet(
            'resnet34', encoder_weights='imagenet', activation=None),
        'pspnet_resnet34': lambda: smp.PSPNet(
            'resnet34', encoder_weights='imagenet', classes=1,
            activation=None),
        'linknet': lambda: LinkNet34(num_classes=1),
        'deeplabv3plus': lambda: DeepLabV3Plus(
            model_backbone='res50_atrous', num_classes=1),
    }
    try:
        factory = factories[self.model_type]
    except KeyError:
        raise ValueError(
            "Unsupported model_type: {!r}".format(self.model_type)) from None

    self.unet = factory()
    self.unet.to(self.device)
class Test(object):
    """Cross-validated inference with a classification and a segmentation stage.

    For every image and every requested fold, the checkpoint of stage
    ``stage_cla`` decides whether the image contains any mask. If it does,
    the checkpoint of stage ``stage_seg`` produces the pixel prediction.
    Per-fold results are merged by averaging or majority vote and shown
    next to the ground-truth mask.
    """

    def __init__(self, model_type, image_size, mean, std, t=None):
        """
        Args:
            model_type: architecture name, see ``build_model``.
            image_size: network input size passed to ``transforms.Resize``;
                also used to reshape the network output in ``detection``.
            mean: per-channel mean passed to ``transforms.Normalize``.
            std: per-channel std passed to ``transforms.Normalize``.
            t: stored for interface compatibility; not read by any method
                of this class.
        """
        # Models
        self.unet = None
        self.image_size = image_size  # input size of the model
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = model_type
        self.t = t
        self.mean = mean
        self.std = std

    def build_model(self):
        """Build the network named by ``self.model_type`` into ``self.unet``.

        Constructors that take a channel/class count are given ``1``. The
        network is moved to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported name.
        """
        # Zero-argument factories: only the requested network is constructed.
        factories = {
            'U_Net': lambda: U_Net(img_ch=3, output_ch=1),
            'AttU_Net': lambda: AttU_Net(img_ch=3, output_ch=1),
            'unet_resnet34': lambda: smp.Unet(
                'resnet34', encoder_weights='imagenet', activation=None),
            'unet_resnet50': lambda: smp.Unet(
                'resnet50', encoder_weights='imagenet', activation=None),
            'unet_se_resnext50_32x4d': lambda: smp.Unet(
                'se_resnext50_32x4d', encoder_weights='imagenet',
                activation=None),
            'unet_densenet121': lambda: smp.Unet(
                'densenet121', encoder_weights='imagenet', activation=None),
            'unet_resnet34_t': lambda: Unet_t(
                'resnet34', encoder_weights='imagenet', activation=None,
                use_ConvTranspose2d=True),
            'unet_resnet34_oct': lambda: OctaveUnet(
                'resnet34', encoder_weights='imagenet', activation=None),
            'pspnet_resnet34': lambda: smp.PSPNet(
                'resnet34', encoder_weights='imagenet', classes=1,
                activation=None),
            'linknet': lambda: LinkNet34(num_classes=1),
            'deeplabv3plus': lambda: DeepLabV3Plus(
                model_backbone='res50_atrous', num_classes=1),
        }
        try:
            factory = factories[self.model_type]
        except KeyError:
            raise ValueError("Unsupported model_type: {!r}".format(
                self.model_type)) from None
        self.unet = factory()
        self.unet.to(self.device)

    @staticmethod
    def _fold_value(value, fold):
        """Return ``value[fold]`` for per-fold sequences, or ``value`` if scalar."""
        if np.isscalar(value):
            return value
        return value[fold]

    @staticmethod
    def _combine_fold_predictions(pred_nfolds, num_models, seg_average_vote,
                                  average_threshold):
        """Merge the summed per-fold predictions into a binary 0/1 mask.

        Args:
            pred_nfolds: element-wise sum of the per-fold predictions
                (probabilities when averaging, 0/1 votes when voting).
            num_models: number of folds that contributed to the sum.
            seg_average_vote: True -> average then threshold;
                False -> majority vote.
            average_threshold: threshold applied to the averaged values.

        Returns:
            uint8 array of 0/1 values with the shape of ``pred_nfolds``.
        """
        if not seg_average_vote:
            # Strict majority: more than half of the models must vote 1.
            # The former ``round(n / 2.0)`` used round-half-to-even, so
            # e.g. 3 models required a unanimous vote while 5 needed 3.
            pred = np.where(pred_nfolds > num_models / 2.0, 1, 0)
        else:
            pred = np.where(pred_nfolds / num_models > average_threshold,
                            1, 0)
        # uint8 so that cv2.resize accepts the array (it rejects int64).
        return pred.astype(np.uint8)

    def _checkpoint_path(self, stage, fold, test_best_model):
        """Path of the checkpoint for ``stage``/``fold`` in checkpoints/<model_type>/."""
        suffix = '_best' if test_best_model else ''
        return os.path.join(
            'checkpoints', self.model_type,
            self.model_type + '_{}_{}{}.pth'.format(stage, fold, suffix))

    def _load_fold_models(self, fold, stage_cla, stage_seg, test_best_model):
        """Build and load the (classification, segmentation) networks of a fold.

        Both networks share the architecture from ``build_model`` and differ
        only in the checkpoint (stage) loaded into them; both are put in eval
        mode. ``self.unet`` is left pointing at the classification network.
        """
        self.unet = None
        self.build_model()
        cla_path = self._checkpoint_path(stage_cla, fold, test_best_model)
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        self.unet.load_state_dict(
            torch.load(cla_path, map_location=self.device)['state_dict'])
        self.unet.eval()

        seg_unet = copy.deepcopy(self.unet)
        seg_path = self._checkpoint_path(stage_seg, fold, test_best_model)
        seg_unet.load_state_dict(
            torch.load(seg_path, map_location=self.device)['state_dict'])
        seg_unet.eval()
        return self.unet, seg_unet

    def test_model(self, thresholds_classify, thresholds_seg,
                   average_threshold, stage_cla, stage_seg, n_splits,
                   test_best_model=True, less_than_sum=2048 * 2,
                   seg_average_vote=True, images_path=None, masks_path=None):
        """Run the two-stage, multi-fold prediction on each image and display it.

        Args:
            thresholds_classify: per-fold thresholds (indexed by fold id) for
                the classification model; values above become 1, else 0.
                A single scalar is applied to every fold.
            thresholds_seg: per-fold thresholds for the segmentation model
                (only used when voting); scalar or indexed by fold id.
            average_threshold: threshold applied to the fold average when
                ``seg_average_vote`` is True.
            stage_cla: training stage whose weights act as the classifier.
            stage_seg: training stage whose weights act as the segmenter.
            n_splits: list of fold ids whose predictions are combined.
            test_best_model: load the ``*_best.pth`` checkpoints if True,
                otherwise the latest ones.
            less_than_sum: if a fold's classifier predicts fewer positive
                pixels than this, the whole prediction is zeroed. Scalar or
                indexed by fold id (the scalar default previously crashed
                because it was always indexed).
            seg_average_vote: True -> average the folds, False -> vote.
            images_path: sequence of image file paths.
            masks_path: sequence of ground-truth mask paths, aligned with
                ``images_path``.

        Raises:
            ValueError: if no paths or no folds are given, or if
                ``model_type`` is unsupported.
        """
        if images_path is None or masks_path is None:
            raise ValueError('images_path and masks_path must be provided')
        if len(n_splits) == 0:
            raise ValueError('n_splits must contain at least one fold')

        with torch.no_grad():
            # Load every fold's networks once, not once per image.
            # NOTE: this keeps 2 * len(n_splits) networks in memory at once.
            fold_models = []
            for fold in n_splits:
                cla_unet, seg_unet = self._load_fold_models(
                    fold, stage_cla, stage_seg, test_best_model)
                fold_models.append((fold, cla_unet, seg_unet))

            for image_path, mask_path in tqdm(zip(images_path, masks_path),
                                              total=len(images_path)):
                img = Image.open(image_path).convert('RGB')
                pred_nfolds = 0
                for fold, cla_unet, seg_unet in fold_models:
                    # Classifier: threshold, then drop predictions whose
                    # positive-pixel count is too small to be trusted.
                    pred = self.tta(img, cla_unet)
                    pred = np.where(
                        pred > self._fold_value(thresholds_classify, fold),
                        1, 0)
                    if np.sum(pred) < self._fold_value(less_than_sum, fold):
                        pred[:] = 0
                    # Only images judged to contain a mask are segmented.
                    if np.sum(pred) > 0:
                        pred = self.tta(img, seg_unet)
                        if not seg_average_vote:
                            # Voting needs each fold's output as 0/1.
                            pred = np.where(
                                pred > self._fold_value(thresholds_seg, fold),
                                1, 0)
                    pred_nfolds += pred

                pred = self._combine_fold_predictions(
                    pred_nfolds, len(n_splits), seg_average_vote,
                    average_threshold)
                # Back to the original image resolution; PIL's size and
                # cv2's dsize are both (width, height).
                pred = cv2.resize(pred, img.size)
                mask = Image.open(mask_path)
                mask = np.around(np.array(mask.convert('L')) / 256.)
                self.combine_display(img, mask, pred, 'demo')

    def image_transform(self, image):
        """Resize, convert to tensor and normalize a PIL image."""
        resize = transforms.Resize(self.image_size)
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(self.mean, self.std)
        transform_compose = transforms.Compose([resize, to_tensor, normalize])
        return transform_compose(image)

    def detection(self, image, model):
        """Predict a sigmoid probability map for one image.

        Args:
            image: PIL image to predict on.
            model: network to use.

        Returns:
            numpy array of shape (image_size, image_size).
        """
        image = self.image_transform(image)
        image = torch.unsqueeze(image, dim=0)
        image = image.float().to(self.device)
        pred = torch.sigmoid(model(image))
        # NOTE(review): assumes a single-channel output whose element count
        # is image_size ** 2 (i.e. a square input) — confirm for new models.
        pred = pred.view(self.image_size, self.image_size)
        pred = pred.detach().cpu().numpy()
        return pred

    def tta(self, image, model):
        """Average the predictions of three test-time views of ``image``.

        Views: horizontal flip (its prediction flipped back), the
        CLAHE-enhanced image, and the original image.

        Args:
            image: PIL image.
            model: network to use.

        Returns:
            numpy array of shape (image_size, image_size).
        """
        preds = np.zeros([self.image_size, self.image_size])
        # Horizontal flip: predict on the mirror image, mirror the result back.
        image_hflip = image.transpose(Image.FLIP_LEFT_RIGHT)
        hflip_pred = self.detection(image_hflip, model)
        hflip_pred_img = Image.fromarray(hflip_pred)
        pred_img = hflip_pred_img.transpose(Image.FLIP_LEFT_RIGHT)
        preds += np.asarray(pred_img)
        # CLAHE
        aug = CLAHE(p=1.0)
        image_np = np.asarray(image)
        clahe_image = aug(image=image_np)['image']
        clahe_image = Image.fromarray(clahe_image)
        clahe_pred = self.detection(clahe_image, model)
        preds += clahe_pred
        # Original image
        original_pred = self.detection(image, model)
        preds += original_pred
        # Average of the three views
        return preds / 3.0

    # dice for threshold selection
    def dice_overall(self, preds, targs):
        """Mean per-sample Dice coefficient of binary predictions and targets.

        A sample whose prediction and target are both empty scores 1.
        """
        n = preds.shape[0]  # batch size
        preds = preds.view(n, -1)
        targs = targs.view(n, -1)
        preds, targs = preds.cpu(), targs.cpu()
        # Element-wise product = intersection (only 1 * 1 gives 1); summed
        # per sample -> [batch] tensor.
        intersect = (preds * targs).sum(-1).float()
        # Element-wise sum, summed per sample. This is |A| + |B| (the
        # intersection is not subtracted), hence 2 * intersect below.
        union = (preds + targs).sum(-1).float()
        # Both empty: force intersect=1, union=2 so the score is 1.
        u0 = union == 0
        intersect[u0] = 1
        union[u0] = 2
        return (2. * intersect / union).mean()

    def combine_display(self, image_raw, mask, pred, title_diplay):
        """Show the raw image, ground-truth mask and prediction side by side."""
        plt.suptitle(title_diplay)
        plt.subplot(1, 3, 1)
        plt.title('image_raw')
        plt.imshow(image_raw)
        plt.subplot(1, 3, 2)
        plt.title('mask')
        plt.imshow(mask)
        plt.subplot(1, 3, 3)
        plt.title('pred')
        plt.imshow(pred)
        plt.show()
class Train(object): def __init__(self, config, train_loader, valid_loader): # Data loader self.train_loader = train_loader self.valid_loader = valid_loader # Models self.unet = None self.optimizer = None self.img_ch = config.img_ch self.output_ch = config.output_ch self.criterion = SoftBCEDiceLoss(weight=[0.25, 0.75]) # self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(50)) self.criterion_stage2 = SoftBCEDiceLoss(weight=[0.25, 0.75]) self.criterion_stage3 = SoftBCEDiceLoss(weight=[0.25, 0.75]) self.model_type = config.model_type self.t = config.t self.mode = config.mode self.resume = config.resume # Hyper-parameters self.lr = config.lr self.lr_stage2 = config.lr_stage2 self.lr_stage3 = config.lr_stage3 self.start_epoch, self.max_dice = 0, 0 self.weight_decay = config.weight_decay self.weight_decay_stage2 = config.weight_decay self.weight_decay_stage3 = config.weight_decay # save set self.save_path = config.save_path if 'choose_threshold' not in self.mode: TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now()) self.writer = SummaryWriter(log_dir=self.save_path + '/' + TIMESTAMP) # 配置参数 self.epoch_stage1 = config.epoch_stage1 self.epoch_stage1_freeze = config.epoch_stage1_freeze self.epoch_stage2 = config.epoch_stage2 self.epoch_stage2_accumulation = config.epoch_stage2_accumulation self.accumulation_steps = config.accumulation_steps self.epoch_stage3 = config.epoch_stage3 self.epoch_stage3_accumulation = config.epoch_stage3_accumulation # 模型初始化 self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') self.build_model() def build_model(self): print("Using model: {}".format(self.model_type)) """Build generator and discriminator.""" if self.model_type == 'U_Net': self.unet = U_Net(img_ch=3, output_ch=self.output_ch) elif self.model_type == 'R2U_Net': self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t) elif self.model_type == 'AttU_Net': self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch) elif 
self.model_type == 'R2AttU_Net': self.unet = R2AttU_Net(img_ch=3, output_ch=self.output_ch, t=self.t) elif self.model_type == 'unet_resnet34': # self.unet = Unet(backbone_name='resnet34', pretrained=True, classes=self.output_ch) self.unet = smp.Unet('resnet34', encoder_weights='imagenet', activation=None) elif self.model_type == 'unet_resnet50': self.unet = smp.Unet('resnet50', encoder_weights='imagenet', activation=None) elif self.model_type == 'unet_se_resnext50_32x4d': self.unet = smp.Unet('se_resnext50_32x4d', encoder_weights='imagenet', activation=None) elif self.model_type == 'unet_densenet121': self.unet = smp.Unet('densenet121', encoder_weights='imagenet', activation=None) elif self.model_type == 'unet_resnet34_t': self.unet = Unet_t('resnet34', encoder_weights='imagenet', activation=None, use_ConvTranspose2d=True) elif self.model_type == 'unet_resnet34_oct': self.unet = OctaveUnet('resnet34', encoder_weights='imagenet', activation=None) elif self.model_type == 'linknet': self.unet = LinkNet34(num_classes=self.output_ch) elif self.model_type == 'deeplabv3plus': self.unet = DeepLabV3Plus(model_backbone='res50_atrous', num_classes=self.output_ch) elif self.model_type == 'pspnet_resnet34': self.unet = smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1, activation=None) if torch.cuda.is_available(): self.unet = torch.nn.DataParallel(self.unet) self.criterion = self.criterion.cuda() self.criterion_stage2 = self.criterion_stage2.cuda() self.criterion_stage3 = self.criterion_stage3.cuda() self.unet.to(self.device) def print_network(self, model, name): """Print out the network information.""" num_params = 0 for p in model.parameters(): num_params += p.numel() print(model) print(name) print("The number of parameters: {}".format(num_params)) def reset_grad(self): """Zero the gradient buffers.""" self.unet.zero_grad() def save_checkpoint(self, state, stage, index, is_best): # 保存权重,每一epoch均保存一次,若为最优,则复制到最优权重;index可以区分不同的交叉验证 pth_path = os.path.join( 
self.save_path, '%s_%d_%d.pth' % (self.model_type, stage, index)) torch.save(state, pth_path) if is_best: print('Saving Best Model.') write_txt(self.save_path, 'Saving Best Model.') shutil.copyfile( os.path.join(self.save_path, '%s_%d_%d.pth' % (self.model_type, stage, index)), os.path.join( self.save_path, '%s_%d_%d_best.pth' % (self.model_type, stage, index))) def load_checkpoint(self, load_optimizer=True): # Load the pretrained Encoder weight_path = os.path.join(self.save_path, self.resume) if os.path.isfile(weight_path): checkpoint = torch.load(weight_path) # 加载模型的参数,学习率,优化器,开始的epoch,最小误差等 if torch.cuda.is_available: self.unet.module.load_state_dict(checkpoint['state_dict']) else: self.unet.load_state_dict(checkpoint['state_dict']) self.start_epoch = checkpoint['epoch'] self.max_dice = checkpoint['max_dice'] if load_optimizer: self.lr = checkpoint['lr'] self.optimizer.load_state_dict(checkpoint['optimizer']) print('%s is Successfully Loaded from %s' % (self.model_type, weight_path)) write_txt( self.save_path, '%s is Successfully Loaded from %s' % (self.model_type, weight_path)) else: raise FileNotFoundError( "Can not find weight file in {}".format(weight_path)) def train(self, index): # self.optimizer = optim.Adam([{'params': self.unet.decoder.parameters(), 'lr': 1e-4}, {'params': self.unet.encoder.parameters(), 'lr': 1e-6},]) self.optimizer = optim.Adam(self.unet.module.parameters(), self.lr, weight_decay=self.weight_decay) # 若训练到一半暂停了,则需要加载之前训练的参数,并加载之前学习率 TODO:resume学习率没有接上,所以resume暂时无法使用 if self.resume: self.load_checkpoint(load_optimizer=True) ''' CosineAnnealingLR:若存在['initial_lr'],则从initial_lr开始衰减; 若不存在,则执行CosineAnnealingLR会在optimizer.param_groups中添加initial_lr键值,其值等于lr 重置初始学习率,在load_checkpoint中会加载优化器,但其中的initial_lr还是之前的,所以需要覆盖为self.lr,让其从self.lr衰减 ''' self.optimizer.param_groups[0]['initial_lr'] = self.lr stage1_epoches = self.epoch_stage1 - self.start_epoch lr_scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, stage1_epoches + 10) # 
防止训练到一半暂停重新训练,日志被覆盖 global_step_before = self.start_epoch * len(self.train_loader) for epoch in range(self.start_epoch, self.epoch_stage1): epoch += 1 self.unet.train(True) # 学习率重启 # if epoch == 30: # self.optimizer.param_groups[0]['initial_lr'] = 0.0001 # lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 25) epoch_loss = 0 tbar = tqdm.tqdm(self.train_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) masks = masks.to(self.device) # SR : Segmentation Result net_output = self.unet(images) net_output_flat = net_output.view(net_output.size(0), -1) masks_flat = masks.view(masks.size(0), -1) loss_set = self.criterion(net_output_flat, masks_flat) try: loss_num = len(loss_set) except: loss_num = 1 # 依据返回的损失个数分情况处理 if loss_num > 1: for loss_index, loss_item in enumerate(loss_set): if loss_index > 0: loss_name = 'stage1_loss_%d' % loss_index self.writer.add_scalar(loss_name, loss_item.item(), global_step_before + i) loss = loss_set[0] else: loss = loss_set epoch_loss += loss.item() # Backprop + optimize self.reset_grad() loss.backward() self.optimizer.step() params_groups_lr = str() for group_ind, param_group in enumerate( self.optimizer.param_groups): params_groups_lr = params_groups_lr + 'params_group_%d' % ( group_ind) + ': %.12f, ' % (param_group['lr']) # 保存到tensorboard,每一步存储一个 self.writer.add_scalar('Stage1_train_loss', loss.item(), global_step_before + i) descript = "Train Loss: %.7f, lr: %s" % (loss.item(), params_groups_lr) tbar.set_description(desc=descript) # 更新global_step_before为下次迭代做准备 global_step_before += len(tbar) # Print the log info print('Finish Stage1 Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch_stage1, epoch_loss / len(tbar))) write_txt( self.save_path, 'Finish Stage1 Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch_stage1, epoch_loss / len(tbar))) # 验证模型,保存权重,并保存日志 loss_mean, dice_mean = self.validation(stage=1) if dice_mean > self.max_dice: is_best = True 
self.max_dice = dice_mean else: is_best = False self.lr = lr_scheduler.get_lr() state = { 'epoch': epoch, 'state_dict': self.unet.module.state_dict(), 'max_dice': self.max_dice, 'optimizer': self.optimizer.state_dict(), 'lr': self.lr } self.save_checkpoint(state, 1, index, is_best) self.writer.add_scalar('Stage1_val_loss', loss_mean, epoch) self.writer.add_scalar('Stage1_val_dice', dice_mean, epoch) self.writer.add_scalar('Stage1_lr', self.lr[0], epoch) # 学习率衰减 lr_scheduler.step() def train_stage2(self, index): # # 冻结BN层, see https://zhuanlan.zhihu.com/p/65439075 and https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/100736591271 for more information # def set_bn_eval(m): # classname = m.__class__.__name__ # if classname.find('BatchNorm') != -1: # m.eval() # self.unet.apply(set_bn_eval) # self.optimizer = optim.Adam([{'params': self.unet.decoder.parameters(), 'lr': 1e-5}, {'params': self.unet.encoder.parameters(), 'lr': 1e-7},]) self.optimizer = optim.Adam(self.unet.module.parameters(), self.lr_stage2, weight_decay=self.weight_decay_stage2) # 加载的resume分为两种情况:之前没有训练第二个阶段,现在要加载第一个阶段的参数;第二个阶段训练了一半要继续训练 if self.resume: # 若第二个阶段训练一半,要重新加载 TODO if self.resume.split('_')[2] == '2': self.load_checkpoint( load_optimizer=True) # 当load_optimizer为True会重新加载学习率和优化器 ''' CosineAnnealingLR:若存在['initial_lr'],则从initial_lr开始衰减; 若不存在,则执行CosineAnnealingLR会在optimizer.param_groups中添加initial_lr键值,其值等于lr 重置初始学习率,在load_checkpoint中会加载优化器,但其中的initial_lr还是之前的,所以需要覆盖为self.lr,让其从self.lr衰减 ''' self.optimizer.param_groups[0]['initial_lr'] = self.lr # 若第一阶段结束后没有直接进行第二个阶段,中间暂停了 elif self.resume.split('_')[2] == '1': self.load_checkpoint(load_optimizer=False) self.start_epoch = 0 self.max_dice = 0 # 第一阶段结束后直接进行第二个阶段,中间并没有暂停 else: self.start_epoch = 0 self.max_dice = 0 # 防止训练到一半暂停重新训练,日志被覆盖 global_step_before = self.start_epoch * len(self.train_loader) stage2_epoches = self.epoch_stage2 - self.start_epoch lr_scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, 
stage2_epoches + 5) for epoch in range(self.start_epoch, self.epoch_stage2): epoch += 1 self.unet.train(True) epoch_loss = 0 self.reset_grad() # 梯度累加的时候需要使用 tbar = tqdm.tqdm(self.train_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) masks = masks.to(self.device) assert images.size(2) == 1024 # SR : Segmentation Result net_output = self.unet(images) net_output_flat = net_output.view(net_output.size(0), -1) masks_flat = masks.view(masks.size(0), -1) loss_set = self.criterion_stage2(net_output_flat, masks_flat) try: loss_num = len(loss_set) except: loss_num = 1 # 依据返回的损失个数分情况处理 if loss_num > 1: for loss_index, loss_item in enumerate(loss_set): if loss_index > 0: loss_name = 'stage2_loss_%d' % loss_index self.writer.add_scalar(loss_name, loss_item.item(), global_step_before + i) loss = loss_set[0] else: loss = loss_set epoch_loss += loss.item() # Backprop + optimize, see https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20 for Accumulating Gradients if epoch <= self.epoch_stage2 - self.epoch_stage2_accumulation: self.reset_grad() loss.backward() self.optimizer.step() else: # loss = loss / self.accumulation_steps # Normalize our loss (if averaged) loss.backward() # Backward pass if ( i + 1 ) % self.accumulation_steps == 0: # Wait for several backward steps self.optimizer.step( ) # Now we can do an optimizer step self.reset_grad() params_groups_lr = str() for group_ind, param_group in enumerate( self.optimizer.param_groups): params_groups_lr = params_groups_lr + 'params_group_%d' % ( group_ind) + ': %.12f, ' % (param_group['lr']) # 保存到tensorboard,每一步存储一个 self.writer.add_scalar('Stage2_train_loss', loss.item(), global_step_before + i) descript = "Train Loss: %.7f, lr: %s" % (loss.item(), params_groups_lr) tbar.set_description(desc=descript) # 更新global_step_before为下次迭代做准备 global_step_before += len(tbar) # Print the log info print('Finish Stage2 Epoch [%d/%d], Average 
Loss: %.7f' % (epoch, self.epoch_stage2, epoch_loss / len(tbar))) write_txt( self.save_path, 'Finish Stage2 Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch_stage2, epoch_loss / len(tbar))) # 验证模型,保存权重,并保存日志 loss_mean, dice_mean = self.validation(stage=2) if dice_mean > self.max_dice: is_best = True self.max_dice = dice_mean else: is_best = False self.lr = lr_scheduler.get_lr() state = { 'epoch': epoch, 'state_dict': self.unet.module.state_dict(), 'max_dice': self.max_dice, 'optimizer': self.optimizer.state_dict(), 'lr': self.lr } self.save_checkpoint(state, 2, index, is_best) self.writer.add_scalar('Stage2_val_loss', loss_mean, epoch) self.writer.add_scalar('Stage2_val_dice', dice_mean, epoch) self.writer.add_scalar('Stage2_lr', self.lr[0], epoch) # 学习率衰减 lr_scheduler.step() # stage3, 接着stage2的训练,只训练有mask的样本 def train_stage3(self, index): # # 冻结BN层, see https://zhuanlan.zhihu.com/p/65439075 and https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/100736591271 for more information # def set_bn_eval(m): # classname = m.__class__.__name__ # if classname.find('BatchNorm') != -1: # m.eval() # self.unet.apply(set_bn_eval) # self.optimizer = optim.Adam([{'params': self.unet.decoder.parameters(), 'lr': 1e-5}, {'params': self.unet.encoder.parameters(), 'lr': 1e-7},]) self.optimizer = optim.Adam(self.unet.module.parameters(), self.lr_stage3, weight_decay=self.weight_decay_stage3) # 如果是 train_stage23,则resume只在第二阶段起作用 if self.mode == 'train_stage23': self.resume = None # 加载的resume分为两种情况:之前没有训练第三个阶段,现在要加载第二个阶段的参数;第三个阶段训练了一半要继续训练 if self.resume: # 若第三个阶段训练一半,要重新加载 TODO if self.resume.split('_')[2] == '3': self.load_checkpoint( load_optimizer=True) # 当load_optimizer为True会重新加载学习率和优化器 ''' CosineAnnealingLR:若存在['initial_lr'],则从initial_lr开始衰减; 若不存在,则执行CosineAnnealingLR会在optimizer.param_groups中添加initial_lr键值,其值等于lr 重置初始学习率,在load_checkpoint中会加载优化器,但其中的initial_lr还是之前的,所以需要覆盖为self.lr,让其从self.lr衰减 ''' self.optimizer.param_groups[0]['initial_lr'] = self.lr # 
若第二阶段结束后没有直接进行第三个阶段,中间暂停了 elif self.resume.split('_')[2] == '2': self.load_checkpoint(load_optimizer=False) self.start_epoch = 0 self.max_dice = 0 # 第二阶段结束后直接进行第三个阶段,中间并没有暂停 else: print('start stage3 after stage2 directly!') self.start_epoch = 0 self.max_dice = 0 # 防止训练到一半暂停重新训练,日志被覆盖 global_step_before = self.start_epoch * len(self.train_loader) stage3_epoches = self.epoch_stage3 - self.start_epoch lr_scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, stage3_epoches + 5) for epoch in range(self.start_epoch, self.epoch_stage3): epoch += 1 self.unet.train(True) epoch_loss = 0 self.reset_grad() # 梯度累加的时候需要使用 tbar = tqdm.tqdm(self.train_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) masks = masks.to(self.device) assert images.size(2) == 1024 # SR : Segmentation Result net_output = self.unet(images) net_output_flat = net_output.view(net_output.size(0), -1) masks_flat = masks.view(masks.size(0), -1) loss_set = self.criterion_stage3(net_output_flat, masks_flat) try: loss_num = len(loss_set) except: loss_num = 1 # 依据返回的损失个数分情况处理 if loss_num > 1: for loss_index, loss_item in enumerate(loss_set): if loss_index > 0: loss_name = 'stage3_loss_%d' % loss_index self.writer.add_scalar(loss_name, loss_item.item(), global_step_before + i) loss = loss_set[0] else: loss = loss_set epoch_loss += loss.item() # Backprop + optimize, see https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20 for Accumulating Gradients if epoch <= self.epoch_stage3 - self.epoch_stage3_accumulation: self.reset_grad() loss.backward() self.optimizer.step() else: # loss = loss / self.accumulation_steps # Normalize our loss (if averaged) loss.backward() # Backward pass if ( i + 1 ) % self.accumulation_steps == 0: # Wait for several backward steps self.optimizer.step( ) # Now we can do an optimizer step self.reset_grad() params_groups_lr = str() for group_ind, param_group in enumerate( 
self.optimizer.param_groups): params_groups_lr = params_groups_lr + 'params_group_%d' % ( group_ind) + ': %.12f, ' % (param_group['lr']) # 保存到tensorboard,每一步存储一个 self.writer.add_scalar('Stage3_train_loss', loss.item(), global_step_before + i) descript = "Train Loss: %.7f, lr: %s" % (loss.item(), params_groups_lr) tbar.set_description(desc=descript) # 更新global_step_before为下次迭代做准备 global_step_before += len(tbar) # Print the log info print('Finish Stage3 Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch_stage3, epoch_loss / len(tbar))) write_txt( self.save_path, 'Finish Stage3 Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch_stage3, epoch_loss / len(tbar))) # 验证模型,保存权重,并保存日志 loss_mean, dice_mean = self.validation(stage=3) if dice_mean > self.max_dice: is_best = True self.max_dice = dice_mean else: is_best = False self.lr = lr_scheduler.get_lr() state = { 'epoch': epoch, 'state_dict': self.unet.module.state_dict(), 'max_dice': self.max_dice, 'optimizer': self.optimizer.state_dict(), 'lr': self.lr } self.save_checkpoint(state, 3, index, is_best) self.writer.add_scalar('Stage3_val_loss', loss_mean, epoch) self.writer.add_scalar('Stage3_val_dice', dice_mean, epoch) self.writer.add_scalar('Stage3_lr', self.lr[0], epoch) # 学习率衰减 lr_scheduler.step() def validation(self, stage=1): # 验证的时候,train(False)是必须的0,设置其中的BN层、dropout等为eval模式 # with torch.no_grad(): 可以有,在这个上下文管理器中,不反向传播,会加快速度,可以使用较大batch size self.unet.eval() tbar = tqdm.tqdm(self.valid_loader) loss_sum, dice_sum = 0, 0 if stage == 1: criterion = self.criterion elif stage == 2: criterion = self.criterion_stage2 elif stage == 3: criterion = self.criterion_stage3 with torch.no_grad(): for i, (images, masks) in enumerate(tbar): images = images.to(self.device) masks = masks.to(self.device) net_output = self.unet(images) net_output_flat = net_output.view(net_output.size(0), -1) masks_flat = masks.view(masks.size(0), -1) loss_set = criterion(net_output_flat, masks_flat) try: loss_num = len(loss_set) except: 
loss_num = 1 # 依据返回的损失个数分情况处理 if loss_num > 1: loss = loss_set[0] else: loss = loss_set loss_sum += loss.item() # 计算dice系数,预测出的矩阵要经过sigmoid含义以及阈值,阈值默认为0.5 net_output_flat_sign = (torch.sigmoid(net_output_flat) > 0.5).float() dice = self.dice_overall(net_output_flat_sign, masks_flat).mean() dice_sum += dice.item() descript = "Val Loss: {:.7f}, dice: {:.7f}".format( loss.item(), dice.item()) tbar.set_description(desc=descript) loss_mean, dice_mean = loss_sum / len(tbar), dice_sum / len(tbar) print("Val Loss: {:.7f}, dice: {:.7f}".format(loss_mean, dice_mean)) write_txt( self.save_path, "Val Loss: {:.7f}, dice: {:.7f}".format(loss_mean, dice_mean)) return loss_mean, dice_mean # dice for threshold selection def dice_overall(self, preds, targs): n = preds.shape[0] # batch size为多少 preds = preds.view(n, -1) targs = targs.view(n, -1) # preds, targs = preds.to(self.device), targs.to(self.device) preds, targs = preds.cpu(), targs.cpu() # tensor之间按位相成,求两个集合的交(只有1×1等于1)后。按照第二个维度求和,得到[batch size]大小的tensor,每一个值代表该输入图片真实类标与预测类标的交集大小 intersect = (preds * targs).sum(-1).float() # tensor之间按位相加,求两个集合的并。然后按照第二个维度求和,得到[batch size]大小的tensor,每一个值代表该输入图片真实类标与预测类标的并集大小 union = (preds + targs).sum(-1).float() ''' 输入图片真实类标与预测类标无并集有两种情况:第一种为预测与真实均没有类标,此时并集之和为0;第二种为真实有类标,但是预测完全错误,此时并集之和不为0; 寻找输入图片真实类标与预测类标并集之和为0的情况,将其交集置为1,并集置为2,最后还有一个2*交集/并集,值为1; 其余情况,直接按照2*交集/并集计算,因为上面的并集并没有减去交集,所以需要拿2*交集,其最大值为1 ''' u0 = union == 0 intersect[u0] = 1 union[u0] = 2 return (2. 
* intersect / union) def classify_score(self, preds, targs): '''若当前图像中有mask,则为正类,若当前图像中无mask,则为负类。从分类的角度得分当前的准确率 Args: preds: 预测出的mask矩阵 targs: 真实的mask矩阵 Return: 分类准确率 ''' n = preds.shape[0] # batch size为多少 preds = preds.view(n, -1) targs = targs.view(n, -1) # preds, targs = preds.to(self.device), targs.to(self.device) preds_, targs_ = torch.sum(preds, 1), torch.sum(targs, 1) preds_, targs_ = preds_ > 0, targs_ > 0 preds_, targs_ = preds_.cpu(), targs_.cpu() score = torch.sum(preds_ == targs_) return score.item() / n def choose_threshold(self, model_path, index): '''利用线性法搜索当前模型的最优阈值和最优像素阈值;先利用粗略搜索和精细搜索两个过程搜索出最优阈值,然后搜索出最优像素阈值;并保存搜索图 Args: model_path: 当前模型权重的位置 index: 当前为第几个fold Return: 最优阈值,最优像素阈值,最高得分 ''' self.unet.module.load_state_dict(torch.load(model_path)['state_dict']) stage = eval(model_path.split('/')[-1].split('_')[2]) print('Loaded from %s, using choose_threshold!' % model_path) self.unet.eval() with torch.no_grad(): # 先大概选取阈值范围 dices_big = [] thrs_big = np.arange(0.1, 1, 0.1) # 阈值列表 for th in thrs_big: tmp = [] tbar = tqdm.tqdm(self.valid_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) net_output = torch.sigmoid(self.unet(images)) preds = (net_output > th).to( self.device).float() # 大于阈值的归为1 # preds[preds.view(preds.shape[0],-1).sum(-1) < noise_th,...] 
= 0.0 # 过滤噪声点 tmp.append(self.dice_overall(preds, masks).mean()) # tmp.append(self.classify_score(preds, masks)) dices_big.append(sum(tmp) / len(tmp)) dices_big = np.array(dices_big) best_thrs_big = thrs_big[dices_big.argmax()] # 精细选取范围 dices_little = [] thrs_little = np.arange(best_thrs_big - 0.05, best_thrs_big + 0.05, 0.01) # 阈值列表 for th in thrs_little: tmp = [] tbar = tqdm.tqdm(self.valid_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) net_output = torch.sigmoid(self.unet(images)) preds = (net_output > th).to( self.device).float() # 大于阈值的归为1 # preds[preds.view(preds.shape[0],-1).sum(-1) < noise_th,...] = 0.0 # 过滤噪声点 tmp.append(self.dice_overall(preds, masks).mean()) # tmp.append(self.classify_score(preds, masks)) dices_little.append(sum(tmp) / len(tmp)) dices_little = np.array(dices_little) # score = dices.max() best_thr = thrs_little[dices_little.argmax()] # 选最优像素阈值 if stage != 3: dices_pixel = [] pixel_thrs = np.arange(0, 2304, 256) # 阈值列表 for pixel_thr in pixel_thrs: tmp = [] tbar = tqdm.tqdm(self.valid_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) net_output = torch.sigmoid(self.unet(images)) preds = (net_output > best_thr).to( self.device).float() # 大于阈值的归为1 preds[ preds.view(preds.shape[0], -1).sum(-1) < pixel_thr, ...] 
= 0.0 # 过滤噪声点 tmp.append(self.dice_overall(preds, masks).mean()) # tmp.append(self.classify_score(preds, masks)) dices_pixel.append(sum(tmp) / len(tmp)) dices_pixel = np.array(dices_pixel) score = dices_pixel.max() best_pixel_thr = pixel_thrs[dices_pixel.argmax()] elif stage == 3: best_pixel_thr, score = 0, dices_little.max() print('best_thr:{}, best_pixel_thr:{}, score:{}'.format( best_thr, best_pixel_thr, score)) plt.figure(figsize=(10.4, 4.8)) plt.subplot(1, 3, 1) plt.title('Large-scale search') plt.plot(thrs_big, dices_big) plt.subplot(1, 3, 2) plt.title('Little-scale search') plt.plot(thrs_little, dices_little) plt.subplot(1, 3, 3) plt.title('pixel thrs search') if stage != 3: plt.plot(pixel_thrs, dices_pixel) plt.savefig( os.path.join(self.save_path, 'stage{}'.format(stage) + '_fold' + str(index))) # plt.show() plt.close() return float(best_thr), float(best_pixel_thr), float(score) def pred_mask_count(self, model_path, masks_bool, val_index, best_thr, best_pixel_thr): '''加载模型,根据最优阈值和最优像素阈值,得到在验证集上的分类准确率。适用于训练的第二阶段使用 dice 选完阈值,查看分类准确率 Args: model_path: 当前模型的权重路径 masks_bool: 全部数据集中的每个是否含有mask val_index: 当前验证集的在全部数据集的下标 best_thr: 选出的最优阈值 best_pixel_thr: 选出的最优像素阈值 Return: None, 打印出有多少个真实情况有多少个正样本,实际预测出了多少个样本。但是不是很严谨,因为这不能代表正确率。 ''' count_true, count_pred = 0, 0 for index1 in val_index: if masks_bool[index1]: count_true += 1 self.unet.module.load_state_dict(torch.load(model_path)['state_dict']) print('Loaded from %s' % model_path) self.unet.eval() with torch.no_grad(): tmp = [] tbar = tqdm.tqdm(self.valid_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) net_output = torch.sigmoid(self.unet(images)) preds = (net_output > best_thr).to( self.device).float() # 大于阈值的归为1 preds[preds.view(preds.shape[0], -1).sum(-1) < best_pixel_thr, ...] = 0.0 # 过滤噪声点 n = preds.shape[0] # batch size为多少 preds = preds.view(n, -1) for index2 in range(n): pred = preds[index2, ...] 
if torch.sum(pred) > 0: count_pred += 1 tmp.append(self.dice_overall(preds, masks).mean()) print('score:', sum(tmp) / len(tmp)) print('count_true:{}, count_pred:{}'.format(count_true, count_pred)) def grid_search(self, thrs_big, pixel_thrs): '''利用网格法搜索最优阈值和最优像素阈值 Args: thrs_big: 网格法搜索时的一系列阈值 pixel_thrs: 网格搜索时的一系列像素阈值 Return: 最优阈值,最优像素阈值,最高得分,网络矩阵中每个位置的得分 ''' with torch.no_grad(): # 先大概选取阈值范围和像素阈值范围 dices_big = [] # 存放的是二维矩阵,每一行为每一个阈值下所有像素阈值得到的得分 for th in thrs_big: dices_pixel = [] for pixel_thr in pixel_thrs: tmp = [] tbar = tqdm.tqdm(self.valid_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) net_output = torch.sigmoid(self.unet(images)) preds = (net_output > th).to( self.device).float() # 大于阈值的归为1 preds[ preds.view(preds.shape[0], -1).sum(-1) < pixel_thr, ...] = 0.0 # 过滤噪声点 tmp.append(self.dice_overall(preds, masks).mean()) # tmp.append(self.classify_score(preds, masks)) dices_pixel.append(sum(tmp) / len(tmp)) dices_big.append(dices_pixel) dices_big = np.array(dices_big) print('粗略挑选最优阈值和最优像素阈值,dices_big_shape:{}'.format( np.shape(dices_big))) re = np.where(dices_big == np.max(dices_big)) # 如果有多个最大值的处理方式 if np.shape(re)[1] != 1: re = re[0] best_thrs_big, best_pixel_thr = thrs_big[int( re[0])], pixel_thrs[int(re[1])] best_thr, score = best_thrs_big, dices_big.max() return best_thr, best_pixel_thr, score, dices_big def choose_threshold_grid(self, model_path, index): '''利用网格法搜索当前模型的最优阈值和最优像素阈值,分为粗略搜索和精细搜索两个过程;并保存热力图 Args: model_path: 当前模型权重的位置 index: 当前为第几个fold Return: 最优阈值,最优像素阈值,最高得分 ''' self.unet.module.load_state_dict(torch.load(model_path)['state_dict']) stage = eval(model_path.split('/')[-1].split('_')[2]) print('Loaded from %s, using choose_threshold_grid!' 
% model_path) self.unet.eval() thrs_big1 = np.arange(0.60, 0.81, 0.015) # 阈值列表 pixel_thrs1 = np.arange(768, 2305, 256) # 像素阈值列表 best_thr1, best_pixel_thr1, score1, dices_big1 = self.grid_search( thrs_big1, pixel_thrs1) print('best_thr1:{}, best_pixel_thr1:{}, score1:{}'.format( best_thr1, best_pixel_thr1, score1)) thrs_big2 = np.arange(best_thr1 - 0.015, best_thr1 + 0.015, 0.0075) # 阈值列表 pixel_thrs2 = np.arange(best_pixel_thr1 - 256, best_pixel_thr1 + 257, 128) # 像素阈值列表 best_thr2, best_pixel_thr2, score2, dices_big2 = self.grid_search( thrs_big2, pixel_thrs2) print('best_thr2:{}, best_pixel_thr2:{}, score2:{}'.format( best_thr2, best_pixel_thr2, score2)) if score1 < score2: best_thr, best_pixel_thr, score, dices_big = best_thr2, best_pixel_thr2, score2, dices_big2 else: best_thr, best_pixel_thr, score, dices_big = best_thr1, best_pixel_thr1, score1, dices_big1 print('best_thr:{}, best_pixel_thr:{}, score:{}'.format( best_thr, best_pixel_thr, score)) f, (ax1, ax2) = plt.subplots(figsize=(14.4, 4.8), ncols=2) cmap = sns.cubehelix_palette(start=1.5, rot=3, gamma=0.8, as_cmap=True) data1 = pd.DataFrame(data=dices_big1, index=np.around(thrs_big1, 3), columns=pixel_thrs1) sns.heatmap(data1, linewidths=0.05, ax=ax1, vmax=np.max(dices_big1), vmin=np.min(dices_big1), cmap=cmap, annot=True, fmt='.4f') ax1.set_title('Large-scale search') data2 = pd.DataFrame(data=dices_big2, index=np.around(thrs_big2, 3), columns=pixel_thrs2) sns.heatmap(data2, linewidths=0.05, ax=ax2, vmax=np.max(dices_big2), vmin=np.min(dices_big2), cmap=cmap, annot=True, fmt='.4f') ax2.set_title('Little-scale search') f.savefig( os.path.join(self.save_path, 'stage{}'.format(stage) + '_fold' + str(index))) # plt.show() plt.close() return float(best_thr), float(best_pixel_thr), float(score) def get_dice_onval(self, model_path, best_thr, pixel_thr): '''已经训练好模型,并且选完阈值后。根据当前模型,best_thr, pixel_thr得到在验证集的表现 Args: model_path: 要加载的模型路径 best_thr: 选出的最优阈值 pixel_thr: 选出的最优像素阈值 Return: None ''' 
self.unet.module.load_state_dict(torch.load(model_path)['state_dict']) stage = eval(model_path.split('/')[-1].split('_')[2]) print('Loaded from %s, using get_dice_onval!' % model_path) self.unet.eval() with torch.no_grad(): # 选最优像素阈值 tmp = [] tbar = tqdm.tqdm(self.valid_loader) for i, (images, masks) in enumerate(tbar): # GT : Ground Truth images = images.to(self.device) net_output = torch.sigmoid(self.unet(images)) preds = (net_output > best_thr).to( self.device).float() # 大于阈值的归为1 if stage != 3: preds[preds.view(preds.shape[0], -1).sum(-1) < pixel_thr, ...] = 0.0 # 过滤噪声点 tmp.append(self.dice_overall(preds, masks).mean()) # tmp.append(self.classify_score(preds, masks)) score = sum(tmp) / len(tmp) print('best_thr:{}, best_pixel_thr:{}, score:{}'.format( best_thr, pixel_thr, score))
class Test(object):
    """K-fold ensembled inference: classify, segment, and write ``submission.csv``."""

    def __init__(self, model_type, image_size, mean, std, t=None):
        """
        Args:
            model_type: architecture name understood by ``build_model``.
            image_size: side length of the square network input; predictions
                are produced at this resolution.
            mean: per-channel mean passed to ``transforms.Normalize``.
            std: per-channel std passed to ``transforms.Normalize``.
            t: forwarded as ``t`` to ``R2U_Net`` / ``R2AttU_Net``.
        """
        self.unet = None  # built by build_model()
        self.image_size = image_size
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = model_type
        self.t = t
        self.mean = mean
        self.std = std

    def build_model(self):
        """Build the network named by ``self.model_type`` and move it to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a known architecture.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'unet_resnet34':
            self.unet = smp.Unet('resnet34', encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50', encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d', encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121', encoder_weights='imagenet',
                                 activation=None)
        elif self.model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34', encoder_weights='imagenet',
                               activation=None, use_ConvTranspose2d=True)
        elif self.model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34', encoder_weights='imagenet',
                                   activation=None)
        elif self.model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34', encoder_weights='imagenet',
                                   classes=1, activation=None)
        elif self.model_type == 'linknet':
            self.unet = LinkNet34(num_classes=1)
        elif self.model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous', num_classes=1)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # ``None.to`` below.
            raise ValueError('Unknown model_type: {}'.format(self.model_type))
        print('build model done!')
        self.unet.to(self.device)

    def _checkpoint_path(self, stage, fold, best):
        """Path of the checkpoint for ``stage``/``fold``; the ``_best`` one when ``best``."""
        pattern = '_{}_{}_best.pth' if best else '_{}_{}.pth'
        return os.path.join('checkpoints', self.model_type,
                            self.model_type + pattern.format(stage, fold))

    @staticmethod
    def _per_fold(value, fold):
        """Return ``value[fold]`` for per-fold settings, or ``value`` itself for a scalar."""
        if np.isscalar(value):
            return value
        return value[fold]

    def test_model(self, thresholds_classify, thresholds_seg, average_threshold,
                   stage_cla, stage_seg, n_splits, test_best_model=True,
                   less_than_sum=2048 * 2, seg_average_vote=True, csv_path=None,
                   test_image_path=None):
        """Predict masks for every image listed in ``csv_path`` and write ``submission.csv``.

        For each fold, a classification checkpoint first decides whether an
        image contains a mask. Only then is the segmentation checkpoint run.
        Fold predictions are combined by averaging or by voting.

        Args:
            thresholds_classify: per-fold sequence (indexed by fold number) or
                scalar; pixels of the classification output above it are set
                to 1, others to 0.
            thresholds_seg: per-fold sequence or scalar threshold for the
                segmentation output; used only when voting.
            average_threshold: threshold applied to the fold-averaged output
                when averaging.
            stage_cla: training stage whose weights are used for classification.
            stage_seg: training stage whose weights are used for segmentation.
            n_splits: fold numbers whose models are ensembled.
            test_best_model: load ``*_best.pth`` checkpoints if True, otherwise
                the latest ``*.pth`` ones.
            less_than_sum: per-fold sequence or scalar. When the classification
                mask has fewer positive pixels than this, the image is treated
                as empty.
            seg_average_vote: True to average fold outputs, False to vote.
            csv_path: CSV with an ``ImageId`` column listing the test images.
            test_image_path: directory holding the ``<ImageId>.jpg`` files.
        """
        sample_df = pd.read_csv(csv_path)
        # Accumulated segmentation output over folds, one map per test image.
        preds = np.zeros([len(sample_df), self.image_size, self.image_size])
        for fold in n_splits:
            # Classification model for this fold.
            self.unet = None
            self.build_model()
            unet_path = self._checkpoint_path(stage_cla, fold, test_best_model)
            print("Load classify weight from %s" % unet_path)
            self.unet.load_state_dict(torch.load(unet_path)['state_dict'])
            self.unet.eval()

            # Segmentation model: same architecture, different weights.
            seg_unet = copy.deepcopy(self.unet)
            unet_path = self._checkpoint_path(stage_seg, fold, test_best_model)
            print('Load segmentation weight from %s.' % unet_path)
            seg_unet.load_state_dict(torch.load(unet_path)['state_dict'])
            seg_unet.eval()

            cla_thr = self._per_fold(thresholds_classify, fold)
            seg_thr = self._per_fold(thresholds_seg, fold)
            min_pixels = self._per_fold(less_than_sum, fold)
            count_mask_classify = 0
            with torch.no_grad():
                for index, row in tqdm(sample_df.iterrows(), total=len(sample_df)):
                    file = row['ImageId']
                    img_path = os.path.join(test_image_path, file.strip() + '.jpg')
                    # convert() returns a new image, so the file can be closed here.
                    with Image.open(img_path) as raw:
                        img = raw.convert('RGB')
                    cla_pred = self.tta(img, self.unet)
                    positives = int(np.sum(cla_pred > cla_thr))
                    # Only images classified as containing a mask are segmented;
                    # the others contribute zeros.
                    if positives > 0 and positives >= min_pixels:
                        count_mask_classify += 1
                        pred = self.tta(img, seg_unet)
                        if not seg_average_vote:
                            # Voting: binarize each fold's output into one vote.
                            pred = np.where(pred > seg_thr, 1, 0)
                        preds[index, ...] += pred
            print('Fold %d Detect %d mask in classify.' %
                  (fold, count_mask_classify))

        if not seg_average_vote:
            vote_model_num = len(n_splits)
            # NOTE(review): round() uses banker's rounding, so e.g. 3 models need
            # all 3 votes while 5 models need 3; confirm this is intended.
            vote_ticket = round(vote_model_num / 2.0)
            print("Using voting strategy, Ticket / Vote models: %d / %d" %
                  (vote_ticket, vote_model_num))
        else:
            print('Using average strategy.')
            preds = preds / len(n_splits)

        rle = []
        count_has_mask = 0
        for index, row in tqdm(sample_df.iterrows(), total=len(sample_df)):
            file = row['ImageId']
            pred = preds[index, ...]
            if not seg_average_vote:
                pred = np.where(pred > vote_ticket, 1, 0)
            else:
                pred = np.where(pred > average_threshold, 1, 0)
            # np.where yields int64, which cv2.resize does not support; uint8
            # keeps the same 0/1 values.
            pred = cv2.resize(pred.astype(np.uint8), (1024, 1024))
            encoding = mask_to_rle(pred.T, 1024, 1024)
            if encoding == ' ':
                rle.append([file.strip(), '-1'])
            else:
                count_has_mask += 1
                rle.append([file.strip(), encoding[1:]])
        print('The number of masked pictures predicted:', count_has_mask)
        submission_df = pd.DataFrame(rle, columns=['ImageId', 'EncodedPixels'])
        submission_df.to_csv('submission.csv', index=False)

    def image_transform(self, image):
        """Resize, convert to a tensor, and normalize a PIL image."""
        transform_compose = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
        return transform_compose(image)

    def detection(self, image, model):
        """Run ``model`` on one PIL image.

        Args:
            image: PIL image to predict.
            model: network to use.

        Returns:
            np.ndarray of shape ``(image_size, image_size)`` with the sigmoid output.
        """
        batch = self.image_transform(image).unsqueeze(0).float().to(self.device)
        pred = torch.sigmoid(model(batch))
        pred = pred.view(self.image_size, self.image_size)
        return pred.detach().cpu().numpy()

    def tta(self, image, model):
        """Average ``model``'s predictions over three test-time augmentations.

        The three views are: horizontal flip (prediction flipped back), CLAHE,
        and the original image.

        Args:
            image: PIL image.
            model: network to use.

        Returns:
            np.ndarray of shape ``(image_size, image_size)``: the mean prediction.
        """
        preds = np.zeros([self.image_size, self.image_size])
        # Horizontal flip; mirror the prediction back to the original layout.
        hflip_pred = self.detection(image.transpose(Image.FLIP_LEFT_RIGHT), model)
        preds += np.fliplr(hflip_pred)
        # CLAHE.
        clahe_image = CLAHE(p=1.0)(image=np.asarray(image))['image']
        preds += self.detection(Image.fromarray(clahe_image), model)
        # Original image.
        preds += self.detection(image, model)
        return preds / 3.0