Example #1
0
File: plot.py Project: lewfish/mlx
    def make_debug_plots(self,
                         dataloader,
                         model,
                         classes,
                         output_dir,
                         max_plots=25,
                         score_thresh=0.3):
        """Plot predictions for up to ``max_plots`` images and zip the plots.

        Each image is plotted with its ground truth and the predicted boxes
        kept by ``score_filter(score_thresh)`` using ``self.plot_image_preds``.
        Images are drawn from successive batches of ``dataloader`` until
        ``max_plots`` images have been plotted (or the loader is exhausted).
        The PNGs are written to ``<output_dir>/preds``, zipped into
        ``<output_dir>/preds.zip``, and the unzipped directory is removed.

        Args:
            dataloader: iterable of ``(batch_x, batch_y)`` batches.
            model: torch module whose output, indexed per image, supports
                ``score_filter(thresh)`` and ``cpu()``.
            classes: class names passed through to the plotting helper.
            output_dir: directory that receives ``preds.zip``.
            max_plots: maximum number of images to plot.
            score_thresh: score threshold passed to ``score_filter``.
        """
        preds_dir = join(output_dir, 'preds')
        zip_path = join(output_dir, 'preds.zip')
        make_dir(preds_dir, force_empty=True)

        model.eval()
        # The model does not move between devices inside the loop, so the
        # device only needs to be looked up once.
        device = next(model.parameters()).device

        num_plotted = 0
        for batch_x, batch_y in dataloader:
            with torch.no_grad():
                batch_x = batch_x.to(device=device)
                batch_boxlist = model(batch_x)

            # Only take as many images from this batch as the budget allows.
            batch_sz = min(batch_x.shape[0], max_plots - num_plotted)
            for img_ind in range(batch_sz):
                x = batch_x[img_ind].cpu()
                y = batch_y[img_ind].cpu()
                boxlist = batch_boxlist[img_ind].score_filter(
                    score_thresh).cpu()

                # Plot image, ground truth, and predictions. Save the figure
                # object itself rather than pyplot's "current" figure.
                fig = self.plot_image_preds(x, y, boxlist, classes)
                fig.savefig(join(preds_dir,
                                 '{}-images.png'.format(num_plotted)),
                            bbox_inches='tight')
                plt.close(fig)
                num_plotted += 1

            if num_plotted >= max_plots:
                break

        zipdir(preds_dir, zip_path)
        shutil.rmtree(preds_dir)
Example #2
0
File: data.py Project: lewfish/mlx
def setup_output_dir(cfg, tmp_dir):
    """Return the directory that outputs should be written to.

    A non-S3 ``cfg.output_uri`` is returned unchanged. For an ``s3://`` URI,
    a local path under ``tmp_dir`` is prepared with
    ``make_dir(..., force_empty=True)``, filled with the remote contents via
    ``sync_from_dir``, and returned.
    """
    uri = cfg.output_uri
    if uri.startswith('s3://'):
        local_dir = get_local_path(uri, tmp_dir)
        make_dir(local_dir, force_empty=True)
        sync_from_dir(uri, local_dir)
        return local_dir
    return uri
Example #3
0
def main(test, s3_data, batch, debug):
    """Train a semantic segmentation FPN model on the CamVid-Tiramisu dataset.

    When ``test`` is truthy, the batch size, worker count, epoch count and
    data sample fraction are all reduced for a quick smoke run. Predictions
    are plotted into the local output directory, which is synced to
    ``remote_output_uri`` when ``s3_data`` is set.
    """
    if batch:
        run_on_batch(test, debug)

    # Training options (shrunk when running in test mode).
    batch_sz = 1 if test else 8
    num_workers = 0 if test else 4
    num_epochs = 2 if test else 20
    sample_pct = 0.01 if test else 1.0
    lr = 1e-4
    backbone_arch = 'resnet18'

    # Hold on to the TemporaryDirectory object for the whole run: the
    # directory is cleaned up once this object is garbage collected.
    tmp_dir_holder = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_holder.name
    output_dir = local_output_uri
    make_dir(output_dir)

    data = get_databunch(download_data(s3_data, tmp_dir),
                         sample_pct=sample_pct,
                         batch_sz=batch_sz,
                         num_workers=num_workers)
    print(data)
    plot_data(data, output_dir)

    # Build the model and learner; all layers are trainable.
    model = SegmentationFPN(backbone_arch, data.c)
    learn = Learner(data,
                    model,
                    metrics=[acc_camvid],
                    loss_func=SegmentationFPN.loss,
                    path=output_dir)
    learn.unfreeze()

    learn.fit_one_cycle(
        num_epochs,
        lr,
        callbacks=[
            SaveModelCallback(learn, monitor='valid_loss'),
            CSVLogger(learn, filename='log'),
        ])

    plot_preds(data, learn, output_dir)

    if s3_data:
        sync_to_dir(output_dir, remote_output_uri)
Example #4
0
File: data.py Project: lewfish/mlx
def setup_output_dir(cfg, tmp_dir):
    """Prepare and return the directory that outputs should be written to.

    A non-S3 ``cfg.output_uri`` is created with ``make_dir`` and returned.
    For an ``s3://`` URI, a local path under ``tmp_dir`` is prepared with
    ``make_dir(..., force_empty=True)``; unless ``cfg.overfit_mode`` is set,
    the remote contents are then pulled into it with ``sync_from_dir``. The
    local path is returned.
    """
    uri = cfg.output_uri
    if uri.startswith('s3://'):
        local_dir = get_local_path(uri, tmp_dir)
        make_dir(local_dir, force_empty=True)
        # Overfit mode skips pulling existing remote outputs.
        if not cfg.overfit_mode:
            sync_from_dir(uri, local_dir)
        return local_dir

    make_dir(uri)
    return uri
Example #5
0
    def plot_dataloader(self, dataloader, output_path):
        """Save a grid plot of the first batch yielded by ``dataloader``.

        Items are laid out on a square grid just large enough to hold the
        whole batch (3x3 inches per cell); each cell is drawn by ``plot_xyz``
        with ``self.label_names``. The parent directory of ``output_path`` is
        prepared with ``make_dir(..., use_dirname=True)`` before saving.

        Args:
            dataloader: iterable yielding ``(x, y)`` batches; only the first
                batch is plotted.
            output_path: file path of the saved image.
        """
        x, y = next(iter(dataloader))
        batch_sz = x.shape[0]

        ncols = nrows = math.ceil(math.sqrt(batch_sz))
        fig = plt.figure(constrained_layout=True,
                         figsize=(3 * ncols, 3 * nrows))
        try:
            grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
            for i in range(batch_sz):
                ax = fig.add_subplot(grid[i])
                plot_xyz(ax, x[i], y[i], self.label_names)

            make_dir(output_path, use_dirname=True)
            # Save this specific figure rather than pyplot's "current" one,
            # in case a plotting helper switched the current figure.
            fig.savefig(output_path)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
Example #6
0
    def _plot_data(split):
        """Plot up to ``max_per_split`` items of a split and zip the plots.

        ``split`` selects ``data.train_ds`` when it equals ``'train'`` and
        ``data.valid_ds`` otherwise. Each item is drawn with ``x.show(y=y)``
        and saved to ``<output_dir>/<split>-debug-chips/<i>.png``. That
        directory is then zipped into ``<output_dir>/<split>-debug-chips.zip``
        and removed.
        """
        debug_chips_dir = join(output_dir, '{}-debug-chips'.format(split))
        zip_path = join(output_dir, '{}-debug-chips.zip'.format(split))
        make_dir(debug_chips_dir, force_empty=True)

        ds = data.train_ds if split == 'train' else data.valid_ds
        for i, (x, y) in enumerate(ds):
            if i == max_per_split:
                break
            x.show(y=y)
            # savefig() has no figsize parameter. Older matplotlib silently
            # ignored it; newer releases reject unknown keyword arguments.
            plt.savefig(join(debug_chips_dir, '{}.png'.format(i)))
            plt.close()
        zipdir(debug_chips_dir, zip_path)
        shutil.rmtree(debug_chips_dir)
Example #7
0
def plot_preds(data, learn, output_dir, max_plots=50):
    """Plot predictions for validation items and zip the plots.

    Up to ``max_plots`` items of ``data.valid_ds`` are passed through
    ``learn.predict``. Each item is drawn with ``x.show`` overlaid with the
    first element of the prediction and saved as
    ``<output_dir>/preds/<i>.png``. The PNGs are zipped into
    ``<output_dir>/preds.zip`` and the unzipped directory is removed.

    Args:
        data: databunch whose ``valid_ds`` is plotted.
        learn: learner providing ``predict``.
        output_dir: directory that receives ``preds.zip``.
        max_plots: maximum number of items to plot.
    """
    preds_dir = join(output_dir, 'preds')
    zip_path = join(output_dir, 'preds.zip')
    make_dir(preds_dir, force_empty=True)

    for i, (x, _) in enumerate(data.valid_ds):
        if i == max_plots:
            break
        pred = learn.predict(x)
        x.show(y=pred[0])
        # savefig() has no figsize parameter. Older matplotlib silently
        # ignored it; newer releases reject unknown keyword arguments.
        plt.savefig(join(preds_dir, '{}.png'.format(i)))
        plt.close()

    zipdir(preds_dir, zip_path)
    shutil.rmtree(preds_dir)
Example #8
0
    def make_split(split, split_size):
        nonlocal im_id
        nonlocal ann_id

        image_dir = join(output_dir, 'train')
        make_dir(image_dir)

        images = []
        annotations = []
        for _ in range(split_size):
            img, boxes = make_scene(img_size, max_boxes)
            img = np.transpose(img, (1, 2, 0))
            file_name = '{}.png'.format(im_id)
            Image.fromarray(img).save(join(image_dir, file_name))
            images.append({
                'id': im_id,
                'height': img_size,
                'width': img_size,
                'file_name': file_name
            })
            for box in boxes:
                annotations.append({
                    'id':
                    ann_id,
                    'image_id':
                    im_id,
                    'category_id':
                    1,
                    'area': (box[2] - box[0]) * (box[3] - box[1]),
                    'bbox': [box[1], box[0], box[3] - box[1], box[2] - box[0]]
                })
                ann_id += 1
            im_id += 1

        categories = [{'id': 1, 'name': 'rectangle'}]
        labels = {
            'images': images,
            'annotations': annotations,
            'categories': categories
        }
        json_to_file(labels, join(output_dir, '{}.json'.format(split)))
Example #9
0
File: plot.py Project: lewfish/mlx
    def make_debug_plots(self,
                         dataset,
                         model,
                         classes,
                         output_dir,
                         max_plots=25,
                         score_thresh=0.25):
        """Write debug plots for up to ``max_plots`` items and zip them.

        For each ``(x, y)`` item of ``dataset`` this saves:

        * ``<id>.png``: the image with ground truth and the predictions
          returned by ``self.get_pred(x, model, score_thresh)``.
        * ``<id>-<stride>-label-arr.png`` and
          ``<id>-<stride>-reg-center-arr.png``: the network's raw head
          outputs at each pyramid level. The label and center outputs are
          passed through a sigmoid.
        * the same two plots with a ``-gt`` suffix for the ground truth
          encoded by ``encode_single_targets``. These are plotted as-is,
          with no sigmoid.

        All files go to ``<output_dir>/preds``, which is zipped into
        ``<output_dir>/preds.zip`` and then removed.
        """
        preds_dir = join(output_dir, 'preds')
        zip_path = join(output_dir, 'preds.zip')
        make_dir(preds_dir, force_empty=True)

        def save_fig(fig, file_name, dpi):
            # Save this figure into preds_dir and release it so open figures
            # do not accumulate over the loop.
            fig.savefig(join(preds_dir, file_name),
                        dpi=dpi,
                        bbox_inches='tight')
            plt.close(fig)

        model.eval()
        for img_id, (x, y) in enumerate(dataset):
            if img_id == max_plots:
                break

            # Get predictions.
            boxlist, head_out = self.get_pred(x, model, score_thresh)

            # Plot image, ground truth, and predictions.
            fig = self.plot_image_preds(x, y, boxlist, classes)
            save_fig(fig, '{}.png'.format(img_id), dpi=200)

            # Plot raw output of the network at each level. Index 0 selects
            # the first (presumably only) element of the batch axis.
            for level, (reg_arr, label_arr, center_arr) in enumerate(head_out):
                stride = model.fpn.strides[level]

                label_probs = torch.sigmoid(label_arr[0].detach().cpu())
                fig = self.plot_label_arr(label_probs, classes, stride)
                save_fig(fig, '{}-{}-label-arr.png'.format(img_id, stride),
                         dpi=100)

                # Plot top, left, bottom, right from reg_arr and center_arr.
                center_probs = torch.sigmoid(center_arr[0][0].detach().cpu())
                fig = plot_reg_center_arr(reg_arr[0].detach().cpu(),
                                          center_probs, stride)
                save_fig(fig,
                         '{}-{}-reg-center-arr.png'.format(img_id, stride),
                         dpi=100)

            # Get encoding of ground truth targets.
            targets = encode_single_targets(y.boxes, y.get_field('labels'),
                                            model.pyramid_shape,
                                            model.num_labels)

            # Plot encoding of ground truth at each level. Targets have no
            # batch axis and are plotted without a sigmoid.
            for level, (reg_arr, label_arr, center_arr) in enumerate(targets):
                stride = model.fpn.strides[level]

                label_probs = label_arr.detach().cpu()
                fig = self.plot_label_arr(label_probs, classes, stride)
                save_fig(fig,
                         '{}-{}-label-arr-gt.png'.format(img_id, stride),
                         dpi=100)

                # Plot top, left, bottom, right from reg_arr and center_arr.
                center_probs = center_arr[0].detach().cpu()
                fig = plot_reg_center_arr(reg_arr.detach().cpu(),
                                          center_probs, stride)
                save_fig(fig,
                         '{}-{}-reg-center-arr-gt.png'.format(img_id, stride),
                         dpi=100)

        zipdir(preds_dir, zip_path)
        shutil.rmtree(preds_dir)
Example #10
0
def main(test, s3_data, batch, debug):
    """Train a grid-based object detector on PASCAL 2007.

    When ``test`` is truthy, a smaller image size, batch size, epoch count
    and a 32-image subset are used. After training, predictions are plotted
    into ``output_dir``, which is synced to ``output_uri`` when ``s3_data``
    is set.
    """
    if batch:
        run_on_batch(test, debug)

    # Training options (shrunk when running in test mode).
    bs = 8 if test else 16
    size = 128 if test else 256
    num_workers = 0 if test else 4
    num_epochs = 1 if test else 100
    lr = 1e-4
    # The final convolution has no padding, so the grid is two cells smaller
    # than 8 (for size 256) or 4 (for size 128).
    grid_sz = (4 if test else 8) - 2
    if test:
        num_debug_images = 32

    make_dir(output_dir)

    data_dir = untar_data(URLs.PASCAL_2007, dest='/opt/data/pascal2007/data')
    img_path = data_dir/'train/'
    trn_path = data_dir/'train.json'
    trn_images, trn_lbl_bbox = get_annotations(trn_path)
    val_images, val_lbl_bbox = get_annotations(data_dir/'valid.json')

    # Map each image file name to its (boxes, labels) annotation.
    img2bbox = dict(zip(trn_images + val_images, trn_lbl_bbox + val_lbl_bbox))

    def get_y_func(o):
        return img2bbox[o.name]

    # Class names ordered by category id, with 'background' prepended.
    with open(trn_path) as f:
        categories = json.load(f)['categories']
    classes = ['background'] + [
        cat['name'] for cat in sorted(categories, key=lambda cat: cat['id'])
    ]
    num_classes = len(classes)

    anc_sizes = torch.tensor(
        [[1, 1], [2, 2], [3, 3], [3, 1], [1, 3]], dtype=torch.float32)
    grid = ObjectDetectionGrid(grid_sz, anc_sizes, num_classes)
    score_thresh = 0.1
    iou_thresh = 0.8

    class MyObjectCategoryList(ObjectCategoryList):
        def analyze_pred(self, pred):
            # Decode one prediction by giving it a leading axis of size 1,
            # then unwrap the single result.
            boxes, labels, _ = grid.get_preds(
                pred.unsqueeze(0), score_thresh=score_thresh,
                iou_thresh=iou_thresh)
            return boxes[0], labels[0]

    class MyObjectItemList(ObjectItemList):
        _label_cls = MyObjectCategoryList

    def get_data(bs, size):
        items = MyObjectItemList.from_folder(img_path)
        if test:
            items = items[0:num_debug_images]
        return (items.split_by_files(val_images)
                     .label_from_func(get_y_func, classes=classes)
                     .transform(get_transforms(), size=size, tfm_y=True)
                     .databunch(path=data_dir, bs=bs,
                                collate_fn=bb_pad_collate,
                                num_workers=num_workers))

    data = get_data(bs, size)
    print(data)
    plot_data(data, output_dir)

    model = ObjectDetectionModel(grid)

    def loss(out, gt_boxes, gt_classes):
        encoded_gt = model.grid.encode(gt_boxes, gt_classes)
        box_loss, class_loss = model.grid.compute_losses(out, encoded_gt)
        return box_loss + class_loss

    metrics = [F1(grid, score_thresh=score_thresh, iou_thresh=iou_thresh)]
    learn = Learner(data, model, metrics=metrics, loss_func=loss,
                    path=output_dir)
    learn.fit_one_cycle(num_epochs, lr,
                        callbacks=[CSVLogger(learn, filename='log')])

    plot_preds(data, learn, output_dir)

    if s3_data:
        sync_to_dir(output_dir, output_uri)
Example #11
0
File: main.py Project: lewfish/mlx
def main(config_path, opts):
    """Train and/or evaluate the model described by a config file.

    Loads the config from ``config_path`` (with ``opts`` passed to
    ``load_config``) and prepares the output directory. Then it builds the
    databunch, model, optimizer and schedulers, and trains unless
    ``cfg.predict_mode`` is set. Finally it evaluates on the test set (and on
    the train set if ``cfg.eval_train``), saving metrics and prediction plots.

    For an ``s3://`` ``cfg.output_uri``, work happens in a local directory
    under a temporary dir, which is synced back to S3 at the end.
    """
    # Load config and setup output_dir.
    torch_cache_dir = '/opt/data/torch-cache'
    # TORCH_HOME controls where torch caches downloaded files such as
    # pretrained weights.
    os.environ['TORCH_HOME'] = torch_cache_dir
    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    cfg = load_config(config_path, opts)
    if cfg.output_uri.startswith('s3://'):
        output_dir = get_local_path(cfg.output_uri, tmp_dir)
        make_dir(output_dir, force_empty=True)
        # Overfit mode skips pulling existing remote outputs.
        if not cfg.overfit_mode:
            sync_from_dir(cfg.output_uri, output_dir)
    else:
        output_dir = cfg.output_uri
        make_dir(output_dir)
    shutil.copyfile(config_path, join(output_dir, 'config.yml'))

    print(cfg)
    print()

    # Setup databunch and plot.
    databunch = build_databunch(cfg)
    print(databunch)
    print()
    if not cfg.predict_mode:
        databunch.plot_dataloaders(output_dir)

    # Setup learner.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = build_model(cfg)
    model.to(device)
    if cfg.model.init_weights:
        # map_location loads the saved tensors onto the selected device.
        model.load_state_dict(
            torch.load(cfg.model.init_weights, map_location=device))

    opt = build_optimizer(cfg, model)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Resume epoch numbering after the last epoch recorded in a previous
    # run's train_state.json, if one exists in output_dir.
    start_epoch = 0
    train_state_path = join(output_dir, 'train_state.json')
    if isfile(train_state_path):
        train_state = file_to_json(train_state_path)
        start_epoch = train_state['epoch'] + 1
    num_samples = len(databunch.train_ds)
    step_scheduler, epoch_scheduler = build_scheduler(cfg, num_samples, opt,
                                                      start_epoch)
    learner = Learner(cfg, databunch, output_dir, model, loss_fn, opt, device,
                      epoch_scheduler, step_scheduler)

    # Train
    if not cfg.predict_mode:
        if cfg.overfit_mode:
            learner.overfit()
        else:
            learner.train()

    # Optionally evaluate on the train set and plot.
    if cfg.eval_train:
        print('\nEvaluating on train set...')
        # NOTE(review): the train set uses validate_epoch() while the test
        # set below uses validate(); confirm this asymmetry is intended.
        metrics = learner.validate_epoch(databunch.train_dl)
        print('train metrics: {}'.format(metrics))
        json_to_file(metrics, join(output_dir, 'train_metrics.json'))

        print('\nPlotting training set predictions...')
        learner.plot_preds(databunch.train_dl,
                           join(output_dir, 'train_preds.png'))

    # Evaluate on test set and plot.
    print('\nEvaluating on test set...')
    metrics = learner.validate(databunch.test_dl)
    print('test metrics: {}'.format(metrics))
    json_to_file(metrics, join(output_dir, 'test_metrics.json'))

    print('\nPlotting test set predictions...')
    learner.plot_preds(databunch.test_dl, join(output_dir, 'test_preds.png'))

    if cfg.output_uri.startswith('s3://'):
        sync_to_dir(output_dir, cfg.output_uri)