class StereoViewConverter(object):
    """
    Converts stereo image data between two formats:

    A) A dense design matrix, one stereo pair per row (VectorSpace)
    B) An image pair (CompositeSpace of two Conv2DSpaces)

    Parameters
    ----------
    shape : tuple
        See doc for __init__'s <shape> parameter.
    axes : tuple, optional
        See doc for __init__'s <axes> parameter.
    """

    # Axis order used when the caller does not pass <axes>.
    _DEFAULT_AXES = ('b', 's', 0, 1, 'c')

    def __init__(self, shape, axes=None):
        """
        The arguments describe how the data is laid out in the design matrix.

        Parameters
        ----------
        shape : tuple
            A tuple of 4 ints, describing the shape of each datum.
            This is the size of each axis in <axes>, excluding the 'b' axis.
            NumPy integer scalars are accepted and converted to int.
        axes : tuple, optional
            A tuple of the following elements in any order, except that
            'b' must come first:

            'b'  batch axis
            's'  stereo axis
             0   image axis 0 (row)
             1   image axis 1 (column)
            'c'  channel axis

            Defaults to ('b', 's', 0, 1, 'c').

        Raises
        ------
        TypeError
            If <shape> contains non-integer entries.
        ValueError
            If <shape> does not have 4 entries, if <axes> is malformed, or
            if the 's' axis does not have size 2.
        """
        if axes is None:
            axes = self._DEFAULT_AXES

        shape = tuple(shape)
        # numpy.integer is accepted because shapes are often taken from
        # NumPy arrays, whose integer scalars are not `int` under Python 3.
        if not all(isinstance(s, (int, numpy.integer)) for s in shape):
            raise TypeError("Shape must be a tuple/list of ints")
        shape = tuple(int(s) for s in shape)

        if len(shape) != 4:
            # Wrap in a 1-tuple; otherwise "%" would unpack <shape> itself
            # and fail with a TypeError instead of this ValueError.
            raise ValueError("Shape array needs to be of length 4, got %s." %
                             (shape,))

        # Validate <axes> before indexing into it, so a malformed tuple gets
        # a descriptive error instead of one from list.remove()/index().
        axes = self._check_axes(axes)

        datum_axes = list(self._remove_b_axis(axes))
        stereo_size = shape[datum_axes.index('s')]
        if stereo_size != 2:
            raise ValueError("Expected 's' axis to have size 2, got %d.\n"
                             " axes: %s\n"
                             " shape: %s" % (stereo_size, axes, shape))

        self.shape = shape
        # set_axes() stores the axes and builds topo_space / storage_space.
        self.set_axes(axes)

    @staticmethod
    def _check_axes(axes):
        """
        Validate an axes specification and return it as a tuple.

        Parameters
        ----------
        axes : iterable
            Must have length 5, contain 'b', 's', 0, 1, 'c', and have 'b'
            first.

        Returns
        -------
        axes : tuple
            <axes> converted to a tuple.

        Raises
        ------
        ValueError
            If any of the above conditions is violated.
        """
        axes = tuple(axes)
        if len(axes) != 5:
            raise ValueError("Axes must have 5 elements; got %s" % str(axes))

        for required_axis in ('b', 's', 0, 1, 'c'):
            if required_axis not in axes:
                raise ValueError("Axes must contain 'b', 's', 0, 1, and 'c'. "
                                 "Got %s." % str(axes))

        if axes.index('b') != 0:
            raise ValueError("The 'b' axis must come first (axes = %s)." %
                             str(axes))
        return axes

    @staticmethod
    def _remove_b_axis(axes):
        """Return <axes> as a tuple with the 'b' entry removed."""
        axes = list(axes)
        axes.remove('b')
        return tuple(axes)

    def _update_spaces(self):
        """
        Rebuild self.topo_space and self.storage_space from self.shape and
        self.axes.
        """
        shape_axes = self._remove_b_axis(self.axes)
        image_shape = tuple(self.shape[shape_axes.index(axis)]
                            for axis in (0, 1))
        # Each image of the pair uses the full axes minus the stereo axis.
        conv2d_axes = list(self.axes)
        conv2d_axes.remove('s')
        conv2d_space = Conv2DSpace(
            shape=image_shape,
            num_channels=self.shape[shape_axes.index('c')],
            axes=conv2d_axes,
            dtype=None)
        self.topo_space = CompositeSpace((conv2d_space, conv2d_space))
        self.storage_space = VectorSpace(dim=numpy.prod(self.shape))

    def get_formatted_batch(self, batch, space):
        """
        Returns a batch formatted to a space.

        Parameters
        ----------
        batch : ndarray
            The batch to format, laid out as in self.storage_space.
        space : a pylearn2.space.Space
            The target space to format to.

        Returns
        -------
        The result of self.storage_space.np_format_as(batch, space).
        """
        return self.storage_space.np_format_as(batch, space)

    def design_mat_to_topo_view(self, design_mat):
        """
        Called by DenseDesignMatrix.get_formatted_view(), get_batch_topo()

        Parameters
        ----------
        design_mat : ndarray
            Design matrix laid out as in self.storage_space.

        Returns
        -------
        design_mat formatted as self.topo_space (a pair of images).
        """
        return self.storage_space.np_format_as(design_mat, self.topo_space)

    def design_mat_to_weights_view(self, design_mat):
        """
        Called by DenseDesignMatrix.get_weights_view()

        Parameters
        ----------
        design_mat : ndarray

        Returns
        -------
        Same as design_mat_to_topo_view(design_mat).
        """
        return self.design_mat_to_topo_view(design_mat)

    def topo_view_to_design_mat(self, topo_batch):
        """
        Used by DenseDesignMatrix.set_topological_view(), .get_design_mat()

        Parameters
        ----------
        topo_batch : ndarray
            Batch laid out as in self.topo_space.

        Returns
        -------
        topo_batch formatted as self.storage_space.
        """
        return self.topo_space.np_format_as(topo_batch, self.storage_space)

    def view_shape(self):
        """
        Returns the per-datum shape: the size of each axis in self.axes,
        excluding 'b', in self.axes order.
        """
        return self.shape

    def weights_view_shape(self):
        """
        Returns the shape of the weights view; identical to view_shape().
        """
        return self.view_shape()

    def set_axes(self, axes):
        """
        Change the order of the axes.

        The stored per-datum shape is permuted to match the new order, and
        topo_space / storage_space are rebuilt so that subsequent
        conversions use the new layout.

        Parameters
        ----------
        axes : tuple
            Must have length 5, must contain 'b', 's', 0, 1, 'c', and must
            have 'b' first.

        Raises
        ------
        ValueError
            If <axes> is malformed.
        """
        axes = self._check_axes(axes)

        if hasattr(self, 'axes'):
            # Reorders the shape vector to match the new axis ordering.
            assert hasattr(self, 'shape')
            old_axes = self._remove_b_axis(self.axes)  # pylint: disable-msg=E0203
            new_axes = self._remove_b_axis(axes)
            self.shape = tuple(self.shape[old_axes.index(a)]
                               for a in new_axes)

        self.axes = axes
        # Without this, the spaces would keep using the previous axis order.
        self._update_spaces()
class StereoViewConverter(object):
    """
    Converts stereo image data between two formats:

    #. A dense design matrix, one stereo pair per row (`VectorSpace`)
    #. An image pair (`CompositeSpace` of two `Conv2DSpace`)

    The arguments describe how the data is laid out in the design matrix.

    Parameters
    ----------
    shape : tuple
        A tuple of 4 ints, describing the shape of each datum.
        This is the size of each axis in `<axes>`, excluding the `b` axis.
        NumPy integer scalars are accepted and converted to int.
    axes : tuple, optional
        Tuple of the following elements in any order, except that 'b' must
        come first:

        * 'b' : batch axis
        * 's' : stereo axis
        * 0 : image axis 0 (row)
        * 1 : image axis 1 (column)
        * 'c' : channel axis

        Defaults to ('b', 's', 0, 1, 'c').

    Raises
    ------
    TypeError
        If `shape` contains non-integer entries.
    ValueError
        If `shape` does not have 4 entries, if `axes` is malformed, or if
        the 's' axis does not have size 2.
    """

    # Axis order used when the caller does not pass `axes`.
    _DEFAULT_AXES = ('b', 's', 0, 1, 'c')

    def __init__(self, shape, axes=None):
        if axes is None:
            axes = self._DEFAULT_AXES

        shape = tuple(shape)
        # numpy.integer is accepted because shapes are often taken from
        # NumPy arrays, whose integer scalars are not `int` under Python 3.
        if not all(isinstance(s, (int, numpy.integer)) for s in shape):
            raise TypeError("Shape must be a tuple/list of ints")
        shape = tuple(int(s) for s in shape)

        if len(shape) != 4:
            # Wrap in a 1-tuple; otherwise "%" would unpack `shape` itself
            # and fail with a TypeError instead of this ValueError.
            raise ValueError("Shape array needs to be of length 4, got %s." %
                             (shape,))

        # Validate `axes` before indexing into it, so a malformed tuple gets
        # a descriptive error instead of one from list.remove()/index().
        axes = self._check_axes(axes)

        datum_axes = list(self._get_batchless_axes(axes))
        stereo_size = shape[datum_axes.index('s')]
        if stereo_size != 2:
            raise ValueError("Expected 's' axis to have size 2, got %d.\n"
                             " axes: %s\n"
                             " shape: %s" % (stereo_size, axes, shape))

        self.shape = shape
        # set_axes() stores the axes and builds topo_space / storage_space.
        self.set_axes(axes)

    @staticmethod
    def _check_axes(axes):
        """
        Validate an axes specification and return it as a tuple.

        Raises `ValueError` unless `axes` has 5 elements, contains
        'b', 's', 0, 1 and 'c', and has 'b' first.
        """
        axes = tuple(axes)
        if len(axes) != 5:
            raise ValueError("Axes must have 5 elements; got %s" % str(axes))

        for required_axis in ('b', 's', 0, 1, 'c'):
            if required_axis not in axes:
                raise ValueError("Axes must contain 'b', 's', 0, 1, and 'c'. "
                                 "Got %s." % str(axes))

        if axes.index('b') != 0:
            raise ValueError("The 'b' axis must come first (axes = %s)." %
                             str(axes))
        return axes

    @staticmethod
    def _get_batchless_axes(axes):
        """Return `axes` as a tuple with the 'b' entry removed."""
        axes = list(axes)
        axes.remove('b')
        return tuple(axes)

    def _update_spaces(self):
        """
        Rebuild `self.topo_space` and `self.storage_space` from
        `self.shape` and `self.axes`.
        """
        shape_axes = self._get_batchless_axes(self.axes)
        image_shape = tuple(self.shape[shape_axes.index(axis)]
                            for axis in (0, 1))
        # Each image of the pair uses the full axes minus the stereo axis.
        conv2d_axes = list(self.axes)
        conv2d_axes.remove('s')
        conv2d_space = Conv2DSpace(
            shape=image_shape,
            num_channels=self.shape[shape_axes.index('c')],
            axes=conv2d_axes)
        self.topo_space = CompositeSpace((conv2d_space, conv2d_space))
        self.storage_space = VectorSpace(dim=numpy.prod(self.shape))

    def get_formatted_batch(self, batch, space):
        """
        Format `batch` (laid out as in `self.storage_space`) to `space`,
        via `self.storage_space.np_format_as`.
        """
        return self.storage_space.np_format_as(batch, space)

    def design_mat_to_topo_view(self, design_mat):
        """
        Called by DenseDesignMatrix.get_formatted_view(), get_batch_topo()

        Returns `design_mat` formatted as `self.topo_space`.
        """
        return self.storage_space.np_format_as(design_mat, self.topo_space)

    def design_mat_to_weights_view(self, design_mat):
        """
        Called by DenseDesignMatrix.get_weights_view()

        Same as `design_mat_to_topo_view(design_mat)`.
        """
        return self.design_mat_to_topo_view(design_mat)

    def topo_view_to_design_mat(self, topo_batch):
        """
        Used by `DenseDesignMatrix.set_topological_view()` and
        `DenseDesignMatrix.get_design_mat()`.

        Returns `topo_batch` formatted as `self.storage_space`.
        """
        return self.topo_space.np_format_as(topo_batch, self.storage_space)

    def view_shape(self):
        """
        Return the per-datum shape: the size of each axis in `self.axes`,
        excluding 'b', in `self.axes` order.
        """
        return self.shape

    def weights_view_shape(self):
        """
        Return the shape of the weights view; identical to `view_shape()`.
        """
        return self.view_shape()

    def set_axes(self, axes):
        """
        Change the order of the axes.

        The stored per-datum shape is permuted to match the new order, and
        `topo_space` / `storage_space` are rebuilt so that subsequent
        conversions use the new layout.

        Parameters
        ----------
        axes : tuple
            Must have length 5, must contain 'b', 's', 0, 1, 'c', and must
            have 'b' first.

        Raises
        ------
        ValueError
            If `axes` is malformed.
        """
        axes = self._check_axes(axes)

        if hasattr(self, 'axes'):
            # Reorders the shape vector to match the new axis ordering.
            assert hasattr(self, 'shape')
            old_axes = self._get_batchless_axes(self.axes)
            new_axes = self._get_batchless_axes(axes)
            self.shape = tuple(self.shape[old_axes.index(a)]
                               for a in new_axes)

        self.axes = axes
        # Without this, the spaces would keep using the previous axis order.
        self._update_spaces()
class StereoViewConverter(object):
    """
    Converts stereo image data between two formats:

    #. A dense design matrix, one stereo pair per row (`VectorSpace`)
    #. An image pair (`CompositeSpace` of two `Conv2DSpace`)

    The arguments describe how the data is laid out in the design matrix.

    Parameters
    ----------
    shape : tuple
        A tuple of 4 ints, describing the shape of each datum.
        This is the size of each axis in `<axes>`, excluding the `b` axis.
        NumPy integer scalars are accepted and converted to int.
    axes : tuple, optional
        Tuple of the following elements in any order, except that 'b' must
        come first:

        * 'b' : batch axis
        * 's' : stereo axis
        * 0 : image axis 0 (row)
        * 1 : image axis 1 (column)
        * 'c' : channel axis

        Defaults to ('b', 's', 0, 1, 'c').

    Raises
    ------
    TypeError
        If `shape` contains non-integer entries.
    ValueError
        If `shape` does not have 4 entries, if `axes` is malformed, or if
        the 's' axis does not have size 2.
    """

    # Axis order used when the caller does not pass `axes`.
    _DEFAULT_AXES = ('b', 's', 0, 1, 'c')

    def __init__(self, shape, axes=None):
        if axes is None:
            axes = self._DEFAULT_AXES

        shape = tuple(shape)
        # numpy.integer is accepted because shapes are often taken from
        # NumPy arrays, whose integer scalars are not `int` under Python 3.
        if not all(isinstance(s, (int, numpy.integer)) for s in shape):
            raise TypeError("Shape must be a tuple/list of ints")
        shape = tuple(int(s) for s in shape)

        if len(shape) != 4:
            # Wrap in a 1-tuple; otherwise "%" would unpack `shape` itself
            # and fail with a TypeError instead of this ValueError.
            raise ValueError("Shape array needs to be of length 4, got %s." %
                             (shape,))

        # Validate `axes` before indexing into it, so a malformed tuple gets
        # a descriptive error instead of one from list.remove()/index().
        axes = self._check_axes(axes)

        datum_axes = list(self._get_batchless_axes(axes))
        stereo_size = shape[datum_axes.index('s')]
        if stereo_size != 2:
            raise ValueError("Expected 's' axis to have size 2, got %d.\n"
                             " axes: %s\n"
                             " shape: %s" % (stereo_size, axes, shape))

        self.shape = shape
        # set_axes() stores the axes and builds topo_space / storage_space.
        self.set_axes(axes)

    @staticmethod
    def _check_axes(axes):
        """
        Validate an axes specification and return it as a tuple.

        Raises `ValueError` unless `axes` has 5 elements, contains
        'b', 's', 0, 1 and 'c', and has 'b' first.
        """
        axes = tuple(axes)
        if len(axes) != 5:
            raise ValueError("Axes must have 5 elements; got %s" % str(axes))

        for required_axis in ('b', 's', 0, 1, 'c'):
            if required_axis not in axes:
                raise ValueError("Axes must contain 'b', 's', 0, 1, and 'c'. "
                                 "Got %s." % str(axes))

        if axes.index('b') != 0:
            raise ValueError("The 'b' axis must come first (axes = %s)." %
                             str(axes))
        return axes

    @staticmethod
    def _get_batchless_axes(axes):
        """Return `axes` as a tuple with the 'b' entry removed."""
        axes = list(axes)
        axes.remove('b')
        return tuple(axes)

    def _update_spaces(self):
        """
        Rebuild `self.topo_space` and `self.storage_space` from
        `self.shape` and `self.axes`.
        """
        shape_axes = self._get_batchless_axes(self.axes)
        image_shape = tuple(self.shape[shape_axes.index(axis)]
                            for axis in (0, 1))
        # Each image of the pair uses the full axes minus the stereo axis.
        conv2d_axes = list(self.axes)
        conv2d_axes.remove('s')
        conv2d_space = Conv2DSpace(
            shape=image_shape,
            num_channels=self.shape[shape_axes.index('c')],
            axes=conv2d_axes)
        self.topo_space = CompositeSpace((conv2d_space, conv2d_space))
        self.storage_space = VectorSpace(dim=numpy.prod(self.shape))

    def get_formatted_batch(self, batch, space):
        """
        Format `batch` (laid out as in `self.storage_space`) to `space`,
        via `self.storage_space.np_format_as`.
        """
        return self.storage_space.np_format_as(batch, space)

    def design_mat_to_topo_view(self, design_mat):
        """
        Called by DenseDesignMatrix.get_formatted_view(), get_batch_topo()

        Returns `design_mat` formatted as `self.topo_space`.
        """
        return self.storage_space.np_format_as(design_mat, self.topo_space)

    def design_mat_to_weights_view(self, design_mat):
        """
        Called by DenseDesignMatrix.get_weights_view()

        Same as `design_mat_to_topo_view(design_mat)`.
        """
        return self.design_mat_to_topo_view(design_mat)

    def topo_view_to_design_mat(self, topo_batch):
        """
        Used by `DenseDesignMatrix.set_topological_view()` and
        `DenseDesignMatrix.get_design_mat()`.

        Returns `topo_batch` formatted as `self.storage_space`.
        """
        return self.topo_space.np_format_as(topo_batch, self.storage_space)

    def view_shape(self):
        """
        Return the per-datum shape: the size of each axis in `self.axes`,
        excluding 'b', in `self.axes` order.
        """
        return self.shape

    def weights_view_shape(self):
        """
        Return the shape of the weights view; identical to `view_shape()`.
        """
        return self.view_shape()

    def set_axes(self, axes):
        """
        Change the order of the axes.

        The stored per-datum shape is permuted to match the new order, and
        `topo_space` / `storage_space` are rebuilt so that subsequent
        conversions use the new layout.

        Parameters
        ----------
        axes : tuple
            Must have length 5, must contain 'b', 's', 0, 1, 'c', and must
            have 'b' first.

        Raises
        ------
        ValueError
            If `axes` is malformed.
        """
        axes = self._check_axes(axes)

        if hasattr(self, 'axes'):
            # Reorders the shape vector to match the new axis ordering.
            assert hasattr(self, 'shape')
            old_axes = self._get_batchless_axes(self.axes)
            new_axes = self._get_batchless_axes(axes)
            self.shape = tuple(self.shape[old_axes.index(a)]
                               for a in new_axes)

        self.axes = axes
        # Without this, the spaces would keep using the previous axis order.
        self._update_spaces()