Beispiel #1
0
    def on_batch_end(self):
        """Update loss meters and, outside training, collect COCO-style detections.

        Always logs the criterion's classification and box losses. When
        ``self.state.is_train`` is false, the method does the following:

        * decodes ``self.state.output`` with ``decode``;
        * rescales the boxes by ``self.batch_ratios``;
        * converts the boxes from xyxy to xywh;
        * appends one dict per kept detection (``image_id``, ``bbox``,
          ``score``, ``category_id``) to ``self.all_results``.
        """
        super().on_batch_end()
        # Log cls & box loss. This reads protected attributes of the loss class,
        # which is brittle: it couples this callback to the criterion's internals.
        criterion = self.state.criterion
        self.cls_loss_meter.update(utils.to_numpy(criterion._cls_loss))
        self.box_loss_meter.update(utils.to_numpy(criterion._box_loss))

        # COCO results are only produced for non-training batches.
        if self.state.is_train:
            return

        cls_out, box_out = self.state.output
        res = decode(cls_out, box_out, self.anchors, **self.decode_params)

        # Rescale to image size; no clipping needed afterwards.
        res[..., :4] *= self.batch_ratios.view(-1, 1, 1).to(res)
        # xyxy -> xywh
        res[..., 2:4] = res[..., 2:4] - res[..., :2]

        # Convert to Python lists once. Calling ``.tolist()`` on single elements
        # (as before) costs a device->host sync per value when ``res`` is on GPU.
        detections = res.tolist()
        img_ids = self.batch_img_ids[:, 0].tolist()
        # Detections scoring below this value are not reported.
        min_score = 0.001
        for batch_idx, image_dets in enumerate(detections):
            for det in image_dets:
                score = det[4]
                # Scores come in descending order, so stop at the first low one.
                if score < min_score:
                    break
                self.all_results.append(
                    dict(
                        image_id=img_ids[batch_idx],
                        bbox=det[:4],
                        score=score,
                        category_id=int(det[5]),
                    )
                )
Beispiel #2
0
 def on_batch_end(self):
     """Push the current batch's metric values into their meters.

     Every metric in ``self.metrics`` is called as ``metric(output, target)``.
     ``output`` is ``self.state.output`` and ``target`` is the second element
     of ``self.state.input``. The call runs under
     ``amp.autocast(self.state.use_fp16)``. The squeezed result is converted
     with ``utils.to_numpy`` and passed to the meter registered under the
     matching entry of ``self.metric_names``.
     """
     state = self.state
     _, batch_target = state.input
     batch_output = state.output
     with amp.autocast(state.use_fp16):
         for metric_fn, metric_name in zip(self.metrics, self.metric_names):
             meter = state.metric_meters[metric_name]
             batch_value = metric_fn(batch_output, batch_target).squeeze()
             meter.update(utils.to_numpy(batch_value))
Beispiel #3
0
    def _make_step(self):
        """Run one forward pass and, in training mode, an AMP-scaled backward pass.

        Gradients are accumulated over ``self.accumulate_steps`` batches:

        * the loss is divided by that count before ``backward``;
        * the optimizer is stepped only when ``self.state.step`` is a multiple
          of it, after optional gradient-norm clipping.

        The model output is stored on ``self.state.output``. The undivided loss
        is recorded in ``self.state.loss_meter`` for every batch.
        """
        state = self.state
        data, target = state.input

        with amp.autocast(state.use_fp16):
            output = state.model(data)
            loss = state.criterion(output, target)
        state.output = output

        if state.is_train:
            scaler = state.grad_scaler
            # Each batch contributes a 1/accumulate_steps share of the gradient.
            scaler.scale(loss / self.accumulate_steps).backward()

            if state.step % self.accumulate_steps == 0:
                optimizer = state.optimizer
                if self.gradient_clip_val > 0:
                    # Unscale first so the clip threshold applies to real grads,
                    # not to loss-scaled ones.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        state.model.parameters(), self.gradient_clip_val
                    )
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            torch.cuda.synchronize()

        # Loss is tracked for both training and evaluation batches.
        state.loss_meter.update(to_numpy(loss))
Beispiel #4
0
    def on_loader_end(self):
        """Evaluate every metric once over all outputs/targets gathered by the loader.

        ``self.target`` and ``self.output`` are sequences of tensors. Each is
        concatenated along dim 0. Each metric is then evaluated under
        ``amp.autocast(self.state.use_fp16)``, squeezed, converted with
        ``utils.to_numpy`` and passed to its meter.
        """
        all_targets = torch.cat(self.target)
        all_outputs = torch.cat(self.output)
        state = self.state
        with amp.autocast(state.use_fp16):
            for metric_fn, metric_name in zip(self.metrics, self.metric_names):
                meter = state.metric_meters[metric_name]
                epoch_value = metric_fn(all_outputs, all_targets).squeeze()
                meter.update(utils.to_numpy(epoch_value))
Beispiel #5
0
    def on_loader_end(self):
        """Compute metrics on the features collected during the loader pass.

        The prediction and target feature sequences are each concatenated along
        dim 0 and squeezed. Each metric result is converted with ``to_numpy``
        and pushed to the matching meter in ``self.state.metric_meters``.
        """
        # NOTE(review): ``squeeze()`` removes *every* size-1 dim, so a loader that
        # yields a single sample would also lose its batch dim -- confirm intent.
        preds = torch.cat(self.prediction_features).squeeze()
        targets = torch.cat(self.target_features).squeeze()

        meters = self.state.metric_meters
        for metric_fn, metric_name in zip(self.metrics, self.metric_names):
            score = to_numpy(metric_fn(preds, targets))
            meters[metric_name].update(score)
Beispiel #6
0
def main():
    """Validate a trained segmentation model and/or predict on the test set.

    Parses the training config (``get_parser``) plus the inference-only flags
    defined below, then loads ``model.chpn`` from ``FLAGS.outdir``. With
    ``--tta`` the model is wrapped in test-time augmentation. The rest runs in
    two optional stages:

    * Validation (skipped with ``--no_val``): prints the thresholded Jaccard
      score on the validation split.
    * Test prediction (skipped with ``--no_test``): passes test predictions to
      ``save_pred`` in a process pool, targeting ``data/preds``.

    With ``--short_predict N``, side-by-side preview images (input | input with
    predicted pixels painted red) are also written to
    ``<outdir>/preds_preview``. Prediction stops after the first batches that
    together cover at least N images.
    """
    parser = get_parser()
    parser.add_argument("--no_val", action="store_true", help="Disable validation")
    parser.add_argument("--no_test", action="store_true", help="Disable prediction on test")
    parser.add_argument("--short_predict", default=0, type=int, help="Number of first images to show predict for")
    parser.add_argument("--thr", default=0.5, type=float, help="Threshold for cutting")
    parser.add_argument("--tta", action="store_true", help="Enables TTA")
    FLAGS = parser.parse_args()
    assert os.path.exists(FLAGS.outdir), "You have to pass config after training to inference script"

    # get model
    print("Loading model")
    model = MODEL_FROM_NAME[FLAGS.segm_arch](FLAGS.arch, **FLAGS.model_params)
    sd = torch.load(os.path.join(FLAGS.outdir, "model.chpn"))["state_dict"]
    model.load_state_dict(sd)
    model = model.cuda().eval()
    if FLAGS.tta:
        # activation="sigmoid" is handed to the TTA wrapper; the explicit
        # sigmoid in the test loop below is applied only when TTA is off.
        model = pt.tta_wrapper.TTA(
            model, segm=True, h_flip=True, rotation=[90], merge="gmean", activation="sigmoid"
        )
    model = apex.amp.initialize(model, verbosity=0)
    print("Loaded model")

    if not FLAGS.no_val:
        # The validation loader is only built when validation actually runs.
        val_aug = albu.Compose([albu.CenterCrop(FLAGS.size, FLAGS.size), albu.Normalize(), ToTensor(),])
        val_dtst = OpenCitiesDataset(split="val", transform=val_aug, buildings_only=True)
        val_dtld = DataLoader(val_dtst, batch_size=FLAGS.bs, shuffle=False, num_workers=8)
        val_dtld = ToCudaLoader(val_dtld)

        runner = pt.fit_wrapper.Runner(
            model,
            None,
            TargetWrapper(pt.losses.JaccardLoss(), "mask"),
            [
                TargetWrapper(pt.metrics.JaccardScore(), "mask"),
                TargetWrapper(ThrJaccardScore(thr=FLAGS.thr), "mask"),
            ],
        )
        _, (jacc_score, thr_jacc_score) = runner.evaluate(val_dtld)
        print(f"Validation Jacc Score: {thr_jacc_score:.4f}")

    if FLAGS.no_test:
        return

    # Predict on test
    # for now simply resize it to proper size
    test_aug = get_aug("test", size=FLAGS.size)

    test_dataset = OpenCitiesTestDataset(transform=test_aug)
    test_loader = DataLoader(test_dataset, batch_size=FLAGS.bs, shuffle=False, num_workers=8)

    # NOTE(review): ``save_pred`` runs in pool workers and presumably reads these
    # module globals; they are set before the pool is created so forked workers
    # see them -- confirm against ``save_pred``.
    global PREDS_PATH
    global THR
    THR = FLAGS.thr
    PREDS_PATH = Path("data/preds")
    preds_preview_path = Path(FLAGS.outdir) / "preds_preview"
    shutil.rmtree(preds_preview_path, ignore_errors=True)
    PREDS_PATH.mkdir(parents=True, exist_ok=True)
    preds_preview_path.mkdir(exist_ok=True)

    workers_pool = pool.Pool()
    n_previewed = 0
    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for imgs, aug_imgs, idxs in tqdm(test_loader):
                preds = model(aug_imgs.cuda())
                if not FLAGS.tta:
                    preds = preds.sigmoid()
                preds = to_numpy(preds).squeeze()
                if preds.ndim == 2:
                    # squeeze() also dropped the batch dim of a one-image batch;
                    # restore it so zip() pairs whole masks with ids, not rows.
                    preds = preds[None]
                workers_pool.map(save_pred, zip(preds, idxs))

                if not FLAGS.short_predict:
                    continue
                for img, idx, pred in zip(imgs, idxs, preds):
                    img_np = to_numpy(img)
                    overlay = img_np.copy()
                    # Resize the prediction to the image size (cv2 takes (w, h)).
                    pred = cv2.resize(pred, (overlay.shape[1], overlay.shape[0]))
                    # Paint pixels above the threshold red.
                    overlay[pred > THR] = [255, 0, 0]
                    combined = cv2.cvtColor(np.hstack([img_np, overlay]), cv2.COLOR_RGB2BGR)
                    cv2.imwrite(str(preds_preview_path / (idx + ".jpg")), combined)
                n_previewed += len(idxs)
                if n_previewed >= FLAGS.short_predict:
                    break
    finally:
        workers_pool.close()
        workers_pool.join()