Example #1
 def __init__(self,
              encoder_name="resnet34",
              encoder_depth=5,
              encoder_weights="imagenet",
              encoder_output_stride=16,
              decoder_channels=256,
              decoder_atrous_rates=(12, 24, 36),
              in_channels=3,
              classes=10,
              activation=None,
              upsampling=4,
              aux_params=None, ):
     super().__init__()
     self.model = smp.DeepLabV3Plus(
         encoder_name=encoder_name,
         encoder_depth=encoder_depth,
         encoder_weights=encoder_weights,
         encoder_output_stride=encoder_output_stride,
         decoder_channels=decoder_channels,
         decoder_atrous_rates=decoder_atrous_rates,
         in_channels=in_channels,
         classes=classes,
         activation=activation,
         upsampling=upsampling,
         aux_params=aux_params,
     )
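
For reference, a minimal sketch (illustrative input size, not from the original) of instantiating smp.DeepLabV3Plus with the same defaults this wrapper exposes and running a dummy forward pass:

import torch
import segmentation_models_pytorch as smp

# Sketch: the same arguments the wrapper above forwards to smp.DeepLabV3Plus.
model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_depth=5,
    encoder_weights="imagenet",
    encoder_output_stride=16,
    decoder_channels=256,
    decoder_atrous_rates=(12, 24, 36),
    in_channels=3,
    classes=10,
    activation=None,
    upsampling=4,
    aux_params=None,
)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))  # logits of shape (1, 10, 256, 256)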
Example #2
 def __init__(self, config):
     super(LesionModel, self).__init__()
     if config["A"] == "unet":
         self.model = smp.Unet(
             encoder_name=config[
                 "EN"],  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
             encoder_weights=config["encoder_weights"],
             # use `imagenet` pre-trained weights for encoder initialization
             in_channels=
             3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
             classes=config[
                 "OC"],  # model output channels (number of classes in your dataset)
         )
     elif config["A"] == "DeepLabV3Plus":
         self.model = smp.DeepLabV3Plus(
             encoder_name=config[
                 "EN"],  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
             encoder_weights=config["encoder_weights"],
             # use `imagenet` pre-trained weights for encoder initialization
             in_channels=
             3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
             classes=config[
                 "OC"],  # model output channels (number of classes in your dataset)
         )
     self.model_cfg = config
     self.loss_function = DiceLoss()
     self.save_hyperparameters()
     self.iou_function = IOU()
     self.checkpoint_path = ""
Example #3
def createDeepLabv3Plus(outputchannels=1):
    """Creates the DeepLabV3+ model with the segmentation models library.

    Args:
        outputchannels (int, optional): Number of output channels.

    Returns:
        model: Returns the DeepLabv3+ model with the MobileNetV2 backbone.
    """
    # aux_params=dict(
    #     pooling='avg',             # one of 'avg', 'max'
    #     dropout=0.5,               # dropout ratio, default is None
    #     activation='sigmoid',      # activation function, default is None
    #     classes=1,                 # define number of output labels
    # )

    model = smp.DeepLabV3Plus(
        encoder_name="mobilenet_v2",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=outputchannels,  # model output channels (number of classes in your dataset)
    )

    return model
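
The commented-out aux_params block above configures an auxiliary classification head. If it were passed to smp.DeepLabV3Plus, the forward pass would return a (mask, label) pair instead of a single mask; a hedged sketch of that variant (illustrative only, not part of the original function):

import torch
import segmentation_models_pytorch as smp

aux_params = dict(
    pooling='avg',         # one of 'avg', 'max'
    dropout=0.5,           # dropout ratio, default is None
    activation='sigmoid',  # activation function, default is None
    classes=1,             # number of image-level output labels
)
model = smp.DeepLabV3Plus(encoder_name="mobilenet_v2", in_channels=3,
                          classes=1, aux_params=aux_params)
model.eval()
with torch.no_grad():
    mask, label = model(torch.randn(1, 3, 256, 256))  # mask: (1, 1, 256, 256), label: (1, 1)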
Example #4
def deeplab(encoder, encoder_weights):
    model = smp.DeepLabV3Plus(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=1,
        activation='sigmoid')
    return model
Example #5
    def get_model_factory(self):
        """Get the corresponding model to self.architecture_name and self.encoder_name
      Design pattern: Factory"""
        if self.architecture_name == "Unet":
            return smp.Unet(encoder_name=self.encoder_name,
                            in_channels=self.in_channels,
                            classes=self.num_classes)
        elif self.architecture_name == "DeepLabV3+":
            return smp.DeepLabV3Plus(encoder_name=self.encoder_name,
                                     in_channels=self.in_channels,
                                     classes=self.num_classes)
        elif self.architecture_name == "SwinUnet":
            return models.swin_unet_2d(
                input_size=self.input_size,
                filter_num_begin=self.filter_num_begin,
                n_labels=self.n_labels,
                depth=self.depth,
                stack_num_down=self.stack_num_down,
                stack_num_up=self.stack_num_up,
                patch_size=self.patch_size,
                num_heads=self.num_heads,
                window_size=self.window_size,
                num_mlp=self.num_mlp,
                output_activation=self.output_activation,
                shift_window=self.shift_window)

        else:
            raise Exception(
                "Architecture {0:s} has not been implemented. Please select among [Unet, DeepLabV3+, SwinUnet]"
                .format(self.architecture_name))
Example #6
def get_model():
    model = smp.DeepLabV3Plus(
        encoder_name="resnext101_32x8d",
        encoder_weights='imagenet',
        in_channels=3,
        classes=1
    )
    return model
Example #7
def build_model(configuration):
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    if configuration.Model.model_name.lower() == 'unet':
        return smp.Unet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
            decoder_attention_type=None,
        )
    if configuration.Model.model_name.lower() == 'linknet':
        return smp.Linknet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'pspnet':
        return smp.PSPNet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'fpn':
        return smp.FPN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'pan':
        return smp.PAN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3+':
        return smp.DeepLabV3Plus(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3':
        return smp.DeepLabV3(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    raise KeyError(f'Model should be one of {model_list}')
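
build_model expects a nested configuration object with Model and DataSet sections. A minimal illustrative configuration (attribute names follow the function above, the concrete values are placeholders):

from types import SimpleNamespace

configuration = SimpleNamespace(
    Model=SimpleNamespace(model_name='Deeplab_v3+', encoder='resnet34', encoder_weights='imagenet'),
    DataSet=SimpleNamespace(number_of_classes=4),
)
model = build_model(configuration)  # returns smp.DeepLabV3Plus with 4 output channels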
Example #8
def SkinLesionModel(model, pretrained=True):
    models_zoo = {
        'deeplabv3plus':
        smp.DeepLabV3Plus('resnet101',
                          encoder_weights='imagenet',
                          aux_params=None),
        'deeplabv3plus_resnext':
        smp.DeepLabV3Plus('resnext101_32x8d',
                          encoder_weights='imagenet',
                          aux_params=None),
        'pspnet':
        smp.PSPNet('resnet101', encoder_weights='imagenet', aux_params=None),
        'unetplusplus':
        smp.UnetPlusPlus('resnet101',
                         encoder_weights='imagenet',
                         aux_params=None),
    }
    net = models_zoo.get(model)
    if net is None:
        raise Warning('Wrong Net Name!!')
    return net
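
As written, the dictionary instantiates every network (and fetches every set of ImageNet weights) on each call before one is picked. A lazily constructed variant is sketched below under a hypothetical name; it builds only the requested model and also makes use of the otherwise unused pretrained flag:

def skin_lesion_model_lazy(model, pretrained=True):
    """Hypothetical variant of SkinLesionModel: build only the requested network."""
    weights = 'imagenet' if pretrained else None
    builders = {
        'deeplabv3plus': lambda: smp.DeepLabV3Plus('resnet101', encoder_weights=weights, aux_params=None),
        'deeplabv3plus_resnext': lambda: smp.DeepLabV3Plus('resnext101_32x8d', encoder_weights=weights, aux_params=None),
        'pspnet': lambda: smp.PSPNet('resnet101', encoder_weights=weights, aux_params=None),
        'unetplusplus': lambda: smp.UnetPlusPlus('resnet101', encoder_weights=weights, aux_params=None),
    }
    if model not in builders:
        raise ValueError(f'Wrong net name: {model!r}')
    return builders[model]()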
Example #9
 def get_model_factory(self):
   """Get the corresponding model to self.architecture_name and self.encoder_name
   Design pattern: Factory"""
   if self.architecture_name == "Unet":
     return smp.Unet(encoder_name = self.encoder_name, in_channels = self.in_channels,
                     classes = self.num_classes)
   elif self.architecture_name == "DeepLabV3+":
     return smp.DeepLabV3Plus(encoder_name = self.encoder_name, in_channels = self.in_channels,
                     classes = self.num_classes)
   
   else:
     raise Exception("Architecture {0:s} has not been implemented. Please select among [Unet, DeepLabV3+]".format(self.architecture_name))
Example #10
 def __init__(self):
     super().__init__()
     # Resnet config
     aux_params = dict(
         pooling='max',  # one of 'avg', 'max'
         dropout=0.5,  # dropout ratio, default is None
         activation='sigmoid',  # activation function, default is None
         classes=1,  # define number of output labels
     )
     self.model = smp.DeepLabV3Plus(encoder_name="efficientnet-b6",
                                    encoder_weights="imagenet",
                                    in_channels=3,
                                    classes=1,
                                    aux_params=aux_params)
Example #11
    def __init__(self):
        super().__init__()
        pretrained_model = smp.DeepLabV3Plus(encoder_name='resnet152',
                                             encoder_depth=5,
                                             encoder_weights='imagenet')
        self.model = pretrained_model
        self.sigmoid = nn.Sigmoid()

        self.iou_loss = IoUBCELoss()

        self.train_acc = pl.metrics.Accuracy()
        self.train_f1 = pl.metrics.F1()
        self.train_f2 = pl.metrics.FBeta(num_classes=2, beta=2)
        self.val_acc = pl.metrics.Accuracy()
        self.val_f1 = pl.metrics.F1()
        self.val_f2 = pl.metrics.FBeta(num_classes=2, beta=2)
Example #12
def create_smp_model(arch, **kwargs):
    'Create segmentation_models_pytorch model'

    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    if arch == "Unet": model = smp.Unet(**kwargs)
    elif arch == "UnetPlusPlus": model = smp.UnetPlusPlus(**kwargs)
    elif arch == "MAnet": model = smp.MAnet(**kwargs)
    elif arch == "FPN": model = smp.FPN(**kwargs)
    elif arch == "PAN": model = smp.PAN(**kwargs)
    elif arch == "PSPNet": model = smp.PSPNet(**kwargs)
    elif arch == "Linknet": model = smp.Linknet(**kwargs)
    elif arch == "DeepLabV3": model = smp.DeepLabV3(**kwargs)
    elif arch == "DeepLabV3Plus": model = smp.DeepLabV3Plus(**kwargs)
    else: raise NotImplementedError

    setattr(model, 'kwargs', kwargs)
    return model
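
A usage sketch for create_smp_model: keyword arguments are passed straight through to the chosen smp constructor, so any DeepLabV3Plus parameter can be supplied here (the values below are placeholders):

model = create_smp_model(
    "DeepLabV3Plus",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)
print(model.kwargs)  # the constructor kwargs are stored on the model by create_smp_model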
Example #13
 def __init__(self,
              num_classes=12,
              encoder="resnext101_32x8d",
              pretrain_weight="imagenet",
              decoder="DeepLabV3Plus"):
     super(smpModel, self).__init__()
     if (decoder == "DeepLabV3Plus"):
         self.backbone = smp.DeepLabV3Plus(encoder_name=encoder,
                                           encoder_weights=pretrain_weight,
                                           in_channels=3,
                                           classes=num_classes)
     elif (decoder == "DeepLabV3"):
         self.backbone = smp.DeepLabV3(encoder_name=encoder,
                                       encoder_weights=pretrain_weight,
                                       in_channels=3,
                                       classes=num_classes)
     elif (decoder == "UnetPlusPlus"):
         self.backbone = smp.UnetPlusPlus(encoder_name=encoder,
                                          encoder_weights=pretrain_weight,
                                          in_channels=3,
                                          classes=num_classes)
Example #14
def get_model(num_classes,
              encoder='resnet50',
              encoder_weight='imagenet',
              activation='sigmoid'):
    ''' return model and optional preprocessing function '''

    # create segmentation model with pretrained encoder
    model = smp.DeepLabV3Plus(
        encoder_name=encoder,
        encoder_weights=None,  # note: pretrained weights are not loaded into the model here; encoder_weight only selects the preprocessing below
        classes=num_classes,
        activation=activation,
    )

    if encoder_weight is None:
        preprocessing_fn = no_pretrain_precessing
    else:
        preprocessing_fn = smp.encoders.get_preprocessing_fn(
            encoder, encoder_weight)

    return model, preprocessing_fn
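
A short sketch of how the returned pair is typically used: smp.encoders.get_preprocessing_fn yields a function that normalizes an H x W x C image with the chosen encoder's statistics (no_pretrain_precessing is a project-specific helper that is not shown here):

import numpy as np
import torch

model, preprocessing_fn = get_model(num_classes=4)       # resnet50 encoder, imagenet statistics
image = np.random.rand(512, 512, 3).astype(np.float32)   # placeholder H x W x C image
x = preprocessing_fn(image)                               # encoder-specific normalization
x = torch.from_numpy(x.transpose(2, 0, 1)).unsqueeze(0).float()  # to N x C x H x W
model.eval()
with torch.no_grad():
    masks = model(x)  # (1, 4, 512, 512), already sigmoid-activated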
Example #15
    def choose_network_architecture(self):
        # UNet
        if self.cfg.ARCH == 'UNet':
            print(f"===> Network Architecture: {self.cfg.ARCH}")
            # create segmentation model with pretrained encoder
            Unet = smp.Unet(
                encoder_name=self.cfg.ENCODER,
                encoder_weights=self.cfg.ENCODER_WEIGHTS,
                classes=len(self.cfg.CLASSES),
                activation=self.cfg.ACTIVATION,
            )
            return Unet

        # DeepLabV3+
        if self.cfg.ARCH == 'DeepLabV3+':
            print(f"===> Network Architecture: {self.cfg.ARCH}")
            # create segmentation model with pretrained encoder
            deeplab = smp.DeepLabV3Plus(
                encoder_name=self.cfg.ENCODER,
                encoder_weights=self.cfg.ENCODER_WEIGHTS,
                classes=len(self.cfg.CLASSES),
                activation=self.cfg.ACTIVATION,
                # in_channels = 1
            )
            return deeplab

        # FCN
        if self.cfg.ARCH == 'FCN':
            print(f"===> Network Architecture: {self.cfg.ARCH}")
            if self.cfg.ENCODER == 'resnet50':
                fcn = models.segmentation.fcn_resnet50(pretrained=True)
            if self.cfg.ENCODER == 'resnet101':
                fcn = models.segmentation.fcn_resnet101(pretrained=True)

            # FCNHead = models.segmentation.fcn.FCNHead
            UnetSegHead = smp.base.heads.SegmentationHead
            fcn.classifier[4] = UnetSegHead(
                512, 1, kernel_size=1,
                activation='sigmoid')  # FCNHead(2048, 1)
            return fcn
Example #16
    def __init__(self):
        super(model1, self).__init__()
        # model = smp.Unet()
        self.deeplab_model = smp.DeepLabV3Plus(encoder_name='resnet101')

        # encoder
        self.deeplab_encoder = self.deeplab_model.encoder

        # self attention generation
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.mlp1a = nn.Linear(2048, 4096)
        self.mlp2a = nn.Linear(4096, 2048)

        self.mlp1b = nn.Linear(256, 4096)
        self.mlp2b = nn.Linear(4096, 256)

        self.upsample = nn.Upsample(16)  #.cuda(0)

        # decoder
        self.deeplab_decoder = self.deeplab_model.decoder
        self.deeplab_decoder = self.deeplab_decoder  #.cuda(1)
        self.deeplab_seghead = self.deeplab_model.segmentation_head
Example #17
def test_1(model_path, output_dir, test_loader, addNDVI):
    in_channels = 4
    if(addNDVI):
        in_channels += 1
    # model = smp.UnetPlusPlus(
    #         encoder_name="resnet101",
    #         encoder_weights="imagenet",
    #         in_channels=4,
    #         classes=10,
    # )
    model = smp.DeepLabV3Plus(
            encoder_name="resnet101",
            encoder_weights="imagenet",
            in_channels=4,
            classes=10,
    )
    # If the model was trained with SWA (stochastic weight averaging), wrap it before loading
    if("swa" in model_path):
        model = AveragedModel(model)
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    for image, image_stretch, image_path, ndvi in test_loader:
        with torch.no_grad():
            image = image.cuda()
            image_stretch = image_stretch.cuda()
            output1 = model(image).cpu().data.numpy()
            output2 = model(image_stretch).cpu().data.numpy()
        output = (output1 + output2) / 2.0
        for i in range(output.shape[0]):
            pred = output[i]
            pred = np.argmax(pred, axis = 0) + 1
            pred = np.uint8(pred)
            save_path = os.path.join(output_dir, image_path[i][-10:].replace('.tif', '.png'))
            print(save_path)
            cv2.imwrite(save_path, pred)
Example #18
import torch
import numpy as np
import segmentation_models_pytorch as smp

device = torch.device("cuda")
model = smp.DeepLabV3Plus(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    encoder_depth=5,
    in_channels=3,
    classes=2,
).to(device)
model.eval()
print(model)
a = np.random.randn(1, 3, 256, 256).astype(np.float32)
a = torch.from_numpy(a).to(device)
b = model(a)
print(b.shape)
Example #19
            torch.save({'model_dict': model.state_dict()},
                       checkpoint_save_path)
            print('model saved')
            m = 0
        else:
            m += 1
            if m >= max_patient:
                torch.save({'model_dict': model.state_dict()},
                           checkpoint_save_path)
                return

    return


# In[ ]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr = 0.001

model = smp.DeepLabV3Plus(encoder_name='resnet101',
                          aux_params={'classes': num_class})
model = model.to(device)

opt = Adam(model.parameters(), lr=lr)

bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()

train(model, opt, bce_loss, ce_loss, device, a1_class_path, max_patient=5)
print('A1+class training completed.')
Example #20
            data, target = sam[0].to(device=device,dtype=torch.float), sam[1].to(device=device,dtype=torch.float)
            out = model(data)
            loss = bce_loss(out,target) * 100.0
            test_loss.append(loss.item())
            mask.append(target.cpu().detach().numpy())
            pred.append(out.cpu().detach().numpy())
        avg_test_loss = sum(test_loss) / len(test_loss)
        print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))
    return mask,pred


# In[ ]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimal_model = smp.DeepLabV3Plus(encoder_name='resnet101').to(device)
checkpoint = torch.load(a1_path)
optimal_model.load_state_dict(checkpoint['model_dict'])
optimal_model.eval()
bce_loss = nn.BCEWithLogitsLoss()

mask,pred = test(optimal_model,test_loader)
gt_map = np.concatenate(mask, axis=0)
pred_map = np.concatenate(pred, axis=0)


# In[ ]:


pred_map_sig = nn.Sigmoid()(torch.tensor(pred_map))
threshold = 0.5
folder_path, file_name = os.path.split(save_path)
# create the output folder if it does not exist
os.makedirs(folder_path, exist_ok=True)

# model = smp.FPN(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     classes=len(CLASSES),
#     activation=ACTIVATION,
# )

model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
    in_channels=4,
)

# model = smp.Unet(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     classes=len(CLASSES),
#     activation=ACTIVATION,
#     in_channels = 4,
# )

##model = torch.load('/media/xingshi2/data/matting/Unet3/vgg_current_model.pth')
if os.path.exists(init_pre_path):
    model1 = torch.load(init_pre_path)
Example #22
                          batch_size=1,
                          shuffle=False,
                          num_workers=4)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

loss = sm.utils.losses.DiceLoss()

metrics = [
    sm.utils.metrics.Accuracy(threshold=0.5),
]

model = sm.DeepLabV3Plus('resnet34',
                         classes=1,
                         decoder_channels=512,
                         in_channels=3,
                         activation='sigmoid',
                         encoder_weights='imagenet')
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = sm.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=device,
)

valid_epoch = sm.utils.train.ValidEpoch(
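
The entry is cut off in the middle of the ValidEpoch call. Under the old smp.utils training API it would presumably mirror the TrainEpoch arguments minus the optimizer, and both helpers are driven once per epoch with .run(); a hedged sketch (loader names are assumptions):

valid_epoch = sm.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=device,
)

# Illustrative training loop over the two epoch runners.
for epoch in range(40):
    print(f"Epoch {epoch}")
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)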
Example #23
valid_dataset = CloudDataset(path=path, df=train, datatype='valid', img_ids=valid_ids,
                             transforms=utils.get_validation_augmentation(),
                             preprocessing=utils.get_preprocessing(preprocessing_fn))

train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=num_workers)

loaders = {
    "train": train_loader,
    "valid": valid_loader
}

print("setting for training...")

ACTIVATION = None
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=4, 
    activation=ACTIVATION,
)
wandb.watch(model)

num_epochs = 50
logdir = "./logs/segmentation"

# model, criterion, optimizer
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-2}, 
    {'params': model.encoder.parameters(), 'lr': 1e-3},  
])
scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
loss = BCEDiceLoss() # or DiceLoss()
metrics = [
Example #24
 def __init__(self, num_classes=12):
     super(SMP_DeepLabV3Plus_efficientnet_b1, self).__init__()
     self.seg_model = smp.DeepLabV3Plus(encoder_name="efficientnet-b1",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
Example #25
 def __init__(self, num_classes=12):
     super(SMP_DeepLabV3Plus_timm_resnest101e, self).__init__()
     self.seg_model = smp.DeepLabV3Plus(encoder_name="timm-resnest101e",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
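
These SMP_DeepLabV3Plus_* wrappers only show their constructors; presumably each also defines a forward that delegates to self.seg_model. A minimal sketch of such a method (an assumption, not shown in the original):

 def forward(self, x):
     # Delegate to the wrapped segmentation model; returns per-class logits.
     return self.seg_model(x)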
Example #26
            print('model saved')
            m=0
        else:
            m+=1
            if m >= max_patient:
                return
                
    return


# In[11]:


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = smp.DeepLabV3Plus(encoder_name='resnet101')


# In[ ]:


lr=0.001
model = model.to(device)

opt = Adam(model.parameters(),lr=lr)
bce_loss = nn.BCEWithLogitsLoss()

train(model,opt,bce_loss,device, a1_path)
print('A1 training completed.')

Example #27
 def __init__(self, num_classes=12):
     super(SMP_DeepLabV3Plus_se_resnext101_32x4d, self).__init__()
     self.seg_model = smp.DeepLabV3Plus(encoder_name="se_resnext101_32x4d",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
Example #28
def test_calculate_metric(iter_nums):
    if args.net == 'unet':
        # timm-efficientnet performs slightly worse.
        if not args.vis_mode:
            backbone_type = re.sub("^eff", "efficientnet", args.backbone_type)
            net = smp.Unet(backbone_type, classes=args.num_classes, encoder_weights='imagenet')
        else:
            net = VanillaUNet(n_channels=3, num_classes=args.num_classes)
    elif args.net == 'unet-scratch':
        # net = UNet(num_classes=args.num_classes)
        net = VanillaUNet(n_channels=3, num_classes=args.num_classes, 
                          use_polyformer=args.polyformer_mode,
                          num_modes=args.num_modes)
    elif args.net == 'nestedunet':
        net = NestedUNet(num_classes=args.num_classes)
    elif args.net == 'unet3plus':
        net = UNet_3Plus(num_classes=args.num_classes)
    elif args.net == 'pranet':
        net = PraNet(num_classes=args.num_classes - 1)
    elif args.net == 'attunet':
        net = AttU_Net(output_ch=args.num_classes)
    elif args.net == 'attr2unet':
        net = R2AttU_Net(output_ch=args.num_classes)
    elif args.net == 'dunet':
        net = DeformableUNet(n_channels=3, n_classes=args.num_classes)
    elif args.net == 'setr':
        # Install mmcv first: 
        # pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cu{Your CUDA Version}/torch{Your Pytorch Version}/index.html
        # E.g.: pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.7.1/index.html
        from mmcv.utils import Config
        sys.path.append("networks/setr")
        from mmseg.models import build_segmentor
        task2config = { 'refuge': 'SETR_PUP_288x288_10k_refuge_context_bs_4.py', 
                        'polyp':  'SETR_PUP_320x320_10k_polyp_context_bs_4.py' }
        setr_cfg = Config.fromfile("networks/setr/configs/SETR/{}".format(task2config[args.task_name]))
        net = build_segmentor(setr_cfg.model, train_cfg=setr_cfg.train_cfg, test_cfg=setr_cfg.test_cfg)        
        # By default, net() calls forward_train(), which receives extra parameters, and returns losses.
        # net.forward_dummy() receives/returns the traditional input/output.
        # Relevant file: mmseg/models/segmentors/encoder_decoder.py
        net.forward = net.forward_dummy
    elif args.net == 'transunet':
        transunet_config = TransUNet_CONFIGS[args.backbone_type]
        transunet_config.n_classes = args.num_classes
        if args.backbone_type.find('R50') != -1:
            # The "patch" in TransUNet means grid-like patches of the input image.
            # The "patch" in our code means the whole input image after cropping/resizing (part of the augmentation)
            transunet_config.patches.grid = (int(args.patch_size[0] / transunet_config.patches.size[0]), 
                                             int(args.patch_size[1] / transunet_config.patches.size[1]))
        net = TransUNet(transunet_config, img_size=args.patch_size, num_classes=args.num_classes)
        
    elif args.net.startswith('deeplab'):
        use_smp_deeplab = args.net.endswith('smp')
        if use_smp_deeplab:
            backbone_type = re.sub("^eff", "efficientnet", args.backbone_type)
            net = smp.DeepLabV3Plus(backbone_type, classes=args.num_classes, encoder_weights='imagenet')
        else:
            model_name = args.net + "_" + args.backbone_type
            model_map = {
                'deeplabv3_resnet50':       deeplab.deeplabv3_resnet50,
                'deeplabv3plus_resnet50':   deeplab.deeplabv3plus_resnet50,
                'deeplabv3_resnet101':      deeplab.deeplabv3_resnet101,
                'deeplabv3plus_resnet101':  deeplab.deeplabv3plus_resnet101,
                'deeplabv3_mobilenet':      deeplab.deeplabv3_mobilenet,
                'deeplabv3plus_mobilenet':  deeplab.deeplabv3plus_mobilenet
            }
            net = model_map[model_name](num_classes=args.num_classes, output_stride=8)
            
    elif args.net == 'nnunet':
        from nnunet.network_architecture.initialization import InitWeights_He
        from nnunet.network_architecture.generic_UNet import Generic_UNet
        net = Generic_UNet(
                            input_channels=3,
                            base_num_features=32,
                            num_classes=args.num_classes,
                            num_pool=7,
                            num_conv_per_stage=2,
                            feat_map_mul_on_downscale=2,
                            norm_op=nn.InstanceNorm2d,
                            norm_op_kwargs={'eps': 1e-05, 'affine': True},
                            dropout_op_kwargs={'p': 0, 'inplace': True},
                            nonlin_kwargs={'negative_slope': 0.01, 'inplace': True},
                            final_nonlin=(lambda x: x),
                            weightInitializer=InitWeights_He(1e-2),
                            pool_op_kernel_sizes=[[2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
                            conv_kernel_sizes=([[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]),
                            upscale_logits=False,
                            convolutional_pooling=True,
                            convolutional_upsampling=True,
                           )
        net.inference_apply_nonlin = (lambda x: F.softmax(x, 1))
                        
    elif args.net == 'segtran':
        get_default(args, 'num_modes',  default_settings, -1,   [args.net, 'num_modes', args.in_fpn_layers])
        set_segtran2d_config(args)
        print(args)
        net = Segtran2d(config)
    else:
        breakpoint()

    net.cuda()
    net.eval()

    if args.robust_ref_cp_path:
        refnet = copy.deepcopy(net)
        print("Reference network created")
        load_model(refnet, args, args.robust_ref_cp_path)
    else:
        refnet = None
        
    # Currently colormap is used only for OCT task.
    colormap = get_seg_colormap(args.num_classes, return_torch=True).cuda()

    # prepred: pre-prediction. postpred: post-prediction.
    task2mask_prepred   = { 'refuge': refuge_map_mask,      'polyp': polyp_map_mask,
                            'oct': partial(index_to_onehot, num_classes=args.num_classes) }
    task2mask_postpred  = { 'refuge': refuge_inv_map_mask,  'polyp': polyp_inv_map_mask,
                            'oct': partial(onehot_inv_map, colormap=colormap) }
    mask_prepred_mapping_func   =   task2mask_prepred[args.task_name]
    mask_postpred_mapping_funcs = [ task2mask_postpred[args.task_name] ]
    if args.do_remove_frag:
        remove_frag = lambda segmap: remove_fragmentary_segs(segmap, 255)
        mask_postpred_mapping_funcs.append(remove_frag)

    if not args.checkpoint_dir:
        if args.vis_mode is not None:
            visualize_model(net, args.vis_mode, args.vis_layers, args.patch_size, db_test)
            return

        if args.eval_robustness:
            eval_robustness(args, net, refnet, testloader, mask_prepred_mapping_func)
            return

    allcls_avg_metric = None
    all_results = np.zeros((args.num_classes, len(iter_nums)))
    
    for iter_idx, iter_num in enumerate(iter_nums):
        if args.checkpoint_dir:
            checkpoint_path = os.path.join(args.checkpoint_dir, 'iter_' + str(iter_num) + '.pth')
            load_model(net, args, checkpoint_path)

            if args.vis_mode is not None:
                visualize_model(net, args.vis_mode, args.vis_layers, args.patch_size, db_test)
                continue

            if args.eval_robustness:
                eval_robustness(args, net, refnet, testloader, mask_prepred_mapping_func)
                continue

        save_results = args.save_results and (not args.test_interp)
        if save_results:
            test_save_paths = []
            test_save_dirs  = []
            test_save_dir_tmpl  = "%s-%s-%s-%d" %(args.net, args.job_name, timestamp, iter_num)
            for suffix in ("-soft", "-%.1f" %args.mask_thres):
                test_save_dir = test_save_dir_tmpl + suffix
                test_save_path = "../prediction/%s" %(test_save_dir)
                if not os.path.exists(test_save_path):
                    os.makedirs(test_save_path)
                test_save_dirs.append(test_save_dir)
                test_save_paths.append(test_save_path)
        else:
            test_save_paths = None
            test_save_dirs  = None

        if args.save_features_img_count > 0:
            args.save_features_file_path = "%s-%s-feat-%s.pth" %(args.net, args.job_name, timestamp)
        else:
            args.save_features_file_path = None
            
        allcls_avg_metric, allcls_metric_count = \
                test_all_cases(net, testloader, task_name=args.task_name,
                               num_classes=args.num_classes,
                               mask_thres=args.mask_thres,
                               model_type=args.net,
                               orig_input_size=args.orig_input_size,
                               patch_size=args.patch_size,
                               stride=(args.orig_input_size[0] // 2, args.orig_input_size[1] // 2),
                               test_save_paths=test_save_paths,
                               out_origsize=args.out_origsize,
                               mask_prepred_mapping_func=mask_prepred_mapping_func,
                               mask_postpred_mapping_funcs=mask_postpred_mapping_funcs,
                               reload_mask=args.reload_mask,
                               test_interp=args.test_interp,
                               save_features_img_count=args.save_features_img_count,
                               save_features_file_path=args.save_features_file_path,
                               verbose=args.verbose_output)

        print("Iter-%d scores on %d images:" %(iter_num, allcls_metric_count[0]))
        dice_sum = 0
        for cls in range(1, args.num_classes):
            dice = allcls_avg_metric[cls-1]
            print('class %d: dice = %.3f' %(cls, dice))
            dice_sum += dice
            all_results[cls, iter_idx] = dice
        avg_dice = dice_sum / (args.num_classes - 1)
        print("Average dice: %.3f" %avg_dice)

        if args.net == 'segtran':
            max_attn, avg_attn, clamp_count, call_count = \
                [ segtran_shared.__dict__[v] for v in ('max_attn', 'avg_attn', 'clamp_count', 'call_count') ]
            print("max_attn={:.2f}, avg_attn={:.2f}, clamp_count={}, call_count={}".format(
                  max_attn, avg_attn, clamp_count, call_count))

        if save_results:
            FNULL = open(os.devnull, 'w')
            for pred_type, test_save_dir, test_save_path in zip(('soft', 'hard'), test_save_dirs, test_save_paths):
                do_tar = subprocess.run(["tar", "cvf", "%s.tar" %test_save_dir, test_save_dir], cwd="../prediction",
                                        stdout=FNULL, stderr=subprocess.STDOUT)
                # print(do_tar)
                print("{} tarball:\n{}.tar".format(pred_type, os.path.abspath(test_save_path)))

    np.set_printoptions(precision=3, suppress=True)
    print(all_results[1:])
    return allcls_avg_metric
Example #29
 def __init__(self, num_classes=12):
     super(SMP_DeepLabV3Plus_xception, self).__init__()
     self.seg_model = smp.DeepLabV3Plus(encoder_name="xception",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
Example #30
def main(args, config):
    logger.info("-------------------- Hyperparameters --------------------")
    logger.info(f"Seed: {args.seed}")
    logger.info(f"Loss: {config.loss}")
    logger.info(f"Optimizer: {config.optimizer}")
    logger.info(f"Weight decay: {config.weight_decay}")
    logger.info(f"Epochs: {config.epochs}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.learning_rate}")
    logger.info("--------------------------------------------------\n")

    logger.info("--------------------------------------------------")
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f'There are {torch.cuda.device_count()} GPU(s) available.')
        logger.info(f'We will use the GPU:{torch.cuda.get_device_name(0)}')
    else:
        device = torch.device("cpu")
        logger.info('No GPU available, using the CPU instead.')
    logger.info("--------------------------------------------------\n")
    """ Dataset """
    # Augmentation
    train_transform, val_transform, _ = get_transforms(config.aug)

    # Train dataset
    train_dataset = CustomDataLoader(data_dir=train_data_path,
                                     mode='train',
                                     transform=train_transform)
    # Validation dataset
    val_dataset = CustomDataLoader(data_dir=val_data_path,
                                   mode='val',
                                   transform=val_transform)

    # DataLoader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=8,
                                             shuffle=False,
                                             num_workers=1,
                                             collate_fn=collate_fn)
    """ Model """
    if args.api == 'smp':
        if config.model == 'DeepLabV3Plus':
            model = smp.DeepLabV3Plus(encoder_name=config.enc,
                                      encoder_weights=config.enc_weights,
                                      classes=12)
    else:
        if config.model == 'DeepLabV3EffiB7Timm':
            model = DeepLabV3EffiB7Timm(n_classes=12,
                                        n_blocks=[3, 4, 23, 3],
                                        atrous_rates=[6, 12, 18, 24])

    logger.info("-------------------- Model Test --------------------")
    x = torch.randn([2, 3, 512, 512])
    logger.info(f"Input shape : {x.shape}")
    out = model(x).to(device)
    logger.info(f"Output shape: {out.shape}")
    logger.info("--------------------------------------------------\n")

    model = model.to(device)

    # Loss function
    if config.loss == 'CE':
        criterion = nn.CrossEntropyLoss()
    elif config.loss == 'SoftCE':
        criterion = SoftCrossEntropyLoss(smooth_factor=config.smooth_factor)
    elif config.loss == 'Focal':
        criterion = FocalLoss('multiclass', gamma=config.focal_gamma)
    elif config.loss == 'DiceCE':
        pass
    elif config.loss == 'RMI':
        # criterion = RMILoss(num_classes=12, loss_weight_lambda=config.RMI_weight)
        criterion = [
            SoftCrossEntropyLoss(smooth_factor=config.smooth_factor),
            FocalLoss('multiclass', gamma=config.focal_gamma),
            # RMILoss(num_classes=12, rmi_only=True)
            RMILoss(num_classes=12)
        ]
    # TODO: apply split and getattr() here
    elif config.loss == 'SoftCE+Focal+RMI':
        criterion = [
            SoftCrossEntropyLoss(smooth_factor=config.smooth_factor),
            FocalLoss('multiclass', gamma=config.focal_gamma),
            # RMILoss(num_classes=12, rmi_only=True)
            RMILoss(num_classes=12)
        ]
    else:
        raise Exception('[ERROR] Invalid loss')

    # Optimizer
    learning_rate = config.lr_min if config.lr_scheduler == 'SGDR' else config.learning_rate
    if config.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'AdamP':
        optimizer = AdamP(params=model.parameters(),
                          lr=learning_rate,
                          weight_decay=config.weight_decay)
    else:
        raise Exception('[ERROR] Invalid optimizer')

    # Learning rate scheduler
    if config.lr_scheduler == 'no':
        scheduler = None
    elif config.lr_scheduler == 'SGDR':
        scheduler = CosineAnnealingWarmUpRestarts(optimizer,
                                                  T_0=config.T,
                                                  T_up=config.T_warmup,
                                                  T_mult=config.T_mult,
                                                  eta_max=config.lr_max,
                                                  gamma=config.lr_max_decay)
    else:
        raise Exception('[ERROR] Invalid learning rate scheduler')
    """ Train the model """
    train(epochs=config.epochs,
          model=model,
          data_loader=train_loader,
          val_loader=val_loader,
          criterion=criterion,
          optimizer=optimizer,
          scheduler=scheduler,
          val_every=1,
          device=device)