from pathlib import Path

import cv2
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from catalyst.dl import SupervisedRunner, utils
from pytorch_toolbelt.inference.tta import TTAWrapper, d4_image2mask, fliplr_image2mask

# Project-local helpers assumed to be importable from this repository:
# get_args, py2cfg, load_checkpoint, ApplySoftmaxToLogits,
# TestSegmentationDataset, unpad.


def main():
    args = get_args()
    image_paths = sorted(args.input_path.glob("*.jpg"))
    config = py2cfg(args.config_path)

    args.output_path.mkdir(exist_ok=True, parents=True)

    if args.visualize:
        vis_output_path = Path(str(args.output_path) + "_vis")
        vis_output_path.mkdir(exist_ok=True, parents=True)

    test_aug = config.test_augmentations
    model = config.model

    # Strip wrapper prefixes from the checkpoint's state-dict keys before loading.
    checkpoint = load_checkpoint(args.checkpoint, {"model.0.": "", "model.": ""})
    utils.unpack_checkpoint(checkpoint, model=model)

    model = nn.Sequential(model, ApplySoftmaxToLogits())
    model, _, _, _, device = utils.process_components(model=model)

    # Optional test-time augmentation: horizontal flips ("lr") or the full D4 group.
    if args.tta == "lr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)

    runner = SupervisedRunner(model=model, device=device)

    with torch.no_grad():
        test_loader = DataLoader(
            TestSegmentationDataset(
                image_paths,
                test_aug,
                factor=config.pad_factor,
                imread_lib=config.imread_library,
            ),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        for input_images in tqdm(test_loader):
            # Use the device chosen by process_components instead of assuming CUDA.
            raw_predictions = runner.predict_batch(
                {"features": input_images["features"].to(device)}
            )["logits"]

            image_height, image_width = input_images["features"].shape[2:]
            pads = input_images["pads"].cpu().numpy()
            image_ids = input_images["image_id"]

            _, predictions = raw_predictions.max(1)

            for i in range(raw_predictions.shape[0]):
                # argmax yields int64; cast to uint8 before resizing, since
                # cv2.resize does not support 64-bit integer arrays.
                unpadded_mask = predictions[i].cpu().numpy().astype(np.uint8)
                if unpadded_mask.shape != (image_height, image_width):
                    unpadded_mask = cv2.resize(
                        unpadded_mask,
                        (image_width, image_height),
                        interpolation=cv2.INTER_NEAREST,
                    )

                mask = unpad(unpadded_mask, pads[i])
                mask_name = image_ids[i] + ".png"
                cv2.imwrite(str(args.output_path / mask_name), mask)

                if args.visualize:
                    # Spread class ids over the 0-255 range so masks are visible.
                    factor = 255 // config.num_classes
                    cv2.imwrite(str(vis_output_path / mask_name), mask * factor)
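# The script above relies on a project-local get_args() helper that is not
# shown. Below is a minimal argparse sketch consistent with the attributes the
# script reads (input_path, config_path, output_path, checkpoint, batch_size,
# num_workers, tta, visualize); the flag names and defaults are assumptions,
# not the original CLI.
import argparse
from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser("Inference for a segmentation model")
    parser.add_argument("-i", "--input_path", type=Path, required=True,
                        help="Directory with input *.jpg images.")
    parser.add_argument("-c", "--config_path", type=Path, required=True,
                        help="Python config file understood by py2cfg.")
    parser.add_argument("-o", "--output_path", type=Path, required=True,
                        help="Directory for the predicted masks.")
    parser.add_argument("--checkpoint", type=Path, required=True,
                        help="Path to the trained model checkpoint.")
    parser.add_argument("-b", "--batch_size", type=int, default=1)
    parser.add_argument("-j", "--num_workers", type=int, default=4)
    parser.add_argument("--tta", choices=["lr", "d4"], default=None,
                        help="Optional test-time augmentation mode.")
    parser.add_argument("--visualize", action="store_true",
                        help="Also write rescaled masks for visual inspection.")
    return parser.parse_args()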
import os
import pickle
import time

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import AccuracyCallback
from sklearn.metrics import accuracy_score, classification_report, f1_score

# Project-local modules assumed to be importable from this repository:
# BERTBaseJapaneseModel, ClassificationDataset, and classifiers (which
# provides EvaluationMetrics). The runner API (runner.infer with resume=...)
# matches the Catalyst 20.x era used elsewhere in this code.


class BERTClassificationModel:
    def __init__(self,
                 model_name="cl-tohoku/bert-base-japanese-whole-word-masking",
                 checkpoints_dir=None):
        """Text classification model based on a pretrained Japanese BERT model.

        Attributes
        ----------
        model_name : str
            Name of the pretrained BERT model.
        checkpoints_dir : str
            Path to the directory of a trained model.

        Methods
        -------
        fit()
            Train a text classification model.
        eval()
            Evaluate the trained model.
        predict()
            Predict a label.
        """
        self.runner = SupervisedRunner(
            input_key=("features", "attention_mask")
        )
        if checkpoints_dir:
            config_file = f"{checkpoints_dir}/checkpoints/config.pkl"
            if os.path.exists(config_file):
                with open(config_file, "rb") as f:
                    self.label2id, self.config = pickle.load(f)
                self.id2label = {v: k for k, v in self.label2id.items()}
                num_labels = len(self.label2id)
                self.max_seq_length = self.config["max_seq_length"]
                self.batch_size = self.config["batch_size"]
                self.model_name = self.config["model_name"]
                self.elapsed_time = self.config["elapsed_time"]

                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = BERTBaseJapaneseModel(self.model_name, num_labels)

                # A dummy one-item dataset: runner.infer needs a loader
                # to restore the trained weights from the checkpoint.
                self.data_for_predict = ClassificationDataset(
                    tokenizer=self.tokenizer,
                    label2id=self.label2id,
                    max_seq_length=self.max_seq_length,
                    texts=["checkpoints"]
                )
                temporary_data = {
                    "temporary": DataLoader(
                        dataset=self.data_for_predict,
                        batch_size=self.batch_size,
                        shuffle=False
                    )
                }
                # Load the trained BERT model
                self.runner.infer(
                    model=self.model,
                    loaders=temporary_data,
                    resume=f"{checkpoints_dir}/checkpoints/best.pth"
                )
        else:
            self.model_name = model_name
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.pad_vid = self.tokenizer.vocab["[PAD]"]
            self.data_for_predict = None

    def fit(self, train_df, dev_df, batch_size=16, max_seq_length=256,
            learning_rate=5e-5, epochs=1, log_dir=None, verbose=False):
        start = time.time()
        config = {
            "model_name": self.model_name,
            "batch_size": batch_size,
            "max_seq_length": max_seq_length,
            "learning_rate": learning_rate,
            "epochs": epochs,
            "log_dir": log_dir
        }

        # Column 0 holds labels, column 1 holds texts.
        train_y = train_df[0]
        train_X = train_df[1]
        label2id = dict(
            zip(sorted(set(train_y)), range(len(set(train_y))))
        )
        self.id2label = {v: k for k, v in label2id.items()}
        num_labels = len(label2id)

        self.train_data = ClassificationDataset(
            tokenizer=self.tokenizer,
            label2id=label2id,
            max_seq_length=max_seq_length,
            texts=train_X,
            labels=train_y
        )

        dev_y = dev_df[0]
        dev_X = dev_df[1]
        self.dev_data = ClassificationDataset(
            tokenizer=self.tokenizer,
            label2id=label2id,
            max_seq_length=max_seq_length,
            texts=dev_X,
            labels=dev_y
        )

        train_dev_loaders = {
            "train": DataLoader(
                dataset=self.train_data,
                batch_size=batch_size,
                shuffle=True
            ),
            "valid": DataLoader(
                dataset=self.dev_data,
                batch_size=batch_size,
                shuffle=False
            )
        }

        model = BERTBaseJapaneseModel(self.model_name, num_labels)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

        self.runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=train_dev_loaders,
            callbacks=[
                AccuracyCallback(num_classes=num_labels),
            ],
            fp16=None,
            logdir=log_dir,
            num_epochs=epochs,
            verbose=verbose
        )
        self.elapsed_time = time.time() - start

        # Persist the label mapping and hyperparameters next to the weights
        # so the model can be reloaded via checkpoints_dir.
        config["elapsed_time"] = self.elapsed_time
        if os.path.exists(f"{log_dir}/checkpoints"):
            filename = f"{log_dir}/checkpoints/config.pkl"
            with open(filename, "wb") as f:
                pickle.dump([label2id, config], f)

    def predict(self, text):
        if self.data_for_predict:
            x = self.data_for_predict._from_text(text)
        else:
            x = self.train_data._from_text(text)
        # Add a batch dimension of size 1.
        x["features"] = x["features"].reshape(1, -1)
        x["attention_mask"] = x["attention_mask"].reshape(1, -1)

        logits = self.runner.predict_batch(x)["logits"]
        pred_id = logits.argmax(axis=1)
        pred_y = self.id2label[int(pred_id)]
        return pred_y

    def eval(self, test_df):
        test_Y = test_df[0]
        pred_Y = [self.predict(text) for text in test_df[1]]

        accuracy = accuracy_score(test_Y, pred_Y)
        macro_f1 = f1_score(test_Y, pred_Y, average="macro")
        cr = classification_report(test_Y, pred_Y)
        eval_metrics = classifiers.EvaluationMetrics(
            accuracy, macro_f1, cr, self.elapsed_time
        )
        return eval_metrics
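# A hedged usage sketch for BERTClassificationModel. It assumes train_df,
# dev_df, and test_df are indexable as df[0] -> labels and df[1] -> texts,
# matching what fit() and eval() expect (e.g. header-less pandas DataFrames
# with integer column names); the file paths are illustrative, not from the
# original code.
import pandas as pd

if __name__ == "__main__":
    train_df = pd.read_csv("train.tsv", sep="\t", header=None)  # col 0: label, col 1: text
    dev_df = pd.read_csv("dev.tsv", sep="\t", header=None)
    test_df = pd.read_csv("test.tsv", sep="\t", header=None)

    clf = BERTClassificationModel()
    clf.fit(train_df, dev_df, epochs=1, log_dir="logdir")

    metrics = clf.eval(test_df)
    print(metrics)

    # Predict the label of a single sentence ("This is a test sentence.").
    print(clf.predict("これはテスト用の文です。"))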
tta_runner = SupervisedRunner(
    model=tta_model,
    device=utils.get_device(),
    input_key="image",
)

# In[ ]:

infer_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=num_workers,
)

batch = next(iter(infer_loader))
# predict_batch will automatically move the batch to the Runner's device
tta_predictions = tta_runner.predict_batch(batch)

# shape is `batch_size x channels x height x width`
print(tta_predictions["logits"].shape)

# Let's see our mask after TTA

# In[ ]:

threshold = 0.5

image = utils.tensor_to_ndimage(batch["image"][0])

mask_ = tta_predictions["logits"][0, 0].sigmoid()
mask = utils.detach(mask_ > threshold).astype("float")
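# An illustrative follow-up cell (not part of the original notebook): it
# displays the input image next to the thresholded TTA mask computed above,
# using matplotlib. Only `image`, `mask`, and `threshold` from the previous
# cell are assumed.

# In[ ]:

import matplotlib.pyplot as plt

fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
ax_image.imshow(image)
ax_image.set_title("input image")
ax_mask.imshow(mask, cmap="gray")
ax_mask.set_title(f"TTA mask, threshold={threshold}")
for ax in (ax_image, ax_mask):
    ax.axis("off")
plt.show()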