def test_image_dtype(self):
    """Check that NORB stores images with the requested dtype.

    Without an `image_dtype` argument the images are expected to be
    uint8; with image_dtype='float32' they are expected to be float32.
    """
    cases = [
        (NORB(which_set='train', which_norb='small'), 'uint8'),
        (NORB(which_set='train', which_norb='small',
              image_dtype='float32'), 'float32'),
    ]
    for dataset, dtype_name in cases:
        assert str(dataset.X.dtype) == dtype_name
def test_label_to_value_funcs(self):
    """Check NORB.label_to_value_funcs against hand-written label tables.

    Each entry of `expected_maps` maps the integer labels of one label
    dimension to the value the matching function should return. The
    check runs on both the small and the big NORB test sets.
    """
    category_names = ('animal', 'human', 'airplane', 'truck', 'car', 'blank')

    def identity_map(labels):
        # Label dimensions whose value is the label itself.
        return dict((label, label) for label in labels)

    expected_maps = (
        dict(enumerate(category_names)),                            # category
        identity_map(range(10)),                                     # instance
        dict(safe_zip(range(9), numpy.arange(9) * 5 + 30)),          # elevation
        dict(safe_zip(range(0, 36, 2), numpy.arange(0, 360, 20))),   # azimuth
        identity_map(range(5)),                                      # lighting
        identity_map(range(-5, 6)),                  # horizontal shift
        identity_map(range(-5, 6)),                  # vertical shift
        identity_map(range(-19, 20)),                # lumination change
        dict(safe_zip(range(2), (0.8, 1.3))),        # contrast change
    )

    def check(norb):
        # Plain zip on purpose: small NORB exposes fewer label functions
        # than there are tables above, so the surplus tables are skipped.
        for expected_map, to_value in zip(expected_maps,
                                          norb.label_to_value_funcs):
            for label, expected in six.iteritems(expected_map):
                assert expected == to_value(label)

    for which_norb in ('small', 'big'):
        check(NORB(which_set='test', which_norb=which_norb))
def test_get_topological_view(self):
    """Check NORB.get_topological_view in single- and two-tensor modes.

    The single-tensor view must equal norb.X reshaped to
    (batch, 2) + SmallNORB.original_image_shape + (1,). The two-tensor
    view must equal that same array split along axis 1.
    """
    def check(norb):
        # Single "(b, s, 0, 1, c)" tensor.
        single = norb.get_topological_view(single_tensor=True)
        expected_shape = ((norb.X.shape[0], 2) +
                          SmallNORB.original_image_shape +
                          (1, ))
        expected = norb.X.reshape(expected_shape)

        # Compare one example at a time to keep peak memory low.
        for row in range(single.shape[0]):
            assert numpy.all(single[row] == expected[row])

        # Two "(b, 0, 1, c)" tensors, one per element of axis 1.
        pair = norb.get_topological_view(single_tensor=False)
        expected_pair = tuple(expected[:, side, ...] for side in range(2))
        for actual_view, expected_view in safe_zip(pair, expected_pair):
            assert numpy.all(actual_view == expected_view)

    # Use stop parameter for SmallNORB; otherwise the buildbot uses close
    # to 10G of RAM.
    datasets = (SmallNORB('train', stop=1000),
                NORB(which_norb='small', which_set='train'))
    for norb in datasets:
        check(norb)
def main():
    """Top-level function: interactive viewer for the NORB dataset.

    Loads the dataset chosen on the command line (via `_parse_args`) and
    opens a matplotlib window with three panels: a text panel listing the
    current label values, and two image panels showing the image pair.

    Keyboard controls:

    * up / down: choose which entry is being edited. When dataset.y has
      more than 5 columns, an extra "image number" entry follows the 5
      grid labels.
    * left / right: decrement / increment the chosen entry (wrapping).
    * q: exit the program.
    """
    args = _parse_args()
    dataset = NORB(args.which_norb, args.which_set)

    # Indexes into the first 5 labels, which live on a 5-D grid.
    # This and the one-element lists below are mutated in place by the
    # nested handlers to share state.
    grid_indices = [0] * 5

    grid_to_short_label = _make_grid_to_short_label(dataset)

    # Maps 5-D label vector to a list of row indices for dataset.X, dataset.y
    # that have those labels.
    label_to_row_indices = _make_label_to_row_indices(dataset.y)

    # Indexes into the row index lists returned by label_to_row_indices.
    # Blank images and object images each remember their own position.
    object_image_index = [0]
    blank_image_index = [0]

    blank_label = _get_blank_label(dataset)

    # Index into grid_indices currently being edited. The value 5 selects
    # the image number instead of a grid label.
    grid_dimension = [0]

    figure, all_axes = pyplot.subplots(1, 3, squeeze=True, figsize=(10, 3.5))

    window_title = "NORB dataset (%sing set)" % args.which_set
    # FigureCanvasBase.set_window_title was deprecated in matplotlib 3.4 and
    # removed in 3.6; the figure manager's method works on old and new
    # versions. Fall back to the canvas for figures without a manager.
    manager = getattr(figure.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(window_title)
    else:
        figure.canvas.set_window_title(window_title)

    figure.suptitle('Up/down arrows choose label, '
                    'left/right arrows change it',
                    x=0.1,
                    horizontalalignment="left")

    def hide_ticks(axes):
        """Hides the x and y axes (tick marks and tick labels) of `axes`."""
        axes.get_xaxis().set_visible(False)
        axes.get_yaxis().set_visible(False)

    # Hides axes' tick marks
    for axes in all_axes:
        hide_ticks(axes)

    text_axes, image_axes = (all_axes[0], all_axes[1:])

    image_captions = ('left', 'right')
    if args.stereo_viewer:
        image_captions = tuple(reversed(image_captions))

    for image_ax, caption in safe_zip(image_axes, image_captions):
        image_ax.set_title(caption)

    text_axes.set_frame_on(False)  # Hides background of text_axes

    def is_blank(grid_indices):
        """Returns True if grid_indices points at the 'blank' category."""
        assert len(grid_indices) == 5
        assert all(x >= 0 for x in grid_indices)
        ci = dataset.label_name_to_index['category']  # category index
        category = grid_to_short_label[ci][grid_indices[ci]]
        category_name = dataset.label_to_value_funcs[ci](category)
        return category_name == 'blank'

    def get_short_label(grid_indices):
        """
        Returns the first 5 elements of the label vector pointed to by
        grid_indices.

        We use the first 5, since they're the labels used by both the 'big'
        and Small NORB datasets.
        """
        # Need to special-case the 'blank' category, since it lies outside of
        # the grid.
        if is_blank(grid_indices):  # won't happen with SmallNORB
            return tuple(blank_label[:5])
        else:
            return tuple(grid_to_short_label[i][g]
                         for i, g in enumerate(grid_indices))

    def get_row_indices(grid_indices):
        """Returns the dataset rows whose short label matches, or None."""
        short_label = get_short_label(grid_indices)
        return label_to_row_indices.get(short_label, None)

    def get_image_index_cell():
        """Returns the mutable image-index cell for the current grid point.

        Blank and object images keep separate positions, so this must be
        re-evaluated after grid_indices changes (it may switch category).
        """
        return (blank_image_index if is_blank(grid_indices)
                else object_image_index)

    def redraw(redraw_text, redraw_images):
        """Redraws the text and/or image panels, then refreshes the canvas."""
        row_indices = get_row_indices(grid_indices)
        if row_indices is None:
            row_index = None
            image_index = 0
            num_images = 0
        else:
            image_index = get_image_index_cell()[0]
            row_index = row_indices[image_index]
            num_images = len(row_indices)

        def draw_text():
            if row_indices is None:
                # No such image: show the requested short label, padded
                # with zeros for the remaining label columns.
                padding_length = dataset.y.shape[1] - len(grid_indices)
                current_label = (tuple(get_short_label(grid_indices)) +
                                 (0, ) * padding_length)
            else:
                current_label = dataset.y[row_index, :]

            label_names = dataset.label_index_to_name
            label_values = [label_to_value(label) for label_to_value, label
                            in safe_zip(dataset.label_to_value_funcs,
                                        current_label)]

            lines = ['%s: %s' % (t, v)
                     for t, v in safe_zip(label_names, label_values)]

            if dataset.y.shape[1] > 5:
                # Inserts image number & blank line between editable and
                # fixed labels.
                lines = (lines[:5] +
                         ['No such image' if num_images == 0
                          else 'image: %d of %d' % (image_index + 1,
                                                    num_images),
                          '\n'] +
                         lines[5:])

            # prepends the current index's line with an arrow.
            lines[grid_dimension[0]] = '==> ' + lines[grid_dimension[0]]

            text_axes.clear()

            # "transAxes": 0, 0 = bottom-left, 1, 1 at upper-right.
            text_axes.text(0, 0.5,  # coords
                           '\n'.join(lines),
                           verticalalignment='center',
                           transform=text_axes.transAxes)

        def draw_images():
            # Clearing drops the previously shown image (imshow would
            # otherwise keep stacking image artists), but it also resets
            # the title, so restore the caption and hidden ticks.
            for axis, caption in safe_zip(image_axes, image_captions):
                axis.clear()
                axis.set_title(caption)
                hide_ticks(axis)

            if row_indices is None:
                return

            data_row = dataset.X[row_index:row_index + 1, :]
            image_pair = dataset.get_topological_view(mat=data_row,
                                                      single_tensor=True)

            # Shaves off the singleton dimensions (batch # and channel #).
            image_pair = tuple(image_pair[0, :, :, :, 0])

            if args.stereo_viewer:
                image_pair = tuple(reversed(image_pair))

            for axis, image in safe_zip(image_axes, image_pair):
                axis.imshow(image, cmap='gray')

        if redraw_text:
            draw_text()

        if redraw_images:
            draw_images()

        figure.canvas.draw()

    def on_key_press(event):
        """Handles up/down (choose entry), left/right (change it), q (quit)."""

        def add_mod(arg, step, size):
            return (arg + size + step) % size

        def incr_index_type(step):
            num_dimensions = len(grid_indices)
            if dataset.y.shape[1] > 5:
                # If dataset is big NORB, add one for the image index
                num_dimensions += 1

            grid_dimension[0] = add_mod(grid_dimension[0],
                                        step,
                                        num_dimensions)

        def incr_index(step):
            assert step in (0, -1, 1), ("Step was %d" % step)

            if grid_dimension[0] == 5:  # i.e. the image index
                image_index = get_image_index_cell()
                row_indices = get_row_indices(grid_indices)
                if row_indices is None:
                    image_index[0] = 0
                else:
                    # increment the image index
                    image_index[0] = add_mod(image_index[0],
                                             step,
                                             len(row_indices))
            else:
                # increment one of the grid indices
                gd = grid_dimension[0]
                grid_indices[gd] = add_mod(grid_indices[gd],
                                           step,
                                           len(grid_to_short_label[gd]))

                # Pick the cell *after* moving: a category change may
                # switch between blank and object images.
                image_index = get_image_index_cell()
                row_indices = get_row_indices(grid_indices)
                if row_indices is None:
                    image_index[0] = 0
                else:
                    # some grid indices have 2 images instead of 3; clamp
                    # to the last valid position.
                    image_index[0] = min(image_index[0],
                                         len(row_indices) - 1)

        # Disables left/right key if we're currently showing a blank,
        # and the current index type is neither 'category' (0) nor
        # 'image number' (5)
        disable_left_right = (is_blank(grid_indices) and
                              not (grid_dimension[0] in (0, 5)))

        if event.key == 'up':
            incr_index_type(-1)
            redraw(True, False)
        elif event.key == 'down':
            incr_index_type(1)
            redraw(True, False)
        elif event.key == 'q':
            sys.exit(0)
        elif not disable_left_right:
            if event.key == 'left':
                incr_index(-1)
                redraw(True, True)
            elif event.key == 'right':
                incr_index(1)
                redraw(True, True)

    figure.canvas.mpl_connect('key_press_event', on_key_press)
    redraw(True, True)

    pyplot.show()