Example #1
    def predict_loader(
        self,
        loader: DataLoader,
        resume: str = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ):
        loaders = OrderedDict([("infer", loader)])

        callbacks = OrderedDict([("inference", InferCallback())])
        if resume is not None:
            callbacks["loader"] = CheckpointCallback(resume=resume)

        self.infer(model=self.model,
                   loaders=loaders,
                   callbacks=callbacks,
                   verbose=verbose,
                   state_kwargs=state_kwargs,
                   fp16=fp16,
                   check=check)

        output = callbacks["inference"].predictions
        if isinstance(self.output_key, str):
            output = output[self.output_key]

        return output
Example #2
def valid_model(runner, model, valid_loader, valid_dataset, log_dir):
    encoded_pixels = []
    loaders = {"infer": valid_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[CheckpointCallback(resume=log_dir+'/checkpoints/best.pth'),
                   InferCallback()
                   ],
    )
    valid_masks = []
    probabilities = np.zeros((2220, HEIGHT_TRAIN, WIDTH_TRAIN))
    for i, (batch, output) in enumerate(tqdm(zip(valid_dataset, runner.callbacks[0].predictions["logits"]))):
        image, mask = batch
        for m in mask:
            if m.shape != (HEIGHT_TRAIN, WIDTH_TRAIN):
                m = cv2.resize(m, dsize=(WIDTH_TRAIN, HEIGHT_TRAIN),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (HEIGHT_TRAIN, WIDTH_TRAIN):
                probability = cv2.resize(probability, dsize=(
                    WIDTH_TRAIN, HEIGHT_TRAIN), interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability
    return probabilities, valid_masks
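A call sketch for valid_model; the runner, model, loader, dataset and log directory below are placeholders, not values from the example:

runner = SupervisedRunner()
probabilities, valid_masks = valid_model(
    runner, model, valid_loader, valid_dataset, log_dir=LOG_DIR)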
Example #3
def generate_test_preds(class_params):

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    dummy_dataset = CloudDataset(df=sub, datatype='test', img_ids=test_ids[:1],  transforms=get_validation_augmentation(),
                                preprocessing=get_preprocessing(preprocessing_fn))
    dummy_loader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)

    model = smp.FPN(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=ACTIVATION,
    )
    runner = SupervisedRunner(model)

    # HACK: We are loading a few examples from our dummy loader so catalyst will properly load the weights
    # from our checkpoint
    loaders = {"test": dummy_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(
                resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )

    # Now we do real inference on the full dataset
    test_dataset = CloudDataset(df=sub, datatype='test', img_ids=test_ids,  transforms=get_validation_augmentation(),
                                preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    encoded_pixels = []
    image_id = 0
    for i, test_batch in enumerate(tqdm.tqdm(test_loader)):
        runner_out = runner.predict_batch({"features": test_batch[0].cuda()})['logits'].cpu().detach().numpy()
        for i, batch in enumerate(runner_out):
            for probability in batch:

                # probability = probability.cpu().detach().numpy()
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
                predict, num_predict = post_process(sigmoid(probability), class_params[image_id % 4][0],
                                                    class_params[image_id % 4][1])
                if num_predict == 0:
                    encoded_pixels.append('')
                else:
                    r = mask2rle(predict)
                    encoded_pixels.append(r)
                image_id += 1

    print("Saving submission...")
    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv('submission.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
    print("Saved.")
Example #4
def train(args):
    set_random_seed(42)
    model = get_model(args.network)
    print('Loading model')
    model.encoder.conv1 = nn.Conv2d(
        count_channels(args.channels), 64, kernel_size=(7, 7),
        stride=(2, 2), padding=(3, 3), bias=False)
    model, device = UtilsFactory.prepare_model(model)

    train_df = pd.read_csv(args.train_df).to_dict('records')
    val_df = pd.read_csv(args.val_df).to_dict('records')

    ds = Dataset(args.channels, args.dataset_path, args.image_size, args.batch_size, args.num_workers)
    loaders = ds.create_loaders(train_df, val_df)
    print(loaders['train'].dataset.data)

    criterion = BCE_Dice_Loss(bce_weight=0.2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20, 40], gamma=0.3
    )

    save_path = os.path.join(
        args.logdir,
        '_'.join([args.network, *args.channels])
    )

    # model runner
    runner = SupervisedRunner()

    # model training
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=[
            DiceCallback()
        ],
        logdir=save_path,
        num_epochs=args.epochs,
        verbose=True
    )

    infer_loader = collections.OrderedDict([('infer', loaders['valid'])])
    runner.infer(
        model=model,
        loaders=infer_loader,
        callbacks=[
            CheckpointCallback(resume=f'{save_path}/checkpoints/best.pth'),
            InferCallback()
        ],
    )
Example #5
def run_validation(data,
                   valid_path,
                   image_size,
                   batch_size,
                   splits,
                   fold_idx,
                   model,
                   exp_name,
                   labels,
                   ttatype=None):
    logdir = 'logs/{}_fold{}/'.format(exp_name, fold_idx)
    valid_data = data.loc[splits['test_idx'][fold_idx], :]
    model.load_state_dict(
        torch.load(os.path.join(logdir,
                                'checkpoints/best.pth'))['model_state_dict'])
    model.eval()
    if ttatype == 'd4':
        model = tta.TTAWrapper(model, tta.d4_image2label)
    elif ttatype == 'fliplr_image2label':
        model = tta.TTAWrapper(model, tta.fliplr_image2label)
    runner = SupervisedRunner(model=model)
    val_dataset = EyeDataset(dataset_path=valid_path,
                             labels=data.loc[splits['test_idx'][fold_idx],
                                             labels].values,
                             ids=data.loc[splits['test_idx'][fold_idx],
                                          'id'].values,
                             albumentations_tr=aug_val(image_size))
    val_loader = DataLoader(val_dataset,
                            num_workers=8,
                            pin_memory=False,
                            batch_size=batch_size,
                            shuffle=False)
    loaders = collections.OrderedDict()
    loaders["valid"] = val_loader
    #predictions = runner.predict_loader(loaders["valid"], resume=f"{logdir}/checkpoints/best.pth")
    runner.infer(model=model, loaders=loaders, callbacks=[InferCallback()])
    predictions = runner.callbacks[0].predictions['logits']
    probabilities = softmax(torch.from_numpy(predictions), dim=1).numpy()
    for idx in range(probabilities.shape[0]):
        if all(probabilities[idx, :] < 0.5):
            probabilities[idx, 0] = 1.0
    predicted_labels = pd.DataFrame(probabilities, columns=labels)
    predicted_labels['id'] = data.loc[splits['test_idx'][fold_idx],
                                      'id'].values
    predicted_labels.loc[:, 'group'] = predicted_labels.id.apply(
        lambda x: x.split('_')[0])
    valid_data.loc[:, 'group'] = valid_data.id.apply(lambda x: x.split('_')[0])
    valid_data_groupped = valid_data.groupby(['group']).aggregate(
        dict(zip(labels, ['max'] * (len(labels)))))
    predicted_labels_groupped = predicted_labels.groupby(['group']).aggregate(
        dict(zip(labels, ['max'] * (len(labels)))))
    return (valid_data_groupped, predicted_labels_groupped)
Example #6
def generate_valid_preds(args):

    train_ids, valid_ids, logdir = args
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, 'imagenet')
    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=0)

    model = smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=ACTIVATION,
    )

    runner = SupervisedRunner()
    # Generate validation predictions
    loaders = {"infer": valid_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )

    valid_preds = np.load('data/valid_preds.npy')

    for im_id, preds in zip(valid_ids,
                            runner.callbacks[0].predictions["logits"]):

        preds = preds.transpose((1, 2, 0))
        preds = cv2.resize(preds, (525, 350))
        preds = preds.transpose((2, 0, 1))

        indexes = train.index[train['im_id'] == im_id]
        valid_preds[indexes[0]] = preds[0]  # fish
        valid_preds[indexes[1]] = preds[1]  # flower
        valid_preds[indexes[2]] = preds[2]  # gravel
        valid_preds[indexes[3]] = preds[3]  # sugar

    np.save('data/valid_preds.npy', valid_preds)

    return True
Example #7
def main():
    test = pd.read_csv(
        '/home/yuko/kaggle_understanding_cloud_organization/src/data_process/data/sample_submission.csv'
    )

    test['label'] = test['Image_Label'].apply(lambda x: x.split('_')[-1])
    test['im_id'] = test['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))

    test['img_label'] = test.EncodedPixels.apply(lambda x: 0
                                                 if x is np.nan else 1)

    img_label = test.groupby('im_id')['img_label'].agg(list).reset_index()

    test_id = np.array(img_label.im_id)

    test_dataset = CloudClassDataset(datatype='test',
                                     img_ids=test_id,
                                     transforms=get_validation_augmentation(),
                                     preprocessing=ort_get_preprocessing())

    test_loader = DataLoader(test_dataset,
                             batch_size=8,
                             shuffle=False,
                             num_workers=16)

    loaders = {"infer": test_loader}

    for fold in range(5):
        runner = SupervisedRunner()
        clf_model = ResNet()

        checkpoint = torch.load(
            f'/home/yuko/kaggle_understanding_cloud_organization/src/class/segmentation/fold_{fold}/checkpoints/best.pth'
        )
        clf_model.load_state_dict(checkpoint['model_state_dict'])
        clf_model.eval()
        runner.infer(
            model=clf_model,
            loaders=loaders,
            callbacks=[InferCallback()],
        )
        callbacks_num = 0
        pred = runner.callbacks[callbacks_num].predictions["logits"]

        df_pred = pd.DataFrame(
            [pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]]).T
        df_pred.to_csv(
            f'/home/yuko/kaggle_understanding_cloud_organization/src/class/segmentation/pred_{fold}.csv'
        )
Example #8
def train(args):
    set_random_seed(42)
    for fold in range(args.folds):
        model = get_model(args.network)

        print("Loading model")
        model, device = UtilsFactory.prepare_model(model)
        train_df = pd.read_csv(
            os.path.join(args.dataset_path,
                         f'train{fold}.csv')).to_dict('records')
        val_df = pd.read_csv(os.path.join(args.dataset_path,
                                          f'val{fold}.csv')).to_dict('records')

        ds = Dataset(args.channels, args.dataset_path, args.image_size,
                     args.batch_size, args.num_workers)
        loaders = ds.create_loaders(train_df, val_df)

        criterion = BCE_Dice_Loss(bce_weight=0.2)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[10, 20, 40], gamma=0.3)

        # model runner
        runner = SupervisedRunner()

        save_path = os.path.join(args.logdir, f'fold{fold}')

        # model training
        runner.train(model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     loaders=loaders,
                     callbacks=[DiceCallback()],
                     logdir=save_path,
                     num_epochs=args.epochs,
                     verbose=True)

        infer_loader = collections.OrderedDict([("infer", loaders["valid"])])
        runner.infer(
            model=model,
            loaders=infer_loader,
            callbacks=[
                CheckpointCallback(resume=f'{save_path}/checkpoints/best.pth'),
                InferCallback()
            ],
        )

        print(f'Fold {fold} ended')
Example #9
    def predict_loader(
        self,
        model: Model,
        loader: DataLoader,
        resume: str = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ) -> Any:
        """
        Makes a prediction on the whole loader with the specified model.

        Args:
            model (Model): model to infer
            loader (DataLoader): ``torch.utils.data.DataLoader``
                to run inference on
            resume (str): path to checkpoint for model
            verbose (bool): if True, it displays the status of the inference
                to the console.
            state_kwargs (dict): additional state params to ``RunnerState``
            fp16 (Union[Dict, bool]): If not None, then sets inference to FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                if fp16=True, params by default will be ``{"opt_level": "O1"}``
            check (bool): if True, then only checks that pipeline is working
                (3 epochs only)
        """
        loaders = OrderedDict([("infer", loader)])

        callbacks = OrderedDict([("inference", InferCallback())])
        if resume is not None:
            callbacks["loader"] = CheckpointCallback(resume=resume)

        self.infer(
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            verbose=verbose,
            state_kwargs=state_kwargs,
            fp16=fp16,
            check=check
        )

        output = callbacks["inference"].predictions
        if isinstance(self.output_key, str):
            output = output[self.output_key]

        return output
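A short call sketch for predict_loader as defined above; the model, loader and checkpoint path below are placeholders rather than values taken from the examples:

runner = SupervisedRunner()
logits = runner.predict_loader(
    model=model,                           # placeholder: a trained nn.Module
    loader=test_loader,                    # placeholder: torch.utils.data.DataLoader
    resume="logdir/checkpoints/best.pth",  # placeholder checkpoint path
    verbose=True,
)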
Example #10
def make_predictions(runner, model, loader, y_true):
    runner.infer(
        model=model,
        loaders=loader,
        callbacks=[
            CheckpointCallback(resume=f"{LOG_DIR}/checkpoints/best.pth"),
            InferCallback(),
        ],
        verbose=True)

    y_preds = runner.callbacks[0].predictions['logits'].argmax(1)

    acc = calc_accuracy(y_preds, y_true)
    acc1, acc2 = calc_accuracy_per_cls(y_preds, y_true)
    f1 = calc_f1_score(y_preds, y_true)

    return {'acc': acc, 'acc1': acc1, 'acc2': acc2, 'f1': f1}
Example #11
def train(args):
    model = Autoencoder_Unet(encoder_name='resnet50')

    print("Loading model")
    model, device = UtilsFactory.prepare_model(model)

    train_df = pd.read_csv(args.train_df).to_dict('records')
    val_df = pd.read_csv(args.val_df).to_dict('records')
    ds = AutoencoderDataset(args.channels, args.dataset_path, args.image_size,
                            args.batch_size, args.num_workers)
    loaders = ds.create_loaders(train_df, val_df)

    criterion = MSELoss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[10, 20, 40],
                                                     gamma=0.3)

    # model runner
    runner = SupervisedRunner()

    # model training
    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 callbacks=[],
                 scheduler=scheduler,
                 loaders=loaders,
                 logdir=args.logdir,
                 num_epochs=args.epochs,
                 verbose=True)

    infer_loader = collections.OrderedDict([("infer", loaders["valid"])])
    runner.infer(
        model=model,
        loaders=infer_loader,
        callbacks=[
            CheckpointCallback(resume=f"{args.logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )
Example #12
    def prepare_callbacks(*, args, mode, stage=None, **kwargs):
        callbacks = collections.OrderedDict()

        if mode == "train":
            if stage == "debug":
                callbacks["stage"] = StageCallback()
                callbacks["loss"] = LossCallback(
                    emb_l2_reg=kwargs.get("emb_l2_reg", -1))
                callbacks["optimizer"] = OptimizerCallback(
                    grad_clip=kwargs.get("grad_clip", None))
                callbacks["metrics"] = BaseMetrics()
                callbacks["lr-finder"] = LRFinder(
                    final_lr=kwargs.get("final_lr", 0.1),
                    n_steps=kwargs.get("n_steps", None))
                callbacks["logger"] = Logger()
                callbacks["tflogger"] = TensorboardLogger()
            else:
                callbacks["stage"] = StageCallback()
                callbacks["loss"] = LossCallback(
                    emb_l2_reg=kwargs.get("emb_l2_reg", -1))
                callbacks["optimizer"] = OptimizerCallback(
                    grad_clip=kwargs.get("grad_clip", None))
                callbacks["metrics"] = BaseMetrics()
                callbacks["map"] = MapKCallback(
                    map_args=kwargs.get("map_args", [3]))
                callbacks["saver"] = CheckpointCallback(save_n_best=getattr(
                    args, "save_n_best", 7),
                                                        resume=args.resume)

                # Pytorch scheduler callback
                callbacks["scheduler"] = SchedulerCallback(
                    reduce_metric="map03")

                callbacks["logger"] = Logger()
                callbacks["tflogger"] = TensorboardLogger()
        elif mode == "infer":
            callbacks["saver"] = CheckpointCallback(resume=args.resume)
            callbacks["infer"] = InferCallback(out_prefix=args.out_prefix)
        else:
            raise NotImplementedError

        return callbacks
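Since runner.infer also accepts an ordered dict of callbacks (see Examples #1 and #9), the factory above can be wired up roughly as follows; runner, model, loaders and args are assumed to exist:

callbacks = prepare_callbacks(args=args, mode="infer")  # CheckpointCallback + InferCallback
runner.infer(model=model, loaders=loaders, callbacks=callbacks, verbose=True)
predictions = callbacks["infer"].predictions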
Example #13
 def inference(self, model):
     if not os.path.exists((os.path.join(self.__rez_dir, "mask"))):
         os.mkdir(os.path.join(self.__rez_dir, "mask"))
     if not os.path.exists((os.path.join(self.__rez_dir, "overlay"))):
         os.mkdir(os.path.join(self.__rez_dir, "overlay"))
     model = model
     loaders = self.__get_data()
     runner = SupervisedRunner()
     runner.infer(
         model=model,
         loaders=loaders,
         verbose=True,
         callbacks=[
             CheckpointCallback(resume=os.path.join(
                 self.__logs_dir, "checkpoints/best.pth")),
             InferCallback(),
         ],
     )
     sigmoid = lambda x: 1 / (1 + np.exp(-x))
     for i, (input, output) in enumerate(
             zip(self.loader.image_list,
                 runner.callbacks[1].predictions["logits"])):
         threshold = self.threshold
         output = sigmoid(output)
         image_path = input
         file_name = image_path[0].split("/")[-1]
         image = cv2.imread(image_path[0])
         canvas = (output[0] > threshold).astype(np.uint8) * 255
         canvas = np.squeeze(canvas)
         original_height, original_width = image.shape[:2]
         canvas = CenterCrop(p=1,
                             height=original_height,
                             width=original_width)(image=canvas)["image"]
         canvas = np.reshape(canvas, list(canvas.shape) + [1])
         overlay = make_overlay(image, canvas)
         cv2.imwrite(os.path.join(self.__rez_dir, "mask", file_name),
                     canvas)
         cv2.imwrite(os.path.join(self.__rez_dir, "overlay", file_name),
                     overlay)
Example #14
 def inference(self, model):
     if not os.path.exists((os.path.join(self.__rez_dir, "mask"))):
         os.mkdir(os.path.join(self.__rez_dir, "mask"))
     if not os.path.exists((os.path.join(self.__rez_dir, "overlay"))):
         os.mkdir(os.path.join(self.__rez_dir, "overlay"))
     model = model
     loaders = self.__get_data()
     runner = SupervisedRunner()
     runner.infer(
         model=model,
         loaders=loaders,
         verbose=True,
         callbacks=[
             CheckpointCallback(resume=os.path.join(
                 self.__logs_dir, "checkpoints/best.pth")),
             InferCallback(),
         ],
     )
     sigmoid = lambda x: 1 / (1 + np.exp(-x))
     for i, (input, output) in enumerate(
             zip(self.loader.image_list,
                 runner.callbacks[1].predictions["logits"])):
         threshold = self.threshold
         classes = np.argmax(output, axis=0)
         image_path = input
         file_name = image_path[0].split("/")[-1]
         image = cv2.imread(image_path[0])
         original_height, original_width = image.shape[:2]
         classes_cropped = CenterCrop(
             p=1, height=original_height,
             width=original_width)(image=classes)["image"]
         overlay = output2final(classes_cropped)
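          # NOTE: the bare `raise` below stops execution here, so the mask and
          # overlay writes further down are never reached as written.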
         raise
         # cv2.imwrite(os.path.join(self.__rez_dir, "mask",file_name), canvas)
         # cv2.imwrite(os.path.join(self.__rez_dir, "overlay", file_name), overlay)
         plt.imsave(os.path.join(self.__rez_dir, "mask", file_name),
                    classes_cropped)
         plt.imsave(os.path.join(self.__rez_dir, "overlay", file_name),
                    overlay)
Example #15
def generate_class_params(i_dont_know_how_to_return_values_without_map):

    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)
    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=0)

    model = smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=ACTIVATION,
    )

    runner = SupervisedRunner()
    # Generate validation predictions
    encoded_pixels = []
    loaders = {"infer": valid_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )

    valid_masks = []
    probabilities = np.zeros((2220, 350, 525))
    for i, (batch, output) in enumerate(
            tqdm.tqdm(
                zip(valid_dataset,
                    runner.callbacks[0].predictions["logits"]))):
        image, mask = batch
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m,
                               dsize=(525, 350),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    class_params = {}
    for class_id in range(4):
        print(class_id)
        attempts = []
        for t in range(30, 100, 5):
            t /= 100
            for ms in [1200, 5000, 10000]:
                masks = []
                for i in range(class_id, len(probabilities), 4):
                    probability = probabilities[i]
                    predict, num_predict = post_process(
                        sigmoid(probability), t, ms)
                    masks.append(predict)

                d = []
                for i, j in zip(masks, valid_masks[class_id::4]):
                    if (i.sum() == 0) & (j.sum() == 0):
                        d.append(1)
                    else:
                        d.append(dice(i, j))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)

    return class_params
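The dice metric used in the threshold search above is likewise not defined on this page. A minimal sketch for binary masks follows (an assumed implementation; the caller above already handles the case where both masks are empty by appending 1 instead of calling dice):

import numpy as np

def dice(img1, img2):
    # Dice coefficient between two binary masks; division by zero when both
    # masks are empty, which the calling loop guards against.
    img1 = np.asarray(img1).astype(bool)
    img2 = np.asarray(img2).astype(bool)
    intersection = np.logical_and(img1, img2).sum()
    return 2.0 * intersection / (img1.sum() + img2.sum())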
Example #16
def main():

    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--seed', type=int, default=1234, help='Random seed')
    arg('--model-name',
        type=str,
        default=Path('seresnext101'),
        help='String model name used for saving')
    arg('--run-root',
        type=Path,
        default=Path('../results'),
        help='Directory for saving model')
    arg('--data-root', type=Path, default=Path('../data'))
    arg('--image-size', type=int, default=224, help='Image size for training')
    arg('--batch-size',
        type=int,
        default=16,
        help='Batch size during training')
    arg('--fold', type=int, default=0, help='Validation fold')
    arg('--n-epochs', type=int, default=10, help='Epoch to run')
    arg('--learning-rate',
        type=float,
        default=1e-3,
        help='Initial learning rate')
    arg('--step', type=int, default=1, help='Current training step')
    arg('--patience', type=int, default=4)
    arg('--criterion', type=str, default='bce', help='Criterion')
    arg('--optimizer', default='Adam', help='Name of the optimizer')
    arg('--continue_train', type=bool, default=False)
    arg('--checkpoint',
        type=str,
        default=Path('../results'),
        help='Checkpoint file path')
    arg('--workers', type=int, default=2)
    arg('--debug', type=bool, default=True)
    args = parser.parse_args()

    set_seed(args.seed)
    """
    
    SET PARAMS
    
    """
    args.debug = True
    ON_KAGGLE = configs.ON_KAGGLE
    N_CLASSES = configs.NUM_CLASSES
    args.image_size = configs.SIZE
    args.data_root = configs.DATA_ROOT
    use_cuda = cuda.is_available()
    fold = args.fold
    num_workers = args.workers
    num_epochs = args.n_epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    """

    LOAD DATA
    
    """
    print(os.listdir(args.data_root))
    folds = pd.read_csv(args.data_root / 'folds.csv')
    train_root = args.data_root / 'train'

    if args.debug:
        folds = folds.head(50)
    train_fold = folds[folds['fold'] != fold]
    valid_fold = folds[folds['fold'] == fold]
    check_fold(train_fold, valid_fold)

    def get_dataloader(df: pd.DataFrame, image_transform) -> DataLoader:
        """
        Calls dataloader to load Imet Dataset
        """
        return DataLoader(
            ImetDataset(train_root, df, image_transform),
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    train_loader = get_dataloader(train_fold, image_transform=albu_transform)
    valid_loader = get_dataloader(valid_fold, image_transform=valid_transform)
    print('{} items in train, {} in valid'.format(len(train_loader.dataset),
                                                  len(valid_loader.dataset)))
    loaders = OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader
    """
    
    MODEL
    
    """
    model = seresnext101(num_classes=N_CLASSES)
    if use_cuda:
        model = model.cuda()

    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.5,
                                               patience=args.patience)
    """
    
    MODEL RUNNER
    
    """
    # call an instance of the model runner
    runner = SupervisedRunner()
    # logs folder
    current_time = datetime.now().strftime('%b%d_%H_%M')
    prefix = f'{current_time}_{args.model_name}'
    logdir = os.path.join(args.run_root, prefix)
    os.makedirs(logdir, exist_ok=False)

    print('\tTrain session    :', prefix)
    print('\tOn KAGGLE      :', ON_KAGGLE)
    print('\tDebug          :', args.debug)
    print('\tClasses number :', N_CLASSES)
    print('\tModel          :', args.model_name)
    print('\tParameters     :', model.parameters())
    print('\tImage size     :', args.image_size)
    print('\tEpochs         :', num_epochs)
    print('\tWorkers        :', num_workers)
    print('\tLog dir        :', logdir)
    print('\tLearning rate  :', learning_rate)
    print('\tBatch size     :', batch_size)
    print('\tPatience       :', args.patience)

    if args.continue_train:
        state = load_model(model, args.checkpoint)
        epoch = state['epoch']
        step = state['step']
        print('Loaded model weights from {}, epoch {}, step {}'.format(
            args.checkpoint, epoch, step))

    # model training
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=[
            F1ScoreCallback(threshold=0.5),
            #F2ScoreCallback(num_classes=N_CLASSES),
            EarlyStoppingCallback(patience=args.patience, min_delta=0.01)
        ],
        logdir=logdir,
        num_epochs=num_epochs,
        verbose=True)

    # by default it only plots loss, works in IPython Notebooks
    #utils.plot_metrics(logdir=logdir, metrics=["loss", "_base/lr"])
    """
    
    INFERENCE TEST
    
    """
    loaders = OrderedDict([("infer", loaders["train"])])
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )
    print(runner.callbacks[1].predictions["logits"])
Example #17
def generate_test_preds(args):

    valid_dice, class_params, = args

    test_preds = np.zeros((len(sub), 350, 525), dtype=np.float32)

    for i in range(NFOLDS):
        logdir = LOG_DIR_BASE + str(i)

        preprocessing_fn = smp.encoders.get_preprocessing_fn(
            ENCODER, ENCODER_WEIGHTS)
        dummy_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids[:1],
            transforms=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        dummy_loader = DataLoader(dummy_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)

        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=ACTIVATION,
        )
        runner = SupervisedRunner(model)

        # HACK: We are loading a few examples from our dummy loader so catalyst will properly load the weights
        # from our checkpoint
        loaders = {"test": dummy_loader}
        runner.infer(
            model=model,
            loaders=loaders,
            callbacks=[
                CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
                InferCallback()
            ],
        )

        # Now we do real inference on the full dataset
        test_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids,
            transforms=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        test_loader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        image_id = 0
        for batch_index, test_batch in enumerate(tqdm.tqdm(test_loader)):
            runner_out = runner.predict_batch(
                {"features":
                 test_batch[0].cuda()})['logits'].cpu().detach().numpy()
            for preds in runner_out:

                preds = preds.transpose((1, 2, 0))
                preds = cv2.resize(
                    preds,
                    (525, 350))  # height and width are backward in cv2...
                preds = preds.transpose((2, 0, 1))

                idx = batch_index * 4
                test_preds[idx + 0] += sigmoid(preds[0]) / NFOLDS  # fish
                test_preds[idx + 1] += sigmoid(preds[1]) / NFOLDS  # flower
                test_preds[idx + 2] += sigmoid(preds[2]) / NFOLDS  # gravel
                test_preds[idx + 3] += sigmoid(preds[3]) / NFOLDS  # sugar

    # Convert ensembled predictions to RLE predictions
    encoded_pixels = []
    for image_id, preds in enumerate(test_preds):

        predict, num_predict = post_process(preds,
                                            class_params[image_id % 4][0],
                                            class_params[image_id % 4][1])
        if num_predict == 0:
            encoded_pixels.append('')
        else:
            r = mask2rle(predict)
            encoded_pixels.append(r)

    print("Saving submission...")
    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv('unet_submission_{}.csv'.format(valid_dice),
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
    print("Saved.")
Example #18
             criterion=criterion,
             optimizer=optimizer,
             loaders=loaders,
             logdir=logdir,
             num_epochs=num_epochs,
             check=True)

# # Setup 8 - loader inference

# In[ ]:

from catalyst.dl.callbacks import InferCallback
loaders = collections.OrderedDict([("infer", loaders["train"])])
runner.infer(model=model,
             loaders=loaders,
             callbacks=[InferCallback()],
             check=True)

# In[ ]:

runner.callbacks[0].predictions["logits"].shape

# # Setup 9 - batch inference

# In[ ]:

features, targets = next(iter(loaders["infer"]))

# In[ ]:

features.shape
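The snippet stops before the actual batch prediction; a plausible continuation of "Setup 9 - batch inference", reusing the predict_batch call seen in the other examples ("features" is the default SupervisedRunner input key):

# In[ ]:

# Assumes the model still sits on CPU; move `features` with .cuda() otherwise.
runner_out = runner.predict_batch({"features": features})
runner_out["logits"].shape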
Example #19
def get_optimal_postprocess(loaders=None, runner=None, logdir: str = ''):
    """
    Calculate optimal thresholds for validation data.

    Args:
        loaders: loaders with necessary datasets
        runner: runner
        logdir: directory with model checkpoints

    Returns:
        dict mapping each class id to its best ``(threshold, min_size)`` pair
    """
    loaders['infer'] = loaders['valid']

    runner.infer(
        model=runner.model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )
    valid_masks = []
    probabilities = np.zeros((2220, 350, 525))
    for i, (batch, output) in enumerate(
            zip(loaders['infer'].dataset,
                runner.callbacks[0].predictions["logits"])):
        image, mask = batch
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m,
                               dsize=(525, 350),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    class_params = {}
    for class_id in range(4):
        print(class_id)
        attempts = []
        for t in range(0, 100, 10):
            t /= 100
            for ms in [
                    0, 100, 1000, 5000, 10000, 11000, 14000, 15000, 16000,
                    18000, 19000, 20000, 21000, 23000, 25000, 27000, 30000,
                    50000
            ]:
                masks = []
                for i in range(class_id, len(probabilities), 4):
                    probability = probabilities[i]
                    predict, num_predict = post_process(
                        sigmoid(probability), t, ms)
                    masks.append(predict)

                d = []
                for i, j in zip(masks, valid_masks[class_id::4]):
                    if (i.sum() == 0) & (j.sum() == 0):
                        d.append(1)
                    else:
                        d.append(dice(i, j))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)

    print(class_params)
    return class_params
Example #20
def train(args):
    set_random_seed(42)
    if args.model == 'lstm_diff':
        model = ULSTMNet(count_channels(args.channels), 1, args.image_size)
    elif args.model == 'lstm_decoder':
        model = Unet_LstmDecoder(count_channels(args.channels), all_masks=args.allmasks)
    else:
        print('Unknown LSTM model. Falling back to the default model.')
        model = ULSTMNet(count_channels(args.channels), 1, args.image_size)

    if torch.cuda.is_available():
        model.cuda()
    print('Loading model')

    model, device = UtilsFactory.prepare_model(model)
    print(device)

    optimizer = get_optimizer(args.optimizer, args.lr, model)
    criterion = get_loss(args.loss)    
    
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 40, 80, 150, 300], gamma=0.2
    )

    save_path = os.path.join(
        args.logdir,
        args.name
    )
    
    os.system(f"mkdir {save_path}")

    train_df = pd.read_csv(args.train_df)
    val_df = pd.read_csv(args.val_df)

    train_dataset = LstmDataset(args.neighbours, train_df, 'train',args.channels, args.dataset_path, args.image_size, args.batch_size, args.allmasks)
    valid_dataset = LstmDataset(args.neighbours, val_df, 'valid',args.channels, args.dataset_path, args.image_size, args.batch_size, args.allmasks)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, 
        shuffle=sampler is None, num_workers=args.num_workers, sampler=sampler(train_df))
    valid_loader = DataLoader(valid_dataset, batch_size=1, 
        shuffle=False, num_workers=args.num_workers)

    loaders = collections.OrderedDict()

    loaders['train'] = train_loader
    loaders['valid'] = valid_loader

    runner = SupervisedRunner()
    if args.model_weights_path:
        checkpoint = torch.load(args.model_weights_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=[
            DiceCallback()
        ],
        logdir=save_path,
        num_epochs=args.epochs,
        verbose=True
    )

    infer_loader = collections.OrderedDict([('infer', loaders['valid'])])
    runner.infer(
        model=model,
        loaders=infer_loader,
        callbacks=[
            CheckpointCallback(resume=f'{save_path}/checkpoints/best.pth'),
            InferCallback()
        ],
    )

Example #21
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    check=True
)

# # Setup 8 - loader inference

# In[ ]:

from catalyst.dl.callbacks import InferCallback
loaders = collections.OrderedDict([("infer", loaders["train"])])
runner.infer(
    model=model, loaders=loaders, callbacks=[InferCallback()], check=True
)

# In[ ]:

runner.callbacks[0].predictions["logits"].shape

# # Setup 9 - batch inference

# In[ ]:

features, targets = next(iter(loaders["infer"]))

# In[ ]:

features.shape
Example #22
def training(train_ids, valid_ids, num_split, encoder, decoder):
    """
    模型训练
    """
    train = "./data/Clouds_Classify/train.csv"

    # Data overview
    train = pd.read_csv(open(train))
    train.head()

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'

    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    bs = 12
    train_dataset = CloudDataset(
        df=train,
        transforms=get_training_augmentation(),
        datatype='train',
        img_ids=train_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))

    train_loader = DataLoader(train_dataset,
                              batch_size=bs,
                              shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=bs,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"train": train_loader, "valid": valid_loader}

    num_epochs = 50
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    # model, criterion, optimizer
    optimizer = torch.optim.Adam([
        {
            'params': model.decoder.parameters(),
            'lr': 1e-2
        },
        {
            'params': model.encoder.parameters(),
            'lr': 1e-3
        },
    ])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.35, patience=4)
    criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    runner = SupervisedRunner()

    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=[DiceCallback()],
                 logdir=logdir,
                 num_epochs=num_epochs,
                 verbose=True)

    # Exploring predictions
    loaders = {"infer": valid_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )
Example #23
def main(config):
    opts = config()
    path = opts.path
    train = pd.read_csv(f'{path}/train.csv')
    sub = pd.read_csv(f'{path}/sample_submission.csv')

    n_train = len(os.listdir(f'{path}/train_images'))
    n_test = len(os.listdir(f'{path}/test_images'))

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])
    train.loc[train['EncodedPixels'].isnull() == False,
              'Image_Label'].apply(lambda x: x.split('_')[1]).value_counts()
    train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(
        lambda x: x.split('_')[0]).value_counts().value_counts()

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    valid_ids = pd.read_csv("csvs/valid_threshold.csv")["img_id"].values
    test_ids = sub['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    #     print(valid_ids)
    ENCODER = opts.backborn
    ENCODER_WEIGHTS = opts.encoder_weights
    DEVICE = 'cuda'

    ACTIVATION = None
    model = get_model(model_type=opts.model_type,
                      encoder=ENCODER,
                      encoder_weights=ENCODER_WEIGHTS,
                      activation=ACTIVATION,
                      n_classes=opts.class_num,
                      task=opts.task,
                      attention_type=opts.attention_type,
                      head='simple',
                      center=opts.center,
                      tta=opts.tta)
    if opts.refine:
        model = get_ref_model(infer_model=model,
                              encoder=opts.ref_backborn,
                              encoder_weights=ENCODER_WEIGHTS,
                              activation=ACTIVATION,
                              n_classes=opts.class_num,
                              preprocess=opts.preprocess,
                              tta=opts.tta)
    model = convert_model(model)
    preprocessing_fn = encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    encoded_pixels = []
    runner = SupervisedRunner()
    probabilities = np.zeros((2220, 350, 525))

    for i in range(opts.fold_max):
        if opts.refine:
            logdir = f"{opts.logdir}_refine/fold{i}"
        else:
            logdir = f"{opts.logdir}/fold{i}"
        valid_dataset = CloudDataset(
            df=train,
            datatype='valid',
            img_ids=valid_ids,
            transforms=get_validation_augmentation(opts.img_size),
            preprocessing=get_preprocessing(preprocessing_fn))
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=opts.batchsize,
                                  shuffle=False,
                                  num_workers=opts.num_workers)
        loaders = {"infer": valid_loader}
        runner.infer(
            model=model,
            loaders=loaders,
            callbacks=[
                CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
                InferCallback()
            ],
        )
        valid_masks = []
        for i, (batch, output) in enumerate(
                tqdm.tqdm(
                    zip(valid_dataset,
                        runner.callbacks[0].predictions["logits"]))):
            image, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)

            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[i * 4 + j, :, :] += sigmoid(probability)

    probabilities /= opts.fold_max
    if opts.tta:
        np.save(
            f'probabilities/{opts.logdir.split("/")[-1]}_{opts.img_size[0]}x{opts.img_size[1]}_tta_valid.npy',
            probabilities)
    else:
        np.save(
            f'probabilities/{opts.logdir.split("/")[-1]}_{opts.img_size[0]}x{opts.img_size[1]}_valid.npy',
            probabilities)

    torch.cuda.empty_cache()
    gc.collect()

    class_params = {}
    cv_d = []
    for class_id in tqdm.trange(opts.class_num, desc='class_id', leave=False):
        #         print(class_id)
        attempts = []
        for tt in tqdm.trange(0, 100, 10, desc='top_threshold', leave=False):
            tt /= 100
            for bt in tqdm.trange(0,
                                  100,
                                  10,
                                  desc='bot_threshold',
                                  leave=False):
                bt /= 100
                for ms in tqdm.tqdm([
                        0, 100, 1000, 5000, 10000, 11000, 14000, 15000, 16000,
                        18000, 19000, 20000, 21000, 23000, 25000, 27000, 30000,
                        50000
                ],
                                    desc='min_size',
                                    leave=False):
                    masks = []
                    for i in range(class_id, len(probabilities), 4):
                        probability = probabilities[i]
                        predict, num_predict = post_process(
                            probability, tt, ms, bt)

                        masks.append(predict)

                    d = []
                    for i, j in zip(masks, valid_masks[class_id::4]):
                        #                     print(i.shape, j.shape)
                        if (i.sum() == 0) & (j.sum() == 0):
                            d.append(1)
                        else:
                            d.append(dice(i, j))
                    attempts.append((tt, ms, bt, np.mean(d)))

        attempts_df = pd.DataFrame(
            attempts,
            columns=['top_threshold', 'size', 'bottom_threshold', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        cv_d.append(attempts_df['dice'].values[0])
        best_top_threshold = attempts_df['top_threshold'].values[0]
        best_size = attempts_df['size'].values[0]
        best_bottom_threshold = attempts_df['bottom_threshold'].values[0]

        class_params[class_id] = (best_top_threshold, best_size,
                                  best_bottom_threshold)
    cv_d = np.array(cv_d)
    print("CV Dice:", np.mean(cv_d))
    pathlist = [
        "../input/test_images/" + i.split("_")[0] for i in sub['Image_Label']
    ]

    del masks
    del valid_masks
    del probabilities
    gc.collect()

    ############# predict ###################
    probabilities = np.zeros((n_test, 4, 350, 525))
    for fold in tqdm.trange(opts.fold_max, desc='fold loop'):
        if opts.refine:
            logdir = f"{opts.logdir}_refine/fold{fold}"
        else:
            logdir = f"{opts.logdir}/fold{fold}"


#         loaders = {"test": test_loader}
        test_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids,
            transforms=get_validation_augmentation(opts.img_size),
            preprocessing=get_preprocessing(preprocessing_fn))
        test_loader = DataLoader(test_dataset,
                                 batch_size=opts.batchsize,
                                 shuffle=False,
                                 num_workers=opts.num_workers)
        runner_out = runner.predict_loader(
            model,
            test_loader,
            resume=f"{logdir}/checkpoints/best.pth",
            verbose=True)
        for i, batch in enumerate(
                tqdm.tqdm(runner_out, desc='probability loop')):
            for j, probability in enumerate(batch):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[i, j, :, :] += sigmoid(probability)
        gc.collect()
    probabilities /= opts.fold_max
    if opts.tta:
        np.save(
            f'probabilities/{opts.logdir.split("/")[-1]}_{opts.img_size[0]}x{opts.img_size[1]}_tta_test.npy',
            probabilities)
    else:
        np.save(
            f'probabilities/{opts.logdir.split("/")[-1]}_{opts.img_size[0]}x{opts.img_size[1]}_test.npy',
            probabilities)
    image_id = 0
    print("##################### start post_process #####################")
    for i in tqdm.trange(n_test, desc='post process loop'):
        for probability in probabilities[i]:
            predict, num_predict = post_process(probability,
                                                class_params[image_id % 4][0],
                                                class_params[image_id % 4][1],
                                                class_params[image_id % 4][2])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                black_mask = get_black_mask(pathlist[image_id])
                predict = np.multiply(predict, black_mask)
                r = mask2rle(predict)
                encoded_pixels.append(r)
            image_id += 1
        gc.collect()
    print("##################### Finish post_process #####################")
    #######################################
    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv(
        f'submissions/submission_{opts.logdir.split("/")[-1]}_{opts.img_size[0]}x{opts.img_size[1]}.csv',
        columns=['Image_Label', 'EncodedPixels'],
        index=False)
Example #24
def generate_test_preds(ensemble_info):

    test_preds = np.zeros((len(sub), 350, 525), dtype=np.float32)
    num_models = len(ensemble_info)

    for model_info in ensemble_info:

        class_params = model_info['class_params']
        encoder = model_info['encoder']
        model_type = model_info['model_type']
        logdir = model_info['logdir']

        preprocessing_fn = smp.encoders.get_preprocessing_fn(
            encoder, ENCODER_WEIGHTS)

        model = None
        if model_type == 'unet':
            model = smp.Unet(
                encoder_name=encoder,
                encoder_weights=ENCODER_WEIGHTS,
                classes=4,
                activation=ACTIVATION,
            )
        elif model_type == 'fpn':
            model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=ENCODER_WEIGHTS,
                classes=4,
                activation=ACTIVATION,
            )
        else:
            raise NotImplementedError("We only support FPN and UNet")

        runner = SupervisedRunner(model)

        # HACK: We are loading a few examples from our dummy loader so catalyst will properly load the weights
        # from our checkpoint
        dummy_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids[:1],
            transforms=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        dummy_loader = DataLoader(dummy_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)
        loaders = {"test": dummy_loader}
        runner.infer(
            model=model,
            loaders=loaders,
            callbacks=[
                CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
                InferCallback()
            ],
        )

        # Now we do real inference on the full dataset
        test_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids,
            transforms=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        test_loader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        image_id = 0
        for batch_index, test_batch in enumerate(tqdm.tqdm(test_loader)):
            runner_out = runner.predict_batch(
                {"features":
                 test_batch[0].cuda()})['logits'].cpu().detach().numpy()

            # Apply TTA (test-time augmentation) transforms
            v_flip = test_batch[0].flip(dims=(2, ))
            h_flip = test_batch[0].flip(dims=(3, ))
            v_flip_out = runner.predict_batch({"features": v_flip.cuda()
                                               })['logits'].cpu().detach()
            h_flip_out = runner.predict_batch({"features": h_flip.cuda()
                                               })['logits'].cpu().detach()
            # Undo transforms
            v_flip_out = v_flip_out.flip(dims=(2, )).numpy()
            h_flip_out = h_flip_out.flip(dims=(3, )).numpy()
            # Get average
            tta_avg_out = (v_flip_out + h_flip_out) / 2

            # Combine with original predictions
            beta = 0.4  # From fastai TTA
            runner_out = (beta) * runner_out + (1 - beta) * tta_avg_out

            for preds in runner_out:

                preds = preds.transpose((1, 2, 0))
                preds = cv2.resize(
                    preds,
                    (525, 350))  # height and width are backward in cv2...
                preds = preds.transpose((2, 0, 1))

                idx = batch_index * 4
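                # test_loader has batch_size=1, so batch_index identifies the image
                # and rows idx..idx+3 are its four label entries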
                test_preds[idx + 0] += sigmoid(preds[0]) / num_models  # fish
                test_preds[idx + 1] += sigmoid(preds[1]) / num_models  # flower
                test_preds[idx + 2] += sigmoid(preds[2]) / num_models  # gravel
                test_preds[idx + 3] += sigmoid(preds[3]) / num_models  # sugar

    # Convert ensembled predictions to RLE predictions
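    # Note: class_params here is whatever the last entry in ensemble_info set in the loop above.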
    encoded_pixels = []
    for image_id, preds in enumerate(test_preds):

        predict, num_predict = post_process(preds,
                                            class_params[image_id % 4][0],
                                            class_params[image_id % 4][1])
        if num_predict == 0:
            encoded_pixels.append('')
        else:
            r = mask2rle(predict)
            encoded_pixels.append(r)

    print("Saving submission...")
    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv('ensembled_submission.csv',
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
    print("Saved.")
Beispiel #25
0
def find_class_params(args):
    runner = SupervisedRunner()
    model = create_model(args.encoder_type)
    valid_loader = get_train_val_loaders(args.encoder_type,
                                         batch_size=args.batch_size)['valid']

    encoded_pixels = []
    loaders = {"infer": valid_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[CheckpointCallback(resume=args.ckp),
                   InferCallback()],
    )
    print(runner.callbacks)
    valid_masks = []
    probabilities = np.zeros((2220, 350, 525))
    for i, (batch, output) in enumerate(
            tqdm(
                zip(valid_loader.dataset,
                    runner.callbacks[0].predictions["logits"]))):
        image, mask = batch
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m,
                               dsize=(525, 350),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    class_params = {}
    for class_id in range(4):
        print(class_id)
        attempts = []
        for t in range(0, 100, 5):
            t /= 100
            #for ms in [0, 100, 1200, 5000, 10000]:
            for ms in [5000, 10000, 15000, 20000, 22500, 25000, 30000]:

                masks = []
                for i in range(class_id, len(probabilities), 4):
                    probability = probabilities[i]
                    predict, num_predict = post_process(
                        sigmoid(probability), t, ms)
                    masks.append(predict)

                d = []
                for i, j in zip(masks, valid_masks[class_id::4]):
                    if (i.sum() == 0) & (j.sum() == 0):
                        d.append(1)
                    else:
                        d.append(dice(i, j))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)
    print(class_params)
    return class_params, runner
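
# dice() used in the grid search above is not defined in this excerpt; a minimal
# sketch of the usual binary Dice coefficient. The loop above already handles the
# both-empty case separately, so eps only guards the division.
import numpy as np

def dice(pred, truth, eps=1e-7):
    pred = np.asarray(pred, dtype=np.float32)
    truth = np.asarray(truth, dtype=np.float32)
    intersection = (pred * truth).sum()
    return (2.0 * intersection + eps) / (pred.sum() + truth.sum() + eps)
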
def train(args):
    set_random_seed(42)
    model = get_model(args.network, args.classification_head)
    print('Loading model')

    model.encoder.conv1 = nn.Conv2d(count_channels(args.channels) *
                                    args.neighbours,
                                    64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)

    model, device = UtilsFactory.prepare_model(model)

    train_df = pd.read_csv(args.train_df).to_dict('records')
    val_df = pd.read_csv(args.val_df).to_dict('records')

    ds = Dataset(args.channels, args.dataset_path, args.image_size,
                 args.batch_size, args.num_workers, args.neighbours,
                 args.classification_head)
    loaders = ds.create_loaders(train_df, val_df)

    save_path = os.path.join(args.logdir, args.name)

    optimizer = get_optimizer(args.optimizer, args.lr, model)

    if not args.classification_head:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[10, 40, 80, 150, 300], gamma=0.1)

        criterion = get_loss(args.loss)

        runner = SupervisedRunner()
        if args.model_weights_path:
            checkpoint = torch.load(args.model_weights_path,
                                    map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])

        runner.train(model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     loaders=loaders,
                     callbacks=[DiceCallback()],
                     logdir=save_path,
                     num_epochs=args.epochs,
                     verbose=True)

        infer_loader = collections.OrderedDict([('infer', loaders['valid'])])
        runner.infer(
            model=model,
            loaders=infer_loader,
            callbacks=[
                CheckpointCallback(resume=f'{save_path}/checkpoints/best.pth'),
                InferCallback()
            ],
        )
    else:
        criterion = get_loss('multi')
        net = Model(model,
                    optimizer,
                    criterion,
                    batch_metrics=[
                        classification_head_accuracy, segmentation_head_dice
                    ])
        net = net.to(device)
        net.fit_generator(loaders['train'],
                          loaders['valid'],
                          epochs=args.epochs,
                          callbacks=[
                              ModelCheckpoint(
                                  f'{save_path}/checkpoints/best.pth', ),
                              MultiStepLR(milestones=[10, 40, 80, 150, 300],
                                          gamma=0.1)
                          ])
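
# The conv1 replacement in train() above widens the encoder stem so it accepts
# count_channels(args.channels) * args.neighbours input planes instead of RGB's 3.
# A standalone illustration of the same idea using torchvision's resnet18; the
# helper name and the 12-channel figure are illustrative, not from the source.
import torch
import torch.nn as nn
from torchvision.models import resnet18

def widen_first_conv(model: nn.Module, in_channels: int) -> nn.Module:
    # Fresh (randomly initialised) stem with the same spatial hyper-parameters
    # as the original 7x7 / stride-2 ResNet conv1.
    model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                            padding=3, bias=False)
    return model

# e.g. 4 spectral bands x 3 neighbouring tiles -> 12 input planes
model_12ch = widen_first_conv(resnet18(), in_channels=12)
print(model_12ch(torch.randn(1, 12, 224, 224)).shape)  # torch.Size([1, 1000])
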
model = models[args.model.lower()][0](**models[args.model.lower()][1])
encoded_pixels = []
loaders = {"infer": valid_loader}
logdir = f'./logs/{args.model}/fold_{args.fold}'
gc.collect()
runner = SupervisedRunner(model=model,
                          device='cuda',
                          input_key='image',
                          input_target_key='mask')
runner.infer(
    model=model,
    loaders=loaders,
    callbacks=[
        CheckpointCallback(
            resume=f"{logdir}/checkpoints/best.pth"),
        InferCallback()
    ],
    # fp16={"opt_level": "O1"},
)
valid_masks = []
probabilities = np.zeros((2220, 350, 525))
for i, (batch, output) in enumerate(tqdm.tqdm(zip(
        valid_dataset, runner.callbacks[0].predictions["logits"]))):
    gc.collect()
    image, mask = batch
    for m in mask:
        if m.shape != (350, 525):
            m = cv2.resize(m, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
        valid_masks.append(m)

    for j, probability in enumerate(output):
        # resize each class probability to (350, 525) and store it in the flat
        # probabilities array
        if probability.shape != (350, 525):
            probability = cv2.resize(probability, dsize=(525, 350),
                                     interpolation=cv2.INTER_LINEAR)
        probabilities[i * 4 + j, :, :] = probability
def main():
    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    train_csv = args.train_csv
    sub_csv = args.sub_csv
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    log_path = args.log_path
    is_tta = args.is_tta
    test_batch_size = args.test_batch_size
    attention_type = args.attention_type
    print(log_path)

    train = pd.read_csv(train_csv)
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[-1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.replace('_' + x.split('_')[-1], ''))

    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')
    valid_ids = np.array(val_fold.file_name)

    attention_type = None if attention_type == 'None' else attention_type

    encoder_weights = 'imagenet'

    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
            attention_type=attention_type,
        )
    if model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'ORG':
        model = Linknet_resnet18_ASPP()


    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, encoder_weights)

    valid_dataset = CloudDataset(df=train,
                                 datatype='valid',
                                 img_ids=valid_ids,
                                 transforms=get_validation_augmentation(),
                                 preprocessing=get_preprocessing(preprocessing_fn))

    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    loaders = {"infer": valid_loader}
    runner = SupervisedRunner()

    checkpoint = torch.load(f"{log_path}/checkpoints/best.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
        ]
    )
    model = tta.SegmentationTTAWrapper(model, transforms)
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[InferCallback()],
    )
    callbacks_num = 0

    valid_masks = []
    probabilities = np.zeros((len(valid_dataset) * CLASS, IMG_SIZE[0], IMG_SIZE[1]))

    # ========
    # val predict
    #

    for batch in tqdm(valid_dataset):  # collect the ground-truth masks, one per class per image
        _, mask = batch
        for m in mask:
            m = resize_img(m)
            valid_masks.append(m)

    for i, output in enumerate(tqdm(runner.callbacks[callbacks_num].predictions["logits"])):
        for j, probability in enumerate(output):
            probability = resize_img(probability)  # per-class predicted probability; j indexes the 4 classes
            probabilities[i * CLASS + j, :, :] = probability

    # ========
    # search best size and threshold
    #

    class_params = {}
    for class_id in range(CLASS):
        attempts = []
        for threshold in range(20, 90, 5):
            threshold /= 100
            for min_size in [10000, 15000, 20000]:
                masks = class_masks(class_id, probabilities, threshold, min_size)
                dices = class_dices(class_id, masks, valid_masks)
                attempts.append((threshold, min_size, np.mean(dices)))

        attempts_df = pd.DataFrame(attempts, columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)

        print(attempts_df.head())

        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)

    # ========
    # gc
    #
    torch.cuda.empty_cache()
    gc.collect()

    # ========
    # predict
    #
    sub = pd.read_csv(sub_csv)
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[-1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.replace('_' + x.split('_')[-1], ''))

    test_ids = sub['Image_Label'].apply(lambda x: x.split('_')[0]).drop_duplicates().values

    test_dataset = CloudDataset(df=sub,
                                datatype='test',
                                img_ids=test_ids,
                                transforms=get_validation_augmentation(),
                                preprocessing=get_preprocessing(preprocessing_fn))

    encoded_pixels = get_test_encoded_pixels(test_dataset, runner, class_params, test_batch_size)
    sub['EncodedPixels'] = encoded_pixels

    # ========
    # val dice
    #

    val_Image_Label = []
    for i, row in val_fold.iterrows():
        val_Image_Label.append(row.file_name + '_Fish')
        val_Image_Label.append(row.file_name + '_Flower')
        val_Image_Label.append(row.file_name + '_Gravel')
        val_Image_Label.append(row.file_name + '_Sugar')

    val_encoded_pixels = get_test_encoded_pixels(valid_dataset, runner, class_params, test_batch_size)
    val = pd.DataFrame(val_encoded_pixels, columns=['EncodedPixels'])
    val['Image_Label'] = val_Image_Label

    sub.to_csv(f'./sub/sub_{model_name}_fold_{fold_num}_{encoder}.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
    val.to_csv(f'./val/val_{model_name}_fold_{fold_num}_{encoder}.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
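
# resize_img used in main() above is not defined in this excerpt. A plausible sketch,
# assuming IMG_SIZE is a (height, width) tuple such as (350, 525) and following the
# resize pattern in the other examples (cv2.resize expects dsize as (width, height)).
import cv2

def resize_img(img, size=(350, 525)):
    if img.shape != size:
        img = cv2.resize(img, dsize=(size[1], size[0]),
                         interpolation=cv2.INTER_LINEAR)
    return img
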
Beispiel #29
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--loc', type=str)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    # boolean switches; store_true avoids argparse's type=bool pitfall, where any
    # non-empty string (including 'False') is parsed as True
    parser.add_argument('--optimize', action='store_true')
    parser.add_argument('--tta_pre', action='store_true')
    parser.add_argument('--tta_post', action='store_true')
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str)

    args = parser.parse_args()
    encoder = args.encoder
    model = args.model
    loc = args.loc
    data_folder = args.data_folder
    bs = args.batch_size
    optimize = args.optimize
    tta_pre = args.tta_pre
    tta_post = args.tta_post
    merge = args.merge
    min_size = args.min_size
    thresh = args.thresh
    name = args.name

    if model == 'unet':
        model = smp.Unet(encoder_name=encoder,
                         encoder_weights='imagenet',
                         classes=4,
                         activation=None)
    if model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')

    test_df = get_dataset(train=False)
    test_df = prepare_dataset(test_df)
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

    val_df = get_dataset(train=True)
    val_df = prepare_dataset(val_df)
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)

    model.load_state_dict(torch.load(loc)['model_state_dict'])

    class_params = {
        0: (thresh, min_size),
        1: (thresh, min_size),
        2: (thresh, min_size),
        3: (thresh, min_size)
    }

    if optimize:
        print("OPTIMIZING")
        print(tta_pre)
        if tta_pre:
            opt_model = tta.SegmentationTTAWrapper(
                model,
                tta.Compose([
                    tta.HorizontalFlip(),
                    tta.VerticalFlip(),
                    tta.Rotate90(angles=[0, 180])
                ]),
                merge_mode=merge)
        else:
            opt_model = model
        tta_runner = SupervisedRunner()
        print("INFERRING ON VALID")
        tta_runner.infer(
            model=opt_model,
            loaders={'valid': valid_loader},
            callbacks=[InferCallback()],
            verbose=True,
        )

        valid_masks = []
        probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
        for i, (batch, output) in enumerate(
                tqdm(
                    zip(valid_dataset,
                        tta_runner.callbacks[0].predictions["logits"]))):
            _, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)

            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[(i * 4) + j, :, :] = probability

        print("RUNNING GRID SEARCH")
        for class_id in range(4):
            print(class_id)
            attempts = []
            for t in range(30, 70, 5):
                t /= 100
                for ms in [7500, 10000, 12500, 15000, 17500]:
                    masks = []
                    for i in range(class_id, len(probabilities), 4):
                        probability = probabilities[i]
                        predict, num_predict = post_process(
                            sigmoid(probability), t, ms)
                        masks.append(predict)

                    d = []
                    for i, j in zip(masks, valid_masks[class_id::4]):
                        if (i.sum() == 0) & (j.sum() == 0):
                            d.append(1)
                        else:
                            d.append(dice(i, j))

                    attempts.append((t, ms, np.mean(d)))

            attempts_df = pd.DataFrame(attempts,
                                       columns=['threshold', 'size', 'dice'])

            attempts_df = attempts_df.sort_values('dice', ascending=False)
            print(attempts_df.head())
            best_threshold = attempts_df['threshold'].values[0]
            best_size = attempts_df['size'].values[0]

            class_params[class_id] = (best_threshold, best_size)

        del opt_model
        del tta_runner
        del valid_masks
        del probabilities
    gc.collect()

    if tta_post:
        model = tta.SegmentationTTAWrapper(model,
                                           tta.Compose([
                                               tta.HorizontalFlip(),
                                               tta.VerticalFlip(),
                                               tta.Rotate90(angles=[0, 180])
                                           ]),
                                           merge_mode=merge)
    else:
        model = model
    print(tta_post)

    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    encoded_pixels = []
    image_id = 0

    for image in tqdm(runner.callbacks[0].predictions['logits']):
        for prob in image:
            if prob.shape != (350, 525):
                prob = cv2.resize(prob,
                                  dsize=(525, 350),
                                  interpolation=cv2.INTER_LINEAR)
            predict, num_predict = post_process(sigmoid(prob),
                                                class_params[image_id % 4][0],
                                                class_params[image_id % 4][1])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                r = mask2rle(predict)
                encoded_pixels.append(r)
            image_id += 1

    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(name, columns=['Image_Label', 'EncodedPixels'], index=False)
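
# Neither sigmoid nor mask2rle is defined in these excerpts. Sketches of the usual
# implementations: a plain logistic sigmoid, and the standard Kaggle run-length
# encoder (column-major pixel order, 1-indexed "start length" pairs).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mask2rle(img):
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
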
Beispiel #30
0
def infer_kfold(
    model_name,
    logbase_dir,
    dataset,
    infer_csv,
    infer_root,
    image_size,
    image_key,
    label_key,
    n_class,
):
    for fold in [0, 1, 2, 3, 4, 5]:
        log_dir = f"{logbase_dir}/{model_name}_320/fold_{fold}/"

        model_function = getattr(models, 'finetune')
        params = {'arch': model_name, 'n_class': n_class}
        model = model_function(params)
        use_tta = False

        loaders = OrderedDict()
        if infer_csv:
            transforms = infer_tta_aug(image_size)
            if use_tta:
                for i, transform in enumerate(transforms):
                    inferset = CsvDataset(csv_file=infer_csv,
                                          root=infer_root,
                                          image_key=image_key,
                                          label_key=label_key,
                                          transform=transform,
                                          mode='infer')

                    infer_loader = DataLoader(dataset=inferset,
                                              num_workers=4,
                                              shuffle=False,
                                              batch_size=32)

                    loaders[f'infer_{i}'] = infer_loader

            else:
                inferset = CsvDataset(csv_file=infer_csv,
                                      root=infer_root,
                                      image_key=image_key,
                                      label_key=label_key,
                                      transform=transforms[0],
                                      mode='infer')

                infer_loader = DataLoader(dataset=inferset,
                                          num_workers=4,
                                          shuffle=False,
                                          batch_size=32)

                loaders[f'{dataset}'] = infer_loader

        all_checkpoints = f"{log_dir}/checkpoints/best.pth"

        callbacks = [
            CheckpointCallback(resume=all_checkpoints),
            InferCallback(out_dir=log_dir, out_prefix="/predicts/")
        ]

        runner = ModelRunner()
        runner.infer(
            model,
            loaders,
            callbacks,
            verbose=True,
        )
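
# A possible follow-up to infer_kfold above: average per-fold prediction arrays into
# a single ensemble. The existence and naming of the saved .npy files is an assumption
# about what InferCallback writes under out_dir/out_prefix, not something shown here,
# so the paths are passed in explicitly.
import numpy as np

def average_fold_predictions(prediction_paths):
    # prediction_paths: per-fold .npy files, each holding a prediction array of
    # identical shape (e.g. n_samples x n_class logits)
    preds = [np.load(p) for p in prediction_paths]
    return np.mean(preds, axis=0)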