Example #1
0
def get_model(config):
    """Build the segmentation model selected by ``config``.

    ``config.MODEL.NAME`` selects one of the named models below. Any other
    name falls back to a 4-class encoder/decoder model configured by
    ``config.MODEL.ARCHITECTURE`` ('Unet', 'Linknet', 'FPN' or 'PSPNet'),
    ``config.MODEL.ENCODER`` (encoder name) and ``config.MODEL.PRETRAINED``
    (passed through as ``encoder_weights``).

    If ``config.PARALLEL`` is truthy the model is wrapped in
    ``nn.DataParallel``. The parameter count is printed before returning.

    Args:
        config: configuration object exposing ``MODEL.*`` and ``PARALLEL``.

    Returns:
        The constructed (possibly DataParallel-wrapped) model.

    Raises:
        ValueError: if the fallback ``MODEL.ARCHITECTURE`` is not supported.
    """
    model_name = config.MODEL.NAME

    if model_name == 'hrnetv2':
        model = get_hrnetv2()
        print('model: hrnetv2')

    elif model_name == 'resnet50_upernet':
        model = get_resnet50_upernet()
        print('model: resnet50_upernet')

    elif model_name == 'resnet101_upernet':
        model = get_resnet101_upernet()
        print('model: resnet101_upernet')

    elif model_name == 'acnet':
        model = ACNet(num_class=4, pretrained=True)
        print('model: acnet')

    elif model_name == 'deeplabv3':
        model = get_deeplabv3()
        print('model: deeplabv3')

    elif model_name == 'deeplab_xception':
        model = DeepLab(backbone='xception',
                        output_stride=16,
                        num_classes=4,
                        sync_bn=False,
                        freeze_bn=False)
        # Log the selection like every other named model does.
        print('model: deeplab_xception')

    else:
        model_architecture = config.MODEL.ARCHITECTURE
        model_encoder = config.MODEL.ENCODER
        model_pretrained = config.MODEL.PRETRAINED

        if model_architecture == 'Unet':
            model = Unet(model_encoder,
                         encoder_weights=model_pretrained,
                         classes=4,
                         attention_type='scse')
        elif model_architecture == 'Linknet':
            model = Linknet(model_encoder,
                            encoder_weights=model_pretrained,
                            classes=4)
        elif model_architecture in ('FPN', 'PSPNet'):
            # NOTE(review): 'PSPNet' also builds an FPN here; confirm this is
            # intentional rather than a missing PSPNet branch.
            model = FPN(model_encoder,
                        encoder_weights=model_pretrained,
                        classes=4)
        else:
            # Fail loudly here instead of hitting an UnboundLocalError on
            # `model` further down.
            raise ValueError(
                f'unsupported MODEL.ARCHITECTURE: {model_architecture!r} '
                "(expected 'Unet', 'Linknet', 'FPN' or 'PSPNet')")

        print('architecture:', model_architecture, 'encoder:', model_encoder,
              'pretrained on:', model_pretrained)

    if config.PARALLEL:
        model = nn.DataParallel(model)

    print('[*] num parameters:', count_parameters(model))

    return model
Example #2
0
def get_model(config):
    """Build a 4-class Unet or FPN with an ImageNet-pretrained encoder.

    Args:
        config: object exposing ``ARCHITECTURE`` ('Unet' or 'FPN') and
            ``ENCODER`` (encoder name passed to the model constructor).

    Returns:
        The constructed model. 'Unet' uses scSE attention.

    Raises:
        ValueError: if ``config.ARCHITECTURE`` is neither 'Unet' nor 'FPN'.
    """
    model_architecture = config.ARCHITECTURE
    model_encoder = config.ENCODER

    # `activation` is only applied in eval mode, so during training the
    # sigmoid has to be applied to the outputs explicitly.
    if model_architecture == 'Unet':
        model = Unet(model_encoder, encoder_weights='imagenet', classes=4, attention_type='scse')
    elif model_architecture == 'FPN':
        model = FPN(model_encoder, encoder_weights='imagenet', classes=4)
    else:
        # Without this branch `model` would be unbound below.
        raise ValueError(
            f"unsupported ARCHITECTURE: {model_architecture!r} "
            "(expected 'Unet' or 'FPN')")

    print('architecture:', model_architecture, 'encoder:', model_encoder)

    return model
Example #3
0
def fpn(backbone, pretrained_weights=None, classes=1, activation='sigmoid',
        device="cuda"):
    """Build an FPN model, move it to ``device`` and switch it to eval mode.

    Args:
        backbone: encoder name, passed to ``FPN`` as ``encoder_name``.
        pretrained_weights: encoder weights identifier (``None`` = no
            pretrained encoder weights).
        classes: number of output channels.
        activation: output activation passed to ``FPN``.
        device: ``torch.device`` or device string the model is moved to.
            Defaults to "cuda", so existing callers are unaffected, but
            CPU-only hosts can now pass "cpu".

    Returns:
        The model on ``device``, in eval mode.
    """
    model = FPN(encoder_name=backbone,
                encoder_weights=pretrained_weights,
                classes=classes,
                activation=activation)
    model.to(torch.device(device))
    # Returned ready for inference; callers that fine-tune must call
    # model.train() themselves.
    model.eval()

    return model
Example #4
0
    limit_val_num_samples=100 if debug else None)

# Presumably the number of batches whose gradients are accumulated before
# each optimizer step — NOTE(review): consumed outside this view; confirm.
accumulation_steps = 2

# Batch-preparation callable (fp32 variant); defined outside this view.
prepare_batch = prepare_batch_fp32


# Image denormalization function to plot predictions with images
def img_denormalize(nimg):
    """Undo input normalization of *nimg* and keep channels 0, 1 and 2 for plotting.

    Uses the module-level ``mean``/``std`` and the external ``denormalize``.
    The channel selection uses advanced indexing on the first axis.
    """
    rgb_channels = (0, 1, 2)
    restored = denormalize(nimg, mean=mean, std=std)
    return restored[rgb_channels, :, :]


#################### Model ####################

# 2-class FPN with a 5-channel input and an se_resnext50_32x4d encoder.
model = FPN(in_channels=5, encoder_name='se_resnext50_32x4d', classes=2)

#################### Solver ####################

num_epochs = 75

# Per-class loss weights for class 0 and class 1.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.62, 1.45]))

lr = 0.001
weight_decay = 1e-4
# NOTE(review): Adam's base lr is 1.0, not `lr`. Presumably a LambdaLR-style
# scheduler built from lambda_lr_scheduler (its body is not visible here)
# returns the absolute learning rate starting from lr0 = lr — confirm before
# changing either value.
optimizer = optim.Adam(model.parameters(), lr=1.0, weight_decay=weight_decay)

# Length of the training loader (presumably iterations per epoch).
le = len(train_loader)

def lambda_lr_scheduler(iteration, lr0, n, a):
    [A.Normalize(mean=mean, std=std, max_pixel_value=max_value),
     ToTensorV2()])

# Only the second of the three returned values is kept. The same `transforms`
# are used for both the train and val datasets. NOTE(review): presumably this
# is the loader over val_ds — confirm against get_train_val_loaders.
_, data_loader, _ = get_train_val_loaders(
    train_ds,
    val_ds,
    train_transforms=transforms,
    val_transforms=transforms,
    batch_size=batch_size,
    num_workers=num_workers,
    val_batch_size=batch_size,
    pin_memory=True,
)

# Inference-time batch-preparation callable (fp32); defined outside this view.
prepare_batch = inference_prepare_batch_f32

# Image denormalization function to plot predictions with images
# (denormalize with this config's mean/std bound).
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# No pretrained encoder weights: the trained weights are presumably loaded
# from the run/checkpoint identified below — confirm.
model = FPN(encoder_name='se_resnext50_32x4d', classes=2, encoder_weights=None)
run_uuid = "5230c20f609646cb9870a211036ea5cb"
weights_filename = "best_model_67_val_miou_bg=0.7574240313552584.pth"

# Presumably the inference dataset has ground-truth masks, so metrics can be
# computed — confirm where this flag is read.
has_targets = True

# Test-time augmentation: rotations by 90, -90 and 180 degrees.
tta_transforms = tta.Compose([
    tta.Rotate90(angles=[90, -90, 180]),
])
    val_batch_size=val_batch_size,
    pin_memory=True,
    train_sampler=train_sampler,
    limit_train_num_samples=100 if debug else None,
    limit_val_num_samples=100 if debug else None)

# Presumably the number of batches accumulated per optimizer step — confirm.
accumulation_steps = 3

# Batch-preparation callable (fp32 variant); defined outside this view.
prepare_batch = prepare_batch_fp32

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# FPN with an efficientnet-b4 encoder.
model = FPN(encoder_name='efficientnet-b4', classes=num_classes)

#################### Solver ####################

num_epochs = 75

# Two per-class loss weights. NOTE(review): assumes num_classes == 2 — confirm.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.75, 1.5]))

lr = 0.001
weight_decay = 1e-4
# NOTE(review): base lr is 1.0 rather than `lr`; presumably a LambdaLR-style
# scheduler (lambda_lr_scheduler, body not visible) supplies the actual rate
# starting from lr0 = lr — confirm before changing.
optimizer = optim.Adam(model.parameters(), lr=1.0, weight_decay=weight_decay)

# Length of the training loader (presumably iterations per epoch).
le = len(train_loader)


def lambda_lr_scheduler(iteration, lr0, n, a):
    val_batch_size=val_batch_size,
    pin_memory=True,
    train_sampler=train_sampler,
    limit_train_num_samples=100 if debug else None,
    limit_val_num_samples=100 if debug else None)

# Presumably the number of batches accumulated per optimizer step — confirm.
accumulation_steps = 2

# Batch-preparation callable (fp32 variant); defined outside this view.
prepare_batch = prepare_batch_fp32

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# FPN with an se_resnext50_32x4d encoder.
model = FPN(encoder_name='se_resnext50_32x4d', classes=num_classes)

#################### Solver ####################

num_epochs = 75

# Two per-class loss weights. NOTE(review): assumes num_classes == 2 — confirm.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.75, 1.5]))

lr = 0.001
weight_decay = 1e-4
# NOTE(review): base lr is 1.0 rather than `lr`; presumably a LambdaLR-style
# scheduler (lambda_lr_scheduler, body not visible) supplies the actual rate
# starting from lr0 = lr — confirm before changing.
optimizer = optim.Adam(model.parameters(), lr=1.0, weight_decay=weight_decay)

# Length of the training loader (presumably iterations per epoch).
le = len(train_loader)


def lambda_lr_scheduler(iteration, lr0, n, a):
Example #8
0
def load_model_fpn(_model_weights, is_inference=False):
    """Build a 4-class FPN (``unet_encoder`` backbone, no output activation).

    Args:
        _model_weights: ``"imagenet"`` to initialise the encoder from
            ImageNet weights, ``None`` for an uninitialised model, or a
            checkpoint path whose ``"state_dict"`` entry is loaded.
        is_inference: if true, the model is switched to eval mode.

    Returns:
        The model. Only the checkpoint-path case moves it to CUDA before
        loading. The ImageNet and ``None`` cases return it wherever FPN
        created it.
    """
    print("Using weights {}".format(_model_weights))
    use_imagenet = _model_weights == "imagenet"

    model = FPN(unet_encoder,
                encoder_weights="imagenet" if use_imagenet else None,
                classes=4,
                activation=None)
    if is_inference:
        model.eval()

    # Nothing else to load for the ImageNet or "no weights" cases.
    if use_imagenet or _model_weights is None:
        return model

    model.to(torch.device("cuda"))
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(_model_weights)
    model.load_state_dict(state["state_dict"])
    return model
 def __init__(self):
     """Initialise the module and build its wrapped FPN in ``self.model``.

     The FPN uses ENCODER_NAME as its encoder with ImageNet weights, has one
     output class and no output activation (activation=None).
     """
     super(HuBMAPModel, self).__init__()
     fpn_config = {
         'encoder_name': ENCODER_NAME,
         'encoder_weights': 'imagenet',
         'classes': 1,
         'activation': None,
     }
     self.model = FPN(**fpn_config)
# Switch the ResNet34 U-Net (built above this view) to eval mode.
unet_resnet34 = unet_resnet34.eval()

# Every model below is moved to this device.
device = torch.device("cuda")
# 4-class U-Net with an se_resnext50_32x4d encoder. No pretrained encoder
# weights are requested because a full trained state_dict is loaded below.
model_senet = Unet('se_resnext50_32x4d',
                   encoder_weights=None,
                   classes=4,
                   activation=None)
model_senet.to(device)
model_senet.eval()
# map_location returns the storages unchanged, i.e. tensors stay on the
# device they are deserialized to (CPU) instead of the one they were saved
# from; load_state_dict then copies them into the model's parameters.
state = torch.load(
    '../input/senetmodels/senext50_30_epochs_high_threshold.pth',
    map_location=lambda storage, loc: storage)
model_senet.load_state_dict(state["state_dict"])

# 4-class FPN with the same encoder; weights from the checkpoint below.
model_fpn91lb = FPN(encoder_name="se_resnext50_32x4d",
                    classes=4,
                    activation=None,
                    encoder_weights=None)
model_fpn91lb.to(device)
model_fpn91lb.eval()
# Previously used checkpoint, kept for reference:
#state = torch.load('../input/fpnseresnext/model_se_resnext50_32x4d_fold_0_epoch_7_dice_0.935771107673645.pth', map_location=lambda storage, loc: storage)
state = torch.load(
    '../input/fpnse50dice944/model_se_resnext50_32x4d_fold_0_epoch_26_dice_0.94392.pth',
    map_location=lambda storage, loc: storage)
model_fpn91lb.load_state_dict(state["state_dict"])

# Second FPN with the same configuration (presumably the pseudo-label
# variant, per its name). NOTE(review): no weights are loaded for it within
# this view; presumably a torch.load/load_state_dict follows — confirm.
model_fpn91lb_pseudo = FPN(encoder_name="se_resnext50_32x4d",
                           classes=4,
                           activation=None,
                           encoder_weights=None)
model_fpn91lb_pseudo.to(device)
model_fpn91lb_pseudo.eval()
Example #11
0
def predict_valid():
    """Write out-of-fold predictions for every validation patient of folds 0-4.

    For each fold:
    - load that fold's ``checkpoints/best.pth`` into a 7-class FPN
      (se_resnext50_32x4d encoder);
    - wrap the model in ``nn.DataParallel`` and move it to ``device``;
    - predict each patient listed in ``./csv/5folds/valid_{fold}.csv``;
    - save the per-voxel arg-max labels, with the source CT's direction,
      origin and spacing, to ``<outdir>/<patient_id>/predict.nii.gz``.

    NOTE(review): the model is not put in eval mode here; presumably
    ``predict`` does that — confirm.
    """
    input_root = "/data/Thoracic_OAR/"
    # The same output directory is shared by all folds.
    output_root = f"/data/Thoracic_OAR_predict/FPN-seresnext50/"
    transform = valid_aug(image_size=512)

    for fold in range(5):
        print(fold)
        fold_log_dir = f"/logs/ss_miccai/FPN-se_resnext50_32x4d-fold-{fold}"

        net = FPN(encoder_name='se_resnext50_32x4d',
                  encoder_weights=None,
                  classes=7)
        checkpoint = torch.load(os.path.join(fold_log_dir, "checkpoints/best.pth"))
        net.load_state_dict(checkpoint['model_state_dict'])
        net = nn.DataParallel(net).to(device)

        valid_df = pd.read_csv(f'./csv/5folds/valid_{fold}.csv')
        for patient_id in valid_df.patient_id.unique():
            print(patient_id)
            ct_path = f"{input_root}/{patient_id}/data.nii.gz"

            slices, ct_image = extract_slice(ct_path)
            loader = DataLoader(dataset=TestDataset(slices, transform),
                                num_workers=4,
                                batch_size=8,
                                drop_last=False)

            probs, _ = predict(net, loader)
            # Collapse the class axis (axis 1) to integer labels.
            label_volume = np.argmax(probs, axis=1).astype(np.uint8)
            label_image = SimpleITK.GetImageFromArray(label_volume)

            # Copy the source CT geometry so the mask overlays the scan.
            label_image.SetDirection(ct_image.GetDirection())
            label_image.SetOrigin(ct_image.GetOrigin())
            label_image.SetSpacing(ct_image.GetSpacing())

            patient_dir = f"{output_root}/{patient_id}"
            os.makedirs(patient_dir, exist_ok=True)
            SimpleITK.WriteImage(label_image, f"{patient_dir}/predict.nii.gz")
Example #12
0
    val_batch_size=val_batch_size,
    pin_memory=True,
    train_sampler=train_sampler,
    limit_train_num_samples=100 if debug else None,
    limit_val_num_samples=100 if debug else None)

# Presumably the number of batches accumulated per optimizer step — confirm.
accumulation_steps = 2

# Batch-preparation callable (fp32 variant); defined outside this view.
prepare_batch = prepare_batch_fp32

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# FPN with a dpn92 encoder.
model = FPN(encoder_name='dpn92', classes=num_classes)

#################### Solver ####################

num_epochs = 75

# Two per-class loss weights. NOTE(review): assumes num_classes == 2 — confirm.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.75, 1.5]))

lr = 0.001
weight_decay = 1e-4
# NOTE(review): base lr is 1.0 rather than `lr`; presumably a LambdaLR-style
# scheduler (lambda_lr_scheduler, body not visible) supplies the actual rate
# starting from lr0 = lr — confirm before changing.
optimizer = optim.Adam(model.parameters(), lr=1.0, weight_decay=weight_decay)

# Length of the training loader (presumably iterations per epoch).
le = len(train_loader)


def lambda_lr_scheduler(iteration, lr0, n, a):