Ejemplo n.º 1
0
def write_h5(full_dataset, args, rng):
    '''
    Randomly splits full_dataset into partitions and writes them all to one
    .h5 file.

    The file is written to args.output_dir and named
    "<input basename>_split_<args.ratio>.h5". Each partition is stored under
    h5_file['partitions'][<partition name>], with one entry per tensor name.

    Parameters
    ----------
    full_dataset: dataset with .names, .formats and .tensors
      The dataset to split.

    args: argparse.Namespace
      Reads args.input, args.output_dir, args.ratio and
      args.partition_names.

    rng: random number generator
      Passed through to get_partition_masks() to choose the split.
    '''
    partition_masks = get_partition_masks(full_dataset, args.ratio, rng)
    partition_sizes = [numpy.count_nonzero(mask) for mask in partition_masks]

    basename = os.path.splitext(os.path.split(args.input)[1])[0]

    output_path = os.path.join(args.output_dir,
                               "{}_split_{}.h5".format(basename, args.ratio))

    h5_file = make_h5_file(output_path,
                           args.partition_names,
                           partition_sizes,
                           full_dataset.names,
                           full_dataset.formats)

    # The file handle never leaves this function, so close it here: this
    # flushes pending writes, and also releases the file if a copy fails.
    try:
        for partition_name, partition_mask in safe_izip(args.partition_names,
                                                        partition_masks):
            partition = h5_file['partitions'][partition_name]

            for full_tensor, name, fmt in safe_izip(full_dataset.tensors,
                                                    full_dataset.names,
                                                    full_dataset.formats):
                partition_tensor = partition[name]
                batch_slice = get_batch_slice(fmt, partition_mask)
                partition_tensor[...] = full_tensor[batch_slice]
    finally:
        h5_file.close()
Ejemplo n.º 2
0
def write_memmaps(full_dataset, args, rng):
    '''
    Randomly splits full_dataset into partitions, writing each partition to
    its own .npy memmap file.

    Each file is written to args.output_dir and named
    "<input basename>_split_<args.ratio>_<partition name>.npy".

    Parameters
    ----------
    full_dataset: dataset with .names, .formats and .tensors
      The dataset to split.

    args: argparse.Namespace
      Reads args.input, args.output_dir, args.ratio and
      args.partition_names.

    rng: random number generator
      Passed through to get_partition_masks() to choose the split.
    '''

    def get_partition_path(output_dir, input_path, partition_name):
        basename = os.path.splitext(os.path.split(input_path)[1])[0]
        return os.path.join(output_dir,
                            "{}_split_{}_{}.npy".format(basename,
                                                        args.ratio,
                                                        partition_name))

    partition_masks = get_partition_masks(full_dataset, args.ratio, rng)

    for partition_name, partition_mask in safe_izip(args.partition_names,
                                                    partition_masks):
        partition_path = get_partition_path(args.output_dir,
                                            args.input,
                                            partition_name)

        memmap = make_memmap_file(partition_path,
                                  numpy.count_nonzero(partition_mask),
                                  full_dataset.names,
                                  full_dataset.formats)

        for full_tensor, name, fmt in safe_izip(full_dataset.tensors,
                                                full_dataset.names,
                                                full_dataset.formats):
            partition_tensor = memmap[name]
            batch_slice = get_batch_slice(fmt, partition_mask)
            partition_tensor[...] = full_tensor[batch_slice]

        # Push the copied data to disk now, rather than whenever the memmap
        # happens to be garbage-collected.
        memmap.flush()
Ejemplo n.º 3
0
def main():
    '''
    Copies each partition of an .h5 dataset file into its own .npy memmap.

    For input "<dir>/<name>.h5", partition <p> is written to
    "<dir>/<name>_<p>.npy".
    '''
    args = parse_args()

    h5_datasets = load_h5_dataset(args.input)

    def get_output_filepath(input_name, partition_name):
        dirname, filename = os.path.split(input_name)
        basename, extension = os.path.splitext(filename)
        assert_equal(extension, '.h5')

        return os.path.join(dirname, "{}_{}{}".format(basename,
                                                      partition_name,
                                                      '.npy'))

    # h5_datasets is presumably a namedtuple whose field names are the
    # partition names (it is iterated alongside its own ._fields).
    for dataset, partition_name in safe_izip(h5_datasets, h5_datasets._fields):
        output_filepath = get_output_filepath(args.input, partition_name)
        memmap = make_memmap_file(output_filepath,
                                  dataset.num_examples(),
                                  dataset.names,
                                  dataset.formats)

        memmap_tensors = [memmap[name] for name in dataset.names]
        for in_tensor, out_tensor in safe_izip(dataset.tensors,
                                               memmap_tensors):
            assert_equal(out_tensor.shape, in_tensor.shape)
            out_tensor[...] = in_tensor

        # Make sure the data is actually on disk before reporting it written.
        memmap.flush()

        print("Wrote {}".format(output_filepath))
Ejemplo n.º 4
0
def test_mnist():
    '''
    Tests load_mnist().

    Checks the train & test sets' names, formats and sizes, and that every
    label lies in [0, 9]. Image contents are not checked.
    '''
    train_set, test_set = load_mnist()

    expected_names = [u'images', u'labels']
    expected_formats = [DenseFormat(shape=[-1, 28, 28],
                                    axes=['b', '0', '1'],
                                    dtype='uint8'),
                        DenseFormat(shape=[-1],
                                    axes=['b'],
                                    dtype='uint8')]
    expected_sizes = [60000, 10000]

    for dataset, expected_size in safe_izip((train_set, test_set),
                                            expected_sizes):
        assert_equal(dataset.num_examples(), expected_size)
        assert_all_equal(dataset.names, expected_names)
        assert_all_equal(dataset.formats, expected_formats)

        for tensor, fmt in safe_izip(dataset.tensors, dataset.formats):
            fmt.check(tensor)
            assert_equal(tensor.shape[0], expected_size)

        # Read the labels once, rather than once per comparison.
        labels = dataset.tensors[dataset.names.index('labels')][...]
        assert_true(numpy.logical_and(labels >= 0, labels < 10).all())
Ejemplo n.º 5
0
def test_pickle_h5_dataset():
    '''
    Tests pickling and unpickling of H5Dataset.

    Checks that the pickle of load_mnist()'s output is under 5 KB, and that
    the unpickled datasets have the same names, formats and tensor contents
    as freshly loaded ones.
    '''
    import shutil
    import tempfile

    def make_pickle(file_path):
        '''
        Pickles the MNIST dataset.
        '''
        hdf5_data = load_mnist()
        with open(file_path, 'wb') as pickle_file:
            cPickle.dump(hdf5_data,
                         pickle_file,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    def load_pickle(file_path):
        '''
        Loads the MNIST dataset pickled above.
        '''
        with open(file_path, 'rb') as pickle_file:
            return cPickle.load(pickle_file)

    # Use a private temporary directory rather than a fixed /tmp path, so
    # concurrent runs can't collide and the pickle is removed afterwards.
    temp_dir = tempfile.mkdtemp()
    try:
        # Path for the pickle file, not the .h5 file.
        file_path = os.path.join(temp_dir,
                                 'test_mnist_test_pickle_hdf5_data.pkl')

        make_pickle(file_path)
        assert_less(_file_size_in_bytes(file_path), 1024 * 5)

        mnist_datasets = load_mnist()
        pickled_mnist_datasets = load_pickle(file_path)

        for (mnist_dataset,
             pickled_mnist_dataset) in safe_izip(mnist_datasets,
                                                 pickled_mnist_datasets):
            for (name,
                 expected_name,
                 fmt,
                 expected_fmt,
                 tensor,
                 expected_tensor) in safe_izip(pickled_mnist_dataset.names,
                                               mnist_dataset.names,
                                               pickled_mnist_dataset.formats,
                                               mnist_dataset.formats,
                                               pickled_mnist_dataset.tensors,
                                               mnist_dataset.tensors):
                assert_equal(name, expected_name)
                assert_equal(fmt, expected_fmt)
                assert_array_equal(tensor, expected_tensor)
    finally:
        shutil.rmtree(temp_dir)
Ejemplo n.º 6
0
    def draw_menu():
        '''
        Redraws the label text panel (all_axes[0]) for the current index.

        Shows the first five label fields, the image's position among the
        rows matching that label (or "(no such image)"), and, when labels
        have 11 columns, the remaining fields. The line at index_dim[0] is
        prefixed with "==> ".
        '''
        def format_fields(names, values, field_converters):
            return ['{}: {}'.format(name, convert(value))
                    for name, value, convert
                    in safe_izip(names, values, field_converters)]

        rows, rows_index = label_index_map.index_to_rows(index)

        if rows is None:
            lines = format_fields(label_names[:5],
                                  label_index_map.index_to_label_5d(index),
                                  converters[:5])
            lines += ["image: (no such image)", '']

            if dataset_labels.shape[1] == 11:
                lines += ['{}: N/A'.format(name) for name in label_names[5:]]
        else:
            label = dataset_labels[rows[rows_index[0]]]

            lines = format_fields(label_names[:5], label[:5], converters[:5])
            lines += ["image: {} of {}".format(rows_index[0] + 1, len(rows)),
                      '']

            if dataset_labels.shape[1] == 11:
                lines += format_fields(label_names[5:],
                                       label[5:],
                                       converters[5:])

        selected = index_dim[0]
        lines[selected] = "==> " + lines[selected]

        # "transAxes": 0, 0 = bottom-left, 1, 1 at upper-right.
        text_axes = all_axes[0]
        text_axes.clear()
        text_axes.text(0.1,  # x
                       0.5,  # y
                       '\n'.join(lines),
                       verticalalignment='center',
                       transform=text_axes.transAxes)
Ejemplo n.º 7
0
    def __getitem__(self, arg):
        '''
        Returns a tuple of slices along the batch axis.

        Examples:

          self[slice(3, 10)] returns a 7-element batch (examples 3 through 9)
          from each tensor.

          self[3] is equivalent to self[slice(3, 4)] (batch axis is intact)


        Parameters
        ----------
        arg: integer or slice
          Converted to a batch-axis slice by self._getitem_arg_to_slice().

        Returns
        -------
        rval: tuple
          One entry per tensor, in self.tensors order: the tensor indexed by
          the batch slice along its 'b' axis and by slice(None) along every
          other axis. For numpy tensors these are views, not copies.
        '''

        batch_slice = self._getitem_arg_to_slice(arg)

        def get_slice_tuple(fmt, batch_slice):
            '''
            Returns a tuple for slicing a tensor along its batch axis.
            '''
            return tuple(batch_slice if axis == 'b'
                         else slice(None)
                         for axis in fmt.axes)

        return tuple(tensor[get_slice_tuple(fmt, batch_slice)]
                     for tensor, fmt
                     in safe_izip(self.tensors, self.formats))
Ejemplo n.º 8
0
    def draw_images():
        '''
        Redraws the image panel(s), all_axes[1:], for the current index.

        Clears them if no image matches the current index. Otherwise draws
        the first matching image (local-contrast-normalized when args.lcn is
        set). If the image format has an 's' axis, each sub-image goes to its
        own panel.
        '''
        def show(picture, target_axes):
            # Without LCN, NoNorm makes imshow use the raw pixel values
            # instead of rescaling them.
            if args.lcn:
                extra = {}
            else:
                extra = {'norm': matplotlib.colors.NoNorm()}
            target_axes.imshow(picture, cmap='gray', **extra)

        rows, rows_index = label_index_map.index_to_rows(index)

        image_axes = all_axes[1:]  # could be 1 or two axes

        if rows is None:
            for target_axes in image_axes:
                target_axes.clear()
            return

        picture = dataset_images[rows[rows_index[0]]]
        if args.lcn:
            picture = lcn(picture)

        if 's' not in image_format.axes:
            show(picture, all_axes[1])
            return

        assert_equal(image_format.axes.index('s'), 1)
        for sub_picture, target_axes in safe_izip(picture, image_axes):
            show(sub_picture, target_axes)
Ejemplo n.º 9
0
    def __init__(self, names, tensors, formats):
        '''
        Validates and stores a dataset's parallel names, tensors and formats.

        Parameters
        ----------
        names: sequence of strings
          One name per tensor.

        tensors: sequence of array-likes
          Each is checked with the corresponding format's check().

        formats: sequence of Format
          Each must have a 'b' (batch) axis and a non-None dtype.

        Raises
        ------
        ValueError
          If the three sequences differ in length, or a format lacks a 'b'
          axis or a dtype.
        TypeError
          If a name is not a string, or a format is not a Format.
        '''
        if len(names) != len(formats) or len(names) != len(tensors):
            # A tuple literal: tuple(a, b, c) would itself raise a TypeError
            # and hide this error.
            raise ValueError("Names, formats, and tensors must all have the "
                             "same length, but got %d, %d, and %d "
                             "respectively." %
                             (len(names), len(formats), len(tensors)))

        for name in names:
            if not isinstance(name, basestring):
                raise TypeError("names must be strings, not %s." % type(name))

        for tensor, fmt in safe_izip(tensors, formats):
            if not isinstance(fmt, Format):
                raise TypeError("formats must be Formats, not %s." %
                                type(fmt))

            if 'b' not in fmt.axes:
                raise ValueError("Expected format to contain a 'b' axis "
                                 "(batch axis).")

            if fmt.dtype is None:
                raise ValueError("Expected all formats to specify a dtype.")

            fmt.check(tensor)

        self.names = tuple(names)
        self.formats = tuple(formats)
        self.tensors = tuple(tensors)
Ejemplo n.º 10
0
    def label5d_to_index(self, labels):
        '''
        Converts labels to per-column indices into self.label5d_values.

        Parameters
        ----------
        labels: array-like, 1-D (one label) or 2-D (one label per row)
          Only the first 5 columns (elements, if 1-D) are used.

        Returns
        -------
        rval: numpy.ndarray
          Same rank as labels, with 5 columns. Entry [r, c] is the position
          of labels[r, c] within self.label5d_values[c]. Has the same dtype
          as numpy.asarray(labels). Fails an assertion if any label value is
          missing from its column's label5d_values.
        '''
        # Convert first, so array-likes without a .shape (e.g. lists) work.
        labels = numpy.asarray(labels)
        assert_in(len(labels.shape), (1, 2))

        was_1D = len(labels.shape) == 1
        if was_1D:
            labels = labels[numpy.newaxis, :]

        labels = labels[:, :5]

        # -1 marks label values not (yet) found in label5d_values.
        indices = numpy.empty_like(labels)
        indices[...] = -1

        for (label5d_values,
             label_column,
             index_column) in safe_izip(self.label5d_values,
                                        labels.T,
                                        indices.T):

            for ind, value in enumerate(label5d_values):
                index_column[label_column == value] = ind

        assert_false((indices == -1).any())

        if was_1D:
            return indices[0, :]
        else:
            return indices
Ejemplo n.º 11
0
    def _next(self):
        '''
        Returns the next batch: a tuple with one array per tensor in
        self.dataset.

        The batch indices come from self._next_batch_indices(). They are
        sanity-checked and passed to self._get_batches(). Each resulting
        batch is checked to have self.batch_size examples along its 'b'
        axis.
        '''
        batch_indices = self._next_batch_indices()

        # sanity-check output of _next_batch_indices()
        if not isinstance(batch_indices, slice):
            assert_all_integer(batch_indices)

            if isinstance(batch_indices, numpy.ndarray):
                # Workaround to a bug in h5py.Dataset where indexing by a
                # length-1 ndarray is treated like indexing with the integer it
                # contains.
                if len(batch_indices) == 1:
                    batch_indices = tuple(batch_indices)
            else:
                # The ABCs live in collections.abc since Python 3.3; the old
                # aliases in collections were removed in Python 3.10.
                try:
                    from collections.abc import Sequence
                except ImportError:  # Python 2
                    from collections import Sequence

                assert_is_instance(batch_indices, Sequence)

        result = tuple(self._get_batches(self.dataset.tensors,
                                         self.dataset.formats,
                                         batch_indices))

        # sanity-check size of batches
        for batch, fmt in safe_izip(result, self.dataset.formats):
            assert_equal(batch.shape[fmt.axes.index('b')], self.batch_size)

        return result
Ejemplo n.º 12
0
    def _get_batches(self, tensors, formats, batch_indices):
        '''
        Extracts batches from tensors, given batch_indices.

        Parameters
        ----------
        tensors: Iterable of numpy.ndarray, or similar (e.g. h5py.Dataset)
          The tensors to select a batch from. Usually self.dataset.tensors.

        formats: Iterable of simplelearn.format.DenseFormat
          The formats corresponding to <tensors>. Usually
          self.dataset.formats.

        batch_indices: slice, or Sequence of integers
          The index to apply along each tensor's 'b' axis, e.g. the
          (sanity-checked) output of _next_batch_indices().

        Returns
        -------
        rval: tuple
          One batch per tensor, in the same order as <tensors>.
        '''

        def get_batch(tensor, fmt):

            # h5py has a bug where if len(index_tuple) == 1,
            # Dataset.__getitem__ treats it the same as just
            # index_tuple[0]. Adding a gratuitous Ellipsis element to the end
            # prevents this.
            #
            # See h5py bug report: https://github.com/h5py/h5py/issues/586
            index_tuple = tuple(batch_indices if axis == 'b' else slice(None)
                                for axis in fmt.axes) + (Ellipsis, )

            return tensor[index_tuple]

        return tuple(get_batch(tensor, fmt) for tensor, fmt
                     in safe_izip(tensors, formats))
Ejemplo n.º 13
0
    def index_to_label_5d(self, indices):
        '''
        Converts per-column indices back into 5-D label values.

        Entry [r, c] of the result is self.label5d_values[c][indices[r, c]],
        which reverses the value-to-position lookup of label5d_to_index().

        Parameters
        ----------
        indices: array-like of integers, 1-D (length 5) or 2-D (N x 5)

        Returns
        -------
        rval: numpy.ndarray
          Same shape and dtype as numpy.asarray(indices).
        '''
        # Convert first, so array-likes without a .shape (e.g. lists) work.
        indices = numpy.asarray(indices)
        assert_in(len(indices.shape), (1, 2))

        was_1D = len(indices.shape) == 1
        if was_1D:
            indices = indices[numpy.newaxis, :]

        assert_equal(indices.shape[1], 5)

        # NOTE(review): labels takes indices' dtype, so label values are cast
        # to it (e.g. truncated, if label5d_values holds floats).
        labels = numpy.zeros_like(indices)

        # Fancy-indexing below presumably requires each entry of
        # self.label5d_values to be a numpy array -- confirm.
        for (index_column,
             label_column,
             label5d_values) in safe_izip(indices.T,
                                          labels.T,
                                          self.label5d_values):
            label_column[:] = label5d_values[index_column]

        if was_1D:
            return labels[0, :]
        else:
            return labels
Ejemplo n.º 14
0
    def __init__(self, path):
        '''
        Opens a .npy record-array file, read-only, as a dataset.

        Each record field becomes one tensor. The field's name is the tensor
        name, and its title (element 2 of dtype.fields[name]) is used as the
        format's axes. The format's shape is the field's shape with the batch
        size replaced by -1, and its dtype is the field's dtype.

        Parameters
        ----------
        path: string
          Path to the .npy file. Its absolute path must start with
          simplelearn.data.data_path.
        '''
        assert_is_instance(path, basestring)

        path = os.path.abspath(path)

        assert_true(path.startswith(simplelearn.data.data_path),
                    ("{} is not a subdirectory of simplelearn.data.data_path "
                     "{}").format(path, simplelearn.data.data_path))

        # pylint can't see memmap members
        # pylint: disable=no-member
        self.memmap = numpy.lib.format.open_memmap(path, mode='r')

        names = self.memmap.dtype.names
        tensors = [self.memmap[name] for name in names]

        # Look each field's axes up by name. dtype.fields is a dict, so
        # iterating over its values isn't guaranteed to follow dtype.names'
        # order (and itervalues() doesn't exist in Python 3).
        fields = self.memmap.dtype.fields
        axes_list = [fields[name][2] for name in names]

        formats = [DenseFormat(axes=axes,
                               shape=(-1, ) + tuple(tensor.shape[1:]),
                               dtype=tensor.dtype)
                   for axes, tensor in safe_izip(axes_list, tensors)]

        super(MemmapDataset, self).__init__(names=names,
                                            formats=formats,
                                            tensors=tensors)
Ejemplo n.º 15
0
        def make_id_elev_azim_to_example(example_dataset):
            '''
            Builds a dict mapping (id, elevation, azimuth) tuples to RGB
            images.

            Only rows whose label column 4 (lighting) equals 0 are kept. The
            id comes from self.label_to_id(); elevation and azimuth are label
            columns 2 and 3. Asserts that no two kept rows share a key.
            '''
            image_axes = example_dataset.formats[0].axes
            assert_equal(image_axes.index('c'), 3)
            assert_equal(image_axes.index('b'), 0)
            assert_equal(example_dataset.formats[1].axes.index('b'), 0)

            # Channel axis is last (checked above); keep RGB, drop alpha.
            rgb_images = example_dataset.tensors[0][..., :3]
            all_labels = example_dataset.tensors[1]

            assert_in(all_labels.shape[1], (5, 11))

            # Arbitrarily restrict attention to lighting setup 0.
            keep = all_labels[:, 4] == 0
            kept_images = rgb_images[keep, :]
            kept_labels = all_labels[keep, :]

            object_ids = self.label_to_id(kept_labels)
            keys = numpy.hstack((object_ids[:, numpy.newaxis],
                                 kept_labels[:, 2:4]))

            key_to_image = {}
            for key_row, image in safe_izip((tuple(k) for k in keys),
                                            kept_images):
                key_to_image[key_row] = image

            assert_equal(len(key_to_image), keys.shape[0])

            return key_to_image
Ejemplo n.º 16
0
    def draw_indices():
        '''
        Draws the current indices and their corresponding values.

        values[1] and values[2] are converted in place from radians to
        truncated integer degrees. The line at dimension_to_edit[0] is
        prefixed with "==> ". If fg_baseline_scales is not None, the scale
        for the current id (values[0]) is appended.
        '''
        values = get_current_values()

        # Radians -> degrees, truncated to int, for elev and azim.
        for position in (1, 2):
            values[position] = int(values[position] / numpy.pi * 180.0)

        lines = []
        for label, value in safe_izip(['id', 'elev', 'azim', 'light'],
                                      values):
            lines.append('{}: {}'.format(label, value))

        selected = dimension_to_edit[0]
        lines[selected] = '==> ' + lines[selected]

        if fg_baseline_scales is not None:
            scale = fg_baseline_scales[values[0]]
            lines.append('fg baseline scale: {:0.2f}'.format(scale))

        text_axes.clear()

        # "transAxes": 0, 0 = bottom-left, 1, 1 at upper-right.
        text_axes.text(0.1,  # x
                       0.5,  # y
                       '\n'.join(['fg image:', ''] + lines),
                       verticalalignment='center',
                       transform=text_axes.transAxes)
Ejemplo n.º 17
0
    def __init__(self,
                 inputs,
                 input_iterator,
                 parameters_updaters,
                 parameters,
                 callbacks,
                 theano_function_mode=None):
        '''
        Parameters
        ----------

        inputs: sequence of Nodes.
          Symbols for the outputs of the input_iterator.
          These should come from input_iterator.make_input_nodes()
          Each one's output_format must match that of the corresponding
          node from input_iterator.make_input_nodes().

        input_iterator: simplelearn.data.DataIterator
          Yields tuples of training set batches, such as (values, labels).
          Must be at the start of an epoch.

        parameters_updaters:
          Stored as self.parameter_updaters. Not validated here.

        parameters:
          Stored as self.parameters. Not validated here.

        callbacks: Sequence of EpochCallbacks
          This includes subclasses like IterationCallback &
          ParameterUpdater. One of these callbacks must throw a StopTraining
          exception for the training to halt. Must not contain duplicates.

        theano_function_mode: theano.compile.Mode
          Optional. The 'mode' argument to pass to theano.function().
          An example: pylearn2.devtools.nan_guard.NanGuard()
        '''

        #
        # sanity-checks the arguments.
        #

        assert_all_is_instance(inputs, Node)
        assert_is_instance(input_iterator, DataIterator)
        assert_true(input_iterator.next_is_new_epoch())

        # (Named input_node rather than input, to avoid shadowing the
        # builtin.)
        for (input_node,
             iterator_input) in safe_izip(inputs,
                                          input_iterator.make_input_nodes()):
            assert_equal(input_node.output_format,
                         iterator_input.output_format)

        assert_equal(len(callbacks),
                     len(frozenset(callbacks)),
                     "There were duplicate callbacks.")

        assert_all_is_instance(callbacks, EpochCallback)

        #
        # Sets members
        #

        self._inputs = inputs
        self._input_iterator = input_iterator
        self._theano_function_mode = theano_function_mode
        self.epoch_callbacks = list(callbacks)
        self._train_called = False
        self.parameter_updaters = parameters_updaters
        self.parameters = parameters
Ejemplo n.º 18
0
def add_conv_layers(input_node,
                    yaml_dict,
                    use_dropout,
                    numpy_rng,
                    theano_rng,
                    output_list):
    '''
    Stacks Conv2dLayers (each optionally followed by Dropout) on top of
    input_node, and appends the new layers to output_list.

    Parameters
    ----------
    input_node:
      The node the first conv layer reads from.

    yaml_dict: dict
      Per-layer lists, all of one length, under 'filter_sizes',
      'filter_counts', 'filter_strides', 'filter_pads', 'pool_sizes',
      'pool_strides' and 'pool_pads'. Each size/stride/pad is used for both
      spatial dimensions; a string 'pool_pads' entry is passed unchanged.
      'iranges' is either such a list or a single value used for every
      layer; each layer's irange is passed with its filters to
      _init_params().

    use_dropout: bool
      If true, each conv layer is followed by a Dropout layer with include
      rate .8 after the first conv layer and .5 after the others.

    numpy_rng:
      Passed to _init_params().

    theano_rng:
      Passed to Dropout.

    output_list: list
      The new layers (not input_node) are appended to it, in order.
    '''
    __check_arg_types(input_node,
                      yaml_dict,
                      use_dropout,
                      numpy_rng,
                      theano_rng,
                      output_list)

    iranges = yaml_dict['iranges']
    if not isinstance(iranges, Sequence):
        iranges = [iranges] * len(yaml_dict['filter_counts'])

    per_layer_keys = ('filter_sizes',
                      'filter_counts',
                      'filter_strides',
                      'filter_pads',
                      'pool_sizes',
                      'pool_strides',
                      'pool_pads')
    per_layer_columns = [yaml_dict[key] for key in per_layer_keys]
    per_layer_columns.append(iranges)

    layers = [input_node]

    for row in safe_izip(*per_layer_columns):
        (filter_size, filter_count, filter_stride, filter_pad,
         pool_size, pool_stride, pool_pad, irange) = row

        if isinstance(pool_pad, basestring):
            pool_pads = pool_pad
        else:
            pool_pads = [pool_pad, pool_pad]

        conv_layer = Conv2dLayer(layers[-1],
                                 filter_shape=[filter_size] * 2,
                                 num_filters=filter_count,
                                 filter_strides=[filter_stride] * 2,
                                 conv_pads=[filter_pad] * 2,
                                 pool_window_shape=[pool_size] * 2,
                                 pool_strides=[pool_stride] * 2,
                                 pool_pads=pool_pads)
        layers.append(conv_layer)

        _init_params(numpy_rng,
                     irange,
                     conv_layer.conv2d_node.filters)

        if use_dropout:
            # len(layers) == 2 only right after the very first conv layer.
            include_rate = .8 if len(layers) == 2 else .5
            layers.append(Dropout(conv_layer,
                                  include_rate,
                                  theano_rng))

    output_list.extend(layers[1:])
    def __init__(self,
                 inputs,
                 input_iterator,
                 parameters,
                 old_parameters,
                 parameter_updaters,
                 iterator_full_gradient,
                 epoch_callbacks,
                 theano_function_mode=None):
        '''
        Sanity-checks and stores the arguments, then compiles the update and
        full-gradient functions.

        Parameters
        ----------
        inputs: sequence of Nodes
          Symbols for the outputs of input_iterator. Each one's
          output_format must match that of the corresponding node from
          input_iterator.make_input_nodes().

        input_iterator: DataIterator
          Must be at the start of an epoch.

        parameters, old_parameters, parameter_updaters: sequences
          Stored as tuples. The first updater's .method is stored as
          self.method.

        iterator_full_gradient:
          Stored as self.full_gradient_iterator. The length of its
          dataset's first tensor, floor-divided by its batch_size, gives
          self.batches_in_epoch_full.

        epoch_callbacks: sequence of EpochCallbacks
          Must not contain duplicates.

        theano_function_mode: optional
          Stored as self._theano_function_mode.
        '''

        #
        # sanity-checks the arguments.
        #

        assert_all_is_instance(inputs, Node)
        assert_is_instance(input_iterator, DataIterator)
        assert_true(input_iterator.next_is_new_epoch())

        # (Named input_node rather than input, to avoid shadowing the
        # builtin.)
        for (input_node,
             iterator_input) in safe_izip(inputs,
                                          input_iterator.make_input_nodes()):
            assert_equal(input_node.output_format,
                         iterator_input.output_format)

        assert_equal(len(epoch_callbacks),
                     len(frozenset(epoch_callbacks)),
                     "There were duplicate callbacks.")

        assert_all_is_instance(epoch_callbacks, EpochCallback)

        #
        # Sets members
        #

        self._input_iterator = input_iterator
        self._parameters = tuple(parameters)
        self._old_parameters = tuple(old_parameters)
        self._parameter_updaters = tuple(parameter_updaters)
        self._theano_function_mode = theano_function_mode
        self._inputs = tuple(inputs)

        input_symbols = [i.output_symbol for i in self._inputs]

        self.epoch_callbacks = tuple(epoch_callbacks)

        self._train_called = False

        self.new_epoch = True
        self.method = self._parameter_updaters[0].method
        self.update_function = self._compile_update_function(input_symbols)
        self.full_gradient_function = \
            self._compile_full_gradient_update_function(input_symbols)

        self.full_gradient_iterator = iterator_full_gradient
        total_size_dataset = \
            self.full_gradient_iterator.dataset.tensors[0].shape[0]
        batch_size = self.full_gradient_iterator.batch_size

        # Number of whole batches in the full-gradient dataset. Floor
        # division keeps this an integer under both Python 2 and Python 3.
        self.batches_in_epoch_full = total_size_dataset // batch_size
Ejemplo n.º 20
0
    def allocate_memmap_file(args):
        '''
        Creates a properly shaped & typed memmap file for the video cells.

        Each frame of each input video is cut into a batch of cells by
        transform_image(frame, scale, rotation, grid). Every subsample'th
        frame (frame numbers 0, subsample, 2 * subsample, ...) contributes
        all of its cells. All inputs must yield cells of the same format.

        Returns
        -------
        rval: numpy.memmap
          A recordarray memmap from make_memmap_file(), with a single
          'images' tensor sized to hold every kept cell.

        Raises
        ------
        ValueError
          If args.inputs is empty.
        '''
        for input_path in args.inputs:
            assert_true(os.path.isfile(input_path))

        output_filepath = get_output_filepath(args)
        assert_equal(os.path.splitext(output_filepath)[1], '.npy')

        fmt = None

        # (input_path, subsample, cells_per_frame) per input. These differ
        # between inputs (e.g. different grids), so each input's frames must
        # be counted with its own values, not the last input's.
        per_input_counts = []

        # Get format of each file, make sure they're all the same.
        for (input_path,
             subsample,
             scale,
             rotation,
             grid) in safe_izip(args.inputs,
                                args.subsamples,
                                args.scales,
                                args.rotations,
                                args.grids):
            iterator = FrameIterator(input_path)
            frame = next(iterator)
            cell_batch = transform_image(frame, scale, rotation, grid)
            this_fmt = DenseFormat(axes=('b', '0', '1', 'c'),
                                   shape=((-1, ) + cell_batch.shape[1:]),
                                   dtype=cell_batch.dtype)
            if fmt is None:
                fmt = this_fmt
            else:
                assert_equal(this_fmt,
                             fmt,
                             "All video files (after their respective grid, "
                             "scale, and rotation transforms, must yield the "
                             "same image format")

            per_input_counts.append((input_path,
                                     subsample,
                                     cell_batch.shape[0]))

        if fmt is None:
            raise ValueError("No input video files were given.")

        # Counts num_cells, using each input's own subsample & cell count.
        num_cells = 0
        for input_path, subsample, cells_per_frame in per_input_counts:
            for frame_number, _ in enumerate(FrameIterator(input_path)):
                if frame_number % subsample == 0:
                    num_cells += cells_per_frame

        return make_memmap_file(path=output_filepath,
                                num_examples=num_cells,
                                tensor_names=['images'],
                                tensor_formats=[fmt])
Ejemplo n.º 21
0
def main():
    '''
    Shuffles every partition of a .h5 dataset file in-place.

    Reads --input and --seed from the command line. All tensors are checked
    before anything is shuffled, so an unshuffleable file is left untouched.
    '''
    def parse_args():
        parser = argparse.ArgumentParser(
            description=("Shuffles a .h5 dataset file in-place. \n"
                         "\n"
                         "Due to the limitations of numpy.shuffle, this can "
                         "only shuffle datasets where the batch axis is the "
                         "first axis for all tensors. Will throw an error and "
                         "do nothing if this is not the case.\n"
                         "\n"
                         "Shuffling on hard drives is very slow. Therefore "
                         "it's recommended that you first copy the dataset to "
                         "a SSD or other high-speed storage, if available."))

        def path_to_h5_file(arg):
            # argparse stores this function's return value as the argument's
            # value, and reports ArgumentTypeError as a usage error.
            if not os.path.isfile(arg):
                raise argparse.ArgumentTypeError(
                    "{} is not a file.".format(arg))
            if os.path.splitext(arg)[1] != '.h5':
                raise argparse.ArgumentTypeError(
                    "{} does not have a .h5 extension.".format(arg))
            return arg

        parser.add_argument('-i',
                            '--input',
                            type=path_to_h5_file,
                            required=True,
                            help=("The .h5 file to shuffle in-place. It "
                                  "should have been created by "
                                  "simplelearn.data.make_h5_file()."))

        parser.add_argument('-s',
                            '--seed',
                            type=int,
                            default=1234,
                            help="RNG seed to use for shuffling")

        return parser.parse_args()

    args = parse_args()

    partitions = load_h5_dataset(args.input, mode='r+')

    # Validate every tensor first, so that nothing is modified if any
    # tensor can't be shuffled (as promised in the description above).
    for partition_index, partition in enumerate(partitions):
        for (tensor_index,
             fmt,
             tensor) in safe_izip(range(len(partition.tensors)),
                                  partition.formats,
                                  partition.tensors):
            if fmt.axes[0] != 'b':
                raise ValueError("Can't shuffle this dataset. Partition {}, "
                                 "tensor {} 's first axis is not the batch "
                                 "axis (axes = {}).".format(partition_index,
                                                            tensor_index,
                                                            str(fmt.axes)))

    for partition_index, partition in enumerate(partitions):
        for tensor_index, tensor in enumerate(partition.tensors):
            # Re-seed for every tensor so that, assuming the tensors of a
            # partition all have the same number of examples, they get the
            # same permutation and stay aligned (e.g. images with labels).
            rng = numpy.random.RandomState(args.seed)
            print("Shuffling partition {} tensor {}.".format(partition_index,
                                                             tensor_index))
            rng.shuffle(tensor)
Ejemplo n.º 22
0
    def on_iteration(self, computed_values):
        '''
        Checks each computed value against the output format of the matching
        node in self.nodes_to_compute, then hands the values to
        self._on_iteration(), which must return None.
        '''
        for computed, source_node in safe_izip(computed_values,
                                               self.nodes_to_compute):
            source_node.output_format.check(computed)

        returned = self._on_iteration(computed_values)
        complaint = ("{}._on_iteration implemented incorrectly. It "
                     "shouldn't return anything.".format(type(self)))
        assert_is(returned, None, complaint)
Ejemplo n.º 23
0
def test_flatten():
    '''
    flatten() should turn an arbitrarily nested mix of sequences and scalars
    (including empty tuples) into the flat sequence 0, 1, ..., 10.
    '''
    nested_list = [xrange(3), 3, ((4, 5), (6, ()), 7), 8, (9, ), 10, ()]
    flattened = flatten(nested_list)

    # safe_izip (rather than zip) so that a too-short or too-long result is
    # also caught, presumably by raising on length mismatch.
    for actual, expected in safe_izip(flattened, xrange(11)):
        assert_equal(actual, expected)
Ejemplo n.º 24
0
            def make_element(indices, axes):
                '''
                Deterministically maps (indices, axes) to a number.

                Each index is reduced mod 10, and each axis name is reduced
                to ord(name) mod 5. The function returns the sum of their
                pairwise products, or 0 when there are no pairs.
                '''
                index_array = numpy.array(indices)
                axis_codes = numpy.array(tuple(ord(name) for name in axes))

                total = 0
                for idx, code in safe_izip(index_array, axis_codes):
                    total += numpy.mod(idx, 10) * numpy.mod(int(code), 5)

                return total
Ejemplo n.º 25
0
    def load_rng_state(rng_state_h5):
        '''
        Reassembles a 5-element RNG state tuple from an h5 node.

        Element 1 is read in full from the dataset rng_state_h5['elem_1'].
        Elements 0, 2, 3 and 4 come from rng_state_h5.attrs['elem_<i>'].
        Each element's type is checked against
        (basestring, numpy.ndarray, int, int, float).

        NOTE(review): this layout appears to mirror the tuple returned by
        numpy.random.RandomState.get_state(). Confirm against the writer.

        Returns
        -------
        rval: tuple
          The 5 loaded elements, in index order.
        '''
        state = [rng_state_h5.attrs['elem_0'],
                 rng_state_h5['elem_1'][...]]
        state.extend(rng_state_h5.attrs['elem_{}'.format(position)]
                     for position in range(2, 5))

        required_types = (basestring, numpy.ndarray, int, int, float)
        for elem, required_type in safe_izip(state, required_types):
            assert_is_instance(elem, required_type)

        return tuple(state)
Ejemplo n.º 26
0
    def expected_loss_function(arg0, arg1, arg_format):
        '''
        Reference squared-error loss, one value per example.

        Every non-'b' axis of arg_format is merged into a single 'f' axis.
        The squared differences are then summed along that axis.

        Parameters
        ----------
        arg0, arg1: tensors laid out as described by arg_format.
        arg_format: format object providing .axes, .shape and .convert().

        Returns
        -------
        rval: sum over axis 1 (the 'f' axis) of (arg0 - arg1) ** 2, after
          conversion to a ('b', 'f') layout.
        '''
        residual = arg0 - arg1

        feature_dims = [dim_size
                        for dim_size, axis_name
                        in safe_izip(arg_format.shape, arg_format.axes)
                        if axis_name != 'b']
        num_features = numpy.prod(feature_dims)

        flat_format = DenseFormat(axes=('b', 'f'),
                                  shape=(-1, num_features),
                                  dtype=None)
        feature_axes = tuple(name for name in arg_format.axes if name != 'b')
        flat_residual = arg_format.convert(residual,
                                           flat_format,
                                           axis_map={feature_axes: 'f'})

        return (flat_residual * flat_residual).sum(axis=1)
Ejemplo n.º 27
0
        def get_norb_filename_lists_gz(which_norb, which_set):
            '''
            Returns the lists of files that contain the given NORB dataset.

            The filenames use the 'mat.gz' suffix. Does not check whether the
            files exist.

            Parameters
            ----------
            which_norb: 'big' or 'small'
            which_set: 'test' or 'train'

            Returns
            -------
            rval: tuple
              (cat_files, dat_files, info_files), where each is a list of
              filenames. Each list has one entry per file chunk: 1 for
              small NORB, 2 for the big NORB test set, and 10 for the big
              NORB training set.

            Raises
            ------
            ValueError
              If which_norb or which_set is not one of the values above.
            '''
            # Reject unknown values instead of silently treating them as
            # 'big' / 'train'.
            if which_norb not in ('big', 'small'):
                raise ValueError("Expected which_norb to be 'big' or "
                                 "'small', but got {!r}.".format(which_norb))
            if which_set not in ('test', 'train'):
                raise ValueError("Expected which_set to be 'test' or "
                                 "'train', but got {!r}.".format(which_set))

            if which_norb == 'small':
                prefix = 'smallnorb'
                dim = 96
                file_num_strs = ['']
            else:
                prefix = 'norb'
                dim = 108
                # The big NORB set is split into numbered chunks:
                # 01-02 for test, 01-10 for train.
                file_nums = range(1, 3 if which_set == 'test' else 11)
                file_num_strs = ['-{:02d}'.format(n) for n in file_nums]

            instances = '01235' if which_set == 'test' else '46789'
            suffix = 'mat.gz'

            template = ('{prefix}-5x{instances}x9x18x6x2x{dim}x{dim}'
                        '-{which_set}ing{file_num_str}-{file_type}.'
                        '{suffix}')

            def file_list(file_type):
                # One filename per chunk, all with the given file_type.
                return [template.format(prefix=prefix,
                                        instances=instances,
                                        dim=dim,
                                        which_set=which_set,
                                        file_num_str=file_num_str,
                                        file_type=file_type,
                                        suffix=suffix)
                        for file_num_str in file_num_strs]

            return file_list('cat'), file_list('dat'), file_list('info')
Ejemplo n.º 28
0
    def num_examples(self):
        '''
        The number of examples contained in this Dataset.

        Returns
        -------
        rval:
          The size of the batch ('b') axis, which every tensor must agree on
          (checked with assert_all_equal).

        Raises
        ------
        RuntimeError
          If this Dataset contains no tensors.
        '''
        if len(self.tensors) == 0:
            raise RuntimeError("This dataset has no tensors, so its "
                               "'size' is undefined.")

        # Each format may place the batch axis at a different position, so
        # look up its index per tensor.
        sizes = [t.shape[f.axes.index('b')]
                 for t, f in safe_izip(self.tensors, self.formats)]
        assert_all_equal(sizes)
        return sizes[0]
Ejemplo n.º 29
0
def _get_num_examples(dataset):
    '''
    Returns the number of examples in a Dataset.

    Parameters
    ----------
    dataset: Dataset

    Returns
    -------
    rval:
      The size of the batch ('b') axis, shared by all of dataset's tensors.

    Raises
    ------
    ValueError
      If dataset has no tensors, or if its tensors disagree on the number of
      examples.
    '''
    assert_is_instance(dataset, Dataset)

    # The batch axis may sit at a different index in each tensor's format.
    example_counts = tuple(tensor.shape[fmt.axes.index('b')]
                           for tensor, fmt
                           in safe_izip(dataset.tensors, dataset.formats))

    # Without this check, an empty dataset would fail below with an opaque
    # IndexError on example_counts[0].
    if len(example_counts) == 0:
        raise ValueError("Can't count the examples of a Dataset with no "
                         "tensors.")

    first_count = example_counts[0]
    if any(count != first_count for count in example_counts[1:]):
        raise ValueError("Expected all tensors to have the same number of "
                         "samples, but got {}.".format(example_counts))

    return first_count
Ejemplo n.º 30
0
    def draw_label(labels, axes):
        '''
        Writes a readable summary of one label vector onto `axes`.

        Parameters
        ----------
        labels: sequence
          Either 5 or 11 integer labels. The field names and the extra
          'blank' category suggest small-NORB (5) and big-NORB (11) label
          vectors respectively.
        axes:
          Cleared, then given one centered text block via axes.text(),
          positioned in axes coordinates (transform=axes.transAxes).
        '''
        assert_in(len(labels), (5, 11))
        axes.clear()

        categories = ['animal', 'human', 'plane', 'truck', 'car']
        contrasts = (0.8, 1.3)
        scales = (0.78, 1.0)

        def as_is(x):
            return x

        # (name, converter) pairs. Each converter maps a raw integer label to
        # a human-readable value.
        # NOTE(review): azimuth is displayed as x * 20. Raw NORB azimuth
        # labels are documented as even numbers 0..34 that convert to
        # degrees via x * 10, so this presumably assumes the labels were
        # halved upstream. Confirm.
        fields = [('category', lambda x: categories[x]),
                  ('instance', as_is),
                  ('elevation', lambda x: 30 + x * 5),
                  ('azimuth', lambda x: x * 20),
                  ('lighting', as_is)]

        if len(labels) == 11:
            # The category converter reads `categories` lazily at call time,
            # so appending here makes index 5 ('blank') available to it.
            categories.append('blank')
            fields.extend([('horiz. shift', as_is),
                           ('vert. shift', as_is),
                           ('lumination change', as_is),
                           ('contrast', lambda x: contrasts[x]),
                           ('object scale', lambda x: scales[x]),
                           ('rotation (deg)', as_is)])

        names = [name for name, _ in fields]
        converters = [convert for _, convert in fields]

        text_lines = []
        for name, convert, label in safe_izip(names, converters, labels):
            text_lines.append('{}: {}'.format(name, convert(label)))

        # With transform=axes.transAxes, (0, 0) is the bottom-left corner and
        # (1, 1) is the upper-right corner of the axes.
        axes.text(0.1,
                  0.5,
                  '\n'.join(text_lines),
                  verticalalignment='center',
                  transform=axes.transAxes)