Example #1
import argparse
import json
from pathlib import Path

import cv2
import tabulate

import utils  # project-local helper module


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    output = Path(args.output)
    output.mkdir(exist_ok=True)
    poly_stats = {}
    for im_id in sorted(utils.get_wkt_data()):
        print(im_id)
        im_data = utils.load_image(im_id, rgb_only=True)
        im_data = utils.scale_percentile(im_data)
        cv2.imwrite(str(output.joinpath('{}.jpg'.format(im_id))),
                    255 * im_data)
        im_size = im_data.shape[:2]
        poly_by_type = utils.load_polygons(im_id, im_size)
        for poly_type, poly in sorted(poly_by_type.items()):
            cls = poly_type - 1
            mask = utils.mask_for_polygons(im_size, poly)
            cv2.imwrite(
                str(output.joinpath('{}_mask_{}.png'.format(im_id, cls))),
                255 * mask)
            poly_stats.setdefault(im_id, {})[cls] = {
                'area': poly.area / (im_size[0] * im_size[1]),
                'perimeter': int(poly.length),
                'number': len(poly),
            }

    output.joinpath('stats.json').write_text(json.dumps(poly_stats))

    for key in ['number', 'perimeter', 'area']:
        if key == 'area':
            fmt = '{:.4%}'.format
        else:
            fmt = lambda x: x
        print('\n{}'.format(key))
        print(
            tabulate.tabulate(
                [[im_id] + [fmt(s[cls][key]) for cls in range(10)]
                 for im_id, s in sorted(poly_stats.items())],
                headers=['im_id'] + list(range(10))))
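
Example #1 relies on two helpers from the project-local utils module that are not shown. The sketch below is a plausible reconstruction, not the project's verified code: scale_percentile is assumed to do per-channel percentile stretching, and mask_for_polygons is assumed to rasterize a shapely MultiPolygon with OpenCV.

import cv2
import numpy as np


def scale_percentile(im_data: np.ndarray) -> np.ndarray:
    # Stretch each channel to [0, 1] using its 1st/99th percentiles,
    # clipping outliers, so the image is viewable as a plain JPEG.
    lo = np.percentile(im_data, 1, axis=(0, 1))
    hi = np.percentile(im_data, 99, axis=(0, 1))
    return np.clip((im_data - lo) / (hi - lo + 1e-8), 0, 1)


def mask_for_polygons(im_size, polygons) -> np.ndarray:
    # Rasterize a MultiPolygon: fill exteriors with 1, then carve
    # out interior rings (holes) with 0.
    mask = np.zeros(im_size, np.uint8)
    to_int = lambda ring: np.array(ring.coords).round().astype(np.int32)
    cv2.fillPoly(mask, [to_int(p.exterior) for p in polygons.geoms], 1)
    cv2.fillPoly(mask, [to_int(r) for p in polygons.geoms
                        for r in p.interiors], 0)
    return mask
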
Example #2
import gzip
import logging
from pathlib import Path
from typing import List, Set

import cv2
import numpy as np
import shapely.wkt

import utils

# HyperParams, mask_path, square, get_polygons and log_jaccard are
# project-local and defined alongside this function.

logger = logging.getLogger(__name__)


def get_poly_data(im_id, *, store: Path, classes: List[int], epsilon: float,
                  min_area: float, min_small_area: float, validation: str,
                  to_fix: Set[str], hps: HyperParams, valid_polygons: bool,
                  buffer: float):
    train_polygons = utils.get_wkt_data().get(im_id)
    jaccard_stats = []
    path = mask_path(store, im_id)
    if path.exists():
        logger.info(im_id)
        with gzip.open(str(path), 'rb') as f:
            try:
                masks = np.load(f)  # type: np.ndarray
            except Exception:
                logger.error('Error loading mask {}'.format(path))
                raise
            if validation == 'square':
                masks = square(masks, hps)
        rows = []
        if validation:
            im_data = utils.load_image(im_id, rgb_only=True)
            im_size = im_data.shape[:2]
            if validation == 'square':
                im_data = square(im_data, hps)
            cv2.imwrite(str(store / '{}_image.png'.format(im_id)),
                        255 * utils.scale_percentile(im_data))
        for cls, mask in zip(classes, masks):
            poly_type = cls + 1
            if train_polygons and not validation:
                # Train images need no real predictions in the submission.
                rows.append((im_id, str(poly_type), 'MULTIPOLYGON EMPTY'))
            else:
                unscaled, pred_poly = get_polygons(
                    im_id,
                    mask,
                    epsilon,
                    # classes 1, 8, 9 (small objects) use the lower threshold
                    min_area=min_small_area if cls in {1, 8, 9} else min_area,
                    fix='{}_{}'.format(im_id, poly_type) in to_fix,
                    buffer=buffer,
                )
                rows.append((im_id, str(poly_type),
                             shapely.wkt.dumps(pred_poly,
                                               rounding_precision=8)))
                if validation:
                    poly_mask = utils.mask_for_polygons(mask.shape, unscaled)
                    train_poly = shapely.wkt.loads(train_polygons[poly_type])
                    scaled_train_poly = utils.scale_to_mask(
                        im_id, im_size, train_poly)
                    true_mask = utils.mask_for_polygons(
                        im_size, scaled_train_poly)
                    if validation == 'square':
                        true_mask = square(true_mask, hps)
                    write_mask = lambda m, name: cv2.imwrite(
                        str(store / '{}_{}_{}.png'.format(im_id, cls, name)),
                        255 * m)
                    write_mask(true_mask, 'true_mask')
                    write_mask(mask, 'pixel_mask')
                    write_mask(poly_mask, 'poly_mask')
                    jaccard_stats.append(
                        log_jaccard(im_id,
                                    cls,
                                    true_mask,
                                    mask,
                                    poly_mask,
                                    scaled_train_poly,
                                    unscaled,
                                    valid_polygons=valid_polygons))
    else:
        logger.info('{} empty'.format(im_id))
        rows = [(im_id, str(cls + 1), 'MULTIPOLYGON EMPTY') for cls in classes]
    return rows, jaccard_stats
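
Example #2 collects per-class Jaccard statistics through the project-local log_jaccard, and Example #3 reduces them with a jaccard function. The sketch below is an assumption consistent with the three-element np.float32 accumulators used there: each entry is taken to be (tp, fp, fn) pixel counts, which sum correctly across images, unlike per-image IoU values, which cannot simply be averaged.

import numpy as np


def pixel_counts(true_mask: np.ndarray, pred_mask: np.ndarray) -> np.ndarray:
    # Per-image (tp, fp, fn) counts; summing these across images
    # gives the dataset-level Jaccard.
    t, p = true_mask.astype(bool), pred_mask.astype(bool)
    return np.array([(t & p).sum(), (~t & p).sum(), (t & ~p).sum()],
                    dtype=np.float32)


def jaccard(counts: np.ndarray) -> float:
    # IoU from accumulated counts: tp / (tp + fp + fn).
    tp, fp, fn = counts
    denom = tp + fp + fn
    return float(tp / denom) if denom else 0.0
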
Example #3
import argparse
import csv
import gzip
import json
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import numpy as np

# get_poly_data (Example #2), predict_masks, mask_path, jaccard,
# HyperParams, utils and logger are project-local.


def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('logdir', type=Path, help='Path to log directory')
    arg('output', type=str, help='Submission csv')
    arg('--only', help='Only predict these image ids (comma-separated)')
    arg('--threshold', type=float, default=0.5)
    arg('--epsilon', type=float, default=2.0, help='smoothing')
    arg('--min-area', type=float, default=50.0)
    arg('--min-small-area', type=float, default=10.0)
    arg('--masks-only', action='store_true', help='Do only mask prediction')
    arg('--model-path',
        type=Path,
        help='Path to a specific model (if the last is not desired)')
    arg('--processes', type=int, default=30)
    arg('--validation',
        choices=['square', 'custom'],
        help='only validation images, check jaccard, '
        'save masks and polygons as png')
    arg('--valid-polygons',
        action='store_true',
        help='validation via polygons')
    arg('--fix', nargs='+', help='{im_id}_{poly_type} format, e.g. 6100_1_1_10')
    arg('--force-predict', action='store_true')
    arg('--no-edges', action='store_true', help='disable prediction on edges')
    arg('--buffer', type=float, help='do .buffer(x) on pred polygons')
    args = parser.parse_args()
    to_fix = set(args.fix or [])
    with open('/train_log_dir/checkpoint-folder/hps.json') as jsonfile:
        hps = HyperParams(**json.load(jsonfile))

    only = set(args.only.split(',')) if args.only else set()
    with open('sample_submission.csv') as f:
        reader = csv.reader(f)
        header = next(reader)
        image_ids = [im_id for im_id, cls, _ in reader if cls == '1']

    store = args.logdir  # type: Path
    store.mkdir(exist_ok=True, parents=True)

    train_ids = set(utils.get_wkt_data())
    if only:
        to_predict = only
    elif args.validation:
        if args.validation == 'custom':
            to_predict = [
                '6140_3_1', '6110_1_2', '6160_2_1', '6170_0_4', '6100_2_2'
            ]
        else:
            to_predict = set(train_ids)
    else:
        to_predict = set(image_ids) | set(train_ids)
    if not args.force_predict:
        to_predict_masks = [
            im_id for im_id in to_predict
            if not mask_path(store, im_id).exists()
        ]
    else:
        to_predict_masks = to_predict

    if to_predict_masks:
        predict_masks(args,
                      hps,
                      store,
                      to_predict_masks,
                      args.threshold,
                      validation=args.validation,
                      no_edges=args.no_edges)
    if args.masks_only:
        logger.info('Was building masks only, done.')
        return

    logger.info('Building polygons')
    opener = gzip.open if args.output.endswith('.gz') else open
    with opener(args.output, 'wt') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        to_output = to_predict if args.validation else (only or image_ids)
        jaccard_stats = [[] for _ in hps.classes]
        sizes = [0 for _ in hps.classes]
        with Pool(processes=args.processes) as pool:
            for rows, js in pool.imap(
                    partial(
                        get_poly_data,
                        store=store,
                        classes=hps.classes,
                        epsilon=args.epsilon,
                        min_area=args.min_area,
                        min_small_area=args.min_small_area,
                        validation=args.validation,
                        to_fix=to_fix,
                        hps=hps,
                        valid_polygons=args.valid_polygons,
                        buffer=args.buffer,
                    ), to_output):
                assert len(rows) == hps.n_classes
                writer.writerows(rows)
                for cls_jss, cls_js in zip(jaccard_stats, js):
                    cls_jss.append(cls_js)
                for idx, (_, _, poly) in enumerate(rows):
                    sizes[idx] += len(poly)
        if args.validation:
            pixel_jaccards, poly_jaccards = [], []
            for cls, cls_js in zip(hps.classes, jaccard_stats):
                pixel_jc, poly_jc = [
                    np.array([0, 0, 0], dtype=np.float32) for _ in range(2)
                ]
                for _pixel_jc, _poly_jc in cls_js:
                    pixel_jc += _pixel_jc
                    poly_jc += _poly_jc
                logger.info(
                    'cls-{}: pixel jaccard: {:.5f}, polygon jaccard: {:.5f}'.
                    format(cls, jaccard(pixel_jc), jaccard(poly_jc)))
                pixel_jaccards.append(jaccard(pixel_jc))
                poly_jaccards.append(jaccard(poly_jc))
            logger.info(
                'Mean pixel jaccard: {:.5f}, polygon jaccard: {:.5f}'.format(
                    np.mean(pixel_jaccards), np.mean(poly_jaccards)))
        for cls, size in zip(hps.classes, sizes):
            logger.info('cls-{} size: {:,} bytes'.format(cls, size))
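
Both Example #2 and Example #3 locate cached per-image mask arrays through a mask_path helper. Given the gzip.open + np.load pattern in Example #2, a minimal sketch could look like the following; the exact filename scheme is an assumption.

from pathlib import Path


def mask_path(store: Path, im_id: str) -> Path:
    # One gzipped .npy per image, holding the stacked per-class
    # masks produced by predict_masks.
    return store / '{}.mask.npy.gz'.format(im_id)

Storing the stack gzip-compressed keeps mostly-empty masks small on disk while np.load still reads them from the decompressed file object.
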
Example #4
import argparse
import json
from pathlib import Path
from pprint import pprint

import attr
import numpy as np
from sklearn.model_selection import ShuffleSplit

# Model, HyperParams, utils and logger are project-local.


def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--logdir',
        type=str,
        default="/data/data_gWkkHSkq/kaggle-dstl/log",
        help='Path to log directory')
    arg('--hps', help='Change hyperparameters in k1=v1,k2=v2 format')
    arg('--all',
        action='store_true',
        help='Train on all images without validation')
    arg('--validation',
        choices=['random', 'stratified', 'square', 'custom'],
        default='custom',
        help='validation strategy')
    arg('--valid-only', action='store_true')
    arg('--only',
        help='Train on these image ids only (comma-separated), '
        'without validation')
    arg('--clean', action='store_true', help='Clean logdir')
    arg('--no-mp', action='store_true', help='Disable multiprocessing')
    arg('--model-path', type=Path)
    args = parser.parse_args()

    logdir = Path(args.logdir)
    logdir.mkdir(exist_ok=True, parents=True)
    if args.clean:
        for p in logdir.iterdir():
            p.unlink()

    if args.hps == 'load':
        hps = HyperParams.from_dir(logdir)
    else:
        hps = HyperParams()
        hps.update(args.hps)
        logdir.joinpath('hps.json').write_text(
            json.dumps(attr.asdict(hps), indent=True, sort_keys=True))
    pprint(attr.asdict(hps))
    print(args)

    model = Model(hps=hps)
    all_im_ids = list(utils.get_wkt_data())
    mask_stats = json.loads(Path('cls-stats.json').read_text())
    im_area = [
        (im_id,
         np.mean([mask_stats[im_id][str(cls)]['area'] for cls in hps.classes]))
        for im_id in all_im_ids
    ]
    area_by_id = dict(im_area)
    valid_ids = []

    if args.only:
        train_ids = args.only.split(',')
    elif args.all:
        train_ids = all_im_ids
    elif args.validation == 'stratified':
        train_ids, valid_ids = [], []
        for idx, (im_id, _) in enumerate(
                sorted(im_area, key=lambda x: (x[1], x[0]), reverse=True)):
            (valid_ids if (idx % 4 == 1) else train_ids).append(im_id)
    elif args.validation == 'square':
        train_ids = valid_ids = all_im_ids
    elif args.validation == 'random':
        forced_train_ids = {'6070_2_3', '6120_2_2', '6110_4_0'}
        other_ids = list(set(all_im_ids) - forced_train_ids)
        train_ids, valid_ids = [[other_ids[idx] for idx in g] for g in next(
            ShuffleSplit(random_state=1, n_splits=4).split(other_ids))]
        train_ids.extend(forced_train_ids)
    elif args.validation == 'custom':
        valid_ids = [
            '6140_3_1', '6110_1_2', '6160_2_1', '6170_0_4', '6100_2_2'
        ]
        train_ids = [im_id for im_id in all_im_ids if im_id not in valid_ids]
    else:
        raise ValueError('Unexpected validation kind: {}'.format(
            args.validation))

    if args.valid_only:
        train_ids = []

    train_area_by_class, valid_area_by_class = [{
        cls: np.mean([mask_stats[im_id][str(cls)]['area'] for im_id in im_ids])
        for cls in hps.classes
    } for im_ids in [train_ids, valid_ids]]

    logger.info('Train: {}'.format(' '.join(sorted(train_ids))))
    logger.info('Valid: {}'.format(' '.join(sorted(valid_ids))))
    logger.info('Train area mean: {:.6f}'.format(
        np.mean([area_by_id[im_id] for im_id in train_ids])))
    logger.info('Train area by class: {}'.format(' '.join(
        'cls-{}: {:.6f}'.format(cls, train_area_by_class[cls])
        for cls in hps.classes)))
    logger.info('Valid area mean: {:.6f}'.format(
        np.mean([area_by_id[im_id] for im_id in valid_ids])))
    logger.info('Valid area by class: {}'.format(' '.join(
        'cls-{}: {:.6f}'.format(cls, valid_area_by_class[cls])
        for cls in hps.classes)))

    model.train(logdir=logdir,
                train_ids=train_ids,
                valid_ids=valid_ids,
                validation=args.validation,
                no_mp=args.no_mp,
                valid_only=args.valid_only,
                model_path=args.model_path)
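
Example #3 and Example #4 both depend on a project-local HyperParams class that can be serialized with attr.asdict, restored from a log directory, and overridden from a 'k1=v1,k2=v2' string. Below is a minimal attrs-based sketch consistent with those call sites; the field names are placeholders, and the real class surely has many more.

import json
from pathlib import Path
from typing import List

import attr


@attr.s(auto_attribs=True)
class HyperParams:
    n_epochs: int = 30
    learning_rate: float = 1e-4
    classes: List[int] = attr.Factory(lambda: list(range(10)))

    @property
    def n_classes(self) -> int:
        return len(self.classes)

    @classmethod
    def from_dir(cls, logdir: Path) -> 'HyperParams':
        # Mirrors the hps.json written by main() above.
        return cls(**json.loads(logdir.joinpath('hps.json').read_text()))

    def update(self, hps_string: str):
        # Apply 'k1=v1,k2=v2' overrides, casting each value to the
        # current field's type (scalar fields only in this sketch).
        if not hps_string:
            return
        for pair in hps_string.split(','):
            k, v = pair.split('=')
            setattr(self, k, type(getattr(self, k))(v))
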