Example #1
 def get_transforms(stage: str = None, mode: str = None):
     if mode == 'train':
         return train_aug()
     elif mode == 'valid':
         return valid_aug()
     elif mode == 'infer':
         return infer_tta_aug()
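
A minimal usage sketch (an assumption, not code from the source): the getter simply dispatches on mode, and train_aug / valid_aug / infer_tta_aug are assumed to be augmentation factories defined elsewhere in the repo.

# Hypothetical usage of the getter above.
train_transform = get_transforms(stage="stage1", mode="train")
valid_transform = get_transforms(stage="stage1", mode="valid")
# Any other mode falls through every branch and returns None, so callers
# should pass only 'train', 'valid' or 'infer'.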
Example #2
    def get_datasets(self, stage: str, **kwargs):
        datasets = OrderedDict()
        """
        image_key: 'id'
        label_key: 'attribute_ids'
        """

        image_size = kwargs.get("image_size", 224)
        train_csv = kwargs.get('train_csv', None)
        valid_csv = kwargs.get('valid_csv', None)
        root = kwargs.get('root', None)
        data_csv = kwargs.get('data_csv', None)
        data = kwargs.get('data', '2D')

        if train_csv:
            transform = train_aug(image_size)
            train_set = StructSegTrain2D(
                csv_file=train_csv,
                transform=transform,
            )
            datasets["train"] = train_set

        if valid_csv:
            transform = valid_aug(image_size)
            valid_set = StructSegTrain2D(
                csv_file=valid_csv,
                transform=transform,
            )
            datasets["valid"] = valid_set

        return datasets
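
As a hedged illustration of how this kwargs-driven get_datasets is usually invoked (the experiment instance, csv paths and batch size below are placeholders, not values from the source; in practice the kwargs typically come from a config file):

from torch.utils.data import DataLoader

# Hypothetical call on an experiment object that defines get_datasets.
datasets = experiment.get_datasets(
    stage="stage1",
    train_csv="train_fold_0.csv",   # placeholder path
    valid_csv="valid_fold_0.csv",   # placeholder path
    image_size=256,
)
train_loader = DataLoader(datasets["train"], batch_size=16, shuffle=True)
valid_loader = DataLoader(datasets["valid"], batch_size=16, shuffle=False)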
Example #3
    def get_datasets(
            self,
            stage: str,
            train_ct_dir: str,
            train_pred_dir: str,
            valid_ct_dir: str,
            valid_pred_dir: str,
            all_data: bool = False,
            image_size = [224, 224],
            **kwargs
    ):
        datasets = OrderedDict()
        train_set = CHAOSDataset(
            ct_dir=train_ct_dir,
            pred_dir=train_pred_dir,
            transform=train_aug(image_size),
        )

        valid_set = CHAOSDataset(
            ct_dir=valid_ct_dir,
            pred_dir=valid_pred_dir,
            transform=valid_aug(image_size),
        )

        if all_data:
            concat_dataset = ConcatDataset(
                [train_set, valid_set]
            )
            datasets["train"] = concat_dataset
            datasets["valid"] = concat_dataset
        else:
            datasets["train"] = train_set
            datasets["valid"] = valid_set

        return datasets
Example #4
    def get_datasets(self, stage: str, **kwargs):
        datasets = OrderedDict()
        """
        image_key: 'id'
        label_key: 'attribute_ids'
        """

        image_size = kwargs.get("image_size", 320)
        train_csv = kwargs.get('train_csv', None)
        valid_csv = kwargs.get('valid_csv', None)
        root = kwargs.get('root', None)

        if train_csv:
            transform = train_aug(image_size)
            train_set = SIIMDataset(
                csv_file=train_csv,
                root=root,
                transform=transform,
                mode='train',
            )
            datasets["train"] = train_set

        if valid_csv:
            transform = valid_aug(image_size)
            valid_set = SIIMDataset(
                csv_file=valid_csv,
                root=root,
                transform=transform,
                mode='train',
            )
            datasets["valid"] = valid_set

        return datasets
Example #5
    def get_datasets(self, stage: str, **kwargs):
        datasets = OrderedDict()

        image_size = kwargs.get("image_size", 320)
        train_data_txt = kwargs.get('train_data_txt', None)
        valid_data_txt = kwargs.get('valid_data_txt', None)
        root = kwargs.get('root', None)

        if train_data_txt:
            transform = train_aug(image_size)
            train_set = IP102Dataset(data_txt=train_data_txt,
                                     transform=transform,
                                     root=root)
            datasets["train"] = train_set

        if valid_data_txt:
            transform = valid_aug(image_size)
            valid_set = IP102Dataset(data_txt=valid_data_txt,
                                     transform=transform,
                                     root=root)
            datasets["valid"] = valid_set

        flower_train = kwargs.get('flower_train', None)
        flower_valid = kwargs.get('flower_valid', None)
        flower_root = kwargs.get('flower_root', None)

        if flower_train:
            transform = train_aug(image_size)
            train_set = FlowerDataset(csv_file=flower_train,
                                      transform=transform,
                                      root=flower_root)
            datasets["train"] = train_set

        if flower_valid:
            transform = valid_aug(image_size)
            valid_set = FlowerDataset(csv_file=flower_valid,
                                      transform=transform,
                                      root=flower_root)
            datasets["valid"] = valid_set

        return datasets
Example #6
def predict_test():
    inputdir = "../Lung_GTV/"
    outdir = "../Lung_GTV_pred/"

    transform = valid_aug(image_size=512)

    nii_files = glob.glob(inputdir + "/*/data.nii.gz")

    threshold = 0.1
    for nii_file in nii_files:
        # if not '22' in nii_file:
        #     continue
        # print(nii_file)

        image_slices, spacing = extract_slice(nii_file)
        dataset = TestDataset(image_slices, transform)
        dataloader = DataLoader(dataset=dataset,
                                num_workers=4,
                                batch_size=16,
                                drop_last=False)

        pred_mask_all = 0
        folds = [0]
        for fold in folds:
            log_dir = f"../logs/Unet-resnet34-fold-{fold}/"

            model = Unet(encoder_name='resnet34',
                         classes=1,
                         activation='sigmoid')

            ckp = os.path.join(log_dir, "checkpoints/best.pth")
            checkpoint = torch.load(ckp)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = nn.DataParallel(model)
            model = model.to(device)

            pred_mask = predict(model, dataloader)
            pred_mask = (pred_mask > threshold).astype(np.uint8)
            pred_mask = np.transpose(pred_mask, (1, 0, 2, 3))
            pred_mask = pred_mask[0]

            pred_mask_all += pred_mask

        pred_mask_all = pred_mask_all / len(folds)
        pred_mask_all = (pred_mask_all > threshold).astype(np.uint8)
        pred_mask_all = SimpleITK.GetImageFromArray(pred_mask_all)
        pred_mask_all.SetSpacing(spacing)

        patient_id = nii_file.split("/")[-2]
        patient_dir = f"{outdir}/{patient_id}"
        os.makedirs(patient_dir, exist_ok=True)
        patient_pred = f"{patient_dir}/predict.nii.gz"
        SimpleITK.WriteImage(pred_mask_all, patient_pred)
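
The predict helper called above is not part of the excerpt. Below is a minimal sketch of what it plausibly does for this binary 2D case, stacking per-batch model outputs into one array; the batch format of TestDataset and the device handling are assumptions. Note that the multi-class scripts further down expect predict to return a (probabilities, logits) pair, so the helper differs between scripts.

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def predict(model, dataloader):
    # Sketch only: run inference batch by batch and stack the outputs.
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in dataloader:
            images = batch["images"] if isinstance(batch, dict) else batch
            outputs = model(images.to(device))    # already probabilities here,
            preds.append(outputs.cpu().numpy())   # since the Unet was built
    return np.concatenate(preds, axis=0)          # with activation='sigmoid'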
Example #7
    def get_datasets(self, stage: str, **kwargs):
        datasets = OrderedDict()
        """
        image_key: 'id'
        label_key: 'attribute_ids'
        """

        image_size = kwargs.get("image_size", 320)
        train_csv = kwargs.get('train_csv', None)
        valid_csv = kwargs.get('valid_csv', None)
        sites = kwargs.get('sites', [1])
        channels = kwargs.get('channels', [1, 2, 3, 4, 5, 6])
        site_mode = kwargs.get('site_mode', 'random')
        root = kwargs.get('root', None)
        dataset = kwargs.get('dataset', "non_pseudo")
        if dataset == 'pseudo':
            dataset_function = RecursionCellularPseudo
            print("Using pseudo dataset")
        elif dataset == 'non_pseudo':
            dataset_function = RecursionCellularSite
            print("Using non pseudo dataset")
        elif dataset == 'control':
            dataset_function = RecursionCellularControl
            print("Using Control dataset")
        else:
            raise ValueError(f"Invalid dataset option: {dataset}")

        if train_csv:
            transform = train_aug(image_size)
            train_set = dataset_function(csv_file=train_csv,
                                         root=root,
                                         transform=transform,
                                         mode='train',
                                         sites=sites,
                                         channels=channels,
                                         site_mode=site_mode)
            datasets["train"] = train_set

        if valid_csv:
            transform = valid_aug(image_size)
            valid_set = dataset_function(csv_file=valid_csv,
                                         root=root,
                                         transform=transform,
                                         mode='train',
                                         sites=sites,
                                         channels=channels,
                                         site_mode=site_mode)
            datasets["valid"] = valid_set

        return datasets
Example #8
 def prepare_transforms(*, mode, stage=None, use_tta=False, **kwargs):
     image_size = kwargs.get("image_size", 256)
     # print(image_size)
     if mode == "train":
         if stage in ["debug", "stage1"]:
             return train_aug(image_size=image_size)
         elif stage == "stage2":
             return train_aug(image_size=image_size)
         else:
             return train_aug(image_size=image_size)
     elif mode == "valid":
         return valid_aug(image_size=image_size)
     elif mode == "infer":
         if use_tta:
             return test_tta(image_size=image_size)
         else:
             return test_aug(image_size=image_size)
Example #9
    def get_datasets(
            self,
            stage: str,
            train_file: str,
            valid_file: str,
            image_size = [224, 224],
    ):
        datasets = OrderedDict()

        train_set = TemporalMixDataset(
            csv_file=train_file,
            transform=train_aug(image_size),
        )

        valid_set = TemporalMixDataset(
            csv_file=valid_file,
            transform=valid_aug(image_size),
        )
        # from torch.utils.data import ConcatDataset
        # concat_set = ConcatDataset([train_set, valid_set])
        datasets["train"] = train_set
        datasets["valid"] = valid_set

        return datasets
Example #10
def predict_valid():
    inputdir = "../Lung_GTV/"
    outdir = "../Lung_GTV_val_pred/190917/Unet3D-bs4-0/"

    transform = valid_aug(image_size=512)

    # nii_files = glob.glob(inputdir + "/*/data.nii.gz")
    threshold = 0.5

    folds = [0]

    for fold in folds:
        log_dir = f"../logs/190918/Unet3D-bs4-fold-{fold}"
        model = UNet3D(in_channels=1, out_channels=1, f_maps=64)

        ckp = os.path.join(log_dir, "checkpoints/best.pth")
        checkpoint = torch.load(ckp)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = nn.DataParallel(model)
        model = model.to(device)

        df = pd.read_csv(f'./csv/5folds/valid_{fold}.csv')
        patient_ids = df.patient_id.values
        for patient_id in patient_ids:
            print(patient_id)
            nii_file = f"{inputdir}/{patient_id}/data.nii.gz"

            image_slices, n_slices, ct_image = extract_slice(nii_file)

            # import pdb
            # pdb.set_trace()

            dataset = TestDataset(image_slices, None)
            dataloader = DataLoader(dataset=dataset,
                                    num_workers=4,
                                    batch_size=2,
                                    drop_last=False)

            pred_mask = predict(model, dataloader)

            # pred_mask = torch.FloatTensor(pred_mask)
            # pred_mask = F.upsample(pred_mask, (size, 512, 512), mode='trilinear').detach().cpu().numpy()
            pred_mask = (pred_mask > threshold).astype(np.int16)
            # pred_mask = pred_mask.reshpae(-1, 512, 512)
            pred_mask = np.transpose(pred_mask, (1, 0, 2, 3, 4))
            pred_mask = pred_mask[0]
            pred_mask = pred_mask.reshape(-1, 256, 256)
            # pad back to the original number of slices by repeating the tail
            count = n_slices - pred_mask.shape[0]
            if count > 0:
                pred_mask = np.concatenate(
                    [pred_mask, pred_mask[-count:, :, :]], axis=0)

            # resample back to the original CT geometry; slice_thickness and
            # down_scale are constants defined elsewhere in the script
            pred_mask = ndimage.zoom(
                pred_mask, (slice_thickness / ct_image.GetSpacing()[-1],
                            1 / down_scale, 1 / down_scale),
                order=3)

            pred_mask = SimpleITK.GetImageFromArray(pred_mask)
            pred_mask.SetDirection(ct_image.GetDirection())
            pred_mask.SetOrigin(ct_image.GetOrigin())
            pred_mask.SetSpacing(ct_image.GetSpacing())

            # patient_id = nii_file.split("/")[-2]
            patient_dir = f"{outdir}/{patient_id}"
            os.makedirs(patient_dir, exist_ok=True)
            patient_pred = f"{patient_dir}/predict.nii.gz"
            SimpleITK.WriteImage(pred_mask, patient_pred)
Example #11
    for i, f in enumerate(files[1:]):
        # net2 = model.load(f)
        net2 = models.Unet(
            encoder_name="resnet34",
            activation='sigmoid',
            classes=1,
            # center=True
        )
        checkpoint = torch.load(f)
        net2.load_state_dict(checkpoint['model_state_dict'])
        moving_average(net, net2, 1. / (i + 2))

    test_csv = './csv/train_0.csv'
    root = "/raid/data/kaggle/siim/siim256/"
    # img_size = 128
    batch_size = 16
    train_transform = valid_aug()
    train_dataset = SIIMDataset(csv_file=test_csv,
                                root=root,
                                transform=train_transform,
                                mode='train')
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size,
                                  drop_last=True)
    net.cuda()
    bn_update(train_dataloader, net)

    # models.save(net, args.output)
    torch.save({'model_state_dict': net.state_dict()}, args.output)
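
moving_average and bn_update are SWA-style utilities that are not shown in the excerpt. Here is a sketch of the weight-averaging helper, under the assumption that it follows the common SWA reference implementation:

import torch

def moving_average(net1, net2, alpha=1.0):
    # Fold net2's parameters into net1 with weight alpha (sketch, assumed to
    # match the usual SWA helper). Called with alpha = 1 / (i + 2) over a list
    # of checkpoints, net1 ends up holding the running mean of all weights
    # seen so far.
    for p1, p2 in zip(net1.parameters(), net2.parameters()):
        p1.data *= (1.0 - alpha)
        p1.data += p2.data * alpha

bn_update then re-estimates the BatchNorm running statistics by forwarding training data through the averaged network, which is why the loop is followed by a pass over train_dataloader before saving.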
Example #12
def predict_valid():
    inputdir = "/data/Thoracic_OAR/"

    transform = valid_aug(image_size=512)

    # nii_files = glob.glob(inputdir + "/*/data.nii.gz")

    folds = [0, 1, 2, 3, 4]

    for fold in folds:
        print(fold)
        outdir = f"/data/Thoracic_OAR_predict/FPN-seresnext50/"
        log_dir = f"/logs/ss_miccai/FPN-se_resnext50_32x4d-fold-{fold}"
        # model = VNet(
        #     encoder_name='se_resnext50_32x4d',
        #     encoder_weights=None,
        #     classes=7,
        #     # activation='sigmoid',
        #     group_norm=False,
        #     center='none',
        #     attention_type='scse',
        #     reslink=True,
        #     multi_task=False
        # )

        model = FPN(encoder_name='se_resnext50_32x4d',
                    encoder_weights=None,
                    classes=7)

        ckp = os.path.join(log_dir, "checkpoints/best.pth")
        checkpoint = torch.load(ckp)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = nn.DataParallel(model)
        model = model.to(device)

        df = pd.read_csv(f'./csv/5folds/valid_{fold}.csv')
        patient_ids = df.patient_id.unique()
        for patient_id in patient_ids:
            print(patient_id)
            nii_file = f"{inputdir}/{patient_id}/data.nii.gz"

            image_slices, ct_image = extract_slice(nii_file)
            dataset = TestDataset(image_slices, transform)
            dataloader = DataLoader(dataset=dataset,
                                    num_workers=4,
                                    batch_size=8,
                                    drop_last=False)

            pred_mask, pred_logits = predict(model, dataloader)
            # import pdb
            # pdb.set_trace()
            pred_mask = np.argmax(pred_mask, axis=1).astype(np.uint8)
            pred_mask = SimpleITK.GetImageFromArray(pred_mask)

            pred_mask.SetDirection(ct_image.GetDirection())
            pred_mask.SetOrigin(ct_image.GetOrigin())
            pred_mask.SetSpacing(ct_image.GetSpacing())

            # patient_id = nii_file.split("/")[-2]
            patient_dir = f"{outdir}/{patient_id}"
            os.makedirs(patient_dir, exist_ok=True)
            patient_pred = f"{patient_dir}/predict.nii.gz"
            SimpleITK.WriteImage(pred_mask, patient_pred)
Example #13
def predict_valid():
    inputdir = "/data/HaN_OAR/"

    transform = valid_aug(image_size=512)

    # nii_files = glob.glob(inputdir + "/*/data.nii.gz")

    folds = [0]

    for fold in folds:
        outdir = f"/data/predict_task1/Vnet-se_resnext50_32x4d-weighted2-cedice19-cbam-fold-{fold}"
        log_dir = f"/logs/ss_task1/Vnet-se_resnext50_32x4d-weighted2-cedice19-cbam-fold-{fold}"
        model = VNet(
            encoder_name='se_resnext50_32x4d',
            classes=23,
            # activation='sigmoid',
            group_norm=False,
            center='none',
            attention_type='cbam',
            reslink=True,
            multi_task=False)

        ckp = os.path.join(log_dir, "checkpoints/best.pth")
        checkpoint = torch.load(ckp)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = nn.DataParallel(model)
        model = model.to(device)

        df = pd.read_csv(f'./csv/task1_5folds/valid_{fold}.csv')
        patient_ids = df.patient_id.unique()
        for patient_id in patient_ids:
            print(patient_id)
            nii_file = f"{inputdir}/{patient_id}/data.nii.gz"

            # threshold = 0.7

            image_slices, ct_image = extract_slice(nii_file)
            # import pdb
            # pdb.set_trace()
            dataset = TestDataset(image_slices, transform)
            dataloader = DataLoader(dataset=dataset,
                                    num_workers=4,
                                    batch_size=8,
                                    drop_last=False)

            pred_mask, pred_logits = predict(model, dataloader)
            # import pdb
            # pdb.set_trace()
            pred_mask = np.argmax(pred_mask, axis=1).astype(np.uint8)
            pred_mask = SimpleITK.GetImageFromArray(pred_mask)

            pred_mask.SetDirection(ct_image.GetDirection())
            pred_mask.SetOrigin(ct_image.GetOrigin())
            pred_mask.SetSpacing(ct_image.GetSpacing())

            # patient_id = nii_file.split("/")[-2]
            patient_dir = f"{outdir}/{patient_id}"
            os.makedirs(patient_dir, exist_ok=True)
            patient_pred = f"{patient_dir}/predict.nii.gz"
            SimpleITK.WriteImage(pred_mask, patient_pred)
            np.save(f"{patient_dir}/predic_logits.npy", pred_logits)
Example #14
def predict_valid():
    inputdir = "/data/Thoracic_OAR/"

    transform = valid_aug(image_size=512)

    # nii_files = glob.glob(inputdir + "/*/data.nii.gz")

    folds = [0]

    crop_size = (32, 256, 256)
    xstep = 1
    ystep = 256
    zstep = 256
    num_classes = 7

    for fold in folds:
        print(fold)
        outdir = f"/data/Thoracic_OAR_predict/Unet3D/"
        log_dir = f"/logs/ss_miccai/Unet3D-fold-{fold}"
        model = ResidualUNet3D(in_channels=1, out_channels=num_classes)

        ckp = os.path.join(log_dir, "checkpoints/best.pth")
        checkpoint = torch.load(ckp)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = nn.DataParallel(model)
        model = model.to(device)

        df = pd.read_csv(f'./csv/5folds/valid_{fold}.csv')
        patient_ids = df.patient_id.unique()
        for patient_id in patient_ids:
            print(patient_id)
            nii_file = f"{inputdir}/{patient_id}/data.nii.gz"

            image, ct_image = load_ct_images(nii_file)

            image = (image - LOWER_BOUND) / (UPPER_BOUND - LOWER_BOUND)
            image[image > 1] = 1.
            image[image < 0] = 0.
            image = image.astype(np.float32)
            C, H, W = image.shape

            deep_slices = np.arange(0, C - crop_size[0] + xstep, xstep)
            height_slices = np.arange(0, H - crop_size[1] + ystep, ystep)
            width_slices = np.arange(0, W - crop_size[2] + zstep, zstep)

            # accumulate class probabilities over overlapping crops; the small
            # epsilon avoids division by zero for voxels never covered by a crop
            whole_pred = np.zeros((num_classes, C, H, W))
            count_used = np.zeros((C, H, W)) + 1e-5

            # disable gradient tracking during inference
            with torch.no_grad():
                for i in tqdm(range(len(deep_slices))):
                    for j in range(len(height_slices)):
                        for k in range(len(width_slices)):
                            deep = deep_slices[i]
                            height = height_slices[j]
                            width = width_slices[k]
                            image_crop = image[deep:deep + crop_size[0],
                                               height:height + crop_size[1],
                                               width:width + crop_size[2]]
                            image_crop = np.expand_dims(image_crop, axis=0)
                            image_crop = np.expand_dims(image_crop, axis=0)
                            image_crop = torch.from_numpy(image_crop).to(
                                device)
                            # import pdb
                            # pdb.set_trace()
                            outputs = model(image_crop)
                            outputs = F.softmax(outputs, dim=1)
                            # ----------------Average-------------------------------
                            whole_pred[:, deep:deep + crop_size[0],
                                       height:height + crop_size[1],
                                       width:width +
                                       crop_size[2]] += outputs.data.cpu(
                                       ).numpy()[0]

                            count_used[deep:deep + crop_size[0],
                                       height:height + crop_size[1],
                                       width:width + crop_size[2]] += 1

            whole_pred = whole_pred / count_used
            pred_mask = np.argmax(whole_pred, axis=0).astype(np.uint8)

            # pred_mask, pred_logits = predict(model, dataloader)
            # # import pdb
            # # pdb.set_trace()
            pred_mask = SimpleITK.GetImageFromArray(pred_mask)

            pred_mask.SetDirection(ct_image.GetDirection())
            pred_mask.SetOrigin(ct_image.GetOrigin())
            pred_mask.SetSpacing(ct_image.GetSpacing())

            # patient_id = nii_file.split("/")[-2]
            patient_dir = f"{outdir}/{patient_id}"
            os.makedirs(patient_dir, exist_ok=True)
            patient_pred = f"{patient_dir}/predict.nii.gz"
            SimpleITK.WriteImage(pred_mask, patient_pred)
Example #15
    def get_datasets(self, stage: str, **kwargs):
        datasets = OrderedDict()
        """
        image_key: 'id'
        label_key: 'attribute_ids'
        """

        image_size = kwargs.get("image_size", 320)
        train_csv = kwargs.get('train_csv', None)
        valid_csv = kwargs.get('valid_csv', None)
        root = kwargs.get('root', None)

        if train_csv:
            transform = train_aug(image_size)
            train_set = KERCDataset(csv_file=train_csv,
                                    transform=transform,
                                    mode='train',
                                    root=root)
            datasets["train"] = train_set

        if valid_csv:
            transform = valid_aug(image_size)
            valid_set = KERCDataset(csv_file=valid_csv,
                                    transform=transform,
                                    mode='train',
                                    root=root)
            datasets["valid"] = valid_set

        affectnet_train_csv = kwargs.get("affectnet_train_csv", None)
        affectnet_valid_csv = kwargs.get("affectnet_valid_csv", None)
        affectnet_root = kwargs.get("affectnet_root", None)

        if affectnet_train_csv is not None:

            train_dataset = AffectNetDataset(root=affectnet_root,
                                             df_path=affectnet_train_csv,
                                             transform=train_aug(image_size),
                                             mode="train")
            datasets["train"] = train_dataset

        if affectnet_valid_csv is not None:
            valid_dataset = AffectNetDataset(root=affectnet_root,
                                             df_path=affectnet_valid_csv,
                                             transform=valid_aug(image_size),
                                             mode="train")
            datasets["valid"] = valid_dataset
        """
        RAF Database
        """
        raf_train_csv = kwargs.get("raf_train_csv", None)
        raf_valid_csv = kwargs.get("raf_valid_csv", None)

        if raf_train_csv is not None:
            train_dataset = RAFDataset(raf_train_csv,
                                       transform=train_aug(image_size),
                                       mode="train")
            datasets["train"] = train_dataset

        if raf_valid_csv is not None:
            valid_dataset = RAFDataset(
                raf_valid_csv,
                # transform=Experiment.get_transforms(stage=stage, mode='valid'),
                transform=valid_aug(image_size),
                mode="train")
            datasets["valid"] = valid_dataset
        """
        SFEW Database
        """
        sfew_train_csv = kwargs.get("sfew_train_csv", None)
        sfew_valid_csv = kwargs.get("sfew_valid_csv", None)
        sfew_train_root_image = kwargs.get("sfew_train_root_image", None)
        sfew_train_root_mask = kwargs.get("sfew_train_root_mask", None)

        sfew_valid_root_image = kwargs.get("sfew_valid_root_image", None)
        sfew_valid_root_mask = kwargs.get("sfew_valid_root_mask", None)

        if sfew_train_csv is not None:
            train_dataset = SFEWDataset(sfew_train_csv,
                                        root=sfew_train_root_image,
                                        root_mask=sfew_train_root_mask,
                                        transform=train_sfew_aug(image_size),
                                        mode="train")
            datasets["train"] = train_dataset

        if sfew_valid_csv is not None:
            valid_dataset = SFEWDataset(sfew_valid_csv,
                                        root=sfew_valid_root_image,
                                        root_mask=sfew_valid_root_mask,
                                        transform=valid_sfew_aug(image_size),
                                        mode="train")
            datasets["valid"] = valid_dataset
        """
        Temporal dataset
        """
        train_pool = kwargs.get("train_pool", None)
        valid_pool = kwargs.get("valid_pool", None)

        if train_pool is not None:
            train_dataset = EmotiwPoolingFeature(feature_pkl=train_pool,
                                                 mode="train")
            datasets["train"] = train_dataset

        if valid_pool is not None:
            valid_dataset = EmotiwPoolingFeature(feature_pkl=valid_pool,
                                                 mode="train")
            datasets["valid"] = valid_dataset

        return datasets