Exemple #1
0
def test(model_path_effi7, model_path_resnest, output_dir, test_loader, addNDVI):
    """Run two-model ensemble inference with flip augmentation and save PNG masks.

    Two 10-class UNet++ networks (ResNeSt-101e and EfficientNet-B7 encoders)
    are restored from their checkpoints. For every batch eight predictions are
    averaged: each model on the plain image, on the stretched image, and on
    the image flipped along dim 2 and along dim 3 (the output is flipped back).
    The arg-max class index, shifted by one, is written as an 8-bit PNG.

    Args:
        model_path_effi7: Checkpoint path of the EfficientNet-B7 model. If the
            path contains "swa" the network is wrapped in ``AveragedModel``
            before the state dict is loaded.
        model_path_resnest: Checkpoint path of the ResNeSt model; the same
            "swa" rule applies.
        output_dir: Directory that receives the predicted ``.png`` files.
        test_loader: Iterable yielding
            ``(image, image_stretch, image_path, ndvi)`` batches.
        addNDVI: When truthy the networks take 5 input channels instead of 4.
    """
    in_channels = 5 if addNDVI else 4

    def _load_model(encoder_name, checkpoint_path):
        # Build one ensemble member, restore its weights, switch to eval mode.
        net = smp.UnetPlusPlus(
            encoder_name=encoder_name,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=10,
        )
        # A path containing "swa" marks an SWA checkpoint; the network is
        # wrapped so that the state-dict layout matches an AveragedModel.
        if "swa" in checkpoint_path:
            net = AveragedModel(net)
        net.to(DEVICE)
        # map_location lets a checkpoint saved on another device load onto
        # DEVICE instead of failing.
        net.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        net.eval()
        return net

    model_resnest = _load_model("timm-resnest101e", model_path_resnest)
    model_effi7 = _load_model("efficientnet-b7", model_path_effi7)

    def _to_numpy(tensor):
        return tensor.cpu().numpy()

    # ndvi is unpacked but unused: an NDVI-based masking step was disabled
    # in the original code.
    for image, image_stretch, image_path, _ndvi in test_loader:
        with torch.no_grad():
            # Inputs must live on the same device as the models (DEVICE),
            # not unconditionally on the default CUDA device.
            # Original author's note: image.shape is e.g. 16,4,256,256.
            image = image.to(DEVICE)
            image_stretch = image_stretch.to(DEVICE)
            image_flip2 = torch.flip(image, [2])
            image_flip3 = torch.flip(image, [3])

            # Each prediction is moved to host memory immediately so that
            # only one output tensor at a time stays on the device.
            predictions = [
                _to_numpy(model_resnest(image)),
                _to_numpy(model_resnest(image_stretch)),
                _to_numpy(model_effi7(image)),
                _to_numpy(model_effi7(image_stretch)),
                _to_numpy(torch.flip(model_resnest(image_flip2), [2])),
                _to_numpy(torch.flip(model_effi7(image_flip2), [2])),
                _to_numpy(torch.flip(model_resnest(image_flip3), [3])),
                _to_numpy(torch.flip(model_effi7(image_flip3), [3])),
            ]

        # Summed left to right, then averaged.
        # Original author's note: output.shape is e.g. 16,10,256,256.
        output = predictions[0]
        for extra in predictions[1:]:
            output = output + extra
        output = output / 8.0

        # Class axis is dim 1; stored labels are shifted by one (1..10).
        masks = (np.argmax(output, axis=1) + 1).astype(np.uint8)
        for mask, source_path in zip(masks, image_path):
            # NOTE(review): the last 10 characters are taken as the file
            # name, which presumably assumes names like "000001.tif" —
            # confirm against the dataset.
            save_path = os.path.join(
                output_dir, source_path[-10:].replace('.tif', '.png'))
            print(save_path)
            cv2.imwrite(save_path, mask)
 def __init__(self, model_name, n_class):
     """Create the wrapped UNet++ network.

     Args:
         model_name: Encoder name passed to ``smp.UnetPlusPlus``
             (e.g. mobilenet_v2 or efficientnet-b7).
         n_class: Number of output channels, i.e. classes in the dataset.
     """
     super().__init__()
     # ImageNet-initialised encoder, 3-channel input.
     unetpp_options = {
         "encoder_name": model_name,
         "encoder_weights": "imagenet",
         "in_channels": 3,
         "classes": n_class,
     }
     self.model = smp.UnetPlusPlus(**unetpp_options)
Exemple #3
0
    def __init__(self, model_name, in_channels=3, out_channels=1):
        """Build a resnet18-based smp network selected by ``model_name``.

        ``model_name`` containing 'unet' selects ``smp.Unet``; otherwise one
        containing 'uplus' selects ``smp.UnetPlusPlus``; anything else falls
        back to ``smp.PSPNet``. A ``DownsizeBlock`` is attached as
        ``self.down_layer``.

        Args:
            model_name: String inspected for the substrings above.
            in_channels: Model input channels (1 for grayscale, 3 for RGB, ...).
            out_channels: Model output channels (number of classes).
        """
        super(SmpModel16, self).__init__()

        # Options forwarded as ``aux_params`` to every architecture.
        aux_params = {
            'pooling': 'max',         # one of 'avg', 'max'
            'dropout': 0.5,           # dropout ratio
            'activation': 'softmax',  # activation function
            'classes': out_channels,  # number of output labels
        }

        # Keyword arguments common to all three architectures.
        shared_kwargs = {
            'encoder_name': "resnet18",
            'encoder_weights': "imagenet",
            'in_channels': in_channels,
            'classes': out_channels,
            'aux_params': aux_params,
        }

        if 'unet' in model_name:
            self.model = smp.Unet(
                encoder_depth=3,
                decoder_channels=[128, 64, 32],
                **shared_kwargs,
            )
        elif 'uplus' in model_name:
            self.model = smp.UnetPlusPlus(
                encoder_depth=3,
                decoder_channels=[256, 128, 64],
                **shared_kwargs,
            )
        else:
            # PSPNet keeps the library defaults for depth and decoder.
            self.model = smp.PSPNet(**shared_kwargs)

        self.down_layer = DownsizeBlock(out_channels, downsize_mode=2)
Exemple #4
0
def get_model(data_channel=3, encoder=None, encoder_weight=None):
    """Return a 10-class UNet++ with encoder depth 5.

    Args:
        data_channel: Number of input channels.
        encoder: Encoder name forwarded as ``encoder_name``.
        encoder_weight: Value forwarded as ``encoder_weights``.
    """
    options = {
        "encoder_name": encoder,
        "encoder_weights": encoder_weight,
        "in_channels": data_channel,
        "classes": 10,
        "encoder_depth": 5,
    }
    return smp.UnetPlusPlus(**options)
Exemple #5
0
 def __init__(self, encoder_name, encoder_weights="imagenet", in_channels=3,
              n_class=10):
     """Create the wrapped UNet++ network.

     Args:
         encoder_name: Encoder to use, e.g. mobilenet_v2 or efficientnet-b7.
         encoder_weights: Encoder initialisation; ``"imagenet"`` by default.
         in_channels: Model input channels (1 for grayscale, 3 for RGB, ...).
         n_class: Model output channels (number of classes in the dataset).
     """
     super().__init__()
     network_options = {
         "encoder_name": encoder_name,
         "encoder_weights": encoder_weights,
         "in_channels": in_channels,
         "classes": n_class,
     }
     self.model = smp.UnetPlusPlus(**network_options)
Exemple #6
0
def create_smp_model(arch, **kwargs):
    """Instantiate the segmentation_models_pytorch class named ``arch``.

    ``kwargs`` are passed to the constructor unchanged and are also stored
    on the returned model as its ``kwargs`` attribute.

    Raises:
        AssertionError: if ``arch`` is not listed in ``ARCHITECTURES``.
        NotImplementedError: if ``arch`` has no constructor mapped here.
    """
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    # Every supported name equals the attribute name of its class on smp.
    constructor_names = (
        "Unet", "UnetPlusPlus", "MAnet", "FPN", "PAN",
        "PSPNet", "Linknet", "DeepLabV3", "DeepLabV3Plus",
    )
    for candidate in constructor_names:
        if arch == candidate:
            model = getattr(smp, candidate)(**kwargs)
            break
    else:
        raise NotImplementedError

    model.kwargs = kwargs
    return model
Exemple #7
0
def SkinLesionModel(model, pretrained=True):
    """Build the segmentation network registered under the key ``model``.

    Only the requested network is instantiated. Previously every entry of
    the zoo was constructed (and its weights fetched) on each call, and the
    ``pretrained`` flag was ignored.

    Args:
        model: One of ``'deeplabv3plus'``, ``'deeplabv3plus_resnext'``,
            ``'pspnet'`` or ``'unetplusplus'``.
        pretrained: When truthy (the default) the encoder is initialised
            with ImageNet weights; otherwise ``encoder_weights=None`` is
            passed.

    Returns:
        The constructed segmentation_models_pytorch model.

    Raises:
        Warning: if ``model`` is not a known key (kept for compatibility
            with existing callers).
    """
    # key -> (architecture class, encoder name); nothing is built here.
    models_zoo = {
        'deeplabv3plus': (smp.DeepLabV3Plus, 'resnet101'),
        'deeplabv3plus_resnext': (smp.DeepLabV3Plus, 'resnext101_32x8d'),
        'pspnet': (smp.PSPNet, 'resnet101'),
        'unetplusplus': (smp.UnetPlusPlus, 'resnet101'),
    }
    entry = models_zoo.get(model)
    if entry is None:
        raise Warning('Wrong Net Name!!')

    architecture, encoder_name = entry
    return architecture(
        encoder_name,
        encoder_weights='imagenet' if pretrained else None,
        aux_params=None,
    )
 def __init__(self,
              num_classes=12,
              encoder="resnext101_32x8d",
              pretrain_weight="imagenet",
              decoder="DeepLabV3Plus"):
     """Create the segmentation backbone for the chosen decoder.

     Args:
         num_classes: Number of output classes.
         encoder: Encoder name forwarded as ``encoder_name``.
         pretrain_weight: Value forwarded as ``encoder_weights``.
         decoder: Architecture name; one of ``"DeepLabV3Plus"``,
             ``"DeepLabV3"`` or ``"UnetPlusPlus"``.

     Raises:
         ValueError: if ``decoder`` is not one of the supported names.
             Previously ``self.backbone`` was silently left undefined.
     """
     super(smpModel, self).__init__()
     # Each supported name equals the attribute name of its class on smp.
     supported_decoders = ("DeepLabV3Plus", "DeepLabV3", "UnetPlusPlus")
     if decoder not in supported_decoders:
         raise ValueError(
             "Unsupported decoder {!r}; expected one of {}".format(
                 decoder, supported_decoders))
     self.backbone = getattr(smp, decoder)(encoder_name=encoder,
                                           encoder_weights=pretrain_weight,
                                           in_channels=3,
                                           classes=num_classes)
Exemple #9
0
def test_1(model_path, output_dir, test_loader, addNDVI):
    """Run single-model inference and save the predicted masks as PNG files.

    A 10-class ResNet-101 UNet++ is restored from ``model_path``. For each
    batch the predictions on the plain and on the stretched image are
    averaged; the arg-max class index, shifted by one, is written as an
    8-bit PNG named after the source image.

    Args:
        model_path: Checkpoint path. If it contains "swa" the network is
            wrapped in ``AveragedModel`` before the state dict is loaded.
        output_dir: Directory that receives the ``.png`` files.
        test_loader: Iterable yielding
            ``(image, image_stretch, image_path, ndvi)`` batches.
        addNDVI: When truthy the network takes 5 input channels instead of 4.
    """
    in_channels = 5 if addNDVI else 4
    model = smp.UnetPlusPlus(
        encoder_name="resnet101",
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=10,
    )
    # A path containing "swa" marks an SWA checkpoint; the network is
    # wrapped so that the state-dict layout matches an AveragedModel.
    if "swa" in model_path:
        model = AveragedModel(model)
    model.to(DEVICE)
    # map_location lets a checkpoint saved on another device load onto DEVICE.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    for image, image_stretch, image_path, _ndvi in test_loader:
        with torch.no_grad():
            # Inputs must be on the same device as the model (DEVICE).
            image = image.to(DEVICE)
            image_stretch = image_stretch.to(DEVICE)
            output1 = model(image).cpu().numpy()
            output2 = model(image_stretch).cpu().numpy()
        output = (output1 + output2) / 2.0

        # Class axis is dim 1; stored labels are shifted by one (1..10).
        masks = (np.argmax(output, axis=1) + 1).astype(np.uint8)
        for mask, source_path in zip(masks, image_path):
            # Accept both '\\' and '/' separators. Splitting on '\\' alone
            # left POSIX paths intact, and os.path.join then discarded
            # output_dir whenever the source path was absolute.
            file_name = os.path.basename(source_path.replace('\\', '/'))
            save_path = os.path.join(output_dir,
                                     file_name.replace('.tif', '.png'))
            print(save_path)
            cv2.imwrite(save_path, mask)
Exemple #10
0
def get_segmentation_model(
    arch: str,
    encoder_name: str,
    encoder_weights: Optional[str] = "imagenet",
    pretrained_checkpoint_path: Optional[str] = None,
    checkpoint_path: Optional[Union[str, List[str]]] = None,
    convert_bn: Optional[str] = None,
    convert_bottleneck: Tuple[int, int, int] = (0, 0, 0),
    **kwargs: Any,
) -> nn.Module:
    """
    Fetch segmentation model by its name

    Steps, in order: build the smp model, optionally load encoder weights,
    run the botnet ResNet conversion, optionally replace BatchNorm2d layers,
    and optionally load (averaged) full-model checkpoints.

    :param arch: architecture name, case-insensitive. One of "unet",
        "unetplusplus"/"unet++", "linknet", "pspnet", "pan", "deeplabv3",
        "deeplabv3plus"/"deeplabv3+", "manet".
    :param encoder_name: encoder name passed to the smp constructor.
    :param encoder_weights: encoder initialisation. Forced to ``None`` when
        ``encoder_name == "en_resnet34"`` or when any checkpoint path is
        given, since the weights come from the checkpoint instead.
    :param checkpoint_path: one path, or a list/tuple of paths, to
        checkpoints holding a ``"model_state_dict"`` entry. The state dicts
        are combined with ``average_weights`` and loaded into the model.
    :param pretrained_checkpoint_path: path to a state dict loaded into
        ``model.encoder`` only.
    :param convert_bn: ``"instance"``, ``"group"``, ``"bnet"``, ``"gnet"``
        or a falsy value for no conversion.
    :param convert_bottleneck: passed as ``replacement`` to
        ``botnet.convert_resnet``.
    :param kwargs: extra keyword arguments for the smp constructor.
    :return: the assembled model.
    :raises ValueError: on an unknown ``arch`` or ``convert_bn`` value.
    """

    arch = arch.lower()
    if (encoder_name == "en_resnet34" or checkpoint_path is not None
            or pretrained_checkpoint_path is not None):
        encoder_weights = None

    # Accepted spellings -> attribute name of the class on smp. The class
    # itself is looked up lazily, only for the requested architecture.
    smp_class_names = {
        "unet": "Unet",
        "unetplusplus": "UnetPlusPlus",
        "unet++": "UnetPlusPlus",
        "linknet": "Linknet",
        "pspnet": "PSPNet",
        "pan": "PAN",
        "deeplabv3": "DeepLabV3",
        "deeplabv3plus": "DeepLabV3Plus",
        "deeplabv3+": "DeepLabV3Plus",
        "manet": "MAnet",
    }
    if arch not in smp_class_names:
        raise ValueError(
            f"Unsupported architecture '{arch}'; "
            f"expected one of {sorted(smp_class_names)}")
    model = getattr(smp, smp_class_names[arch])(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        **kwargs)

    if pretrained_checkpoint_path is not None:
        print(f"Loading pretrained checkpoint {pretrained_checkpoint_path}")
        state_dict = torch.load(pretrained_checkpoint_path,
                                map_location=torch.device("cpu"))
        model.encoder.load_state_dict(state_dict)
        del state_dict

    # TODO fmap_size=16 hardcoded for input 256 (matters for positional encoding)
    # NOTE(review): called unconditionally, including with the default
    # (0, 0, 0) replacement — presumably a no-op in that case; confirm.
    botnet.convert_resnet(
        model.encoder,
        replacement=convert_bottleneck,
        fmap_size=16,
        position_encoding=None,
    )

    # TODO parametrize conversion
    print(f"Convert BN to {convert_bn}")
    if convert_bn == "instance":
        print("Converting BatchNorm2d to InstanceNorm2d")
        model = batch_norm2instance(model)
    elif convert_bn == "group":
        print("Converting BatchNorm2d to GroupNorm")
        model = batch_norm2group(model, channels_per_group=1)
    elif convert_bn == "bnet":
        print("Converting BatchNorm2d to BNet2d")
        model = batch_norm2bnet(model)
    elif convert_bn == "gnet":
        print("Converting BatchNorm2d to GNet2d")
        model = batch_norm2gnet(model, channels_per_group=1)
    elif not convert_bn:
        print("Do not convert BatchNorm2d")
    else:
        raise ValueError(
            f"Unsupported convert_bn value '{convert_bn}'; expected "
            "'instance', 'group', 'bnet', 'gnet' or None")

    if checkpoint_path is not None:
        # A single path is wrapped; lists and tuples are used as given.
        if not isinstance(checkpoint_path, (list, tuple)):
            checkpoint_path = [checkpoint_path]
        states = []
        for cp in checkpoint_path:
            # Load checkpoint
            print(f"\nLoading checkpoint {str(cp)}")
            state_dict = torch.load(
                cp, map_location=torch.device("cpu"))["model_state_dict"]
            states.append(state_dict)
        state_dict = average_weights(states)
        model.load_state_dict(state_dict)
        del state_dict

    return model
Exemple #11
0
    def __init__(self,
                 architecture="Unet",
                 encoder="resnet34",
                 depth=5,
                 in_channels=3,
                 classes=2,
                 activation="softmax"):
        """Store the configuration and build the matching smp network.

        Args:
            architecture: Name of the smp architecture class to build.
            encoder: Encoder name forwarded as ``encoder_name``.
            depth: Forwarded as ``encoder_depth``; also sets
                ``self.pad_unit = 2 ** depth``.
            in_channels: Number of input channels.
            classes: Number of output classes.
            activation: Activation forwarded to the smp constructor.

        The encoder is always created with ``encoder_weights=None``.
        """
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model

        _ARCHITECTURES = [
            "Unet", "UnetPlusPlus", "Linknet", "MAnet", "FPN", "PSPNet", "PAN",
            "DeepLabV3", "DeepLabV3Plus"
        ]
        assert self.architecture in _ARCHITECTURES, "architecture=={0}, actual '{1}'".format(
            _ARCHITECTURES, self.architecture)

        # Every supported name equals the attribute name of its class on
        # smp, and all of them take the same constructor arguments.
        for arch_name in _ARCHITECTURES:
            if self.architecture != arch_name:
                continue
            network_cls = getattr(smp, arch_name)
            self.model = network_cls(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
            break
Exemple #12
0
def fine_tune(EPOCHES, BATCH_SIZE, train_image_paths, train_label_paths,
              val_image_paths, val_label_paths, channels, model_path,
              swa_model_path, addNDVI):
    """Fine-tune a ResNet-101 UNet++ checkpoint with SGD and SWA.

    The model is restored from ``model_path`` and trained for ``EPOCHES``
    epochs with Lovasz loss. After every epoch the SWA average is updated
    and the validation IoU is computed; whenever the mean IoU improves, the
    plain model is saved to ``model_path[:-4] + "_finetune.pth"``. At the
    end the SWA model's batch-norm statistics are refreshed, its validation
    mIoU is printed and its state dict is saved to ``swa_model_path``.

    Args:
        EPOCHES: Number of training epochs.
        BATCH_SIZE: Batch size for both data loaders.
        train_image_paths: Training image paths.
        train_label_paths: Training label paths.
        val_image_paths: Validation image paths.
        val_label_paths: Validation label paths.
        channels: Number of model input channels.
        model_path: Checkpoint to start from.
        swa_model_path: Destination of the SWA model state dict.
        addNDVI: Forwarded to ``get_dataloader``.

    Returns:
        Three per-epoch lists: mean training loss, validation mIoU and the
        learning rate of the first parameter group.
    """
    train_loader = get_dataloader(train_image_paths,
                                  train_label_paths,
                                  "train",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8)
    valid_loader = get_dataloader(val_image_paths,
                                  val_label_paths,
                                  "val",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=8)

    # Define the model, optimizer and loss function.
    model = smp.UnetPlusPlus(
        encoder_name="resnet101",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=10,
    )
    model.to(DEVICE)
    # map_location lets a checkpoint saved on another device load onto DEVICE.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    # SGD optimizer.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=3e-4,
                                weight_decay=1e-3,
                                momentum=0.9)

    # Stochastic weight averaging (SWA) for better generalization.
    swa_model = AveragedModel(model).to(DEVICE)
    # SWA learning-rate schedule.
    swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

    # Lovasz loss, used as a surrogate for optimizing mIoU.
    loss_fn = LovaszLoss(mode='multiclass').to(DEVICE)

    header = r'Epoch/EpochNum | TrainLoss | ValidmIoU | Time(m)'
    raw_line = r'{:5d}/{:8d} | {:9.3f} | {:9.3f} | {:9.2f}'
    print(header)

    # Best validation mIoU so far; decides whether the model is saved.
    best_miou = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    for epoch in range(1, EPOCHES + 1):
        # Loss of every training batch in this epoch.
        losses = []
        start_time = time.time()
        model.train()
        model.to(DEVICE)
        for image, target in train_loader:
            image, target = image.to(DEVICE), target.to(DEVICE)
            # Gradients have to be cleared before each backward pass.
            optimizer.zero_grad()
            output = model(image)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        swa_model.update_parameters(model)
        swa_scheduler.step()

        # Validation IoU of the plain (non-averaged) model.
        val_iou = cal_val_iou(model, valid_loader)
        mean_loss = np.array(losses).mean()
        mean_iou = np.mean(val_iou)

        # Record this epoch's training loss, validation mIoU and LR.
        train_loss_epochs.append(mean_loss)
        val_mIoU_epochs.append(mean_iou)
        lr_epochs.append(optimizer.param_groups[0]['lr'])

        # Progress line; completed by one of the two prints below.
        print(raw_line.format(epoch, EPOCHES, mean_loss, mean_iou,
                              (time.time() - start_time) / 60),
              end="")
        epoch_miou = np.stack(val_iou).mean(0).mean()
        if best_miou < epoch_miou:
            best_miou = epoch_miou
            torch.save(model.state_dict(), model_path[:-4] + "_finetune.pth")
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")

    # Refresh the batch-norm statistics of the averaged model.
    torch.optim.swa_utils.update_bn(train_loader, swa_model, device=DEVICE)
    # Evaluate the SWA model itself. The original evaluated `model` here,
    # so the printed "swa_model" mIoU was that of the plain model.
    # NOTE(review): presumably cal_val_iou puts the model in eval mode, as
    # in the per-epoch validation above — confirm.
    val_iou = cal_val_iou(swa_model, valid_loader)
    print("swa_model'mIoU is {}".format(np.mean(val_iou)))
    torch.save(swa_model.state_dict(), swa_model_path)
    return train_loss_epochs, val_mIoU_epochs, lr_epochs
Exemple #13
0
        elif config['model'] == 'resnext101':
            model = ResneXt(5, 'resnext101', shared=True)
        elif config['model'] == 'densenet161':
            model = DensenetUnet(5)
        elif config['model'] == 'dpn92':
            model = DPNUnet(5, shared=True)

        model = torch.nn.DataParallel(model)
        model.load_state_dict(torch.load(config['pretrained'])['state_dict'])
        model = model.module
        model.final = model.make_final_classifier(
            in_filters=config['in_filters'], num_classes=1)
else:
    model = smp.UnetPlusPlus(encoder_name=config["model"],
                             encoder_weights=config["pretrained"],
                             in_channels=1,
                             classes=1,
                             activation=config["activation"])

# Augmentation pipeline: random square crop of side config['img_size'],
# random 90-degree rotation and flip. 'image0' and 'image1' are registered
# as additional image targets so they go through the same pipeline.
transform = A.Compose(
    [
        A.RandomCrop(height=config['img_size'], width=config['img_size']),
        A.RandomRotate90(),
        A.Flip(),
    ],
    additional_targets=dict(image0='image', image1='image'),
)

# Table read from the sample-submission CSV.
df = pd.read_csv(config['sample_submission_path'])
Exemple #14
0
    parallel = args.parallel
    overlap = True

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_number)
    print('Start data preparation')
    seed_everything(2021)
    data_path = '../data/'
    data = pd.read_csv(data_path + 'train.csv')
    if seg_model_name == 'unet':
        print('Use unet model')
        model = smp.Unet(encoder, encoder_weights="imagenet", in_channels=3, classes=1,
                         decoder_use_batchnorm=False).cuda()
    elif seg_model_name == 'unet++':
        print('Use unet++ model')
        model = smp.UnetPlusPlus(encoder, encoder_weights="imagenet", in_channels=3, classes=1,
                                 decoder_use_batchnorm=False).cuda()
    elif seg_model_name == 'albunet':
        print('Use albunet model')
        from nn.trainer import AlbuNet
        model = AlbuNet(num_classes=1, pretrained=True).cuda()
    else:
        print('Model name is incorrect. Set to unet++')
        model = smp.UnetPlusPlus(encoder, encoder_weights="imagenet", in_channels=3, classes=1,
                                 decoder_use_batchnorm=False).cuda()
    if parallel:
        model = DataParallel(model).cuda()

    sample_sub = pd.read_csv(data_path + 'sample_submission.csv')
    test_paths = sample_sub.id.values
    print('Start test')
Exemple #15
0
    def _get_net(self, net_name):
        """Build and return the network identified by ``net_name``.

        * ``'unet'`` with ``self.encoder_name`` unset uses the project's own
          ``model.unet`` implementation.
        * ``'unet'``, ``'unet++'``, ``'FPN'`` and ``'deeplabv3+'`` with an
          encoder use segmentation_models_pytorch. The encoder is
          ImageNet-initialised when ``self.ex_pre_trained`` is truthy, and
          ``aux_params`` requests ``self.num_classes - 1`` classes.
        * Names starting with ``'ResUNet'`` or ``'deeplabv3plus'`` are
          looked up in ``model.resUnet`` / ``model.deeplab``.

        Args:
            net_name: Name of the network to build.

        Returns:
            The constructed network.

        Raises:
            ValueError: if an smp architecture other than ``'unet'`` is
                requested without an encoder name, or if ``net_name`` is
                not recognised (previously an ``UnboundLocalError``).
        """
        # net_name -> attribute name of the architecture class on smp.
        smp_architectures = {
            'unet': 'Unet',
            'unet++': 'UnetPlusPlus',
            'FPN': 'FPN',
            'deeplabv3+': 'DeepLabV3Plus',
        }

        if net_name == 'unet' and self.encoder_name is None:
            from model import unet
            return unet.__dict__[net_name](n_channels=self.channels,
                                           n_classes=self.num_classes)

        if net_name in smp_architectures:
            if self.encoder_name is None:
                raise ValueError("encoder name must not be 'None'!")
            import segmentation_models_pytorch as smp
            architecture = getattr(smp, smp_architectures[net_name])
            return architecture(
                encoder_name=self.encoder_name,
                encoder_weights='imagenet' if self.ex_pre_trained else None,
                in_channels=self.channels,
                classes=self.num_classes,
                aux_params={"classes": self.num_classes - 1})

        if net_name.startswith('ResUNet'):
            from model import resUnet
            return resUnet.__dict__[net_name](n_channels=self.channels,
                                              n_classes=self.num_classes)

        if net_name.startswith('deeplabv3plus'):
            from model import deeplab
            return deeplab.__dict__[net_name](n_channels=self.channels,
                                              n_classes=self.num_classes)

        raise ValueError("Unsupported network name: {!r}".format(net_name))
Exemple #16
0
def train(EPOCHES, BATCH_SIZE, train_image_paths, train_label_paths,
          val_image_paths, val_label_paths, channels, optimizer_name,
          model_path, swa_model_path, addNDVI, loss, early_stop):
    """Train a UNet++ (ResNeSt-101e encoder) model and checkpoint the best epoch.

    Args:
        EPOCHES: Maximum number of epochs to run.
        BATCH_SIZE: Batch size handed to both data loaders.
        train_image_paths: Training images, forwarded to ``get_dataloader``.
        train_label_paths: Training labels, forwarded to ``get_dataloader``.
        val_image_paths: Validation images, forwarded to ``get_dataloader``.
        val_label_paths: Validation labels, forwarded to ``get_dataloader``.
        channels: Number of input channels of the network.
        optimizer_name: ``"sgd"`` selects SGD with momentum; any other value
            selects AdamW.
        model_path: File the state dict is written to whenever the
            validation mIoU improves.
        swa_model_path: Unused while stochastic weight averaging is disabled.
        addNDVI: Forwarded to ``get_dataloader``.
        loss: ``"SoftCE_dice"`` selects Dice + soft cross-entropy; any other
            value selects Lovasz + soft cross-entropy.
        early_stop: Training stops once this many epochs have passed since
            the best validation mIoU.

    Returns:
        tuple: ``(train_loss_epochs, val_mIoU_epochs, lr_epochs)`` - one entry
        per completed epoch; the learning rate is recorded after the
        scheduler step of that epoch.
    """
    train_loader = get_dataloader(train_image_paths,
                                  train_label_paths,
                                  "train",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8)
    valid_loader = get_dataloader(val_image_paths,
                                  val_label_paths,
                                  "val",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=8)

    # Model, optimizer and loss function.
    model = smp.UnetPlusPlus(
        encoder_name="timm-resnest101e",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=10,
    )
    model.to(DEVICE)

    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4,
                                    weight_decay=1e-3,
                                    momentum=0.9)
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-4,
                                      weight_decay=1e-3)

    # Cosine annealing with warm restarts.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,  # number of epochs until the first restart
        T_mult=2,  # the restart period is multiplied by this after each restart
        eta_min=1e-5  # lowest learning rate
    )
    # Stochastic weight averaging (AveragedModel / SWALR) is not enabled here,
    # which is why ``swa_model_path`` is never written.

    # The region-based term differs per configuration: Dice eases class
    # imbalance but can make training unstable, Lovasz optimizes a convex
    # surrogate of the mIoU directly.
    if loss == "SoftCE_dice":
        region_loss_fn = DiceLoss(mode='multiclass')
    else:
        region_loss_fn = LovaszLoss(mode='multiclass')
    # Soft cross-entropy = cross-entropy with label smoothing.
    soft_ce_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    # Move the loss to the same device as the model and the batches; a
    # hard-coded .cuda() would fail on a CPU-only DEVICE.
    loss_fn = L.JointLoss(first=region_loss_fn,
                          second=soft_ce_fn,
                          first_weight=0.5,
                          second_weight=0.5).to(DEVICE)

    header = r'Epoch/EpochNum | TrainLoss | ValidmIoU | Time(m)'
    raw_line = r'{:5d}/{:8d} | {:9.3f} | {:9.3f} | {:9.2f}'
    print(header)

    # Best validation mIoU so far decides whether the model is saved.
    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []

    for epoch in range(1, EPOCHES + 1):
        batch_losses = []
        start_time = time.time()
        model.train()
        model.to(DEVICE)
        for batch_index, (image, target) in enumerate(train_loader):
            image, target = image.to(DEVICE), target.to(DEVICE)
            # Gradients accumulate, so clear them before every backward pass.
            optimizer.zero_grad()
            output = model(image)
            # Separate name: do not rebind the ``loss`` configuration argument.
            batch_loss = loss_fn(output, target)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())

        scheduler.step()

        val_iou = cal_val_iou(model, valid_loader)

        mean_train_loss = np.array(batch_losses).mean()
        mean_val_iou = np.mean(val_iou)
        train_loss_epochs.append(mean_train_loss)
        val_mIoU_epochs.append(mean_val_iou)
        lr_epochs.append(optimizer.param_groups[0]['lr'])

        print(raw_line.format(epoch, EPOCHES, mean_train_loss, mean_val_iou,
                              (time.time() - start_time) / 60),
              end="")

        # Checkpoint criterion: mean over axis 0 of the stacked IoUs, then
        # over what remains.
        epoch_miou = np.stack(val_iou).mean(0).mean()
        if best_miou < epoch_miou:
            best_miou = epoch_miou
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break

    return train_loss_epochs, val_mIoU_epochs, lr_epochs
Exemple #17
0
def inference():
    """Run the saved UNet++ checkpoint over the whole test set.

    Loads ``unetpp_best_model.pth`` into a 2-channel-in / 3-channel-out
    UNet++ with sigmoid activation, predicts every sample of
    ``Dataset("test")``, scores each of the three output channels with
    F-score against the ground truth and builds an edge overlay per sample.

    Returns:
        tuple: ``(xs, ys, preds, cts, fts, mns, stacks)`` - per-sample lists
        of inputs, ground truths, predictions (channels last), the F-scores
        of output channels 0 / 1 / 2, and the overlay images.
    """
    import segmentation_models_pytorch as smp
    import segmentation_models_pytorch.utils.metrics as metrics
    import torch
    import torchvision.transforms as transforms

    if torch.cuda.is_available():
        print('using gpu')
        DEVICE = "cuda:0"
    else:
        print('using cpu')
        DEVICE = 'cpu'

    load_model_path = "unetpp_best_model.pth"

    model = smp.UnetPlusPlus(encoder_weights=None,
                             in_channels=2,
                             classes=3,
                             activation='sigmoid')
    model.load_state_dict(torch.load(load_model_path, map_location=DEVICE))
    model.to(DEVICE)
    # Evaluation mode is set once; nothing below switches it back.
    model.eval()

    testset_path = "test"
    testset = Dataset(testset_path, )

    transform = transforms.Compose([transforms.ToTensor()])
    fscore = metrics.Fscore()

    def _to_batch(sample):
        """Turn one two-channel HxWx2 sample into a 1x2xHxW float tensor."""
        # Pad to three channels so the image transform can be applied,
        # then keep only the two real channels.
        padded = np.stack(
            (sample[..., 0], sample[..., 1],
             np.zeros((sample.shape[0], sample.shape[1]), dtype=np.float32)),
            axis=2)
        tensor = transform(padded)[0:2]
        return tensor.unsqueeze(0).float().to(DEVICE)

    def _edges(prob_map):
        """Return the Canny edge map of one predicted probability channel."""
        as_uint8 = (prob_map * 255).astype(np.uint8)
        blurred = cv2.GaussianBlur(as_uint8, (11, 11), 0)
        return cv2.Canny(blurred, 20, 160)

    xs = []
    ys = []
    preds = []
    cts = []
    fts = []
    mns = []
    stacks = []

    # No gradients are needed for prediction or scoring.
    with torch.no_grad():
        for i in range(len(testset)):
            x, y = testset[i]
            xs.append(x)
            ys.append(y)

            pred = model(_to_batch(x))
            # 1xCxHxW tensor -> HxWxC array
            np_pred = np.transpose(pred.detach().cpu().numpy().squeeze(),
                                   (1, 2, 0))

            # Ground truth as a batched tensor on the same device.
            _y = transform(y).unsqueeze(0).to(DEVICE)

            print('{}'.format(i))

            # Per-channel F-score (dice).
            cts.append(fscore(pred[:, 0, ...], _y[:, 0, ...]).item())
            fts.append(fscore(pred[:, 1, ...], _y[:, 1, ...]).item())
            mns.append(fscore(pred[:, 2, ...], _y[:, 2, ...]).item())

            preds.append(np_pred)

            edges_ct = _edges(np_pred[..., 0])
            edges_ft = _edges(np_pred[..., 1])
            edges_mn = _edges(np_pred[..., 2])

            # ct red, mn yellow, ft blue. The first channel combines ct and
            # mn with a maximum: adding two uint8 edge maps would wrap
            # around (255 + 255 -> 254) where both edges coincide.
            stacked = np.stack(
                (np.maximum(edges_ct, edges_mn), edges_mn, edges_ft), axis=2)

            clone_t1 = (x[..., 0] * 255).astype(np.uint8).copy()
            clone_t1 = np.stack((clone_t1, clone_t1, clone_t1), axis=2)
            draw_t1 = (clone_t1 * 0.5 + stacked * 0.5).astype(np.uint8)

            stacks.append(draw_t1)

    print('count len')
    print(len(xs))
    print(len(ys))
    print(len(preds))
    print(len(cts))
    print(len(fts))
    print(len(mns))
    print(len(stacks))

    print(np.mean(np.array(cts)))
    print(np.mean(np.array(fts)))
    print(np.mean(np.array(mns)))

    return xs, ys, preds, cts, fts, mns, stacks
Exemple #18
0
 def __init__(self, backbone):
     """Build the wrapped UNet++ segmentation network.

     Args:
         backbone: Encoder name passed to ``smp.UnetPlusPlus``.
     """
     super(SegModel, self).__init__()
     # ImageNet-initialised encoder, two output classes, no final activation.
     unetpp_options = dict(
         encoder_name=backbone,
         encoder_weights='imagenet',
         classes=2,
         activation=None,
     )
     self.seg = smp.UnetPlusPlus(**unetpp_options)
Exemple #19
0
            img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)

        msk_path = f'{cfg.label_path}/{fname}.png'
        mask = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            msk_path = f'{cfg.label_path}/external_{fname}.png'
            mask = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented = self.tfms(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']
        return img2tensor((img / 255.0 - cfg.mean) / cfg.std), img2tensor(mask)


# Label of the architecture built below.
name = 'UnetPlusPlus'

# Shared UNet++ instance, configured entirely from ``cfg``.
base_model = smp.UnetPlusPlus(
    classes=cfg.classes,
    in_channels=cfg.in_channels,
    encoder_weights=cfg.encoder_weights,
    encoder_name=cfg.encoder_name,
)


class HuBMAPModel(nn.Module):
    """Thin ``nn.Module`` wrapper around the module-level ``base_model``."""

    def __init__(self):
        super().__init__()
        # NOTE(review): every instance wraps the same ``base_model`` object,
        # so instances share weights - confirm this is intended when one
        # model per fold is created.
        self.cnn_model = base_model

    def forward(self, imgs):
        """Return the wrapped network's output for ``imgs``."""
        return self.cnn_model(imgs)


for n, (tr, te) in enumerate(kfold):
    fold = n
def create_model(model_name,
                 encoder_name,
                 pretrained=False,
                 num_classes=6,
                 in_chans=3,
                 checkpoint_path='',
                 **kwargs):
    """Create a segmentation model.

    Args:
        model_name (str): name of model to instantiate; one of
            ``'unetplusplus'``, ``'unet'``, ``'fpn'``, ``'linknet'``,
            ``'pspnet'``
        encoder_name (str): name of encoder to instantiate
        pretrained (bool): load pretrained ImageNet-1k weights if true
        num_classes (int): number of classes for final layer (default 6)
        in_chans (int): number of input channels / colors (default: 3)
        checkpoint_path (str): path of checkpoint to load after model is
            initialized; skipped when empty

    Keyword Args:
        **: other kwargs are model specific and forwarded to the constructor

    Returns:
        The constructed model, with the checkpoint loaded if one was given.

    Raises:
        NotImplementedError: if ``model_name`` is not a supported model.
    """
    # Supported ``model_name`` values and the smp class each one maps to.
    # Class names are stored instead of the classes themselves so that an
    # unknown name is rejected before the library is touched.
    architectures = {
        'unetplusplus': 'UnetPlusPlus',
        'unet': 'Unet',
        'fpn': 'FPN',
        'linknet': 'Linknet',
        'pspnet': 'PSPNet',
    }

    weights = None
    if pretrained:
        weights = 'imagenet'
        _logger.info('Using pre-trained imagenet weights')

    try:
        architecture = architectures[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which are unsupported as well.
        supported = ', '.join(sorted(architectures))
        raise NotImplementedError(
            f'Unknown model_name {model_name!r}; supported: {supported}'
        ) from None

    model = getattr(smp, architecture)(encoder_name=encoder_name,
                                       encoder_weights=weights,
                                       classes=num_classes,
                                       in_channels=in_chans,
                                       **kwargs)

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path)

    return model
Exemple #21
0
            val_iou.append(iou)
    return val_iou


# repo: https://github.com/qubvel/segmentation_models.pytorch
# doc: https://smp.readthedocs.io/en/latest/

# Column titles and row template of the per-epoch training log.
header = r'''Epoch |  Loss |  Score | Time(min)'''
raw_line = '{:8d}' + '\u2502{:8f}' * 3

# One label per output class; the network in this script is built with
# ``classes=10``, so the list must hold exactly ten names
# ('farmland' is a single class, not 'farm' + 'land').
class_name = [
    'farmland', 'forest', 'grass', 'road', 'urban_area', 'countryside',
    'industrial_land', 'construction', 'water', 'bareland'
]

# Segmentation network: UNet++ on an ImageNet-initialised EfficientNet-B6
# encoder, 3 input channels, 10 output classes.
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b6",
    encoder_weights='imagenet',
    in_channels=3,
    classes=10,
)
model.train()
model.to(DEVICE)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

# Loss: equally weighted Dice + label-smoothed cross-entropy.
# Other candidates: JaccardLoss, SoftBCEWithLogitsLoss.
DiceLoss_fn = DiceLoss(mode='multiclass')
SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
loss_fn = L.JointLoss(
    first=DiceLoss_fn,
    second=SoftCrossEntropy_fn,
    first_weight=0.5,
    second_weight=0.5,
)
loss_fn = loss_fn.to(DEVICE)
Exemple #22
0
def train_fold(val_index, X_images, Masks, image_dims, fold, train=True, predict=True):
    """Train and/or run test prediction for one cross-validation fold.

    Besides its arguments the function reads many module-level settings that
    are defined outside this block, among them ``size``, ``step_size``,
    ``bs``, ``seg_model_name``, ``encoder``, ``loss_name``, ``max_lr``,
    ``min_lr``, ``epochs``, ``epochs_minlr``, ``model_name``,
    ``predict_by_epochs``, ``data_path``, ``overlap`` and ``store_masks``.

    Args:
        val_index: Image numbers iterated when the per-image validation dice
            is computed; also forwarded to the index / dataset helpers
            (presumably the images held out for this fold - confirm).
        X_images: Training images, forwarded to ``get_indexes`` and
            ``KidneyLoader``. The local reference is deleted before test
            prediction.
        Masks: Ground-truth masks, forwarded to the dataset and dice helpers.
        image_dims: Image dimensions, forwarded to the same helpers.
        fold: Fold identifier; only used in checkpoint and output file names.
        train: Run the training loop when true.
        predict: Run prediction on the test images when true.

    Returns:
        list: With ``predict`` set, one mask per test image - the masks of
        the ``last_best_model.h5`` checkpoint when
        ``predict_by_epochs == 'best'``, otherwise the average over the
        checkpoints listed in ``best_dice_epochs``. An empty list when
        ``predict`` is false.
    """
    positive_idxs, negative_idxs = get_indexes(val_index, X_images, Masks, image_dims, size, step_size, t=0.01)
    kid_sampler = KidneySampler(positive_idxs, negative_idxs, not_empty_ratio)
    train_dataset = KidneyLoader(X_images, Masks, image_dims, positive_idxs, False, size, step_size=step_size,
                                 val_index=val_index, new_augs=new_augs, augumentations=augumentations,
                                 size_after_reshape=size_after_reshape)
    val_dataset = KidneyLoader(X_images, Masks, image_dims, positive_idxs, True, size, val_index=val_index,
                               new_augs=new_augs, size_after_reshape=size_after_reshape)
    if not use_sampler:
        print("Use full dataset")
        trainloader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=16)
    else:
        print("Use sample dataset")
        trainloader = DataLoader(train_dataset, batch_size=bs, shuffle=False, num_workers=16, sampler=kid_sampler)
    valloader = DataLoader(val_dataset, batch_size=bs * 2, shuffle=False, num_workers=16)
    # NOTE(review): x, y and key are never used below; this only fetches
    # sample 10 once (looks like a leftover smoke test).
    x, y, key = train_dataset[10]
    print(len(trainloader), len(valloader))
    # Pick the segmentation architecture; unknown names fall back to UNet++.
    if seg_model_name == 'unet':
        print('Use unet model')
        model = smp.Unet(encoder, encoder_weights="imagenet", in_channels=3, classes=1,
                         decoder_use_batchnorm=False).cuda()
    elif seg_model_name == 'unet++':
        print('Use unet++ model')
        model = smp.UnetPlusPlus(encoder, encoder_weights="imagenet", in_channels=3, classes=1,
                                 decoder_use_batchnorm=False).cuda()
    elif seg_model_name == 'albunet':
        print('Use albunet model')
        from nn.trainer import AlbuNet
        model = AlbuNet(num_classes=1, pretrained=True).cuda()
    elif seg_model_name == 'scseunet':
        print('Use SCSEUnet model')
        model = SCSEUnet(seg_classes=1).cuda()
    else:
        print('Model name is incorrect. Set to unet++')
        model = smp.UnetPlusPlus(encoder, encoder_weights="imagenet", in_channels=3, classes=1,
                                 decoder_use_batchnorm=False).cuda()
    if parallel:
        model = DataParallel(model).cuda()

    if loss_name == 'comboloss':
        print("Use combo loss")
        loss = ComboLoss(weights=weights)
    else:
        print("Use BCE Loss")
        loss = BCEWithLogitsLoss()
    optim = AdamW(model.parameters(), lr=max_lr)
    if fp16:
        # NOTE(review): the returned optimizer is bound to ``optimizer``, a
        # name that is not used again in this function; everything below
        # keeps using ``optim``. Confirm that this is intended.
        model, optimizer = apex.amp.initialize(
            model,
            optim,
            opt_level='O1')
    scheduler_params_cos = dict(
        T_max=epochs, eta_min=min_lr
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, **scheduler_params_cos)
    # ---- Start train loop here ----- #
    print('Start train')
    # NOTE(review): min_val_loss is never read below.
    min_val_loss = 100
    max_val_dice = -100
    # (epoch, val_dice) pairs of the checkpoints to average at prediction time.
    best_dice_epochs = []
    if train:
        for epoch in range(epochs + epochs_minlr):
            with open(f"../{model_name}/{model_name}.log", 'a+') as logger:
                logger.write(f"Epoch # {epoch}, lr = {optim.param_groups[0]['lr']}\n")
            empty_cache()
            # NOTE(review): dice is never read below.
            dice = 0
            print(f"epoch number {epoch}, lr = {optim.param_groups[0]['lr']}")
            # Adaptive sampling: rebuild the loader each epoch, with ratio
            # 0.2 for the first 10 epochs and 0.5 afterwards. ``&`` is the
            # bitwise operator; it matches ``and`` only if both flags are bools.
            if use_sampler & use_adaptive_sampler:
                if epoch < 10:
                    kid_sampler = KidneySampler(positive_idxs, negative_idxs, 0.2)
                    trainloader = DataLoader(train_dataset, batch_size=bs,
                                             shuffle=False, num_workers=16, sampler=kid_sampler)
                else:

                    kid_sampler = KidneySampler(positive_idxs, negative_idxs, 0.5)
                    trainloader = DataLoader(train_dataset, batch_size=bs,
                                             shuffle=False, num_workers=16, sampler=kid_sampler)

            model, optim, pred_keys, final_masks, \
            train_loss, train_dice = train_one_epoch(model, optim, trainloader, size, loss, store_train_masks=False)
            print(f"train loss = {train_loss}")
            del pred_keys, final_masks
            gc.collect()
            model, optim, val_keys, val_masks, val_loss = val_one_epoch(model, optim, valloader, size, loss)
            print(f"val loss = {val_loss}")
            # The cosine schedule is only stepped during the first ``epochs``
            # epochs; the extra ``epochs_minlr`` epochs keep the learning
            # rate reached by then.
            if epoch < epochs:
                scheduler.step()
            # Average the per-image dice over all validation images.
            m = 0
            for img_number in val_index:
                _, val_dice = calculate_dice(Masks, val_keys, val_masks, img_number, size, image_dims)
                with open(f"../{model_name}/{model_name}.log", 'a+') as logger:
                    logger.write(f'dice on image {img_number} = {val_dice} ')
                print(f'dice on image {img_number} = {val_dice}')
                m += val_dice
            with open(f"../{model_name}/{model_name}.log", 'a+') as logger:
                logger.write(f'\n')
            val_dice = m / len(val_index)
            # Keep the ``predict_by_epochs`` epochs with the highest dice.
            if predict_by_epochs != 'best':
                if len(best_dice_epochs) < predict_by_epochs:
                    best_dice_epochs.append((epoch, val_dice))
                    # NOTE(review): the save below is commented out, so an
                    # epoch added here only has a checkpoint on disk if the
                    # ``val_dice > max_val_dice`` branch further down saved
                    # it. The prediction loop loads a checkpoint for every
                    # entry of best_dice_epochs - confirm this cannot fail.
                    #save(model.state_dict(), f"../{model_name}/{model_name}_{epoch}_{fold}.h5")
                    print('Best epochs updated. ', best_dice_epochs)
                else:
                    # Replace the weakest kept epoch if this one is better.
                    ind = np.argmin([el[1] for el in best_dice_epochs])
                    if best_dice_epochs[ind][1] < val_dice:
                        best_dice_epochs[ind] = (epoch, val_dice)
                        save(model.state_dict(), f"../{model_name}/{model_name}_{epoch}_{fold}.h5")
                        print('Best epochs updated. ', best_dice_epochs)
                    else:
                        print('No change in best epochs: ', best_dice_epochs)
            # val_dice = calc_average_dice(Masks, val_keys, val_masks, val_index, image_dims, size)
            # New overall best: save under the epoch name and as the
            # "last best" checkpoint.
            if val_dice > max_val_dice:
                max_val_dice = val_dice
                save(model.state_dict(), f"../{model_name}/{model_name}_{epoch}_{fold}.h5")
                save(model.state_dict(), f"../{model_name}/last_best_model.h5")

            print("Dice on train micro ", train_dice)
            print(f"Dice on val (average) = {val_dice}")
            with open(f"../{model_name}/{model_name}.log", 'a+') as logger:
                logger.write(f'train loss = {train_loss}, train dice (micro) = {train_dice}\n')
                logger.write(f'validation loss = {val_loss}, val dice = {val_dice}\n')
                logger.write('\n')
            del val_keys, val_masks
            gc.collect()
            print("=====================")
        with open(f"../{model_name}/{model_name}.log", 'a+') as logger:
            logger.write(f'Best epochs = {best_dice_epochs}\n')
        # Reload the best checkpoint and search the threshold on validation data.
        model.load_state_dict(load(f"../{model_name}/last_best_model.h5"))
        val_keys, val_masks = predict_data(model, valloader, size, True)
        #val_dice = calc_average_dice(Masks, val_keys, val_masks, val_index, image_dims, size)
        best_t, best_val_dice = search_for_best_threshold(Masks, val_keys, val_masks, val_index, image_dims, size)
        print(f"Dice on val (average) with TTA = {best_val_dice} with t = {best_t}")
        with open(f"../{model_name}/{model_name}.log", 'a+') as logger:
            logger.write(f'dice on val with TTA = {best_val_dice} with t = {best_t}\n')
        del val_keys, val_masks
        gc.collect()
    if predict:
        if not train:
            # NOTE(review): without training best_dice_epochs stays empty, so
            # the averaging branch below loads no checkpoint and returns
            # all-zero masks. The commented-out lists look like values that
            # were filled in by hand for that case.
            pass
            # if fold == 0:
            #     best_dice_epochs = [(34, 0.938852186919906), (39, 0.9393114299799262), (36, 0.9388174075553963), (38, 0.9389920761027962)]
            # elif fold == 1:
            #     best_dice_epochs = [(28, 0.9027552584839075), (36, 0.9033416557294215), (37, 0.9031257069334585), (38, 0.9029346479118595)]
            # elif fold == 2:
            #     best_dice_epochs = [(38, 0.9328974798970685), (33, 0.9329123230368631), (37, 0.9331401395093545), (39, 0.932692250874815)]
            # elif fold == 3:
            #     best_dice_epochs = [(31, 0.9388590937356484), (39, 0.9368731018176651), (30, 0.9372210758433651), (38, 0.9375657655754374)]
            # elif fold == 4:
            #     best_dice_epochs = [(37, 0.9372186165197486), (32, 0.9372384178794498), (36, 0.9372328533782275), (33, 0.9368665971919303)]

        sample_sub = pd.read_csv(data_path + 'sample_submission.csv')
        test_paths = sample_sub.id.values
        print('Start test')
        # Drop the local reference to the training images before the test
        # images are read.
        del X_images
        gc.collect()
        X_test_images = []
        img_dims_test = []
        for name in test_paths:
            img = tiff.imread(data_path + f"test/{name}.tiff")
            # Channels-first images (3, H, W) are moved to channels-last.
            if img.shape[0] == 3:
                img = np.moveaxis(img, 0, 2)
            X_test_images.append(img)
            img_dims_test.append(img.shape)
            print(img.shape)
        # NOTE(review): raises NameError if test_paths is empty, because
        # ``img`` is then never bound.
        del img
        gc.collect()

        test_dataset = ValLoader(X_test_images, img_dims_test, size, new_augs=new_augs,
                                 size_after_reshape=size_after_reshape, overlap=overlap, step_size=step_size)
        testloader = DataLoader(test_dataset, batch_size=bs * 2, shuffle=False, num_workers=16)
        if predict_by_epochs == 'best':
            # Single-checkpoint prediction with the best validation model.
            model.load_state_dict(load(f'../{model_name}/last_best_model.h5'))
            test_masks, test_keys = predict_test(model, size, testloader, True)
            del X_test_images
            gc.collect()
            masks = []
            for n in range(len(sample_sub)):
                mask = make_masks(test_keys, test_masks, n, img_dims_test, size, overlap=overlap, step_size=step_size)
                masks.append(mask)
            return masks

        else:
            # Average the masks predicted by every kept checkpoint.
            bled_masks = [np.zeros(s[:2]) for s in img_dims_test]
            for epoch in best_dice_epochs:
                print(f'Start predict for epoch {epoch[0]}')
                model.load_state_dict(load(f'../{model_name}/{model_name}_{epoch[0]}_{fold}.h5'))
                test_masks, test_keys = predict_test(model, size, testloader, True)
                for n in range(len(sample_sub)):
                    mask = make_masks(test_keys, test_masks, n, img_dims_test, size, overlap=overlap, step_size=step_size)
                    bled_masks[n] += mask / len(best_dice_epochs)

            del X_test_images
            gc.collect()
            if store_masks:
                # Written as HDF5 although the file name ends in ``.txt``.
                for j, mask in enumerate(bled_masks):
                    with h5py.File(f'../{model_name}/{model_name}_mask_{j}_fold_{fold}.txt', "w") as f:
                        dset = f.create_dataset("mask", data=mask, dtype='f')
                    # np.savetxt(f'../{model_name}/{model_name}_mask_{j}.txt', mask)
            # Write one submission file per threshold t = 0.2, 0.3, ..., 0.6.
            for tt in range(2, 7):
                all_enc = []
                t = tt/10
                for mask in bled_masks:
                    # Binarise a copy so the averaged masks stay untouched.
                    mask_c = mask.copy()
                    mask_c[mask_c < t] = 0
                    mask_c[mask_c >= t] = 1
                    enc = mask2enc(mask_c)
                    all_enc.append(enc[0])
                sample_sub.predicted = all_enc
                # Epoch numbers joined with '_' identify the averaged checkpoints.
                s = ''.join([str(e[0]) + '_' for e in best_dice_epochs])[:-1]
                sample_sub.to_csv(f'../{model_name}/mean_{model_name}_{s}_fold_{fold}_t_{t}_overlap_{overlap}.csv', index=False)

            return bled_masks
    else:
        return []