def main():
    args = get_args()
    image_paths = sorted(args.input_path.glob("*.jpg"))

    config = py2cfg(args.config_path)
    args.output_path.mkdir(exist_ok=True, parents=True)

    if args.visualize:
        vis_output_path = Path(str(args.output_path) + "_vis")
        vis_output_path.mkdir(exist_ok=True, parents=True)

    test_aug = config.test_augmentations

    model = config.model

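    # Remap checkpoint keys: strip the wrapper prefixes ("model.0.", "model.")
    # so the saved state_dict matches the bare model built from the config.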
    checkpoint = load_checkpoint(args.checkpoint, {
        "model.0.": "",
        "model.": ""
    })
    utils.unpack_checkpoint(checkpoint, model=model)

    model = nn.Sequential(model, ApplySoftmaxToLogits())

    model, _, _, _, device = utils.process_components(model=model)

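    # Optional test-time augmentation: "lr" adds a horizontal-flip pass,
    # "d4" runs the full D4 group of flips and 90-degree rotations.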
    if args.tta == "lr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)

    runner = SupervisedRunner(model=model, device=device)

    with torch.no_grad():
        test_loader = DataLoader(
            TestSegmentationDataset(image_paths,
                                    test_aug,
                                    factor=config.pad_factor,
                                    imread_lib=config.imread_library),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        for input_images in tqdm(test_loader):
            raw_predictions = runner.predict_batch(
                {"features": input_images["features"].to(device)})["logits"]

            image_height, image_width = input_images["features"].shape[2:]

            pads = input_images["pads"].cpu().numpy()

            image_ids = input_images["image_id"]

            _, predictions = raw_predictions.max(1)

            for i in range(raw_predictions.shape[0]):
                padded_mask = predictions[i].cpu().numpy()

                if padded_mask.shape != (image_height, image_width):
                    padded_mask = cv2.resize(padded_mask,
                                             (image_width, image_height),
                                             interpolation=cv2.INTER_NEAREST)

                mask = unpad(padded_mask, pads[i]).astype(np.uint8)

                mask_name = image_ids[i] + ".png"
                cv2.imwrite(str(args.output_path / mask_name), mask)
                if args.visualize:
                    factor = 255 // config.num_classes
                    cv2.imwrite(str(vis_output_path / mask_name),
                                mask * factor)
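
# Entry-point guard (not shown in the original excerpt) so the script can be
# executed directly; it assumes get_args() parses the CLI arguments used above.
if __name__ == "__main__":
    main()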
Example 2
class BERTClassificationModel:

    def __init__(self,
                 model_name="cl-tohoku/bert-base-japanese-whole-word-masking",
                 checkpoints_dir=None):

        """
        Text classification model based on a Japanese BERT model.

        Attributes
        ----------
        model_name : str
            Name of the pretrained BERT model to load.
        checkpoints_dir : str
            Path to the directory containing a trained model's checkpoints.

        Methods
        -------
        fit()
            Train a text classification model.
        eval()
            Evaluate the trained model.
        predict()
            Predict the label for a single text.
        """

        self.runner = SupervisedRunner(
            input_key=("features", "attention_mask")
        )

        if checkpoints_dir:
            config_file = f"{checkpoints_dir}/checkpoints/config.pkl"
            if os.path.exists(config_file):
                with open(config_file, "rb") as f:
                    self.label2id, self.config = pickle.load(f)
                    self.id2label = {v: k for k, v in self.label2id.items()}

                num_labels = len(self.label2id)
                self.max_seq_length = self.config["max_seq_length"]
                self.batch_size = self.config["batch_size"]
                self.model_name = self.config["model_name"]
                self.elapsed_time = self.config["elapsed_time"]
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = BERTBaseJapaneseModel(self.model_name, num_labels)

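                # Dummy single-text dataset and loader: used only so that
                # runner.infer below can restore the trained weights.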
                self.data_for_predict = ClassificationDataset(
                    tokenizer=self.tokenizer,
                    label2id=self.label2id,
                    max_seq_length=self.max_seq_length,
                    texts=["checkpoints"]
                )

                temporary_data = {
                    "temporary": DataLoader(
                        dataset=self.data_for_predict,
                        batch_size=self.batch_size,
                        shuffle=False
                    )
                }

                # Load the trained BERT model
                self.runner.infer(
                    model=self.model,
                    loaders=temporary_data,
                    resume=f"{checkpoints_dir}/checkpoints/best.pth"
                )

        else:
            self.model_name = model_name
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.pad_vid = self.tokenizer.vocab["[PAD]"]
            self.data_for_predict = None

    def fit(self,
            train_df, dev_df,
            batch_size=16, max_seq_length=256, learning_rate=5e-5,
            epochs=1, log_dir=None, verbose=False):

        start = time.time()
        config = {
            "model_name": self.model_name,
            "batch_size": batch_size,
            "max_seq_length": max_seq_length,
            "learning_rate": learning_rate,
            "epochs": epochs,
            "log_dir": log_dir
        }

        train_y = train_df[0]
        train_X = train_df[1]
        label2id = dict(
            zip(sorted(set(train_y)), range(len(set(train_y))))
        )
        self.id2label = {v: k for k, v in label2id.items()}
        num_labels = len(label2id)

        self.train_data = ClassificationDataset(
            tokenizer=self.tokenizer,
            label2id=label2id,
            max_seq_length=max_seq_length,
            texts=train_X,
            labels=train_y
        )

        dev_y = dev_df[0]
        dev_X = dev_df[1]

        self.dev_data = ClassificationDataset(
            tokenizer=self.tokenizer,
            label2id=label2id,
            max_seq_length=max_seq_length,
            texts=dev_X,
            labels=dev_y
        )

        train_dev_loaders = {
            "train": DataLoader(
                dataset=self.train_data,
                batch_size=batch_size,
                shuffle=True
            ),
            "valid": DataLoader(
                dataset=self.dev_data,
                batch_size=batch_size,
                shuffle=False
            )
        }

        model = BERTBaseJapaneseModel(self.model_name, num_labels)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

        self.runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=train_dev_loaders,
            callbacks=[
                AccuracyCallback(num_classes=num_labels),
            ],
            fp16=None,
            logdir=log_dir,
            num_epochs=epochs,
            verbose=verbose
        )

        self.elapsed_time = time.time() - start
        config["elapsed_time"] = self.elapsed_time

        if os.path.exists(f"{log_dir}/checkpoints"):
            filename = f"{log_dir}/checkpoints/config.pkl"
            with open(filename, "wb") as f:
                pickle.dump([label2id, config], f)

    def predict(self, text):
        if self.data_for_predict:
            x = self.data_for_predict._from_text(text)
        else:
            x = self.train_data._from_text(text)

        x["features"] = x["features"].reshape(1, -1)
        x["attention_mask"] = x["attention_mask"].reshape(1, -1)
        logits = self.runner.predict_batch(x)['logits']
        pred_id = logits.argmax(axis=1)
        pred_y = self.id2label[int(pred_id)]
        return pred_y

    def eval(self, test_df):
        test_Y = test_df[0]
        pred_Y = [self.predict(text) for text in test_df[1]]

        accuracy = accuracy_score(test_Y, pred_Y)
        macro_f1 = f1_score(test_Y, pred_Y, average="macro")
        cr = classification_report(test_Y, pred_Y)

        eval_metrics = classifiers.EvaluationMetrics(
            accuracy, macro_f1, cr, self.elapsed_time
        )
        return eval_metrics
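
# Illustrative usage sketch (not part of the original class). fit() and eval()
# index their inputs as df[0] (labels) and df[1] (texts), so plain two-column
# pandas DataFrames work; the texts and log_dir below are placeholders.
import pandas as pd

train_df = pd.DataFrame([["pos", "とても良い映画でした。"],
                         ["neg", "全く面白くなかった。"]])
dev_df = train_df.copy()

clf = BERTClassificationModel()
clf.fit(train_df, dev_df, batch_size=2, max_seq_length=64,
        epochs=1, log_dir="./logdir")
print(clf.predict("この映画は素晴らしい。"))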
tta_runner = SupervisedRunner(model=tta_model,
                              device=utils.get_device(),
                              input_key="image")

# In[ ]:

infer_loader = DataLoader(test_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=num_workers)

batch = next(iter(infer_loader))

# predict_batch will automatically move the batch to the Runner's device
tta_predictions = tta_runner.predict_batch(batch)

# shape is `batch_size x channels x height x width`
print(tta_predictions["logits"].shape)

# Let's see our mask after TTA

# In[ ]:

threshold = 0.5

image = utils.tensor_to_ndimage(batch["image"][0])

mask_ = tta_predictions["logits"][0, 0].sigmoid()
mask = utils.detach(mask_ > threshold).astype("float")
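
# In[ ]:

# A small follow-up sketch (not in the original cells): display the input
# image next to the thresholded TTA mask with matplotlib.
import matplotlib.pyplot as plt

fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
ax_img.imshow(image)
ax_img.set_title("image")
ax_mask.imshow(mask, cmap="gray")
ax_mask.set_title("mask after TTA")
for ax in (ax_img, ax_mask):
    ax.axis("off")
plt.show()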