Example #1
0
File: plot.py · Project: lewfish/mlx
    def make_debug_plots(self,
                         dataloader,
                         model,
                         classes,
                         output_dir,
                         max_plots=25,
                         score_thresh=0.3):
        """Save debug plots of predictions vs. ground truth to a zip file.

        Plots up to ``max_plots`` images from the first batch of
        ``dataloader``, writes them under ``<output_dir>/preds``, zips that
        directory to ``<output_dir>/preds.zip``, and removes the plot
        directory afterwards.

        Args:
            dataloader: yields ``(batch_x, batch_y)`` pairs; ``batch_x`` is
                an image tensor batch.
            model: detection model; calling it on a batch returns one
                boxlist per image (assumed from usage — confirm with the
                model implementation).
            classes: class names forwarded to ``plot_image_preds``.
            output_dir: directory where ``preds/`` and ``preds.zip`` go.
            max_plots: maximum number of images to plot.
            score_thresh: predictions scoring below this are filtered out.
        """
        preds_dir = join(output_dir, 'preds')
        zip_path = join(output_dir, 'preds.zip')
        make_dir(preds_dir, force_empty=True)

        model.eval()
        for batch_x, batch_y in dataloader:
            with torch.no_grad():
                # Infer the model's device from its first parameter.
                device = list(model.parameters())[0].device
                batch_x = batch_x.to(device=device)
                batch_sz = batch_x.shape[0]
                batch_boxlist = model(batch_x)

            # Fix: honor max_plots — the parameter was previously declared
            # but never used, so a large batch plotted every image.
            for img_ind in range(min(batch_sz, max_plots)):
                x = batch_x[img_ind].cpu()
                y = batch_y[img_ind].cpu()
                boxlist = batch_boxlist[img_ind].score_filter(
                    score_thresh).cpu()

                # Plot image, ground truth, and predictions
                fig = self.plot_image_preds(x, y, boxlist, classes)
                plt.savefig(join(preds_dir, '{}-images.png'.format(img_ind)),
                            bbox_inches='tight')
                plt.close(fig)
            # Only the first batch is plotted.
            break

        zipdir(preds_dir, zip_path)
        shutil.rmtree(preds_dir)
Example #2
0
    def _plot_data(split):
        """Plot debug chips for one split, zip them, and clean up.

        Writes one PNG per chip (up to ``max_per_split``) under
        ``<output_dir>/<split>-debug-chips``, zips that directory to
        ``<output_dir>/<split>-debug-chips.zip``, then removes it.

        Relies on ``output_dir``, ``data``, and ``max_per_split`` from the
        enclosing scope.

        Args:
            split: 'train' selects ``data.train_ds``; anything else selects
                ``data.valid_ds``.
        """
        debug_chips_dir = join(output_dir, '{}-debug-chips'.format(split))
        zip_path = join(output_dir, '{}-debug-chips.zip'.format(split))
        make_dir(debug_chips_dir, force_empty=True)

        ds = data.train_ds if split == 'train' else data.valid_ds
        for i, (x, y) in enumerate(ds):
            if i == max_per_split:
                break
            x.show(y=y)
            # Fix: removed the invalid figsize=(3, 3) kwarg — savefig() has
            # no figsize parameter (figure size is set at figure creation),
            # and unknown savefig kwargs raise an error in matplotlib.
            plt.savefig(join(debug_chips_dir, '{}.png'.format(i)))
            plt.close()
        zipdir(debug_chips_dir, zip_path)
        shutil.rmtree(debug_chips_dir)
Example #3
0
def plot_preds(data, learn, output_dir, max_plots=50):
    """Plot model predictions on the validation set and zip the results.

    For up to ``max_plots`` items of ``data.valid_ds``, runs
    ``learn.predict`` and saves one PNG per item under
    ``<output_dir>/preds``, zips that directory to
    ``<output_dir>/preds.zip``, and removes the plot directory.

    Args:
        data: fastai-style data bunch with a ``valid_ds`` dataset
            (assumed from usage — confirm against caller).
        learn: learner whose ``predict(x)`` returns a tuple with the
            displayable prediction at index 0.
        output_dir: directory where ``preds/`` and ``preds.zip`` go.
        max_plots: maximum number of validation items to plot.
    """
    preds_dir = join(output_dir, 'preds')
    zip_path = join(output_dir, 'preds.zip')
    make_dir(preds_dir, force_empty=True)

    ds = data.valid_ds
    for i, (x, y) in enumerate(ds):
        if i == max_plots:
            break
        z = learn.predict(x)
        x.show(y=z[0])
        # Fix: removed the invalid figsize=(3, 3) kwarg — savefig() has no
        # figsize parameter; figure size must be set at figure creation.
        plt.savefig(join(preds_dir, '{}.png'.format(i)))
        plt.close()

    zipdir(preds_dir, zip_path)
    shutil.rmtree(preds_dir)
Example #4
0
File: plot.py · Project: lewfish/mlx
    def make_debug_plots(self,
                         dataset,
                         model,
                         classes,
                         output_dir,
                         max_plots=25,
                         score_thresh=0.25):
        """Save per-image debug plots of predictions, raw head output, and
        encoded ground truth to a zip file.

        For up to ``max_plots`` items of ``dataset``, plots: (a) the image
        with ground truth and predictions, (b) the raw network output
        (label, regression, and centerness arrays) at each pyramid level,
        and (c) the encoded ground-truth targets at each level. Everything
        is written under ``<output_dir>/preds``, zipped to
        ``<output_dir>/preds.zip``, and the plot directory is removed.

        Args:
            dataset: iterable of ``(x, y)`` pairs; ``x`` is an image tensor
                (channels first, given ``x.shape[1:]`` is used as h, w) and
                ``y`` is a boxlist-like object with ``boxes`` and a
                'labels' field.
            model: detection model exposing ``fpn.strides``,
                ``pyramid_shape`` and ``num_labels``.
            classes: class names used to label the plots.
            output_dir: directory where ``preds/`` and ``preds.zip`` go.
            max_plots: maximum number of dataset items to plot.
            score_thresh: score threshold applied when getting predictions.
        """
        preds_dir = join(output_dir, 'preds')
        zip_path = join(output_dir, 'preds.zip')
        make_dir(preds_dir, force_empty=True)

        model.eval()
        for img_id, (x, y) in enumerate(dataset):
            if img_id == max_plots:
                break

            # Get predictions
            boxlist, head_out = self.get_pred(x, model, score_thresh)

            # Plot image, ground truth, and predictions
            fig = self.plot_image_preds(x, y, boxlist, classes)
            plt.savefig(join(preds_dir, '{}.png'.format(img_id)),
                        dpi=200,
                        bbox_inches='tight')
            plt.close(fig)

            # Plot raw output of network at each level.
            for level, level_out in enumerate(head_out):
                stride = model.fpn.strides[level]
                reg_arr, label_arr, center_arr = level_out

                # Plot label_arr. The head emits logits, so sigmoid is
                # applied to get per-class probabilities.
                label_arr = label_arr[0].detach().cpu()
                label_probs = torch.sigmoid(label_arr)
                fig = self.plot_label_arr(label_probs, classes, stride)
                plt.savefig(join(preds_dir,
                                 '{}-{}-label-arr.png'.format(img_id, stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

                # Plot top, left, bottom, right from reg_arr and center_arr.
                # Centerness is also a logit and goes through sigmoid.
                reg_arr = reg_arr[0].detach().cpu()
                center_arr = center_arr[0][0].detach().cpu()
                center_probs = torch.sigmoid(center_arr)
                fig = plot_reg_center_arr(reg_arr, center_probs, stride)
                plt.savefig(join(
                    preds_dir,
                    '{}-{}-reg-center-arr.png'.format(img_id, stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

            # Get encoding of ground truth targets.
            # NOTE(review): h and w are computed but never used below —
            # confirm whether encode_single_targets should receive them.
            h, w = x.shape[1:]
            targets = encode_single_targets(y.boxes, y.get_field('labels'),
                                            model.pyramid_shape,
                                            model.num_labels)

            # Plot encoding of ground truth at each level.
            for level, level_targets in enumerate(targets):
                stride = model.fpn.strides[level]
                reg_arr, label_arr, center_arr = level_targets

                # Plot label_arr. Targets are already probabilities, so no
                # sigmoid is applied here (unlike the raw head output).
                label_probs = label_arr.detach().cpu()
                fig = self.plot_label_arr(label_probs, classes, stride)
                plt.savefig(join(
                    preds_dir, '{}-{}-label-arr-gt.png'.format(img_id,
                                                               stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

                # Plot top, left, bottom, right from reg_arr and center_arr.
                reg_arr = reg_arr.detach().cpu()
                center_arr = center_arr[0].detach().cpu()
                center_probs = center_arr
                fig = plot_reg_center_arr(reg_arr, center_probs, stride)
                plt.savefig(join(
                    preds_dir,
                    '{}-{}-reg-center-arr-gt.png'.format(img_id, stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

        # Bundle all plots into a zip and remove the working directory.
        zipdir(preds_dir, zip_path)
        shutil.rmtree(preds_dir)
Example #5
0
    def make_debug_plots(self,
                         dataloader,
                         model,
                         classes,
                         output_path,
                         max_plots=25,
                         score_thresh=0.3):
        """Plot predictions, raw network output, and encoded targets for the
        first batch of ``dataloader`` and zip the results to ``output_path``.

        Plots are written into a temporary directory that is deleted when
        the context manager exits; only the zip file remains.

        Args:
            dataloader: yields ``(batch_x, batch_y)`` pairs; ``batch_x`` is
                an image tensor batch.
            model: detection model; called with ``get_head_out=True`` it
                returns ``(boxlists, head_out)`` where ``head_out`` is a
                (keypoint, regression) pair of batched tensors. Also
                exposes a ``stride`` attribute.
            classes: class names used to label the plots.
            output_path: path of the zip file to write.
            max_plots: NOTE(review): declared but not used in this body —
                plotting is bounded by the first batch's size instead;
                confirm intent.
            score_thresh: predictions scoring below this are filtered out.
        """
        with tempfile.TemporaryDirectory() as preds_dir:
            model.eval()
            for batch_x, batch_y in dataloader:
                with torch.no_grad():
                    # Infer the model's device from its first parameter.
                    device = list(model.parameters())[0].device
                    batch_x = batch_x.to(device=device)
                    batch_sz = batch_x.shape[0]
                    batch_boxlist, batch_head_out = model(batch_x,
                                                          get_head_out=True)

                for img_ind in range(batch_sz):
                    x = batch_x[img_ind].cpu()
                    y = batch_y[img_ind].cpu()
                    boxlist = batch_boxlist[img_ind].score_filter(
                        score_thresh).cpu()
                    # Select this image's slice from the batched head output.
                    head_out = (batch_head_out[0][img_ind],
                                batch_head_out[1][img_ind])

                    # Plot image, ground truth, and predictions
                    fig = self.plot_image_preds(x, y, boxlist, classes)
                    fig.savefig(join(preds_dir,
                                     '{}-images.png'.format(img_ind)),
                                bbox_inches='tight')
                    plt.close(fig)

                    # Plot raw output of network.
                    keypoint, reg = head_out
                    keypoint, reg = keypoint.cpu(), reg.cpu()
                    stride = model.stride

                    fig = plot_encoded(boxlist,
                                       stride,
                                       keypoint,
                                       reg,
                                       classes=classes)
                    fig.savefig(join(preds_dir,
                                     '{}-output.png'.format(img_ind)),
                                bbox_inches='tight')
                    plt.close(fig)

                    # Plot encoding of ground truth targets.
                    # x is channels-first, so shape[1:] gives (h, w).
                    h, w = x.shape[1:]
                    positions = get_positions(h, w, stride, y.boxes.device)
                    # encode takes a list of targets and returns batched
                    # arrays; index [0] recovers this single image.
                    keypoint, reg = encode([y], positions, stride,
                                           len(classes), self.cfg)
                    fig = plot_encoded(y,
                                       stride,
                                       keypoint[0],
                                       reg[0],
                                       classes=classes)
                    fig.savefig(join(preds_dir,
                                     '{}-targets.png'.format(img_ind)),
                                bbox_inches='tight')
                    plt.close(fig)
                # Only the first batch is plotted.
                break

            zipdir(preds_dir, output_path)