Пример #1
0
    def test_image_dtype(self):
        """
        Checks the dtype of NORB.X for the small NORB training set, both
        when image_dtype is omitted (expects 'uint8') and when it is
        passed as 'float32'.
        """
        default_norb = NORB(which_set='train', which_norb='small')
        float_norb = NORB(which_set='train',
                          which_norb='small',
                          image_dtype='float32')

        cases = safe_zip((default_norb, float_norb), ('uint8', 'float32'))

        for dataset, dtype_name in cases:
            assert str(dataset.X.dtype) == dtype_name
Пример #2
0
    def test_label_to_value_funcs(self):
        """
        Checks NORB.label_to_value_funcs against hard-coded label -> value
        tables, for the test sets of both small and big NORB.
        """

        def identity_map(labels):
            """Returns a dict that maps every label in `labels` to itself."""
            return dict((label, label) for label in labels)

        def check_dataset(dataset):
            """Compares each of the dataset's converters with its table."""
            category_names = ('animal',
                              'human',
                              'airplane',
                              'truck',
                              'car',
                              'blank')

            # One table per label column, in label-vector order.
            expected_maps = [
                # category
                dict(enumerate(category_names)),
                # instance
                identity_map(range(10)),
                # elevation
                dict(safe_zip(range(9), numpy.arange(9) * 5 + 30)),
                # azimuth
                dict(safe_zip(range(0, 36, 2), numpy.arange(0, 360, 20))),
                # lighting
                identity_map(range(5)),
                # horizontal shift
                identity_map(range(-5, 6)),
                # vertical shift
                identity_map(range(-5, 6)),
                # lumination change
                identity_map(range(-19, 20)),
                # contrast change
                {0: 0.8, 1: 1.3},
            ]

            # Plain zip (not safe_zip) on purpose: small NORB provides fewer
            # converters than there are tables, and the surplus tables must
            # simply be ignored.
            converter_pairs = zip(expected_maps, dataset.label_to_value_funcs)

            for expected_map, to_value in converter_pairs:
                for label, expected in six.iteritems(expected_map):
                    actual = to_value(label)
                    assert expected == actual

        for which_norb in ('small', 'big'):
            check_dataset(NORB(which_set='test', which_norb=which_norb))
Пример #3
0
    def test_get_topological_view(self):
        """
        Checks get_topological_view() against a direct reshape of X, both
        as one stacked tensor and as a pair of per-image tensors.
        """

        def check_views(dataset):
            """Runs both view checks on a single dataset."""
            # Single "(b, s, 0 1, c)" tensor holding both stereo images.
            stacked_view = dataset.get_topological_view(single_tensor=True)

            num_rows = dataset.X.shape[0]
            stacked_shape = ((num_rows, 2) +
                             SmallNORB.original_image_shape +
                             (1, ))
            expected_stacked = dataset.X.reshape(stacked_shape)

            # Compare one batch row at a time to keep peak memory low.
            for row in range(stacked_view.shape[0]):
                assert numpy.all(stacked_view[row] == expected_stacked[row])

            # Two separate "(b, 0, 1, c)" tensors, one per stereo image.
            split_views = dataset.get_topological_view(single_tensor=False)
            expected_split = (expected_stacked[:, 0, ...],
                              expected_stacked[:, 1, ...])

            for view, expected_view in safe_zip(split_views, expected_split):
                assert numpy.all(view == expected_view)

        # SmallNORB is loaded with a stop parameter; otherwise the buildbot
        # uses close to 10G of RAM.
        small_norb = SmallNORB('train', stop=1000)
        norb = NORB(which_norb='small', which_set='train')

        for dataset in (small_norb, norb):
            check_views(dataset)
Пример #4
0
def main():
    """
    Top-level function: opens an interactive browser for a NORB dataset.

    Parses the command-line arguments, loads the requested dataset and
    shows a figure with three panels: a text panel listing the current
    label, and two image panels showing the stereo pair.

    Key bindings:
      up / down     choose which label (or the image number) to edit
      left / right  change the chosen label / image number
      q             quit

    Mutable state shared with the nested callbacks is kept in lists
    (sometimes of length one) so that the callbacks can update it in place.
    """

    args = _parse_args()

    dataset = NORB(args.which_norb, args.which_set)
    # Indexes into the first 5 labels, which live on a 5-D grid.
    grid_indices = [
        0,
    ] * 5

    grid_to_short_label = _make_grid_to_short_label(dataset)

    # Maps 5-D label vector to a list of row indices for dataset.X, dataset.y
    # that have those labels.
    label_to_row_indices = _make_label_to_row_indices(dataset.y)

    # Indexes into the row index lists returned by label_to_row_indices.
    # Blank and object images each keep their own position.
    object_image_index = [
        0,
    ]
    blank_image_index = [
        0,
    ]
    blank_label = _get_blank_label(dataset)

    # Index into grid_indices currently being edited
    grid_dimension = [
        0,
    ]

    figure, all_axes = pyplot.subplots(1, 3, squeeze=True, figsize=(10, 3.5))

    # The canvas-level set_window_title() was deprecated in matplotlib 3.4
    # and later removed; the figure manager's method is the supported way
    # to set the title. A canvas without a manager has no window to title.
    window_title = "NORB dataset (%sing set)" % args.which_set
    manager = getattr(figure.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(window_title)

    label_text = figure.suptitle(
        'Up/down arrows choose label, '
        'left/right arrows change it',
        x=0.1,
        horizontalalignment="left")

    # Hides axes' tick marks
    for axes in all_axes:
        axes.get_xaxis().set_visible(False)
        axes.get_yaxis().set_visible(False)

    text_axes, image_axes = (all_axes[0], all_axes[1:])
    image_captions = ('left', 'right')

    # In stereo-viewer mode the two images (and their captions) are swapped.
    if args.stereo_viewer:
        image_captions = tuple(reversed(image_captions))

    for image_ax, caption in safe_zip(image_axes, image_captions):
        image_ax.set_title(caption)

    text_axes.set_frame_on(False)  # Hides background of text_axes

    def is_blank(grid_indices):
        """Returns True iff grid_indices selects the 'blank' category."""
        assert len(grid_indices) == 5
        assert all(x >= 0 for x in grid_indices)

        ci = dataset.label_name_to_index['category']  # category index
        category = grid_to_short_label[ci][grid_indices[ci]]
        category_name = dataset.label_to_value_funcs[ci](category)
        return category_name == 'blank'

    def get_short_label(grid_indices):
        """
        Returns the first 5 elements of the label vector pointed to by
        grid_indices. We use the first 5, since they're the labels used by
        both the 'big' and Small NORB datasets.
        """

        # Need to special-case the 'blank' category, since it lies outside of
        # the grid.
        if is_blank(grid_indices):  # won't happen with SmallNORB
            return tuple(blank_label[:5])
        else:
            return tuple(grid_to_short_label[i][g]
                         for i, g in enumerate(grid_indices))

    def get_row_indices(grid_indices):
        """
        Returns the list of dataset rows whose short label matches
        grid_indices, or None if label_to_row_indices has no such label.
        """
        short_label = get_short_label(grid_indices)
        return label_to_row_indices.get(short_label, None)

    def get_image_index(grid_indices):
        """
        Returns the one-element list holding the image position that
        applies to grid_indices: the blank one for blank images, the object
        one otherwise.
        """
        return (blank_image_index
                if is_blank(grid_indices) else object_image_index)

    def redraw(redraw_text, redraw_images):
        """Redraws the text panel and/or the image panels, as requested."""
        row_indices = get_row_indices(grid_indices)

        if row_indices is None:
            row_index = None
            image_index = 0
            num_images = 0
        else:
            image_index = get_image_index(grid_indices)[0]
            row_index = row_indices[image_index]
            num_images = len(row_indices)

        def draw_text():
            if row_indices is None:
                # No such image: show the selected short label, zero-padded
                # to the full label length.
                padding_length = dataset.y.shape[1] - len(grid_indices)
                current_label = (tuple(get_short_label(grid_indices)) +
                                 (0, ) * padding_length)
            else:
                current_label = dataset.y[row_index, :]

            label_names = dataset.label_index_to_name

            label_values = [
                label_to_value(label) for label_to_value, label in safe_zip(
                    dataset.label_to_value_funcs, current_label)
            ]

            lines = [
                '%s: %s' % (t, v)
                for t, v in safe_zip(label_names, label_values)
            ]

            if dataset.y.shape[1] > 5:
                # Inserts image number & blank line between editable and
                # fixed labels.
                lines = (lines[:5] + [
                    'No such image' if num_images == 0 else 'image: %d of %d' %
                    (image_index + 1, num_images), '\n'
                ] + lines[5:])

            # prepends the current index's line with an arrow.
            lines[grid_dimension[0]] = '==> ' + lines[grid_dimension[0]]

            text_axes.clear()

            # "transAxes": 0, 0 = bottom-left, 1, 1 at upper-right.
            text_axes.text(
                0,
                0.5,  # coords
                '\n'.join(lines),
                verticalalignment='center',
                transform=text_axes.transAxes)

        def draw_images():
            if row_indices is None:
                for axis in image_axes:
                    axis.clear()
            else:
                # Slice (rather than index) to keep the batch dimension.
                data_row = dataset.X[row_index:row_index + 1, :]
                image_pair = dataset.get_topological_view(mat=data_row,
                                                          single_tensor=True)

                # Shaves off the singleton dimensions (batch # and channel #).
                image_pair = tuple(image_pair[0, :, :, :, 0])

                if args.stereo_viewer:
                    image_pair = tuple(reversed(image_pair))

                for axis, image in safe_zip(image_axes, image_pair):
                    axis.imshow(image, cmap='gray')

        if redraw_text:
            draw_text()

        if redraw_images:
            draw_images()

        figure.canvas.draw()

    def on_key_press(event):
        """Handles a matplotlib key_press_event."""

        def add_mod(arg, step, size):
            # Adding size before the modulo keeps the result in [0, size)
            # for a step of -1.
            return (arg + size + step) % size

        def incr_index_type(step):
            """Moves the '==>' cursor to the previous / next editable line."""
            num_dimensions = len(grid_indices)
            if dataset.y.shape[1] > 5:
                # If dataset is big NORB, add one for the image index
                num_dimensions += 1

            grid_dimension[0] = add_mod(grid_dimension[0], step,
                                        num_dimensions)

        def incr_index(step):
            """Changes the value of the label / image number being edited."""
            assert step in (0, -1, 1), ("Step was %d" % step)

            if grid_dimension[0] == 5:  # i.e. the image index
                image_index = get_image_index(grid_indices)
                row_indices = get_row_indices(grid_indices)
                if row_indices is None:
                    image_index[0] = 0
                else:
                    # increment the image index
                    image_index[0] = add_mod(image_index[0], step,
                                             len(row_indices))
            else:
                # increment one of the grid indices
                gd = grid_dimension[0]
                grid_indices[gd] = add_mod(grid_indices[gd], step,
                                           len(grid_to_short_label[gd]))

                # Pick the counter only after the grid has changed: changing
                # the category can switch between blank and object images,
                # and redraw() reads the counter matching the new selection.
                image_index = get_image_index(grid_indices)

                row_indices = get_row_indices(grid_indices)
                if row_indices is None:
                    image_index[0] = 0
                else:
                    # some grid indices have 2 images instead of 3, so clamp
                    # to the last valid position (len - 1, not len).
                    image_index[0] = min(image_index[0], len(row_indices) - 1)

        # Disables left/right key if we're currently showing a blank,
        # and the current index type is neither 'category' (0) nor
        # 'image number' (5)
        disable_left_right = (is_blank(grid_indices)
                              and not (grid_dimension[0] in (0, 5)))

        if event.key == 'up':
            incr_index_type(-1)
            redraw(True, False)
        elif event.key == 'down':
            incr_index_type(1)
            redraw(True, False)
        elif event.key == 'q':
            sys.exit(0)
        elif not disable_left_right:
            if event.key == 'left':
                incr_index(-1)
                redraw(True, True)
            elif event.key == 'right':
                incr_index(1)
                redraw(True, True)

    figure.canvas.mpl_connect('key_press_event', on_key_press)
    redraw(True, True)

    pyplot.show()