示例#1
0
    def register_names_to_del(self, names):
        """
        Register names of fields that should not be pickled.

        Parameters
        ----------
        names : str or iterable of str
            A single field name, or a collection of strings naming fields on
            this object that should not be pickled. One-shot iterators
            (e.g. generators) are accepted.

        Raises
        ------
        ValueError
            If `names` is not iterable or contains a non-string element.

        Notes
        -----
        All names registered will be deleted from the dictionary returned
        by the model's `__getstate__` method (unless a particular model
        overrides this method).
        """
        # A bare string would otherwise be treated as a sequence of chars.
        if isinstance(names, six.string_types):
            names = [names]
        try:
            # Materialize first: validating a one-shot iterator would
            # exhaust it and leave nothing to register below.
            names = list(names)
        except TypeError:
            reraise_as(ValueError('Invalid names argument'))
        # Explicit check instead of `assert`, which is stripped under -O.
        if not all(isinstance(n, six.string_types) for n in names):
            raise ValueError('Invalid names argument')
        # Quick check in case __init__ was never called, e.g. by a derived
        # class.
        if not hasattr(self, 'names_to_del'):
            self.names_to_del = set()
        self.names_to_del = self.names_to_del.union(names)
示例#2
0
def write(f, mat):
    """Write an ndarray to an open file in filetensor format.

    Layout written: an int32 dtype magic number (from ``_dtype_magic``),
    an int32 rank, the int32 dimension sizes padded with 1s to at least
    three entries, then the raw array data via ``mat.tofile``.

    Parameters
    ----------
    f : file
        Open file to write into
    mat : ndarray
        Array to save

    Raises
    ------
    TypeError
        If ``str(mat.dtype)`` has no entry in ``_dtype_magic``.
    """
    def _write_int32(f, i):
        # Write a single value as a native int32.
        numpy.asarray(i, dtype='int32').tofile(f)

    # Keep the try limited to the lookup that can raise KeyError.
    try:
        magic = _dtype_magic[str(mat.dtype)]
    except KeyError:
        reraise_as(TypeError('Invalid ndarray dtype for filetensor format',
                             mat.dtype))
    _write_int32(f, magic)

    shape = list(mat.shape)
    _write_int32(f, len(shape))
    # The header always holds at least three sizes; pad missing ones with 1.
    shape += [1] * (3 - len(shape))
    for sh in shape:
        _write_int32(f, sh)
    mat.tofile(f)
示例#3
0
def construct_mapping(node, deep=False):
    """Build a dict from a YAML mapping node, rejecting duplicate keys.

    This is a modified version of ``yaml.BaseConstructor.construct_mapping``
    in which a repeated key raises a ``ConstructorError``.

    Parameters
    ----------
    node : yaml.nodes.MappingNode
        The node to convert.
    deep : bool, optional
        Accepted for signature compatibility; keys and values are always
        constructed with ``deep=False``.

    Returns
    -------
    dict
        The constructed key/value pairs.

    Raises
    ------
    yaml.constructor.ConstructorError
        If ``node`` is not a mapping node, a key is unhashable, or a key
        occurs more than once.
    """
    const = yaml.constructor
    if not isinstance(node, yaml.nodes.MappingNode):
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None, "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = const.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=False)
        try:
            hash(key)
        except TypeError as exc:
            # The key's mark is a separate argument: a single "%s" cannot
            # format two values and would raise a TypeError of its own.
            reraise_as(
                const.ConstructorError(
                    "while constructing a mapping", node.start_mark,
                    "found unacceptable key (%s)" % exc,
                    key_node.start_mark))
        if key in mapping:
            # Wrap in a 1-tuple so tuple keys are formatted, not unpacked.
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % (key,))
        mapping[key] = constructor.construct_object(value_node, deep=False)
    return mapping
示例#4
0
    def wrapped_func(*args, **kwargs):
        """
        Call ``func`` and return its result, explaining argument errors.

        If ``func`` raises ``TypeError``, its signature is inspected. An
        unexpected keyword (when ``func`` takes no ``**kwargs``) or a missing
        required argument is reported through ``reraise_as`` as a clearer
        ``TypeError``. Any other ``TypeError`` is re-raised unchanged.
        """
        try:
            # Hand the wrapped function's result back to the caller.
            return func(*args, **kwargs)
        except TypeError:
            # inspect.getargspec was removed in Python 3.11.
            getspec = getattr(inspect, 'getfullargspec', None)
            if getspec is None:
                getspec = inspect.getargspec
            spec = getspec(func)
            argnames = spec[0]
            accepts_var_keywords = spec[2] is not None
            # defaults is None when no argument has a default value.
            defaults = spec[3] or ()
            func_name = getattr(func, '__name__', str(func))

            if not accepts_var_keywords:
                bad_keywords = [kw for kw in kwargs if kw not in argnames]
                if bad_keywords:
                    bad = ', '.join(bad_keywords)
                    reraise_as(TypeError('%s() does not support the following '
                                         'keywords: %s' % (func_name, bad)))

            supplied = set(kwargs) | set(argnames[:len(args)])
            required = argnames[:len(argnames) - len(defaults)]
            missing = [name for name in required if name not in supplied]
            if missing:
                reraise_as(TypeError('%s() did not get required args: %s' %
                                     (func_name, ', '.join(missing))))
            raise
示例#5
0
    def get_monitoring_channels(self, model, data, **kwargs):
        """
        Collect monitoring channels from every cost in this sum.

        Each sub-cost's own channels are merged into the result. If the
        sub-cost's ``expr`` is not None, it is also stored under
        ``'term_<i>'``, suffixed with ``'_<name>'`` when the expression
        has a name.

        Parameters
        ----------
        model : Model
        data : batch matching ``self.get_data_specs(model)[0]``
        kwargs : dict
            Forwarded to each cost's ``get_monitoring_channels`` and ``expr``.

        Returns
        -------
        OrderedDict
        """
        spec_space = self.get_data_specs(model)[0]
        spec_space.validate(data)
        _, mapping = self.get_composite_specs_and_mapping(model)
        per_cost_data = mapping.nest(data)

        channels_out = OrderedDict()
        for idx, cost in enumerate(self.costs):
            cost_batch = per_cost_data[idx]
            try:
                sub_channels = cost.get_monitoring_channels(
                    model, cost_batch, **kwargs)
                channels_out.update(sub_channels)
            except TypeError:
                reraise_as(
                    Exception('SumOfCosts.get_monitoring_channels '
                              'encountered TypeError while calling {0}'
                              '.get_monitoring_channels'.format(type(cost))))

            term = cost.expr(model, cost_batch, **kwargs)
            if term is None:
                continue
            suffix = ''
            if getattr(term, 'name', None) is not None:
                suffix = '_' + term.name
            channels_out['term_' + str(idx) + suffix] = term

        return channels_out
示例#6
0
    def wrapped_func(*args, **kwargs):
        """
        Call ``func`` and return its result, explaining argument errors.

        If ``func`` raises ``TypeError``, its signature is inspected. An
        unexpected keyword (when ``func`` takes no ``**kwargs``) or a missing
        required argument is reported through ``reraise_as`` as a clearer
        ``TypeError``. Any other ``TypeError`` is re-raised unchanged.
        """
        try:
            # Hand the wrapped function's result back to the caller.
            return func(*args, **kwargs)
        except TypeError:
            # inspect.getargspec was removed in Python 3.11.
            getspec = getattr(inspect, 'getfullargspec', None)
            if getspec is None:
                getspec = inspect.getargspec
            spec = getspec(func)
            argnames = spec[0]
            accepts_var_keywords = spec[2] is not None
            # defaults is None when no argument has a default value.
            defaults = spec[3] or ()
            func_name = getattr(func, '__name__', str(func))

            if not accepts_var_keywords:
                bad_keywords = [kw for kw in kwargs if kw not in argnames]
                if bad_keywords:
                    bad = ', '.join(bad_keywords)
                    reraise_as(
                        TypeError('%s() does not support the following '
                                  'keywords: %s' % (func_name, bad)))

            supplied = set(kwargs) | set(argnames[:len(args)])
            required = argnames[:len(argnames) - len(defaults)]
            missing = [name for name in required if name not in supplied]
            if missing:
                reraise_as(
                    TypeError('%s() did not get required args: %s' %
                              (func_name, ', '.join(missing))))
            raise
示例#7
0
def write(f, mat):
    """Write an ndarray to an open file in filetensor format.

    Layout written: an int32 dtype magic number (from ``_dtype_magic``),
    an int32 rank, the int32 dimension sizes padded with 1s to at least
    three entries, then the raw array data via ``mat.tofile``.

    Parameters
    ----------
    f : file
        Open file to write into
    mat : ndarray
        Array to save

    Raises
    ------
    TypeError
        If ``str(mat.dtype)`` has no entry in ``_dtype_magic``.
    """
    def _write_int32(f, i):
        # Write a single value as a native int32.
        numpy.asarray(i, dtype='int32').tofile(f)

    # Keep the try limited to the lookup that can raise KeyError.
    try:
        magic = _dtype_magic[str(mat.dtype)]
    except KeyError:
        reraise_as(TypeError('Invalid ndarray dtype for filetensor format',
                             mat.dtype))
    _write_int32(f, magic)

    shape = list(mat.shape)
    _write_int32(f, len(shape))
    # The header always holds at least three sizes; pad missing ones with 1.
    shape += [1] * (3 - len(shape))
    for sh in shape:
        _write_int32(f, sh)
    mat.tofile(f)
示例#8
0
    def get_gradients(self, model, data, **kwargs):
        """
        Compute the gradients of this cost with respect to the model's
        parameters.

        The base implementation differentiates the expression returned by
        ``expr`` with ``T.grad``. Subclasses may return approximate (e.g.
        sampling-based) or intentionally modified gradients instead.

        Parameters
        ----------
        model : a pylearn2 Model instance
        data : a batch in cost.get_data_specs() form
        kwargs : dict
            Optional extra arguments, forwarded to ``expr``.

        Returns
        -------
        gradients : OrderedDict
            Maps each parameter from ``model.get_params()`` to its gradient.
        updates : OrderedDict
            Maps shared variables to updates that must be applied each time
            these gradients are computed (useful for sampling-based
            approximations). Parameters must never appear here, because
            updating them would make the gradient values outdated. The base
            implementation returns an empty dictionary.

        Raises
        ------
        TypeError
            If calling ``expr`` raises a TypeError.
        NotImplementedError
            If ``expr`` returns None (an intractable cost without a
            gradient approximation scheme).
        """
        try:
            cost = self.expr(model=model, data=data, **kwargs)
        except TypeError:
            # Re-raise naming the class whose expr failed.
            reraise_as(TypeError("Error while calling " + str(type(self)) +
                                 ".expr"))

        if cost is None:
            raise NotImplementedError(
                str(type(self)) + " represents an intractable cost and "
                "does not provide a gradient "
                "approximation scheme.")

        params = list(model.get_params())
        grads = T.grad(cost, params, disconnected_inputs='ignore')
        return OrderedDict(izip(params, grads)), OrderedDict()
示例#9
0
def construct_mapping(node, deep=False):
    """Build a dict from a YAML mapping node, rejecting duplicate keys.

    This is a modified version of ``yaml.BaseConstructor.construct_mapping``
    in which a repeated key raises a ``ConstructorError``.

    Parameters
    ----------
    node : yaml.nodes.MappingNode
        The node to convert.
    deep : bool, optional
        Accepted for signature compatibility; keys and values are always
        constructed with ``deep=False``.

    Returns
    -------
    dict
        The constructed key/value pairs.

    Raises
    ------
    yaml.constructor.ConstructorError
        If ``node`` is not a mapping node, a key is unhashable, or a key
        occurs more than once.
    """
    const = yaml.constructor
    if not isinstance(node, yaml.nodes.MappingNode):
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None,
                                     "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = const.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=False)
        try:
            hash(key)
        except TypeError as exc:
            # The key's mark is a separate argument: a single "%s" cannot
            # format two values and would raise a TypeError of its own.
            reraise_as(const.ConstructorError("while constructing a mapping",
                                              node.start_mark,
                                              "found unacceptable key (%s)" %
                                              exc, key_node.start_mark))
        if key in mapping:
            # Wrap in a 1-tuple so tuple keys are formatted, not unpacked.
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % (key,))
        mapping[key] = constructor.construct_object(value_node, deep=False)
    return mapping
示例#10
0
文件: cost.py 项目: nitbix/pylearn2
    def get_monitoring_channels(self, model, data, **kwargs):
        """
        Collect monitoring channels from every cost in this sum.

        Each sub-cost's own channels are merged into the result. If the
        sub-cost's ``expr`` is not None, it is also stored under
        ``'term_<i>'``, suffixed with ``'_<name>'`` when the expression
        has a name.

        Parameters
        ----------
        model : Model
        data : batch matching ``self.get_data_specs(model)[0]``
        kwargs : dict
            Forwarded to each cost's ``get_monitoring_channels`` and ``expr``.

        Returns
        -------
        OrderedDict
        """
        spec_space = self.get_data_specs(model)[0]
        spec_space.validate(data)
        _, mapping = self.get_composite_specs_and_mapping(model)
        per_cost_data = mapping.nest(data)

        channels_out = OrderedDict()
        for idx, cost in enumerate(self.costs):
            cost_batch = per_cost_data[idx]
            try:
                sub_channels = cost.get_monitoring_channels(
                    model, cost_batch, **kwargs)
                channels_out.update(sub_channels)
            except TypeError:
                reraise_as(
                    Exception('SumOfCosts.get_monitoring_channels '
                              'encountered TypeError while calling {0}'
                              '.get_monitoring_channels'.format(type(cost))))

            term = cost.expr(model, cost_batch, **kwargs)
            if term is None:
                continue
            suffix = ''
            if getattr(term, 'name', None) is not None:
                suffix = '_' + term.name
            channels_out['term_' + str(idx) + suffix] = term

        return channels_out
示例#11
0
    def register_names_to_del(self, names):
        """
        Register names of fields that should not be pickled.

        Parameters
        ----------
        names : str or iterable of str
            A single field name, or a collection of strings naming fields on
            this object that should not be pickled. One-shot iterators
            (e.g. generators) are accepted.

        Raises
        ------
        ValueError
            If `names` is not iterable or contains a non-string element.

        Notes
        -----
        All names registered will be deleted from the dictionary returned
        by the model's `__getstate__` method (unless a particular model
        overrides this method).
        """
        # A bare string would otherwise be treated as a sequence of chars.
        if isinstance(names, six.string_types):
            names = [names]
        try:
            # Materialize first: validating a one-shot iterator would
            # exhaust it and leave nothing to register below.
            names = list(names)
        except TypeError:
            reraise_as(ValueError('Invalid names argument'))
        # Explicit check instead of `assert`, which is stripped under -O.
        if not all(isinstance(n, six.string_types) for n in names):
            raise ValueError('Invalid names argument')
        # Quick check in case __init__ was never called, e.g. by a derived
        # class.
        if not hasattr(self, 'names_to_del'):
            self.names_to_del = set()
        self.names_to_del = self.names_to_del.union(names)
示例#12
0
文件: cost.py 项目: nitbix/pylearn2
    def get_gradients(self, model, data, **kwargs):
        """
        Provides the gradients of the cost function with respect to the model
        parameters.

        These are not necessarily those obtained by theano.tensor.grad
        --you may wish to use approximate or even intentionally incorrect
        gradients in some cases.

        Parameters
        ----------
        model : a pylearn2 Model instance
        data : a batch in cost.get_data_specs() form
        kwargs : dict
            Optional extra arguments, forwarded to ``expr``.

        Returns
        -------
        gradients : OrderedDict
            Maps each parameter from ``model.get_params()`` to its gradient,
            computed with T.grad on the cost part of ``expr``'s result.
            Subclasses may return other values (e.g. a sampling-based
            approximation for an intractable cost).
        updates : OrderedDict
            Maps shared variables to updates that must be applied each time
            these gradients are computed. Parameters must never appear here,
            since that would make the gradient values outdated. Empty in this
            implementation.

        Raises
        ------
        TypeError
            If calling ``expr`` raises a TypeError.
        NotImplementedError
            If ``expr`` returns None, or a pair whose cost is None.
        """
        try:
            result = self.expr(model=model, data=data, **kwargs)
        except TypeError:
            # If anybody knows how to add type(self) to the exception message
            # but still preserve the stack trace, please do so
            # The current code does neither
            message = "Error while calling " + str(type(self)) + ".expr"
            reraise_as(TypeError(message))

        # expr returns a (cost, mask) pair here. Check for None before
        # unpacking: unpacking None would raise a TypeError that hides the
        # intractable-cost error below.
        if result is None:
            cost = None
        else:
            cost, _ = result  # the mask is not needed for the gradient

        if cost is None:
            raise NotImplementedError(str(type(self)) +
                                      " represents an intractable cost and "
                                      "does not provide a gradient "
                                      "approximation scheme.")

        params = list(model.get_params())

        grads = T.grad(cost, params, disconnected_inputs='ignore')

        gradients = OrderedDict(izip(params, grads))

        updates = OrderedDict()

        return gradients, updates
示例#13
0
 def fn(batch, dspace=dspace, sp=sp):
     """
     Convert ``batch`` from ``dspace`` into the format of ``sp``.

     Raises
     ------
     ValueError
         If ``dspace.np_format_as`` raises ValueError. The original message
         is kept and a hint about model/dataset initialization is appended.
     """
     hint = ('\nMake sure that the model and dataset have been '
             'initialized with correct values.')
     try:
         return dspace.np_format_as(batch, sp)
     except ValueError as err:
         reraise_as(ValueError(str(err) + hint))
示例#14
0
 def fn(batch, dspace=dspace, sp=sp):
     """
     Convert ``batch`` from ``dspace`` into the format of ``sp``.

     Raises
     ------
     ValueError
         If ``dspace.np_format_as`` raises ValueError. The original message
         is kept and a hint about model/dataset initialization is appended.
     """
     hint = ('\nMake sure that the model and dataset have been '
             'initialized with correct values.')
     try:
         return dspace.np_format_as(batch, sp)
     except ValueError as err:
         reraise_as(ValueError(str(err) + hint))
示例#15
0
 def wrapped_layer_cost(layer, coef):
     """
     Return ``layer.get_weight_decay(coef)``, tolerating layers without
     weight decay when the coefficient is zero.

     Parameters
     ----------
     layer : object
         Anything with a ``get_weight_decay(coef)`` method.
     coef : float
         Weight-decay coefficient forwarded to the layer.

     Returns
     -------
     The layer's weight-decay term, or ``0.`` if the layer raises
     ``NotImplementedError`` and ``coef == 0.``.

     Raises
     ------
     NotImplementedError
         If the layer does not implement ``get_weight_decay`` and ``coef``
         is nonzero.
     """
     try:
         # Forward this function's own ``coef`` argument, not an outer name.
         return layer.get_weight_decay(coef)
     except NotImplementedError:
         if coef == 0.:
             return 0.
         reraise_as(NotImplementedError(str(type(layer)) +
                    " does not implement get_weight_decay."))
示例#16
0
 def _validate_shape(shape, param_name):
     try:
         shape = tuple(shape)
         [int(val) for val in shape]
     except (ValueError, TypeError):
         try:
             shape = (int(shape), )
         except TypeError:
             reraise_as(
                 TypeError("%s must be int or int tuple" % param_name))
     return shape
示例#17
0
 def _validate_shape(shape, param_name):
     try:
         shape = tuple(shape)
         [int(val) for val in shape]
     except (ValueError, TypeError):
         try:
             shape = (int(shape),)
         except TypeError:
             reraise_as(TypeError("%s must be int or int tuple"
                                  % param_name))
     return shape
示例#18
0
    def next(self):
        """
        Return the next minibatch of ``self.X``.

        The rows are selected by the index (or indices) produced by
        ``self.subset_iterator.next()``. Whatever that call raises when the
        iterator is exhausted propagates unchanged.

        Raises
        ------
        ValueError
            If indexing ``self.X`` raises IndexError.
        """
        indx = self.subset_iterator.next()
        try:
            mini_batch = self.X[indx]
        except IndexError as e:
            # the index of the minibatch goes beyond the boundary
            reraise_as(ValueError("Index out of range"+str(e)))
        return mini_batch
示例#19
0
    def next(self):
        """
        Return the next minibatch of ``self.X``.

        The rows are selected by the index (or indices) produced by
        ``self.subset_iterator.next()``. Whatever that call raises when the
        iterator is exhausted propagates unchanged.

        Raises
        ------
        ValueError
            If indexing ``self.X`` raises IndexError.
        """
        indx = self.subset_iterator.next()
        try:
            mini_batch = self.X[indx]
        except IndexError as e:
            # the index of the minibatch goes beyond the boundary
            reraise_as(ValueError("Index out of range" + str(e)))
        return mini_batch
示例#20
0
 def wrapped_layer_cost(layer, coeff):
     """
     Return ``layer.get_weight_decay(coeff)``, tolerating layers without
     weight decay when the coefficient is zero.

     Returns
     -------
     The layer's weight-decay term, or ``0.`` if the layer raises
     ``NotImplementedError`` and ``coeff == 0.``.

     Raises
     ------
     NotImplementedError
         If the layer does not implement ``get_weight_decay`` and ``coeff``
         is nonzero.
     """
     try:
         decay = layer.get_weight_decay(coeff)
     except NotImplementedError:
         if coeff != 0.:
             reraise_as(
                 NotImplementedError(
                     str(type(layer)) + " does not implement "
                     "get_weight_decay."))
         return 0.
     return decay
示例#21
0
def resolve(d):
    """Return the object described by the configuration dictionary ``d``.

    The tag returned by ``get_tag(d)`` selects a resolver from
    ``resolvers``; that resolver is then called with ``d``.

    Raises
    ------
    TypeError
        If no resolver is registered for the tag.
    """
    object_type = get_tag(d)

    try:
        build = resolvers[object_type]
    except KeyError:
        problem = ('config does not know of any object type "' +
                   object_type + '"')
        reraise_as(TypeError(problem))

    return build(d)
示例#22
0
文件: sgd.py 项目: ballasn/facedet
    def on_monitor(self, model, dataset, algorithm):
        """
        Shrink the learning rate when the monitored channel stops improving.

        The latest value recorded on ``self.channel_name`` is compared with
        the best value seen so far (``self._min_v``). Every call without an
        improvement increments ``self._count``. Once that counter exceeds
        ``self.nb_epoch``, it is reset and the learning rate is multiplied by
        ``self.shrink_lr``. The new rate is never set below ``self.min_lr``.

        Parameters
        ----------
        model : a Model instance
            Not used; ``algorithm.model`` is consulted instead.
        dataset : Dataset
            Not used.
        algorithm : object
            Provides ``model`` and the ``learning_rate`` shared variable.

        Raises
        ------
        ValueError
            If the model's monitor has no channel named ``self.channel_name``.
        """
        model = algorithm.model
        lr = algorithm.learning_rate
        current_learning_rate = lr.get_value()
        assert hasattr(model, 'monitor'), ("no monitor associated with "
                                           + str(model))

        try:
            record = model.monitor.channels[self.channel_name].val_record
        except KeyError:
            name = str(self.channel_name)
            err_message = ("There is no monitoring channel named '" + name +
                           "'. You probably need to specify a valid "
                           "monitoring channel by using either dataset_name "
                           "or channel_name in the MonitorBasedLRDecay "
                           "constructor. The channel_name '" + name +
                           "' is not valid.")
            reraise_as(ValueError(err_message))

        if len(record) == 1:
            # Only the initial monitoring pass has run, so no learning has
            # happened yet: just remember the starting value.
            self._min_v = record[0]
            return

        log.info("monitoring channel is {0}".format(self.channel_name))

        latest = record[-1]
        if latest < self._min_v:
            self._min_v = latest
            self._count = 0
        else:
            self._count += 1

        new_lr = current_learning_rate
        if self._count > self.nb_epoch:
            self._count = 0
            new_lr = self.shrink_lr * new_lr

        lr.set_value(np.cast[lr.dtype](max(self.min_lr, new_lr)))
示例#23
0
def resolve(d):
    """Return the object described by the configuration dictionary ``d``.

    The tag returned by ``get_tag(d)`` selects a resolver from
    ``resolvers``; that resolver is then called with ``d``.

    Raises
    ------
    TypeError
        If no resolver is registered for the tag.
    """
    object_type = get_tag(d)

    try:
        build = resolvers[object_type]
    except KeyError:
        problem = ('config does not know of any object type "' +
                   object_type + '"')
        reraise_as(TypeError(problem))

    return build(d)
示例#24
0
def load(filepath, rescale_image=True, dtype='float64'):
    """
    Load an image file into a NumPy array.

    Parameters
    ----------
    filepath : str
        Path of the image to open.
    rescale_image : bool, optional
        If true, pixel values are divided by 255 after the dtype cast;
        otherwise they are divided by 1.0. Note that the division may
        promote integer dtypes to float.
    dtype : str, optional
        Dtype the decoded pixels are cast to.

    Returns
    -------
    ndarray
        A 3-D array; 2-D (grayscale) images get a trailing axis of size 1.
        Special case: when ``rescale_image == False`` and ``dtype == 'uint8'``,
        PIL's decoded uint8 array is returned as-is, without that reshape.

    Raises
    ------
    Exception
        If the file cannot be opened by PIL.
    AssertionError
        If the decoded image is neither 2-D nor 3-D.
    """
    assert isinstance(filepath, str)

    # Deliberately `== False` (not `not rescale_image`) to keep the original
    # behaviour for values such as None.
    if rescale_image == False and dtype == 'uint8':
        ensure_Image()
        rval = np.asarray(Image.open(filepath))
        assert rval.dtype == 'uint8'
        return rval

    s = 255. if rescale_image else 1.0
    try:
        ensure_Image()
        rval = Image.open(filepath)
    except Exception:
        reraise_as(Exception("Could not open " + filepath))

    numpy_rval = np.array(rval)

    if numpy_rval.ndim not in [2, 3]:
        # Log diagnostics only; no image viewer is launched from library
        # code (that blocks or fails on headless machines).
        logger.error(dir(rval))
        logger.error(rval)
        logger.error(rval.size)
        raise AssertionError(
            "Tried to load an image, got an array with " +
            str(numpy_rval.ndim) + " dimensions. Expected 2 or 3. "
            "This may indicate a mildly corrupted image file. Try "
            "converting it to a different image format with a different "
            "editor like gimp or ImageMagick. Sometimes these programs are "
            "more robust to minor corruption than PIL and will emit a "
            "correctly formatted image in the new format.")

    # astype instead of np.cast, which was removed in NumPy 2.0.
    rval = numpy_rval.astype(dtype) / s

    if rval.ndim == 2:
        rval = rval.reshape(rval.shape[0], rval.shape[1], 1)

    if rval.ndim != 3:
        raise AssertionError("Something went wrong opening " +
                             filepath + '. Resulting shape is ' +
                             str(rval.shape) +
                             " (it's meant to have 3 dimensions by now)")

    return rval
示例#25
0
文件: sgd.py 项目: baucheng/facedet
    def on_monitor(self, model, dataset, algorithm):
        """
        Shrink the learning rate when the monitored channel stops improving.

        The latest value recorded on ``self.channel_name`` is compared with
        the best value seen so far (``self._min_v``). Every call without an
        improvement increments ``self._count``. Once that counter exceeds
        ``self.nb_epoch``, it is reset and the learning rate is multiplied by
        ``self.shrink_lr``. The new rate is never set below ``self.min_lr``.

        Parameters
        ----------
        model : a Model instance
            Not used; ``algorithm.model`` is consulted instead.
        dataset : Dataset
            Not used.
        algorithm : object
            Provides ``model`` and the ``learning_rate`` shared variable.

        Raises
        ------
        ValueError
            If the model's monitor has no channel named ``self.channel_name``.
        """
        model = algorithm.model
        lr = algorithm.learning_rate
        current_learning_rate = lr.get_value()
        assert hasattr(model, 'monitor'), ("no monitor associated with "
                                           + str(model))

        try:
            record = model.monitor.channels[self.channel_name].val_record
        except KeyError:
            name = str(self.channel_name)
            err_message = ("There is no monitoring channel named '" + name +
                           "'. You probably need to specify a valid "
                           "monitoring channel by using either dataset_name "
                           "or channel_name in the MonitorBasedLRDecay "
                           "constructor. The channel_name '" + name +
                           "' is not valid.")
            reraise_as(ValueError(err_message))

        if len(record) == 1:
            # Only the initial monitoring pass has run, so no learning has
            # happened yet: just remember the starting value.
            self._min_v = record[0]
            return

        log.info("monitoring channel is {0}".format(self.channel_name))

        latest = record[-1]
        if latest < self._min_v:
            self._min_v = latest
            self._count = 0
        else:
            self._count += 1

        new_lr = current_learning_rate
        if self._count > self.nb_epoch:
            self._count = 0
            new_lr = self.shrink_lr * new_lr

        lr.set_value(np.cast[lr.dtype](max(self.min_lr, new_lr)))
示例#26
0
    def next(self):
        """
        Fetch the next minibatch of ``self.X``.

        The rows are selected by the index (or indices) produced by
        ``self.subset_iterator.next()``.

        Raises
        ------
        ValueError
            If indexing ``self.X`` raises IndexError.
        """
        batch_index = self.subset_iterator.next()
        try:
            return self.X[batch_index]
        except IndexError as exc:
            # The minibatch index goes beyond the boundary of the data.
            reraise_as(ValueError("Index out of range" + str(exc)))
示例#27
0
def test_multi_constructor_obj():
    """
    Tests that multi_constructor_obj raises a TypeError when a key in the
    mapping is not a string (here the integer 1).
    """
    try:
        load("a: !obj:decimal.Decimal { 1 }")
    except TypeError as e:
        assert str(e) == "Received non string object (1) as key in mapping."
    except Exception as e:
        error_msg = "Got the unexpected error: %s" % (e)
        reraise_as(ValueError(error_msg))
    else:
        # Without this branch the test would pass silently if no error
        # were raised at all.
        raise AssertionError("Expected a TypeError for a non-string key, "
                             "but load() succeeded.")
示例#28
0
def test_multi_constructor_obj():
    """
    Tests that multi_constructor_obj raises a TypeError when a key in the
    mapping is not a string (here the integer 1).
    """
    try:
        load("a: !obj:decimal.Decimal { 1 }")
    except TypeError as e:
        assert str(e) == "Received non string object (1) as key in mapping."
    except Exception as e:
        error_msg = "Got the unexpected error: %s" % (e)
        reraise_as(ValueError(error_msg))
    else:
        # Without this branch the test would pass silently if no error
        # were raised at all.
        raise AssertionError("Expected a TypeError for a non-string key, "
                             "but load() succeeded.")
示例#29
0
    def next(self):
        """
        Fetch the next minibatch of ``self.X``.

        The rows are selected by the index (or indices) produced by
        ``self.subset_iterator.next()``.

        Raises
        ------
        ValueError
            If indexing ``self.X`` raises IndexError.
        """
        batch_index = self.subset_iterator.next()
        try:
            return self.X[batch_index]
        except IndexError as exc:
            # The minibatch index goes beyond the boundary of the data.
            reraise_as(ValueError("Index out of range" + str(exc)))
示例#30
0
def read_bin_lush_matrix(filepath):
    """
    Reads a binary matrix saved by the lush library.

    The file is read as: a magic number (mapped to a dtype through
    ``lush_magic``), the number of dimensions, the dimension sizes (at
    least three entries are read when ndim > 0), and the element data.

    Parameters
    ----------
    filepath : str
        The path to the file.

    Returns
    -------
    matrix : ndarray
        A NumPy version of the stored matrix.

    Raises
    ------
    ValueError
        If the magic number is unrecognized, or if bytes remain after the
        data. A failure to read the magic number is re-raised through
        ``reraise_as`` with the message "Couldn't read magic number".
    """
    # `with` closes the file on every exit path, including the errors below.
    with open(filepath, 'rb') as f:
        try:
            magic = read_int(f)
        except ValueError:
            reraise_as("Couldn't read magic number")
        ndim = read_int(f)

        if ndim == 0:
            shape = ()
        else:
            shape = read_int(f, max(3, ndim))

        total_elems = 1
        for dim in shape:
            total_elems *= dim

        try:
            dtype = lush_magic[magic]
        except KeyError:
            reraise_as(ValueError('Unrecognized lush magic number ' +
                                  str(magic)))

        rval = np.fromfile(file=f, dtype=dtype, count=total_elems)

        excess = f.read(-1)

    if excess:
        raise ValueError(str(len(excess)) +
                         ' extra bytes found at end of file.'
                         ' This indicates a mismatch between header '
                         'and content')

    # Pass the shape as one tuple: reshape(*()) fails for a 0-d matrix.
    return rval.reshape(tuple(shape))
示例#31
0
def read_bin_lush_matrix(filepath):
    """
    Reads a binary matrix saved by the lush library.

    The file is read as: a magic number (mapped to a dtype through
    ``lush_magic``), the number of dimensions, the dimension sizes (at
    least three entries are read when ndim > 0), and the element data.

    Parameters
    ----------
    filepath : str
        The path to the file.

    Returns
    -------
    matrix : ndarray
        A NumPy version of the stored matrix.

    Raises
    ------
    ValueError
        If the magic number is unrecognized, or if bytes remain after the
        data. A failure to read the magic number is re-raised through
        ``reraise_as`` with the message "Couldn't read magic number".
    """
    # `with` closes the file on every exit path, including the errors below.
    with open(filepath, 'rb') as f:
        try:
            magic = read_int(f)
        except ValueError:
            reraise_as("Couldn't read magic number")
        ndim = read_int(f)

        if ndim == 0:
            shape = ()
        else:
            shape = read_int(f, max(3, ndim))

        total_elems = 1
        for dim in shape:
            total_elems *= dim

        try:
            dtype = lush_magic[magic]
        except KeyError:
            reraise_as(ValueError('Unrecognized lush magic number ' +
                                  str(magic)))

        rval = np.fromfile(file=f, dtype=dtype, count=total_elems)

        excess = f.read(-1)

    if excess:
        raise ValueError(
            str(len(excess)) + ' extra bytes found at end of file.'
            ' This indicates a mismatch between header '
            'and content')

    # Pass the shape as one tuple: reshape(*()) fails for a 0-d matrix.
    return rval.reshape(tuple(shape))
示例#32
0
    def get(self, sources, indexes):
        """
        Retrieves the requested elements from the dataset.

        Parameters
        ----------
        sources : tuple
            A tuple of source identifiers
        indexes : slice or list
            A slice or a list of indexes

        Returns
        -------
        rval : tuple
            A tuple of batches, one for each source

        Raises
        ------
        ValueError
            If looking up a source in ``self.data`` raises ValueError.
        """
        assert isinstance(sources, (tuple, list)) and len(sources) > 0, (
            'sources should be an instance of tuple and not empty')
        assert all(isinstance(el, string_types)
                   for el in sources), ('sources elements should be strings')
        assert isinstance(indexes, (tuple, list, slice, py_integer_types)), (
            'indexes should be either an int, a slice or a tuple/list of ints')
        if isinstance(indexes, (tuple, list)):
            assert len(indexes) > 0 and all(
                isinstance(i, py_integer_types)
                for i in indexes), ('indexes elements should be ints')

        rval = []
        for s in sources:
            try:
                sdata = self.data[s]
            except ValueError as e:
                # Report the source name itself; ``sources[s]`` would index
                # the tuple with a string and raise TypeError here instead.
                reraise_as(
                    ValueError(
                        'The requested source %s is not part of the dataset' %
                        s, *e.args))
            if (isinstance(indexes, (slice, py_integer_types))
                    or len(indexes) == 1):
                rval.append(sdata[indexes])
            else:
                warnings.warn('Accessing non sequential elements of an '
                              'HDF5 file will be at best VERY slow. Avoid '
                              'using iteration schemes that access '
                              'random/shuffled data with hdf5 datasets!!')
                rval.append([sdata[idx] for idx in indexes])
        return tuple(rval)
示例#33
0
    def get(self, sources, indexes):
        """
        Retrieves the requested elements from the dataset.

        Parameters
        ----------
        sources : tuple
            A tuple of source identifiers
        indexes : slice or list
            A slice or a list of indexes

        Returns
        -------
        rval : tuple
            A tuple of batches, one for each source

        Raises
        ------
        ValueError
            If looking up a source in ``self.data`` raises ValueError.
        """
        assert (
            isinstance(sources, (tuple, list)) and len(sources) > 0
        ), "sources should be an instance of tuple and not empty"
        assert all(isinstance(el, string_types) for el in sources), "sources elements should be strings"
        assert isinstance(
            indexes, (tuple, list, slice, py_integer_types)
        ), "indexes should be either an int, a slice or a tuple/list of ints"
        if isinstance(indexes, (tuple, list)):
            assert len(indexes) > 0 and all(
                isinstance(i, py_integer_types) for i in indexes
            ), "indexes elements should be ints"

        rval = []
        for s in sources:
            try:
                sdata = self.data[s]
            except ValueError as e:
                # Report the source name itself; ``sources[s]`` would index
                # the tuple with a string and raise TypeError here instead.
                reraise_as(ValueError("The requested source %s is not part of the dataset" % s, *e.args))
            if isinstance(indexes, (slice, py_integer_types)) or len(indexes) == 1:
                rval.append(sdata[indexes])
            else:
                warnings.warn(
                    "Accessing non sequential elements of an "
                    "HDF5 file will be at best VERY slow. Avoid "
                    "using iteration schemes that access "
                    "random/shuffled data with hdf5 datasets!!"
                )
                rval.append([sdata[idx] for idx in indexes])
        return tuple(rval)
示例#34
0
def try_to_import(tag_suffix):
    """
    Import the module part of a dotted ``module.attribute`` path.

    Everything before the last dot of `tag_suffix` is treated as a module
    name and imported; the final component is ignored here.

    Parameters
    ----------
    tag_suffix : str
        Dotted path such as ``'package.module.Name'``.

    Returns
    -------
    None

    Raises
    ------
    ImportError
        (via ``reraise_as``) If the module cannot be imported. When the
        error message mentions a component of the path, the path itself is
        blamed; otherwise successively longer prefixes are imported to
        report the deepest importable parent. If every prefix imports on
        this second attempt, no error is raised.
    """
    components = tag_suffix.split('.')
    modulename = '.'.join(components[:-1])
    try:
        # NOTE(review): `modulename` is interpolated into code run by exec.
        # If `tag_suffix` can come from an untrusted source this allows
        # code injection; importlib.import_module would avoid that.
        exec('import %s' % modulename)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path, or did the module we're importing have an unrelated
        # ImportError? Heuristic: the error is about our path if its
        # message mentions one of the path components. This test can still
        # have false positives (e.g. very short component names); feel
        # free to improve it.
        pieces = modulename.split('.')
        str_e = str(e)
        # A list comprehension (not a generator expression) on purpose: on
        # Python 2 a nested scope with free variables plus exec in the same
        # function is a SyntaxError.
        found = any([piece in str_e for piece in pieces])

        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file
            reraise_as(
                ImportError("Could not import %s; ImportError was %s" %
                            (modulename, str_e)))
        else:
            # Import successively longer prefixes to find the first one
            # that fails.
            pcomponents = components[:-1]
            assert len(pcomponents) >= 1
            for j in range(1, len(pcomponents) + 1):
                modulename = '.'.join(pcomponents[:j])
                try:
                    exec('import %s' % modulename)
                except Exception:
                    base_msg = 'Could not import %s' % modulename
                    if j > 1:
                        modulename = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % modulename
                    reraise_as(
                        ImportError(base_msg + '. Original exception: ' +
                                    str_e))
示例#35
0
def try_to_import(tag_suffix):
    """
    Import the module part of a dotted ``module.attribute`` path.

    Everything before the last dot of `tag_suffix` is treated as a module
    name and imported; the final component is ignored here.

    Parameters
    ----------
    tag_suffix : str
        Dotted path such as ``'package.module.Name'``.

    Returns
    -------
    None

    Raises
    ------
    ImportError
        (via ``reraise_as``) If the module cannot be imported. When the
        error message mentions a component of the path, the path itself is
        blamed; otherwise successively longer prefixes are imported to
        report the deepest importable parent. If every prefix imports on
        this second attempt, no error is raised.
    """
    components = tag_suffix.split('.')
    modulename = '.'.join(components[:-1])
    try:
        # NOTE(review): `modulename` is interpolated into code run by exec.
        # If `tag_suffix` can come from an untrusted source this allows
        # code injection; importlib.import_module would avoid that.
        exec('import %s' % modulename)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path, or did the module we're importing have an unrelated
        # ImportError? Heuristic: the error is about our path if its
        # message mentions one of the path components. This test can still
        # have false positives (e.g. very short component names); feel
        # free to improve it.
        pieces = modulename.split('.')
        str_e = str(e)
        # A list comprehension (not a generator expression) on purpose: on
        # Python 2 a nested scope with free variables plus exec in the same
        # function is a SyntaxError.
        found = any([piece in str_e for piece in pieces])

        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file
            reraise_as(ImportError("Could not import %s; ImportError was %s" %
                                   (modulename, str_e)))
        else:
            # Import successively longer prefixes to find the first one
            # that fails.
            pcomponents = components[:-1]
            assert len(pcomponents) >= 1
            for j in range(1, len(pcomponents) + 1):
                modulename = '.'.join(pcomponents[:j])
                try:
                    exec('import %s' % modulename)
                except Exception:
                    base_msg = 'Could not import %s' % modulename
                    if j > 1:
                        modulename = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % modulename
                    reraise_as(ImportError(base_msg + '. Original exception: '
                                           + str_e))
示例#36
0
def resolve(d):
    """
    Build a dataset from a config dictionary.

    Parameters
    ----------
    d : config dictionary
        Its tag (read with ``pylearn2.config.get_tag``) must be
        ``'dataset'``; its ``'typename'`` entry (read with
        ``pylearn2.config.get_str``) selects the resolver.

    Returns
    -------
    object
        Whatever ``resolvers[typename](d)`` returns.

    Raises
    ------
    TypeError
        If the tag is not ``'dataset'`` or no resolver is registered for
        the typename.
    """
    tag = pylearn2.config.get_tag(d)

    if tag != 'dataset':
        # %-formatting with a 1-tuple never fails on non-str values,
        # unlike string concatenation.
        raise TypeError('pylearn2.datasets.config asked to resolve a config '
                        'dictionary with tag "%s"' % (tag,))

    typename = pylearn2.config.get_str(d, 'typename')

    try:
        resolver = resolvers[typename]
    except KeyError:
        reraise_as(TypeError('pylearn2.datasets does not know of a dataset '
                             'type "%s"' % (typename,)))

    return resolver(d)
示例#37
0
 def __init__(self, max_labels, dtype=None):
     """
     Set up the formatter for a given maximum number of labels.

     Parameters
     ----------
     max_labels : int
         Maximum number of labels. Any value that ``np.empty`` accepts as
         a size passes validation.
     dtype : str or numpy dtype, optional
         Output dtype. Must be understood by ``np.dtype``. When omitted,
         ``config.floatX`` is used.

     Raises
     ------
     ValueError
         If ``np.empty(max_labels)`` rejects `max_labels`.
     TypeError
         If ``np.dtype(dtype)`` rejects `dtype`.
     """
     owner = self.__class__.__name__
     # Validate by letting NumPy build (and discard) an array of that size.
     try:
         np.empty(max_labels)
     except (ValueError, TypeError):
         reraise_as(ValueError("%s got bad max_labels argument '%s'" %
                               (owner, str(max_labels))))
     self._max_labels = max_labels
     if dtype is not None:
         # Validate only; the identifier is stored exactly as given.
         try:
             np.dtype(dtype)
         except TypeError:
             reraise_as(TypeError("%s got bad dtype identifier %s" %
                                  (owner, str(dtype))))
         self._dtype = dtype
     else:
         self._dtype = config.floatX
示例#38
0
 def __init__(self, max_labels, dtype=None):
     """
     Set up the formatter for a given maximum number of labels.

     Parameters
     ----------
     max_labels : int
         Maximum number of labels. Any value that ``np.empty`` accepts as
         a size passes validation.
     dtype : str or numpy dtype, optional
         Output dtype. Must be understood by ``np.dtype``. When omitted,
         ``config.floatX`` is used.

     Raises
     ------
     ValueError
         If ``np.empty(max_labels)`` rejects `max_labels`.
     TypeError
         If ``np.dtype(dtype)`` rejects `dtype`.
     """
     owner = self.__class__.__name__
     # Validate by letting NumPy build (and discard) an array of that size.
     try:
         np.empty(max_labels)
     except (ValueError, TypeError):
         reraise_as(ValueError("%s got bad max_labels argument '%s'" %
                               (owner, str(max_labels))))
     self._max_labels = max_labels
     if dtype is not None:
         # Validate only; the identifier is stored exactly as given.
         try:
             np.dtype(dtype)
         except TypeError:
             reraise_as(TypeError("%s got bad dtype identifier %s" %
                                  (owner, str(dtype))))
         self._dtype = dtype
     else:
         self._dtype = config.floatX
示例#39
0
    def get_batch_design(self, batch_size, include_labels=False):
        """
        Draw a random contiguous batch of rows from the design matrix.

        Parameters
        ----------
        batch_size : int
            Number of consecutive rows of ``self.X`` to return.
        include_labels : bool, optional
            If True, also return the matching slices of ``self.y`` and
            ``self.latent``.

        Returns
        -------
        If `include_labels` is False: the batch of ``self.X`` cast to
        ``config.floatX``.
        If `include_labels` is True and ``self.y`` is None: ``(rx, None)``
        with the batch left uncast.
        Otherwise: ``(rx, ry, rlatent)`` with the batch left uncast.

        Raises
        ------
        ValueError
            If `batch_size` exceeds the number of rows of ``self.X``.
        """
        n_rows = self.X.shape[0]
        try:
            start = self.rng.randint(n_rows - batch_size + 1)
        except ValueError:
            # Only turn the error into a friendlier one when it is
            # explained by an oversized request.
            if batch_size > n_rows:
                reraise_as(ValueError("Requested %d examples from a dataset "
                                      "containing only %d." %
                                      (batch_size, n_rows)))
            raise
        stop = start + batch_size
        batch_x = self.X[start:stop, :]
        if not include_labels:
            return np.cast[config.floatX](batch_x)
        if self.y is None:
            return batch_x, None
        return batch_x, self.y[start:stop], self.latent[start:stop]
示例#40
0
文件: serial.py 项目: w1kke/pylearn2
def read_bin_lush_matrix(filepath):
    """
    Read a matrix stored in Lush binary format.

    Parameters
    ----------
    filepath : str
        Path of the file to read.

    Returns
    -------
    rval : ndarray
        The stored data, with dtype ``lush_magic[magic]`` and the shape
        recorded in the header.

    Raises
    ------
    ValueError
        If the magic number cannot be read or is not recognized, or if the
        file holds more bytes than its header describes.
    """
    # `with` closes the file even when one of the header checks fails.
    with open(filepath, 'rb') as f:
        try:
            magic = read_int(f)
        except ValueError:
            reraise_as(ValueError("Couldn't read magic number"))
        ndim = read_int(f)

        if ndim == 0:
            shape = ()
        else:
            # The header stores at least three dimension sizes.
            shape = read_int(f, max(3, ndim))

        total_elems = 1
        for dim in shape:
            total_elems *= dim

        try:
            dtype = lush_magic[magic]
        except KeyError:
            reraise_as(ValueError('Unrecognized lush magic number ' +
                                  str(magic)))

        rval = np.fromfile(file=f, dtype=dtype, count=total_elems)

        excess = f.read(-1)

    if excess:
        raise ValueError(
            str(len(excess)) + ' extra bytes found at end of file.'
            ' This indicates  mismatch between header and content')

    # Pass the shape as one argument: reshape(*()) would call reshape with
    # no arguments (a TypeError) for the ndim == 0 case.
    return rval.reshape(shape)
示例#41
0
def resolve(d):
    """
    Build a dataset from a config dictionary.

    Parameters
    ----------
    d : config dictionary
        Its tag (read with ``pylearn2.config.get_tag``) must be
        ``'dataset'``; its ``'typename'`` entry (read with
        ``pylearn2.config.get_str``) selects the resolver.

    Returns
    -------
    object
        Whatever the resolver registered for the typename returns when
        called with `d`.

    Raises
    ------
    TypeError
        If the tag is not ``'dataset'`` or no resolver is registered for
        the typename.
    """
    config_tag = pylearn2.config.get_tag(d)
    if config_tag != 'dataset':
        message = ('pylearn2.datasets.config asked to resolve a config '
                   'dictionary with tag "%s"' % config_tag)
        raise TypeError(message)

    dataset_type = pylearn2.config.get_str(d, 'typename')
    try:
        build = resolvers[dataset_type]
    except KeyError:
        reraise_as(TypeError('pylearn2.datasets does not know of a dataset '
                             'type "%s"' % dataset_type))
    return build(d)
示例#42
0
def preprocess(string, environ=None):
    """
    Preprocesses a string, by replacing `${VARNAME}` with
    `os.environ['VARNAME']` and ~ with the path to the user's
    home directory

    Parameters
    ----------
    string : str
        String object to preprocess
    environ : dict, optional
        If supplied, preferentially accept values from
        this dictionary as well as `os.environ`. That is,
        if a key appears in both, this dictionary takes
        precedence.

    Returns
    -------
    rval : str
        The preprocessed string. After substitution, a leading ``~`` is
        expanded with ``os.path.expanduser``.

    Raises
    ------
    ValueError
        If a ``${`` is not closed by ``}`` before the end of the string or
        the next ``${``, or if a referenced variable is not defined.
    NoDataPathError
        If ``${PYLEARN2_DATA_PATH}`` is referenced but not defined.
    EnvironmentVariableError
        If ``${PYLEARN2_VIEWER_COMMAND}`` is referenced but not defined.
    """
    if environ is None:
        environ = {}

    split = string.split('${')

    rval = [split[0]]

    for candidate in split[1:]:
        # Text up to the first '}' is the variable name; everything after
        # it is literal text.
        varname, closing, remainder = candidate.partition('}')

        if not closing:
            raise ValueError('Open ${ not followed by } before '
                             'end of string or next ${ in "' + string + '"')

        try:
            val = (environ[varname] if varname in environ
                   else os.environ[varname])
        except KeyError:
            if varname == 'PYLEARN2_DATA_PATH':
                reraise_as(NoDataPathError())
            if varname == 'PYLEARN2_VIEWER_COMMAND':
                reraise_as(EnvironmentVariableError(
                    viewer_command_error_essay + environment_variable_essay)
                )

            reraise_as(ValueError('Unrecognized environment variable "' +
                                  varname + '". Did you mean ' +
                                  match(varname, os.environ.keys()) + '?'))

        rval.append(val)
        rval.append(remainder)

    # Expand ~ on the substituted result (expanduser only touches a
    # leading ~ and leaves every other string unchanged).
    return os.path.expanduser(''.join(rval))
示例#43
0
def read_bin_lush_matrix(filepath):
    """
    Read a matrix stored in Lush binary format.

    Parameters
    ----------
    filepath : str
        Path of the file to read.

    Returns
    -------
    rval : ndarray
        The stored data, with dtype ``lush_magic[magic]`` and the shape
        recorded in the header.

    Raises
    ------
    ValueError
        If the magic number cannot be read or is not recognized, or if the
        file holds more bytes than its header describes.
    """
    # `with` closes the file even when one of the header checks fails.
    with open(filepath, 'rb') as f:
        try:
            magic = read_int(f)
        except ValueError:
            reraise_as(ValueError("Couldn't read magic number"))
        ndim = read_int(f)

        if ndim == 0:
            shape = ()
        else:
            # The header stores at least three dimension sizes.
            shape = read_int(f, max(3, ndim))

        total_elems = 1
        for dim in shape:
            total_elems *= dim

        try:
            dtype = lush_magic[magic]
        except KeyError:
            reraise_as(ValueError('Unrecognized lush magic number ' +
                                  str(magic)))

        rval = np.fromfile(file=f, dtype=dtype, count=total_elems)

        # A binary-mode read returns bytes on Python 3, so test emptiness
        # rather than comparing against '' (bytes never equal a str).
        excess = f.read(-1)

    if excess:
        raise ValueError(str(len(excess)) + ' extra bytes found at end of file.'
                         ' This indicates  mismatch between header and content')

    # Pass the shape as one argument: reshape(*()) would call reshape with
    # no arguments (a TypeError) for the ndim == 0 case.
    return rval.reshape(shape)
示例#44
0
    def on_monitor(self, model, dataset, algorithm):
        """
        Adjusts the learning rate based on the contents of model.monitor

        Compares the last two values recorded on the monitored channel.
        If the newest exceeds ``high_trigger`` times the previous one, the
        learning rate is multiplied by ``shrink_amt``. Otherwise, if it
        exceeds ``low_trigger`` times the previous one, it is multiplied by
        ``grow_amt``. The result is clipped to ``[min_lr, max_lr]``.
        Nothing is changed while only one value has been recorded.

        Parameters
        ----------
        model : a Model instance
            Unused: the model is read from ``algorithm.model`` instead.
        dataset : Dataset
            Unused.
        algorithm : WRITEME
            Must expose ``model``, ``learning_rate`` (an object with
            ``get_value``/``set_value``/``dtype``, presumably a shared
            variable) and, when the channel has to be inferred,
            ``monitoring_dataset``.

        Raises
        ------
        ValueError
            If no unique monitoring channel can be determined, the named
            channel does not exist, or it has no recorded values.
        """
        model = algorithm.model
        lr = algorithm.learning_rate
        current_learning_rate = lr.get_value()
        assert hasattr(model, 'monitor'), ("no monitor associated with "
                                           + str(model))
        monitor = model.monitor
        monitor_channel_specified = True

        if self.channel_name is None:
            # Infer the channel: there must be exactly one "*objective".
            monitor_channel_specified = False
            channels = [elem for elem in monitor.channels
                        if elem.endswith("objective")]
            if len(channels) < 1:
                raise ValueError("There are no monitoring channels that end "
                                 "with \"objective\". Please specify either "
                                 "channel_name or dataset_name.")
            elif len(channels) > 1:
                datasets = algorithm.monitoring_dataset.keys()
                raise ValueError("There are multiple monitoring channels that "
                                 "end with \"_objective\". The list of "
                                 "available datasets are: " +
                                 str(datasets) + " . Please specify either "
                                 "channel_name or dataset_name in the "
                                 "MonitorBasedLRAdjuster constructor to "
                                 'disambiguate.')
            else:
                self.channel_name = channels[0]
                warnings.warn('The channel that has been chosen for '
                              'monitoring is: ' +
                              str(self.channel_name) + '.')

        try:
            v = monitor.channels[self.channel_name].val_record
        except KeyError:
            err_input = ''
            if monitor_channel_specified:
                if self.dataset_name:
                    err_input = ('The dataset_name \'' +
                                 str(self.dataset_name) + '\' is not valid.')
                else:
                    err_input = ('The channel_name \'' +
                                 str(self.channel_name) + '\' is not valid.')
            err_message = ('There is no monitoring channel named \'' +
                           str(self.channel_name) + '\'. You probably need '
                           'to specify a valid monitoring channel by using '
                           'either dataset_name or channel_name in the '
                           'MonitorBasedLRAdjuster constructor. ' + err_input)
            reraise_as(ValueError(err_message))

        if len(v) < 1:
            if monitor.dataset is None:
                raise ValueError("You're trying to use a monitor-based "
                                 "learning rate adjustor but the monitor has "
                                 "no entries because you didn't specify a "
                                 "monitoring dataset.")

            raise ValueError("For some reason there are no monitor entries "
                             "yet the MonitorBasedLRAdjuster has been "
                             "called. This should never happen. The Train"
                             " object should call the monitor once on "
                             "initialization, then call the callbacks. "
                             "It seems you are either calling the "
                             "callback manually rather than as part of a "
                             "training algorithm, or there is a problem "
                             "with the Train object.")
        if len(v) == 1:
            # Only the initial monitoring has happened: no learning has
            # happened yet, so there is nothing to compare against.
            return

        rval = current_learning_rate

        log.info("monitoring channel is {0}".format(self.channel_name))

        if v[-1] > self.high_trigger * v[-2]:
            rval *= self.shrink_amt
            log.info("shrinking learning rate to %f", rval)
        elif v[-1] > self.low_trigger * v[-2]:
            rval *= self.grow_amt
            log.info("growing learning rate to %f", rval)

        rval = max(self.min_lr, rval)
        rval = min(self.max_lr, rval)

        lr.set_value(np.cast[lr.dtype](rval))
示例#45
0
def get_weights_report(model_path=None,
                       model=None,
                       rescale='individual',
                       border=False,
                       norm_sort=False,
                       dataset=None):
    """
    Returns a PatchViewer displaying a grid of filter weights

    Parameters
    ----------
    model_path : str
        Filepath of the model to make the report on. Must be None when
        `model` is given.
    model : object or dict, optional
        Already-loaded model to report on instead of loading `model_path`.
        A dict is treated as a saved MATLAB dictionary.
    rescale : str
        A string specifying how to rescale the filter images:
            - 'individual' (default) : scale each filter so that it
                  uses as much as possible of the dynamic range
                  of the display under the constraint that 0
                  is gray and no value gets clipped
            - 'global' : scale the whole ensemble of weights
            - 'none' :   don't rescale
    border : bool, optional
        If True, patches are added with ``activation=0``, otherwise with
        ``activation=None``.
    norm_sort : bool, optional
        If True, filters are displayed in order of decreasing norm.
    dataset : pylearn2.datasets.dataset.Dataset
        Dataset object to do view conversion for displaying the weights. If
        not provided one will be loaded from the model's dataset_yaml_src.

    Returns
    -------
    pv : PatchViewer
        Viewer holding one patch per filter. For a dict `model`, the
        result of ``patch_viewer.make_viewer`` is returned instead.
    """

    if model is None:
        logger.info('making weights report')
        logger.info('loading model')
        model = serial.load(model_path)
        logger.info('loading done')
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale=' + rescale +
                         ", must be 'none', 'global', or 'individual'")

    if isinstance(model, dict):
        # Assume this was a saved MATLAB dictionary. Candidate weight
        # matrices are its 2-D array entries; the metadata entries are
        # skipped without mutating the caller's dict.
        metadata = ('__version__', '__header__', '__globals__')
        keys = [key for key in model
                if key not in metadata and
                hasattr(model[key], 'ndim') and model[key].ndim == 2]
        if not keys:
            raise ValueError("no 2-D array found in the model dictionary")
        if len(keys) > 1:
            # Several candidates: ask the user which one holds the weights.
            key = None
            while key not in keys:
                logger.info('Which is the weights?')
                for candidate in keys:
                    logger.info('\t{0}'.format(candidate))
                key = input()
        else:
            key, = keys
        weights = model[key]

        norms = np.sqrt(np.square(weights).sum(axis=1))
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

        return patch_viewer.make_viewer(weights,
                                        is_color=weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            logger.info('loading dataset...')
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            logger.info('...done')

        try:
            W = model.get_weights()
        except AttributeError as e:
            reraise_as(AttributeError("""
Encountered an AttributeError while trying to call get_weights on a model.
This probably means you need to implement get_weights for this model class,
but look at the original exception to be sure.
If this is an older model class, it may have weights stored as weightsShared,
etc.
Original exception: """+str(e)))

    if W is None and weights_view is None:
        raise ValueError("model doesn't support any weights interfaces")

    if weights_view is None:
        weights_format = model.get_weights_format()
        assert hasattr(weights_format, '__iter__')
        assert len(weights_format) == 2
        assert weights_format[0] in ['v', 'h']
        assert weights_format[1] in ['v', 'h']
        assert weights_format[0] != weights_format[1]

        # Make rows of W correspond to filters.
        if weights_format[0] == 'v':
            W = W.T
        h = W.shape[0]

        weights_view = dataset.get_weights_view(W)
        assert weights_view.shape[0] == h

    if norm_sort:
        # Per-filter norms: from the rows of W when available, otherwise
        # from the flattened topological view (only get_weights_topo()
        # succeeded, so W is None).
        source = W if W is not None else np.reshape(weights_view, (h, -1))
        norms = np.sqrt(1e-8 + np.square(source).sum(axis=1))
        norm_prop = norms / norms.max()

    try:
        hr, hc = model.get_weights_view_shape()
    except NotImplementedError:
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

    pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                  patch_shape=weights_view.shape[1:3],
                                  is_color=weights_view.shape[-1] == 3)

    if global_rescale:
        # Rebind instead of dividing in place, so arrays that may be
        # shared with the model or dataset are left unmodified.
        weights_view = weights_view / np.abs(weights_view).max()

    if norm_sort:
        logger.info('sorting weights by decreasing norm')
        idx = sorted(range(h), key=lambda l: -norm_prop[l])
    else:
        idx = range(h)

    act = 0 if border else None

    for i in idx:
        patch = weights_view[i, ...]
        pv.add_patch(patch, rescale=patch_rescale, activation=act)

    abs_weights = np.abs(weights_view)
    logger.info('smallest enc weight magnitude: {0}'.format(abs_weights.min()))
    logger.info('mean enc weight magnitude: {0}'.format(abs_weights.mean()))
    logger.info('max enc weight magnitude: {0}'.format(abs_weights.max()))

    if W is not None:
        norms = np.sqrt(np.square(W).sum(axis=1))
        assert norms.shape == (h,)
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

    return pv
示例#46
0
def test_bad_monitoring_input_in_monitor_based_lr():
    """
    Check that MonitorBasedLRAdjuster rejects an invalid dataset_name or
    channel_name by raising ValueError during training.
    """
    # tests that the class MonitorBasedLRAdjuster in sgd.py avoids wrong
    # settings of channel_name or dataset_name in the constructor.

    dim = 3
    m = 10

    # Plain decimal seed values: the literals 06 / 02 are octal in Python 2
    # and a SyntaxError in Python 3.
    rng = np.random.RandomState([6, 2, 2014])

    X = rng.randn(m, dim)

    learning_rate = 1e-2
    batch_size = 5

    # We need to include this so the test actually stops running at some point
    epoch_num = 2

    dataset = DenseDesignMatrix(X=X)

    # including a monitoring datasets lets us test that
    # the monitor works with supervised data
    monitoring_dataset = DenseDesignMatrix(X=X)

    cost = DummyCost()

    model = SoftmaxModel(dim)

    termination_criterion = EpochCounter(epoch_num)

    algorithm = SGD(learning_rate,
                    cost,
                    batch_size=batch_size,
                    monitoring_batches=2,
                    monitoring_dataset=monitoring_dataset,
                    termination_criterion=termination_criterion,
                    update_callbacks=None,
                    init_momentum=None,
                    set_batch_size=False)

    # testing for bad dataset_name input
    dummy = 'void'

    monitor_lr = MonitorBasedLRAdjuster(dataset_name=dummy)

    train = Train(dataset,
                  model,
                  algorithm,
                  save_path=None,
                  save_freq=0,
                  extensions=[monitor_lr])
    try:
        train.main_loop()
    except ValueError:
        pass
    except Exception:
        reraise_as(AssertionError("MonitorBasedLRAdjuster takes dataset_name "
                                  "that is invalid "))
    else:
        # Without this, a missing error would make the test pass silently.
        raise AssertionError("MonitorBasedLRAdjuster accepted an invalid "
                             "dataset_name without raising ValueError")

    # testing for bad channel_name input
    monitor_lr2 = MonitorBasedLRAdjuster(channel_name=dummy)

    model2 = SoftmaxModel(dim)
    train2 = Train(dataset,
                   model2,
                   algorithm,
                   save_path=None,
                   save_freq=0,
                   extensions=[monitor_lr2])

    try:
        train2.main_loop()
    except ValueError:
        pass
    except Exception:
        reraise_as(AssertionError("MonitorBasedLRAdjuster takes channel_name "
                                  "that is invalid "))
    else:
        raise AssertionError("MonitorBasedLRAdjuster accepted an invalid "
                             "channel_name without raising ValueError")
示例#47
0
文件: image.py 项目: yeahq/pylearn2
def show(image):
    """
    Display an image with the command named by ``PYLEARN2_VIEWER_COMMAND``.

    The image is written to a temporary ``.png`` file. The spawned shell
    command deletes that file after the viewer command returns (on Windows,
    only if the viewer command succeeds).

    Parameters
    ----------
    image : PIL Image object or ndarray
        If ndarray, integer formats are assumed to use 0-255
        and float formats are assumed to use 0-1

    Raises
    ------
    ValueError
        If an array input does not have 2 or 3 dimensions.
    TypeError
        If PIL cannot build an image from the array.
    """
    if hasattr(image, '__array__'):
        # Do some shape checking because PIL just raises a tuple indexing
        # error that doesn't make it very clear what the problem is.
        if len(image.shape) < 2 or len(image.shape) > 3:
            raise ValueError('image must have either 2 or 3 dimensions but its'
                             ' shape is ' + str(image.shape))

        # np.asarray(...).astype(...) is what np.cast[...] did; np.cast
        # was removed in NumPy 2.0.
        if image.dtype == 'int8':
            image = np.asarray(image).astype('uint8')
        elif str(image.dtype).startswith('float'):
            # Don't use *=, we don't want to modify the input array.
            image = np.asarray(image * 255.).astype('uint8')

        # PIL does not handle 3-D single-channel arrays; drop that axis.
        if len(image.shape) == 3 and image.shape[2] == 1:
            image = image[:, :, 0]

        try:
            ensure_Image()
            image = Image.fromarray(image)
        except TypeError:
            reraise_as(
                TypeError("PIL issued TypeError on ndarray of shape " +
                          str(image.shape) + " and dtype " + str(image.dtype)))

    # Create a temporary file with the suffix '.png'.
    fd, name = mkstemp(suffix='.png')
    os.close(fd)

    # Note:
    #   Although we can use tempfile.NamedTemporaryFile() to create
    #   a temporary file, the function should be used with care.
    #
    #   In Python earlier than 2.7, a temporary file created by the
    #   function will be deleted just after the file is closed.
    #   We can re-use the name of the temporary file, but there is an
    #   instant where a file with the name does not exist in the file
    #   system before we re-use the name. This may cause a race
    #   condition.
    #
    #   In Python 2.7 or later, tempfile.NamedTemporaryFile() has
    #   the 'delete' argument which can control whether a temporary
    #   file will be automatically deleted or not. With the argument,
    #   the above race condition can be avoided.
    #

    try:
        image.save(name)
    except Exception:
        # The viewer command (which removes the file) never runs in this
        # case, so clean up here before propagating the error.
        if os.path.exists(name):
            os.remove(name)
        raise
    viewer_command = string.preprocess('${PYLEARN2_VIEWER_COMMAND}')
    if os.name == 'nt':
        subprocess.Popen(viewer_command + ' ' + name + ' && del ' + name,
                         shell=True)
    else:
        subprocess.Popen(viewer_command + ' ' + name + ' ; rm ' + name,
                         shell=True)
示例#48
0
def update_wrapper(wrapper,
                   wrapped,
                   assigned=WRAPPER_ASSIGNMENTS,
                   concatenated=WRAPPER_CONCATENATIONS,
                   append=False,
                   updated=WRAPPER_UPDATES,
                   replace_before=None):
    """
    A Python decorator which acts like `functools.update_wrapper` but
    also has the ability to concatenate attributes.

    Parameters
    ----------
    wrapper : function
        Function to be updated
    wrapped : function
        Original function. It is not modified.
    assigned : tuple, optional
        Tuple naming the attributes assigned directly from the wrapped
        function to the wrapper function.
        Defaults to `utils.WRAPPER_ASSIGNMENTS`.
    concatenated : tuple, optional
        Tuple naming the attributes from the wrapped function
        concatenated with the ones from the wrapper function. A None
        attribute on either side is treated as the empty string.
        Defaults to `utils.WRAPPER_CONCATENATIONS`.
    append : bool, optional
        If True, appends wrapped attributes to wrapper attributes
        instead of prepending them. Defaults to False.
    updated : tuple, optional
        Tuple naming the attributes of the wrapper that are updated
        with the corresponding attribute from the wrapped function.
        Defaults to `functools.WRAPPER_UPDATES`.
    replace_before : str, optional
        If `append` is `False` (meaning we are prepending), delete
        docstring lines occurring before the first line equal to this
        string (the docstring line is stripped of leading/trailing
        whitespace before comparison). The newline of the line preceding
        this string is preserved.

    Returns
    -------
    wrapper : function
        Updated wrapper function

    Notes
    -----
    This can be used to concatenate the wrapper's docstring with the
    wrapped's docstring and should help reduce the amount of
    documentation to write: one can use this decorator on child
    classes' functions when their implementation is similar to the one
    of the parent class. Conversely, if a function defined in a child
    class departs from its parent's implementation, one can simply
    explain the differences in a 'Notes' section without re-writing the
    whole docstring.
    """
    assert not (append and replace_before), ("replace_before cannot "
                                             "be used with append")
    for attr in assigned:
        setattr(wrapper, attr, getattr(wrapped, attr))
    for attr in concatenated:
        # Treat None (e.g. a missing docstring) as "" locally, so that
        # only the wrapper is written to and `wrapped` stays untouched.
        wrapped_attr = getattr(wrapped, attr)
        if wrapped_attr is None:
            wrapped_attr = ""
        wrapper_attr = getattr(wrapper, attr)
        if wrapper_attr is None:
            wrapper_attr = ""
        if append:
            setattr(wrapper, attr, wrapped_attr + wrapper_attr)
        else:
            if replace_before:
                assert replace_before.strip() == replace_before, (
                    'value for replace_before "%s" contains leading/'
                    'trailing whitespace' % replace_before
                )
                split = wrapped_attr.split("\n")
                # Potentially wasting time/memory by stripping everything
                # and duplicating it but probably not enough to worry about.
                split_stripped = [line.strip() for line in split]
                try:
                    index = split_stripped.index(replace_before)
                except ValueError:
                    reraise_as(ValueError('no line equal to "%s" in wrapped '
                                          'function\'s attribute %s' %
                                          (replace_before, attr)))
                # Keep the newline that preceded the matching line.
                wrapped_val = '\n' + '\n'.join(split[index:])
            else:
                wrapped_val = wrapped_attr
            setattr(wrapper, attr, wrapper_attr + wrapped_val)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper
示例#49
0
def _load(filepath, recurse_depth=0, retry=True):
    """
    Recursively tries to load a file until success or maximum number of
    attempts.

    Parameters
    ----------
    filepath : str
        A path to a file to load. Should be a pickle, Matlab, or NumPy
        file; or a .txt or .amat file that numpy.loadtxt can load.
    recurse_depth : int, optional
        End users should not use this argument. It is used by the function
        itself to implement the `retry` option recursively.
    retry : bool, optional
        If True, will make a handful of attempts to load the file before
        giving up. This can be useful if you are for example calling
        show_weights.py on a file that is actively being written to by a
        training script--sometimes the load attempt might fail if the
        training script writes at the same time show_weights tries to
        read, but if you try again after a few seconds you should be able
        to open the file.

    Returns
    -------
    loaded_object : object
        The object that was stored in the file. If it has no `yaml_src`
        attribute, one pointing at this file is attached when possible.
    """
    try:
        import joblib
        joblib_available = True
    except ImportError:
        joblib_available = False
    if recurse_depth == 0:
        filepath = preprocess(filepath)

    if filepath.endswith(('.npy', '.npz')):
        return np.load(filepath)

    # Note: 'txt' has no leading dot, so any name ending in "txt" matches.
    if filepath.endswith(('.amat', 'txt')):
        try:
            return np.loadtxt(filepath)
        except Exception:
            reraise_as("{0} cannot be loaded by serial.load (trying "
                       "to use np.loadtxt)".format(filepath))

    if filepath.endswith('.mat'):
        global io
        if io is None:
            import scipy.io
            io = scipy.io
        try:
            return io.loadmat(filepath)
        except NotImplementedError as nei:
            # Newer (HDF5-based) .mat files are opened with h5py instead.
            if str(nei).find('HDF reader') != -1:
                global hdf_reader
                if hdf_reader is None:
                    import h5py
                    hdf_reader = h5py
                return hdf_reader.File(filepath, 'r')
            else:
                raise
        # this code should never be reached
        assert False

    # for loading PY2 pickle in PY3
    encoding = {'encoding': 'latin-1'} if six.PY3 else {}

    def exponential_backoff():
        # After 10 retries, make one last attempt by reading the raw bytes;
        # otherwise sleep 0.5 * 2**recurse_depth seconds and retry.
        if recurse_depth > 9:
            logger.info('Max number of tries exceeded while trying to open '
                        '{0}'.format(filepath))
            logger.info('attempting to open via reading string')
            with open(filepath, 'rb') as f:
                content = f.read()
            return cPickle.loads(content, **encoding)
        else:
            nsec = 0.5 * (2.0 ** float(recurse_depth))
            logger.info("Waiting {0} seconds and trying again".format(nsec))
            time.sleep(nsec)
            return _load(filepath, recurse_depth + 1, retry)

    try:
        if not joblib_available:
            with open(filepath, 'rb') as f:
                obj = cPickle.load(f, **encoding)
        else:
            try:
                obj = joblib.load(filepath)
            except Exception:
                if os.path.exists(filepath) and not os.path.isdir(filepath):
                    raise
                raise_cannot_open(filepath)
    except MemoryError as e:
        # We want to explicitly catch this exception because for MemoryError
        # __str__ returns the empty string, so some of our default printouts
        # below don't make a lot of sense.
        # Also, a lot of users assume any exception is a bug in the library,
        # so we can cut down on mail to pylearn-users by adding a message
        # that makes it clear this exception is caused by their machine not
        # meeting requirements.
        if os.path.splitext(filepath)[1] == ".pkl":
            improve_memory_error_message(e,
                                         ("You do not have enough memory to "
                                          "open %s \n"
                                          " + Try using numpy.{save,load} "
                                          "(file with extension '.npy') "
                                          "to save your file. It uses less "
                                          "memory when reading and "
                                          "writing files than pickled files.")
                                         % filepath)
        else:
            improve_memory_error_message(e,
                                         "You do not have enough memory to "
                                         "open %s" % filepath)

    except (BadPickleGet, EOFError, KeyError) as e:
        if not retry:
            reraise_as(e.__class__('Failed to open {0}'.format(filepath)))
        obj = exponential_backoff()
    except ValueError:
        # Actually log the failure (with traceback) before retrying or
        # giving up; a bare `logger.exception` reference does nothing.
        logger.exception('Failed to open {0}'.format(filepath))

        if not retry:
            reraise_as(ValueError('Failed to open {0}'.format(filepath)))
        obj = exponential_backoff()
    except Exception:
        # assert False
        reraise_as("Couldn't open {0}".format(filepath))

    # if the object has no yaml_src, we give it one that just says it
    # came from this file. could cause trouble if you save obj again
    # to a different location
    if not hasattr(obj, 'yaml_src'):
        try:
            obj.yaml_src = '!pkl: "' + os.path.abspath(filepath) + '"'
        except Exception:
            pass

    return obj
示例#50
0
def _save(filepath, obj):
    """
    .. todo::

        WRITEME
    """
    try:
        import joblib
        joblib_available = True
    except ImportError:
        joblib_available = False
    if filepath.endswith('.npy'):
        np.save(filepath, obj)
        return
    # This is dumb
    # assert filepath.endswith('.pkl')
    save_dir = os.path.dirname(filepath)
    # Handle current working directory case.
    if save_dir == '':
        save_dir = '.'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
        raise IOError("save path %s exists, not a directory" % save_dir)
    elif not os.access(save_dir, os.W_OK):
        raise IOError("permission error creating %s" % filepath)
    try:
        if joblib_available and filepath.endswith('.joblib'):
            joblib.dump(obj, filepath)
        else:
            if filepath.endswith('.joblib'):
                warnings.warn('Warning: .joblib suffix specified but joblib '
                              'unavailable. Using ordinary pickle.')
            with open(filepath, 'wb') as filehandle:
                cPickle.dump(obj, filehandle, get_pickle_protocol())
    except Exception as e:
        logger.exception("cPickle has failed to write an object to "
                         "{0}".format(filepath))
        if str(e).find('maximum recursion depth exceeded') != -1:
            raise
        try:
            logger.info('retrying with pickle')
            with open(filepath, "wb") as f:
                pickle.dump(obj, f)
        except Exception as e2:
            if str(e) == '' and str(e2) == '':
                logger.exception('neither cPickle nor pickle could write to '
                                 '{0}'.format(filepath))
                logger.exception(
                    'moreover, neither of them raised an exception that '
                    'can be converted to a string'
                )
                logger.exception(
                    'now re-attempting to write with cPickle outside the '
                    'try/catch loop so you can see if it prints anything '
                    'when it dies'
                )
                with open(filepath, 'wb') as f:
                    cPickle.dump(obj, f, get_pickle_protocol())
                logger.info('Somehow or other, the file write worked once '
                            'we quit using the try/catch.')
            else:
                if str(e2) == 'env':
                    raise

                import pdb
                tb = pdb.traceback.format_exc()
                reraise_as(IOError(str(obj) +
                                   ' could not be written to ' +
                                   str(filepath) +
                           ' by cPickle due to ' + str(e) +
                                   ' nor by pickle due to ' + str(e2) +
                                   '. \nTraceback ' + tb))
        logger.warning('{0} was written by pickle instead of cPickle, due to '
                       '{1} (perhaps your object'
                       ' is really big?)'.format(filepath, e))
示例#51
0
        try:
            # Try to figure out what the wrong field name was
            # If we fail to do it, just fall back to giving the usual
            # attribute error
            pieces = tag_suffix.split('.')
            module = '.'.join(pieces[:-1])
            field = pieces[-1]
            candidates = dir(eval(module))

            msg = ('Could not evaluate %s. ' % tag_suffix +
                   'Did you mean ' + match(field, candidates) + '? ' +
                   'Original error was ' + str(e))

        except Exception:
            warnings.warn("Attempt to decipher AttributeError failed")
            reraise_as(AttributeError('Could not evaluate %s. ' % tag_suffix +
                                      'Original error was ' + str(e)))
        reraise_as(AttributeError(msg))
    return obj


def initialize():
    """
    Install this module's YAML multi-constructors on the ``yaml`` module.

    Registers one handler per tag prefix: ``!obj:`` ->
    ``multi_constructor_obj``, ``!pkl:`` -> ``multi_constructor_pkl`` and
    ``!import:`` -> ``multi_constructor_import``. Automatically done on
    first call to load() specified in this file.
    """
    # NOTE(review): declared global but not assigned in this body; confirm
    # whether ``is_initialized = True`` is expected to be set here.
    global is_initialized

    # Add the custom multi-constructors, in the same order as before.
    handlers = (
        ('!obj:', multi_constructor_obj),
        ('!pkl:', multi_constructor_pkl),
        ('!import:', multi_constructor_import),
    )
    for tag_prefix, handler in handlers:
        yaml.add_multi_constructor(tag_prefix, handler)
示例#52
0
文件: iteration.py 项目: dwf/pylearn2
    def __init__(self, dataset, subset_iterator, data_specs=None,
                 return_tuple=False, convert=None):
        """
        Parameters
        ----------
        dataset : object
            Dataset to iterate over. Must provide ``get_data_specs()``. If it
            has no ``get`` method, ``get_data()`` is called here once and
            the requested sources are cached in ``self._raw_data``.
        subset_iterator : object
            Stored as ``self._subset_iterator``.
        data_specs : tuple
            Flat ``(space, source)`` specification of the data to produce.
            Every requested source must be provided by `dataset`, otherwise
            a ValueError is raised.
        return_tuple : bool, optional
            Stored as ``self._return_tuple``.
        convert : sequence or None, optional
            One entry per requested source; must have the same length as
            the source tuple. A ``None`` entry (or ``convert=None``) is
            replaced by a function that formats a batch from the dataset's
            space into the requested space with ``np_format_as``. The
            sequence is copied, so the caller's object is never modified.
        """
        self._data_specs = data_specs
        self._dataset = dataset
        self._subset_iterator = subset_iterator
        self._return_tuple = return_tuple

        # Keep only the needed sources in self._raw_data.
        # Remember what source they correspond to in self._source
        assert is_flat_specs(data_specs)

        dataset_space, dataset_source = self._dataset.get_data_specs()
        assert is_flat_specs((dataset_space, dataset_source))

        # the dataset's data spec is either a single (space, source) pair,
        # or a pair of (non-nested CompositeSpace, non-nested tuple).
        # We could build a mapping and call flatten(..., return_tuple=True)
        # but simply putting spaces, sources and data in tuples is simpler.
        if not isinstance(dataset_source, (tuple, list)):
            dataset_source = (dataset_source,)

        if not isinstance(dataset_space, CompositeSpace):
            dataset_sub_spaces = (dataset_space,)
        else:
            dataset_sub_spaces = dataset_space.components
        assert len(dataset_source) == len(dataset_sub_spaces)

        space, source = data_specs
        if not isinstance(source, tuple):
            source = (source,)
        if not isinstance(space, CompositeSpace):
            sub_spaces = (space,)
        else:
            sub_spaces = space.components
        assert len(source) == len(sub_spaces)

        # If `dataset` is incompatible with the new interface, fall back to the
        # old interface
        if not hasattr(self._dataset, 'get'):
            all_data = self._dataset.get_data()
            if not isinstance(all_data, tuple):
                all_data = (all_data,)
            raw_data = []
            for s in source:
                try:
                    raw_data.append(all_data[dataset_source.index(s)])
                except ValueError as e:
                    msg = str(e) + '\nThe dataset does not provide '\
                                   'a source with name: ' + s + '.'
                    reraise_as(ValueError(msg))
            self._raw_data = tuple(raw_data)

        self._source = source
        self._space = sub_spaces

        if convert is None:
            self._convert = [None] * len(source)
        else:
            assert len(convert) == len(source)
            # Copy: the loop below fills in default converters in place.
            # Aliasing would leak dataset-specific lambdas into the caller's
            # list (and break outright for an immutable tuple).
            self._convert = list(convert)

        for i, (so, sp) in enumerate(safe_izip(source, sub_spaces)):
            try:
                idx = dataset_source.index(so)
            except ValueError as e:
                msg = str(e) + '\nThe dataset does not provide '\
                               'a source with name: ' + so + '.'
                reraise_as(ValueError(msg))
            dspace = dataset_sub_spaces[idx]

            fn = self._convert[i]

            # If there is a fn, it is supposed to take care of the formatting,
            # and it should be an error if it does not. If there was no fn,
            # then the iterator will try to format using the generic
            # space-formatting functions.
            if fn is None:
                # "dspace" and "sp" have to be passed as parameters
                # to lambda, in order to capture their current value,
                # otherwise they would change in the next iteration
                # of the loop.
                fn = (lambda batch, dspace=dspace, sp=sp:
                      dspace.np_format_as(batch, sp))

            self._convert[i] = fn
示例#53
0
""" TrainExtensions for doing random spatial windowing and flipping of an
    image dataset on every epoch. TODO: fill out properly."""

import warnings
import numpy
from . import TrainExtension
from pylearn2.datasets.preprocessing import CentralWindow
from pylearn2.utils.exc import reraise_as
from pylearn2.utils.rng import make_np_rng

# The windowing/flipping routines come from a compiled extension module
# (``pylearn2.utils._window_flip``; the error message below calls it a
# Cython module), so they are only importable after the package is built.
try:
    from ..utils._window_flip import random_window_and_flip_c01b
    from ..utils._window_flip import random_window_and_flip_b01c
except ImportError:
    # Replace the bare ImportError with one that tells the user how to
    # build the extension; reraise_as keeps the original traceback info.
    reraise_as(
        ImportError("Import of Cython module failed. Please make sure "
                    "you have run 'python setup.py develop' in the "
                    "pylearn2 directory"))

__authors__ = "David Warde-Farley"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["David Warde-Farley"]
__license__ = "3-clause BSD"
__maintainer__ = "David Warde-Farley"
__email__ = "wardefar@iro"


def _zero_pad(array, amount, axes=(1, 2)):
    """
    Returns a copy of <array> with zero-filled padding around the margins.

    The new array has the same dimensions as the input array, except for
示例#54
0
def raise_cannot_open(path):
    """
    Raise an exception saying we can't open `path`.

    Walks `path` one '/'-separated component at a time and stops at the
    first component that does not exist. It then raises an IOError that
    names one of these causes:

    * the first component is missing;
    * the parent is not a directory, or is empty;
    * the missing component is a dangling symlink.

    Otherwise, if the parent has at most 100 entries, the error suggests
    the closest existing name (via `match`).

    Parameters
    ----------
    path : str
        The path we cannot open

    Raises
    ------
    IOError
        Always, raised through `reraise_as`. If every component of `path`
        exists, the error says so: the failure is then not a missing file.
    """
    pieces = path.split('/')
    for i in range(1, len(pieces) + 1):
        so_far = '/'.join(pieces[0:i])
        if os.path.exists(so_far):
            continue
        if i == 1:
            if so_far == '':
                # Absolute path: the empty piece before the leading '/'
                # is not a real component.
                continue
            reraise_as(IOError('Cannot open ' + path + ' (' + so_far +
                               ' does not exist)'))
        # For an absolute path such as '/missing/x' the parent of the first
        # real component is the filesystem root, whose pieces join to ''.
        parent = '/'.join(pieces[0:i - 1]) or '/'
        bad = pieces[i - 1]

        if not os.path.isdir(parent):
            reraise_as(IOError("Cannot open " + path + " because " +
                               parent + " is not a directory."))

        # A dangling symlink is listed in its parent but does not "exist",
        # so report it before guessing at a misspelled name.
        if os.path.islink(so_far):
            reraise_as(IOError(so_far + " appears to be a symlink to a "
                               "non-existent file"))

        candidates = os.listdir(parent)

        if len(candidates) == 0:
            reraise_as(IOError("Cannot open " + path + " because " +
                               parent + " is empty."))

        if len(candidates) > 100:
            # Don't attempt to guess the right name if the directory is
            # huge
            reraise_as(IOError("Cannot open " + path + " but can open " +
                               parent + "."))

        reraise_as(IOError("Cannot open " + path + " but can open " +
                           parent + ". Did you mean " +
                           match(bad, candidates) + " instead of " +
                           bad + "?"))

    # No missing component was found (e.g. a permissions problem, or an
    # empty path); still honour the contract of always raising.
    if os.path.exists(path):
        reraise_as(IOError("Cannot open " + path + " even though it exists."))
    reraise_as(IOError("Cannot open " + path + "."))
示例#55
0
    def __init__(self,
                 dataset,
                 subset_iterator,
                 data_specs=None,
                 return_tuple=False,
                 convert=None):
        self._data_specs = data_specs
        self._dataset = dataset
        self._subset_iterator = subset_iterator
        self._return_tuple = return_tuple

        # Keep only the needed sources in self._raw_data.
        # Remember what source they correspond to in self._source
        assert is_flat_specs(data_specs)

        dataset_space, dataset_source = self._dataset.get_data_specs()
        assert is_flat_specs((dataset_space, dataset_source))

        # the dataset's data spec is either a single (space, source) pair,
        # or a pair of (non-nested CompositeSpace, non-nested tuple).
        # We could build a mapping and call flatten(..., return_tuple=True)
        # but simply putting spaces, sources and data in tuples is simpler.
        if not isinstance(dataset_source, (tuple, list)):
            dataset_source = (dataset_source, )

        if not isinstance(dataset_space, CompositeSpace):
            dataset_sub_spaces = (dataset_space, )
        else:
            dataset_sub_spaces = dataset_space.components
        assert len(dataset_source) == len(dataset_sub_spaces)

        space, source = data_specs
        if not isinstance(source, tuple):
            source = (source, )
        if not isinstance(space, CompositeSpace):
            sub_spaces = (space, )
        else:
            sub_spaces = space.components
        assert len(source) == len(sub_spaces)

        # If `dataset` is incompatible with the new interface, fall back to the
        # old interface
        if not hasattr(self._dataset, 'get'):
            all_data = self._dataset.get_data()
            if not isinstance(all_data, tuple):
                all_data = (all_data, )
            raw_data = []
            for s in source:
                try:
                    raw_data.append(all_data[dataset_source.index(s)])
                except ValueError as e:
                    msg = str(e) + '\nThe dataset does not provide '\
                                   'a source with name: ' + s + '.'
                    reraise_as(ValueError(msg))
            self._raw_data = tuple(raw_data)

        self._source = source
        self._space = sub_spaces

        if convert is None:
            self._convert = [None for s in source]
        else:
            assert len(convert) == len(source)
            self._convert = convert

        for i, (so, sp) in enumerate(safe_izip(source, sub_spaces)):
            try:
                idx = dataset_source.index(so)
            except ValueError as e:
                msg = str(e) + '\nThe dataset does not provide '\
                               'a source with name: ' + so + '.'
                reraise_as(ValueError(msg))
            dspace = dataset_sub_spaces[idx]

            fn = self._convert[i]

            # If there is a fn, it is supposed to take care of the formatting,
            # and it should be an error if it does not. If there was no fn,
            # then the iterator will try to format using the generic
            # space-formatting functions.
            if fn is None:
                # "dspace" and "sp" have to be passed as parameters
                # to lambda, in order to capture their current value,
                # otherwise they would change in the next iteration
                # of the loop.
                fn = (lambda batch, dspace=dspace, sp=sp: dspace.np_format_as(
                    batch, sp))

            self._convert[i] = fn