コード例 #1
0
ファイル: train.py プロジェクト: canghaiyunfan/kaggle-dstl
 def _log_im(self, xs: np.ndarray, ys: np.ndarray, dist_ys: np.ndarray,
             pred_ys: np.ndarray):
     """Write PNG previews of a batch into ``self.logdir``.

     For the item at position ``i`` of the batch the following files are
     written (``NNN`` is ``i`` left-padded with zeros to three characters):

     * ``NNN_-x.png`` -- a mosaic of three-band (or single-band) views of
       the input, each overlaid with a white frame marking the inner patch;
     * ``NNN_<cls>-y.png`` / ``NNN_<cls>-z.png`` -- the matching entries
       of ``ys`` and ``pred_ys`` for every class in ``self.hps.classes``;
     * ``NNN_<cls>-d.png`` -- ``dist_ys[i, j]``, only when ``dist_ys`` has
       a non-zero first dimension.

     Args:
         xs: batch of inputs; every item is transposed with axes
             ``(1, 2, 0)`` before display (presumably channel-first --
             TODO confirm with the caller).
         ys: per-item, per-class arrays written with the ``-y`` suffix.
         dist_ys: indexed as ``[item, class]``; skipped when its first
             dimension is zero.
         pred_ys: per-item, per-class arrays written with the ``-z`` suffix.
     """
     pad = self.hps.patch_border
     inner = self.hps.patch_inner
     side = pad * 2 + inner

     # White one-pixel frame drawn ``pad`` pixels in from every edge.
     frame = np.zeros([side, side, 3], dtype=np.float32)
     frame[pad, pad:-pad, :] = 1
     frame[-pad, pad:-pad, :] = 1
     frame[pad:-pad, pad, :] = 1
     frame[pad:-pad, -pad, :] = 1
     # The slices above stop one short of the bottom-right corner.
     frame[-pad, -pad, :] = 1

     # Extra views shown next to RGB, keyed by the number of input bands.
     extra_views = {
         12: [
             slice(4, 7),  # M
             slice(3, 4),  # P (will be shown below RGB)
             # 7 and 8 from M are skipped
             slice(9, 12),  # M
         ],
         20: [
             slice(4, 7),  # M
             slice(6, 9),  # M (overlap)
             slice(9, 12),  # M
             slice(3, 4),  # P (will be shown below RGB)
             slice(12, 15),  # A (overlap)
             slice(14, 17),  # A
             slice(17, None),  # A
         ],
     }

     def out_path(index, suffix):
         """Return the output file name for batch item ``index``."""
         return str(self.logdir / ('{:0>3}_{}.png'.format(index, suffix)))

     for idx, (x, y_true, y_pred) in enumerate(zip(xs, ys, pred_ys)):
         scaled = utils.scale_percentile(x.transpose(1, 2, 0))
         views = [slice(None, 3)]  # RGB
         views += extra_views.get(scaled.shape[-1], [])
         panels = [np.maximum(frame, scaled[:, :, view]) for view in views]

         if len(panels) >= 4:
             # Two rows: first half of the panels on top, the rest below.
             half = len(panels) // 2
             top_row = np.concatenate(panels[:half], axis=1)
             bottom_row = np.concatenate(panels[half:], axis=1)
             mosaic = np.concatenate([top_row, bottom_row], axis=0)
         else:
             mosaic = np.concatenate(panels, axis=1)
         cv2.imwrite(out_path(idx, '-x'), mosaic * 255)

         per_class = zip(self.hps.classes, y_true, y_pred)
         for j, (cls, cls_true, cls_pred) in enumerate(per_class):
             cv2.imwrite(out_path(idx, '{}-y'.format(cls)), cls_true * 255)
             cv2.imwrite(out_path(idx, '{}-z'.format(cls)), cls_pred * 255)
             if dist_ys.shape[0]:
                 cv2.imwrite(out_path(idx, '{}-d'.format(cls)),
                             dist_ys[idx, j] * 255)
コード例 #2
0
def main():
    """Dump RGB previews, class masks and polygon statistics for all images.

    Takes one positional command-line argument, ``output``, the directory
    to write into (created if missing, including parent directories).

    For every image id returned by ``utils.get_wkt_data()`` this writes:

    * ``<im_id>.jpg`` -- the percentile-scaled RGB image;
    * ``<im_id>_mask_<cls>.png`` -- one mask per polygon type, where
      ``cls`` is the polygon type minus one.

    Afterwards ``stats.json`` is written with, per image and class, the
    polygon area as a fraction of the image area, the integer perimeter
    and the polygon count. The same numbers are printed as three tables
    with one column for each of the classes 0..9; a class that is absent
    for an image is shown as an empty cell.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    output = Path(args.output)
    # parents=True so that a nested output path does not fail on first use.
    output.mkdir(parents=True, exist_ok=True)
    poly_stats = {}
    for im_id in sorted(utils.get_wkt_data()):
        print(im_id)
        im_data = utils.load_image(im_id, rgb_only=True)
        im_data = utils.scale_percentile(im_data)
        cv2.imwrite(str(output.joinpath('{}.jpg'.format(im_id))),
                    255 * im_data)
        im_size = im_data.shape[:2]
        poly_by_type = utils.load_polygons(im_id, im_size)
        for poly_type, poly in sorted(poly_by_type.items()):
            cls = poly_type - 1
            mask = utils.mask_for_polygons(im_size, poly)
            cv2.imwrite(
                str(output.joinpath('{}_mask_{}.png'.format(im_id, cls))),
                255 * mask)
            poly_stats.setdefault(im_id, {})[cls] = {
                'area': poly.area / (im_size[0] * im_size[1]),
                'perimeter': int(poly.length),
                # NOTE(review): if ``poly`` is a Shapely multi-part geometry,
                # len() on it no longer works from Shapely 2.0 on; switch to
                # len(poly.geoms) when upgrading.
                'number': len(poly),
            }

    output.joinpath('stats.json').write_text(json.dumps(poly_stats))

    class_ids = list(range(10))
    for key in ['number', 'perimeter', 'area']:
        print('\n{}'.format(key))
        table = []
        for im_id, stats in sorted(poly_stats.items()):
            row = [im_id]
            for cls in class_ids:
                if cls not in stats:
                    # An image without this class must not abort the report
                    # after all the images have already been processed.
                    row.append('')
                    continue
                value = stats[cls][key]
                # Area is a fraction of the image, so show it as a percentage.
                row.append('{:.4%}'.format(value) if key == 'area' else value)
            table.append(row)
        print(tabulate.tabulate(table, headers=['im_id'] + class_ids))
コード例 #3
0
def get_poly_data(im_id, *, store: Path, classes: List[int], epsilon: float,
                  min_area: float, min_small_area: float, validation: str,
                  to_fix: Set[str], hps: HyperParams, valid_polygons: bool,
                  buffer: float):
    """Convert the stored pixel masks of one image into WKT rows.

    The masks are read from the gzip-compressed file at
    ``mask_path(store, im_id)``. When that file does not exist, every
    class gets a ``MULTIPOLYGON EMPTY`` row.

    Otherwise, for each ``(cls, mask)`` pair:

    * if the image has training polygons and ``validation`` is falsy, an
      empty row is emitted;
    * else polygons are extracted with ``get_polygons`` (classes 1, 8 and
      9 use ``min_small_area``, all others ``min_area``) and dumped as
      WKT with 8 digits of rounding precision;
    * if ``validation`` is truthy, the true, pixel and polygon masks are
      also written as PNG files into ``store`` and the result of
      ``log_jaccard`` is collected.

    With ``validation == 'square'`` the masks, the preview image and the
    true mask are passed through ``square(..., hps)`` first.

    Returns:
        A ``(rows, jaccard_stats)`` tuple. ``rows`` is a list of
        ``(im_id, str(poly_type), wkt)`` tuples with
        ``poly_type = cls + 1``; ``jaccard_stats`` is a list of
        ``log_jaccard`` results, empty unless validating.

    Raises:
        Any exception raised while loading the mask file is logged and
        re-raised.
    """
    train_polygons = utils.get_wkt_data().get(im_id)
    jaccard_stats = []
    path = mask_path(store, im_id)

    if not path.exists():
        logger.info('{} empty'.format(im_id))
        empty_rows = [(im_id, str(cls + 1), 'MULTIPOLYGON EMPTY')
                      for cls in classes]
        return empty_rows, jaccard_stats

    logger.info(im_id)
    use_square = validation == 'square'
    with gzip.open(str(path), 'rb') as f:
        try:
            masks = np.load(f)  # type: np.ndarray
        except Exception:
            logger.error('Error loading mask {}'.format(path))
            raise
        if use_square:
            masks = square(masks, hps)

    rows = []
    if validation:
        im_data = utils.load_image(im_id, rgb_only=True)
        # Size is taken before the optional ``square`` call below.
        im_size = im_data.shape[:2]
        if use_square:
            im_data = square(im_data, hps)
        preview_path = str(store / '{}_image.png'.format(im_id))
        cv2.imwrite(preview_path, 255 * utils.scale_percentile(im_data))

    def save_mask(cls_id, m, name):
        """Write mask ``m`` of class ``cls_id`` as a PNG into ``store``."""
        target = str(store / '{}_{}_{}.png'.format(im_id, cls_id, name))
        cv2.imwrite(target, 255 * m)

    small_area_classes = {1, 8, 9}
    for cls, mask in zip(classes, masks):
        poly_type = cls + 1

        if train_polygons and not validation:
            rows.append((im_id, str(poly_type), 'MULTIPOLYGON EMPTY'))
            continue

        if cls in small_area_classes:
            area_threshold = min_small_area
        else:
            area_threshold = min_area
        needs_fix = '{}_{}'.format(im_id, poly_type) in to_fix
        unscaled, pred_poly = get_polygons(
            im_id,
            mask,
            epsilon,
            min_area=area_threshold,
            fix=needs_fix,
            buffer=buffer,
        )
        wkt_text = shapely.wkt.dumps(pred_poly, rounding_precision=8)
        rows.append((im_id, str(poly_type), wkt_text))

        if not validation:
            continue

        # Compare the prediction against the training polygons.
        poly_mask = utils.mask_for_polygons(mask.shape, unscaled)
        train_poly = shapely.wkt.loads(train_polygons[poly_type])
        scaled_train_poly = utils.scale_to_mask(im_id, im_size, train_poly)
        true_mask = utils.mask_for_polygons(im_size, scaled_train_poly)
        if use_square:
            true_mask = square(true_mask, hps)

        save_mask(cls, true_mask, 'true_mask')
        save_mask(cls, mask, 'pixel_mask')
        save_mask(cls, poly_mask, 'poly_mask')

        stats = log_jaccard(im_id,
                            cls,
                            true_mask,
                            mask,
                            poly_mask,
                            scaled_train_poly,
                            unscaled,
                            valid_polygons=valid_polygons)
        jaccard_stats.append(stats)

    return rows, jaccard_stats