def register_names_to_del(self, names):
    """
    Register names of fields that should not be pickled.

    Parameters
    ----------
    names : iterable
        A collection of strings indicating names of fields on this
        object that should not be pickled.

    Notes
    -----
    All names registered will be deleted from the dictionary returned
    by the model's `__getstate__` method (unless a particular model
    overrides this method).
    """
    if isinstance(names, six.string_types):
        names = [names]
    try:
        assert all(isinstance(n, six.string_types) for n in iter(names))
    except (TypeError, AssertionError):
        reraise_as(ValueError('Invalid names argument'))
    # Quick check in case __init__ was never called, e.g. by a derived
    # class.
    if not hasattr(self, 'names_to_del'):
        self.names_to_del = set()
    self.names_to_del = self.names_to_del.union(names)
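# Illustrative sketch (not from the source) of how a model's __getstate__
# could honor names_to_del as described in the docstring above. The class
# name ToyModel and its attributes are invented for the example.
class ToyModel(object):
    def __init__(self):
        self.W = [0.1, 0.2]
        self.compiled_fn = object()  # stand-in for an unpicklable field
        self.names_to_del = set(['compiled_fn'])

    def __getstate__(self):
        # Drop every registered field from the pickled state.
        state = dict(self.__dict__)
        for name in self.names_to_del:
            state.pop(name, None)
        return state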
def write(f, mat):
    """
    Write an ndarray to tensorfile.

    Parameters
    ----------
    f : file
        Open file to write into
    mat : ndarray
        Array to save
    """
    def _write_int32(f, i):
        i_array = numpy.asarray(i, dtype='int32')
        if 0:
            logger.debug('writing int32 {0} {1}'.format(i, i_array))
        i_array.tofile(f)

    try:
        _write_int32(f, _dtype_magic[str(mat.dtype)])
    except KeyError:
        reraise_as(TypeError('Invalid ndarray dtype for filetensor format',
                             mat.dtype))

    _write_int32(f, len(mat.shape))
    shape = mat.shape
    if len(shape) < 3:
        shape = list(shape) + [1] * (3 - len(shape))
    if 0:
        logger.debug('writing shape = {0}'.format(shape))
    for sh in shape:
        _write_int32(f, sh)
    mat.tofile(f)
def construct_mapping(node, deep=False):
    # This is a modified version of yaml.BaseConstructor.construct_mapping
    # in which a repeated key raises a ConstructorError
    if not isinstance(node, yaml.nodes.MappingNode):
        const = yaml.constructor
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None,
                                     "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = yaml.constructor.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=False)
        try:
            hash(key)
        except TypeError as exc:
            const = yaml.constructor
            reraise_as(const.ConstructorError(
                "while constructing a mapping", node.start_mark,
                "found unacceptable key (%s)" % exc, key_node.start_mark))
        if key in mapping:
            const = yaml.constructor
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % key)
        value = constructor.construct_object(value_node, deep=False)
        mapping[key] = value
    return mapping
def wrapped_func(*args, **kwargs):
    """
    .. todo::

        WRITEME
    """
    try:
        func(*args, **kwargs)
    except TypeError:
        argnames, varargs, keywords, defaults = inspect.getargspec(func)
        posargs = dict(zip(argnames, args))
        bad_keywords = []
        for keyword in kwargs:
            if keyword not in argnames:
                bad_keywords.append(keyword)

        if len(bad_keywords) > 0:
            bad = ', '.join(bad_keywords)
            reraise_as(TypeError('%s() does not support the following '
                                 'keywords: %s' % (str(func.func_name),
                                                   bad)))
        allargsgot = set(list(kwargs.keys()) + list(posargs.keys()))
        numrequired = len(argnames) - len(defaults)
        diff = list(set(argnames[:numrequired]) - allargsgot)
        if len(diff) > 0:
            reraise_as(TypeError('%s() did not get required args: %s' %
                                 (str(func.func_name), ', '.join(diff))))
        raise
def get_monitoring_channels(self, model, data, **kwargs):
    self.get_data_specs(model)[0].validate(data)
    rval = OrderedDict()
    composite_specs, mapping = self.get_composite_specs_and_mapping(model)
    nested_data = mapping.nest(data)

    for i, cost in enumerate(self.costs):
        cost_data = nested_data[i]
        try:
            channels = cost.get_monitoring_channels(model, cost_data,
                                                    **kwargs)
            rval.update(channels)
        except TypeError:
            reraise_as(Exception('SumOfCosts.get_monitoring_channels '
                                 'encountered TypeError while calling {0}'
                                 '.get_monitoring_channels'.format(
                                     type(cost))))

        value = cost.expr(model, cost_data, **kwargs)
        if value is not None:
            name = ''
            if hasattr(value, 'name') and value.name is not None:
                name = '_' + value.name
            rval['term_' + str(i) + name] = value

    return rval
def wrapped_func(*args, **kwargs):
    """
    .. todo::

        WRITEME
    """
    try:
        func(*args, **kwargs)
    except TypeError:
        argnames, varargs, keywords, defaults = inspect.getargspec(func)
        posargs = dict(zip(argnames, args))
        bad_keywords = []
        for keyword in kwargs:
            if keyword not in argnames:
                bad_keywords.append(keyword)

        if len(bad_keywords) > 0:
            bad = ', '.join(bad_keywords)
            reraise_as(TypeError('%s() does not support the following '
                                 'keywords: %s' % (str(func.func_name),
                                                   bad)))
        allargsgot = set(list(kwargs.keys()) + list(posargs.keys()))
        numrequired = len(argnames) - len(defaults)
        diff = list(set(argnames[:numrequired]) - allargsgot)
        if len(diff) > 0:
            reraise_as(TypeError('%s() did not get required args: %s' %
                                 (str(func.func_name), ', '.join(diff))))
        raise
def get_gradients(self, model, data, **kwargs):
    """
    Provides the gradients of the cost function with respect to the model
    parameters.

    These are not necessarily those obtained by theano.tensor.grad
    --you may wish to use approximate or even intentionally incorrect
    gradients in some cases.

    Parameters
    ----------
    model : a pylearn2 Model instance
    data : a batch in cost.get_data_specs() form
    kwargs : dict
        Optional extra arguments, not used by the base class.

    Returns
    -------
    gradients : OrderedDict
        a dictionary mapping from the model's parameters
        to their gradients
        The default implementation is to compute the gradients
        using T.grad applied to the value returned by expr.
        However, subclasses may return other values for the gradient.
        For example, an intractable cost may return a sampling-based
        approximation to its gradient.
    updates : OrderedDict
        a dictionary mapping shared variables to updates that must
        be applied to them each time these gradients are computed.
        This is to facilitate computation of sampling-based approximate
        gradients.
        The parameters should never appear in the updates dictionary.
        This would imply that computing their gradient changes
        their value, thus making the gradient value outdated.
    """

    try:
        cost = self.expr(model=model, data=data, **kwargs)
    except TypeError:
        # If anybody knows how to add type(self) to the exception message
        # but still preserve the stack trace, please do so
        # The current code does neither
        message = "Error while calling " + str(type(self)) + ".expr"
        reraise_as(TypeError(message))

    if cost is None:
        raise NotImplementedError(str(type(self)) +
                                  " represents an intractable cost and "
                                  "does not provide a gradient "
                                  "approximation scheme.")

    params = list(model.get_params())

    grads = T.grad(cost, params, disconnected_inputs='ignore')

    gradients = OrderedDict(izip(params, grads))

    updates = OrderedDict()

    return gradients, updates
def construct_mapping(node, deep=False):
    # This is a modified version of yaml.BaseConstructor.construct_mapping
    # in which a repeated key raises a ConstructorError
    if not isinstance(node, yaml.nodes.MappingNode):
        const = yaml.constructor
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None,
                                     "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = yaml.constructor.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=False)
        try:
            hash(key)
        except TypeError as exc:
            const = yaml.constructor
            reraise_as(const.ConstructorError(
                "while constructing a mapping", node.start_mark,
                "found unacceptable key (%s)" % exc, key_node.start_mark))
        if key in mapping:
            const = yaml.constructor
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % key)
        value = constructor.construct_object(value_node, deep=False)
        mapping[key] = value
    return mapping
def get_monitoring_channels(self, model, data, **kwargs):
    self.get_data_specs(model)[0].validate(data)
    rval = OrderedDict()
    composite_specs, mapping = self.get_composite_specs_and_mapping(model)
    nested_data = mapping.nest(data)

    for i, cost in enumerate(self.costs):
        cost_data = nested_data[i]
        try:
            channels = cost.get_monitoring_channels(model, cost_data,
                                                    **kwargs)
            rval.update(channels)
        except TypeError:
            reraise_as(Exception('SumOfCosts.get_monitoring_channels '
                                 'encountered TypeError while calling {0}'
                                 '.get_monitoring_channels'.format(
                                     type(cost))))

        value = cost.expr(model, cost_data, **kwargs)
        if value is not None:
            name = ''
            if hasattr(value, 'name') and value.name is not None:
                name = '_' + value.name
            rval['term_' + str(i) + name] = value

    return rval
def get_gradients(self, model, data, **kwargs):
    """
    Provides the gradients of the cost function with respect to the model
    parameters.

    These are not necessarily those obtained by theano.tensor.grad
    --you may wish to use approximate or even intentionally incorrect
    gradients in some cases.

    Parameters
    ----------
    model : a pylearn2 Model instance
    data : a batch in cost.get_data_specs() form
    kwargs : dict
        Optional extra arguments, not used by the base class.

    Returns
    -------
    gradients : OrderedDict
        a dictionary mapping from the model's parameters
        to their gradients
        The default implementation is to compute the gradients
        using T.grad applied to the value returned by expr.
        However, subclasses may return other values for the gradient.
        For example, an intractable cost may return a sampling-based
        approximation to its gradient.
    updates : OrderedDict
        a dictionary mapping shared variables to updates that must
        be applied to them each time these gradients are computed.
        This is to facilitate computation of sampling-based approximate
        gradients.
        The parameters should never appear in the updates dictionary.
        This would imply that computing their gradient changes
        their value, thus making the gradient value outdated.
    """

    try:
        cost, mask = self.expr(model=model, data=data, **kwargs)
    except TypeError:
        # If anybody knows how to add type(self) to the exception message
        # but still preserve the stack trace, please do so
        # The current code does neither
        message = "Error while calling " + str(type(self)) + ".expr"
        reraise_as(TypeError(message))

    if cost is None:
        raise NotImplementedError(str(type(self)) +
                                  " represents an intractable cost and "
                                  "does not provide a gradient "
                                  "approximation scheme.")

    params = list(model.get_params())

    grads = T.grad(cost, params, disconnected_inputs='ignore')

    gradients = OrderedDict(izip(params, grads))

    updates = OrderedDict()

    return gradients, updates
def fn(batch, dspace=dspace, sp=sp):
    try:
        return dspace.np_format_as(batch, sp)
    except ValueError as e:
        msg = str(e) + '\nMake sure that the model and '\
                       'dataset have been initialized with '\
                       'correct values.'
        reraise_as(ValueError(msg))
def wrapped_layer_cost(layer, coeff):
    try:
        return layer.get_weight_decay(coeff)
    except NotImplementedError:
        if coeff == 0.:
            return 0.
        else:
            reraise_as(NotImplementedError(str(type(layer)) +
                                           " does not implement "
                                           "get_weight_decay."))
def _validate_shape(shape, param_name):
    try:
        shape = tuple(shape)
        [int(val) for val in shape]
    except (ValueError, TypeError):
        try:
            shape = (int(shape),)
        except TypeError:
            reraise_as(TypeError("%s must be int or int tuple"
                                 % param_name))
    return shape
def _validate_shape(shape, param_name):
    try:
        shape = tuple(shape)
        [int(val) for val in shape]
    except (ValueError, TypeError):
        try:
            shape = (int(shape),)
        except TypeError:
            reraise_as(TypeError("%s must be int or int tuple"
                                 % param_name))
    return shape
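# Illustrative calls for _validate_shape above; the parameter names are made
# up for the example. A tuple-like shape passes through unchanged, while a
# bare int is promoted to a 1-tuple.
assert _validate_shape((28, 28), 'window_shape') == (28, 28)
assert _validate_shape(5, 'stride') == (5,)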
def next(self):
    """
    .. todo::

        WRITEME
    """
    indx = self.subset_iterator.next()
    try:
        mini_batch = self.X[indx]
    except IndexError as e:
        reraise_as(ValueError("Index out of range" + str(e)))
def next(self):
    """
    .. todo::

        WRITEME
    """
    indx = self.subset_iterator.next()
    try:
        mini_batch = self.X[indx]
    except IndexError as e:
        reraise_as(ValueError("Index out of range" + str(e)))
def wrapped_layer_cost(layer, coeff):
    try:
        return layer.get_weight_decay(coeff)
    except NotImplementedError:
        if coeff == 0.:
            return 0.
        else:
            reraise_as(NotImplementedError(str(type(layer)) +
                                           " does not implement "
                                           "get_weight_decay."))
def resolve(d):
    """
    given a dictionary d, returns the object described by the dictionary
    """
    tag = get_tag(d)
    try:
        resolver = resolvers[tag]
    except KeyError:
        reraise_as(TypeError('config does not know of any object type "' +
                             tag + '"'))
    return resolver(d)
def on_monitor(self, model, dataset, algorithm):
    """
    Adjusts the learning rate based on the contents of model.monitor

    Parameters
    ----------
    model : a Model instance
    dataset : Dataset
    algorithm : WRITEME
    """
    model = algorithm.model
    lr = algorithm.learning_rate
    current_learning_rate = lr.get_value()
    assert hasattr(model, 'monitor'), ("no monitor associated with " +
                                       str(model))
    monitor = model.monitor
    monitor_channel_specified = True

    try:
        v = monitor.channels[self.channel_name].val_record
    except KeyError:
        err_input = ''
        err_input = 'The channel_name \'' + str(
            self.channel_name) + '\' is not valid.'
        err_message = 'There is no monitoring channel named \'' + \
            str(self.channel_name) + '\'. You probably need to ' + \
            'specify a valid monitoring channel by using either ' + \
            'dataset_name or channel_name in the ' + \
            'MonitorBasedLRDecay constructor. ' + err_input
        reraise_as(ValueError(err_message))

    if len(v) == 1:
        # only the initial monitoring has happened
        # no learning has happened, so we can't adjust the learning rate yet
        # just do nothing
        self._min_v = v[0]
        return

    rval = current_learning_rate
    log.info("monitoring channel is {0}".format(self.channel_name))

    if v[-1] < self._min_v:
        self._min_v = v[-1]
        self._count = 0
    else:
        self._count += 1

    if self._count > self.nb_epoch:
        self._count = 0
        rval = self.shrink_lr * rval

    rval = max(self.min_lr, rval)
    lr.set_value(np.cast[lr.dtype](rval))
def resolve(d):
    """
    given a dictionary d, returns the object described by the dictionary
    """
    tag = get_tag(d)
    try:
        resolver = resolvers[tag]
    except KeyError:
        reraise_as(TypeError('config does not know of any object type "' +
                             tag + '"'))
    return resolver(d)
def load(filepath, rescale_image=True, dtype='float64'):
    """
    .. todo::

        WRITEME
    """
    assert type(filepath) == str

    if rescale_image == False and dtype == 'uint8':
        ensure_Image()
        rval = np.asarray(Image.open(filepath))
        # print 'image.load: ' + str((rval.min(), rval.max()))
        assert rval.dtype == 'uint8'
        return rval

    s = 1.0
    if rescale_image:
        s = 255.
    try:
        ensure_Image()
        rval = Image.open(filepath)
    except Exception:
        reraise_as(Exception("Could not open " + filepath))

    numpy_rval = np.array(rval)

    if numpy_rval.ndim not in [2, 3]:
        logger.error(dir(rval))
        logger.error(rval)
        logger.error(rval.size)
        rval.show()
        raise AssertionError("Tried to load an image, got an array with " +
                             str(numpy_rval.ndim) + " dimensions. "
                             "Expected 2 or 3. "
                             "This may indicate a mildly corrupted image "
                             "file. Try converting it to a different image "
                             "format with a different editor like gimp or "
                             "imagemagic. Sometimes these programs are more "
                             "robust to minor corruption than PIL and will "
                             "emit a correctly formatted image in the new "
                             "format.")
    rval = numpy_rval

    rval = np.cast[dtype](rval) / s

    if rval.ndim == 2:
        rval = rval.reshape(rval.shape[0], rval.shape[1], 1)

    if rval.ndim != 3:
        raise AssertionError("Something went wrong opening " +
                             filepath + '. Resulting shape is ' +
                             str(rval.shape) +
                             " (it's meant to have 3 dimensions by now)")

    return rval
def next(self):
    """
    .. todo::

        WRITEME
    """
    indx = self.subset_iterator.next()
    try:
        mini_batch = self.X[indx]
    except IndexError as e:
        reraise_as(ValueError("Index out of range" + str(e)))
        # the ind of minibatch goes beyond the boundary
    return mini_batch
def test_multi_constructor_obj():
    """
    Tests whether multi_constructor_obj throws an exception when
    the keys in mapping are None.
    """
    try:
        load("a: !obj:decimal.Decimal { 1 }")
    except TypeError as e:
        assert str(e) == "Received non string object (1) as key in mapping."
        pass
    except Exception as e:
        error_msg = "Got the unexpected error: %s" % (e)
        reraise_as(ValueError(error_msg))
def read_bin_lush_matrix(filepath):
    """
    Reads a binary matrix saved by the lush library.

    Parameters
    ----------
    filepath : str
        The path to the file.

    Returns
    -------
    matrix : ndarray
        A NumPy version of the stored matrix.
    """
    f = open(filepath, 'rb')
    try:
        magic = read_int(f)
    except ValueError:
        reraise_as(ValueError("Couldn't read magic number"))
    ndim = read_int(f)
    if ndim == 0:
        shape = ()
    else:
        shape = read_int(f, max(3, ndim))
    total_elems = 1
    for dim in shape:
        total_elems *= dim
    try:
        dtype = lush_magic[magic]
    except KeyError:
        reraise_as(ValueError('Unrecognized lush magic number ' +
                              str(magic)))
    rval = np.fromfile(file=f, dtype=dtype, count=total_elems)
    excess = f.read(-1)
    if excess:
        raise ValueError(str(len(excess)) +
                         ' extra bytes found at end of file.'
                         ' This indicates mismatch between header '
                         'and content')
    rval = rval.reshape(*shape)
    f.close()
    return rval
def read_bin_lush_matrix(filepath):
    """
    Reads a binary matrix saved by the lush library.

    Parameters
    ----------
    filepath : str
        The path to the file.

    Returns
    -------
    matrix : ndarray
        A NumPy version of the stored matrix.
    """
    f = open(filepath, 'rb')
    try:
        magic = read_int(f)
    except ValueError:
        reraise_as(ValueError("Couldn't read magic number"))
    ndim = read_int(f)
    if ndim == 0:
        shape = ()
    else:
        shape = read_int(f, max(3, ndim))
    total_elems = 1
    for dim in shape:
        total_elems *= dim
    try:
        dtype = lush_magic[magic]
    except KeyError:
        reraise_as(ValueError('Unrecognized lush magic number ' +
                              str(magic)))
    rval = np.fromfile(file=f, dtype=dtype, count=total_elems)
    excess = f.read(-1)
    if excess:
        raise ValueError(
            str(len(excess)) + ' extra bytes found at end of file.'
            ' This indicates mismatch between header '
            'and content')
    rval = rval.reshape(*shape)
    f.close()
    return rval
def get(self, sources, indexes):
    """
    Retrieves the requested elements from the dataset.

    Parameters
    ----------
    sources : tuple
        A tuple of source identifiers
    indexes : slice or list
        A slice or a list of indexes

    Returns
    -------
    rval : tuple
        A tuple of batches, one for each source
    """
    assert isinstance(sources, (tuple, list)) and len(sources) > 0, (
        'sources should be an instance of tuple and not empty')
    assert all([isinstance(el, string_types) for el in sources]), (
        'sources elements should be strings')
    assert isinstance(indexes, (tuple, list, slice, py_integer_types)), (
        'indexes should be either an int, a slice or a tuple/list of ints')
    if isinstance(indexes, (tuple, list)):
        assert len(indexes) > 0 and all(
            [isinstance(i, py_integer_types) for i in indexes]), (
            'indexes elements should be ints')

    rval = []
    for s in sources:
        try:
            sdata = self.data[s]
        except ValueError as e:
            reraise_as(ValueError(
                'The requested source %s is not part of the dataset' % s,
                *e.args))
        if (isinstance(indexes, (slice, py_integer_types)) or
                len(indexes) == 1):
            rval.append(sdata[indexes])
        else:
            warnings.warn('Accessing non sequential elements of an '
                          'HDF5 file will be at best VERY slow. Avoid '
                          'using iteration schemes that access '
                          'random/shuffled data with hdf5 datasets!!')
            val = []
            [val.append(sdata[idx]) for idx in indexes]
            rval.append(val)
    return tuple(rval)
def get(self, sources, indexes):
    """
    Retrieves the requested elements from the dataset.

    Parameters
    ----------
    sources : tuple
        A tuple of source identifiers
    indexes : slice or list
        A slice or a list of indexes

    Returns
    -------
    rval : tuple
        A tuple of batches, one for each source
    """
    assert (
        isinstance(sources, (tuple, list)) and len(sources) > 0
    ), "sources should be an instance of tuple and not empty"
    assert all([isinstance(el, string_types) for el in sources]), \
        "sources elements should be strings"
    assert isinstance(
        indexes, (tuple, list, slice, py_integer_types)
    ), "indexes should be either an int, a slice or a tuple/list of ints"
    if isinstance(indexes, (tuple, list)):
        assert len(indexes) > 0 and all(
            [isinstance(i, py_integer_types) for i in indexes]
        ), "indexes elements should be ints"

    rval = []
    for s in sources:
        try:
            sdata = self.data[s]
        except ValueError as e:
            reraise_as(ValueError(
                "The requested source %s is not part of the dataset" % s,
                *e.args))
        if isinstance(indexes, (slice, py_integer_types)) or \
                len(indexes) == 1:
            rval.append(sdata[indexes])
        else:
            warnings.warn(
                "Accessing non sequential elements of an "
                "HDF5 file will be at best VERY slow. Avoid "
                "using iteration schemes that access "
                "random/shuffled data with hdf5 datasets!!"
            )
            val = []
            [val.append(sdata[idx]) for idx in indexes]
            rval.append(val)
    return tuple(rval)
def try_to_import(tag_suffix):
    """
    .. todo::

        WRITEME
    """
    components = tag_suffix.split('.')
    modulename = '.'.join(components[:-1])
    try:
        exec('import %s' % modulename)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path,
        # or did the module we're importing have an unrelated ImportError?
        # and yes, this test can still have false positives, feel free to
        # improve it
        pieces = modulename.split('.')
        str_e = str(e)
        found = True in [piece.find(str(e)) != -1 for piece in pieces]

        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file
            reraise_as(ImportError("Could not import %s; ImportError was %s"
                                   % (modulename, str_e)))
        else:
            pcomponents = components[:-1]
            assert len(pcomponents) >= 1
            j = 1
            while j <= len(pcomponents):
                modulename = '.'.join(pcomponents[:j])
                try:
                    exec('import %s' % modulename)
                except Exception:
                    base_msg = 'Could not import %s' % modulename
                    if j > 1:
                        modulename = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % modulename
                    reraise_as(ImportError(base_msg +
                                           '. Original exception: ' +
                                           str(e)))
                j += 1
def try_to_import(tag_suffix):
    """
    .. todo::

        WRITEME
    """
    components = tag_suffix.split('.')
    modulename = '.'.join(components[:-1])
    try:
        exec('import %s' % modulename)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path,
        # or did the module we're importing have an unrelated ImportError?
        # and yes, this test can still have false positives, feel free to
        # improve it
        pieces = modulename.split('.')
        str_e = str(e)
        found = True in [piece.find(str(e)) != -1 for piece in pieces]

        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file
            reraise_as(ImportError("Could not import %s; ImportError was %s"
                                   % (modulename, str_e)))
        else:
            pcomponents = components[:-1]
            assert len(pcomponents) >= 1
            j = 1
            while j <= len(pcomponents):
                modulename = '.'.join(pcomponents[:j])
                try:
                    exec('import %s' % modulename)
                except Exception:
                    base_msg = 'Could not import %s' % modulename
                    if j > 1:
                        modulename = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % modulename
                    reraise_as(ImportError(base_msg +
                                           '. Original exception: ' +
                                           str(e)))
                j += 1
def resolve(d):
    """
    .. todo::

        WRITEME
    """
    tag = pylearn2.config.get_tag(d)

    if tag != 'dataset':
        raise TypeError('pylearn2.datasets.config asked to resolve a config '
                        'dictionary with tag "' + tag + '"')

    t = pylearn2.config.get_str(d, 'typename')

    try:
        resolver = resolvers[t]
    except KeyError:
        reraise_as(TypeError('pylearn2.datasets does not know of a dataset '
                             'type "' + t + '"'))

    return resolver(d)
def __init__(self, max_labels, dtype=None):
    """
    Initializes the formatter given the number of max labels.
    """
    try:
        np.empty(max_labels)
    except (ValueError, TypeError):
        reraise_as(ValueError("%s got bad max_labels argument '%s'" %
                              (self.__class__.__name__, str(max_labels))))
    self._max_labels = max_labels
    if dtype is None:
        self._dtype = config.floatX
    else:
        try:
            np.dtype(dtype)
        except TypeError:
            reraise_as(TypeError("%s got bad dtype identifier %s" %
                                 (self.__class__.__name__, str(dtype))))
        self._dtype = dtype
def get_batch_design(self, batch_size, include_labels=False):
    try:
        idx = self.rng.randint(self.X.shape[0] - batch_size + 1)
    except ValueError:
        if batch_size > self.X.shape[0]:
            reraise_as(ValueError("Requested %d examples from a dataset "
                                  "containing only %d." %
                                  (batch_size, self.X.shape[0])))
        raise
    rx = self.X[idx:idx + batch_size, :]
    if include_labels:
        if self.y is None:
            return rx, None
        ry = self.y[idx:idx + batch_size]
        rlatent = self.latent[idx:idx + batch_size]
        return rx, ry, rlatent
    rx = np.cast[config.floatX](rx)
    return rx
def read_bin_lush_matrix(filepath):
    """
    .. todo::

        WRITEME
    """
    f = open(filepath, 'rb')
    try:
        magic = read_int(f)
    except ValueError:
        reraise_as(ValueError("Couldn't read magic number"))
    ndim = read_int(f)
    if ndim == 0:
        shape = ()
    else:
        shape = read_int(f, max(3, ndim))
    total_elems = 1
    for dim in shape:
        total_elems *= dim
    try:
        dtype = lush_magic[magic]
    except KeyError:
        reraise_as(ValueError('Unrecognized lush magic number ' +
                              str(magic)))
    rval = np.fromfile(file=f, dtype=dtype, count=total_elems)
    excess = f.read(-1)
    if excess:
        raise ValueError(
            str(len(excess)) + ' extra bytes found at end of file.'
            ' This indicates mismatch between header and content')
    rval = rval.reshape(*shape)
    f.close()
    return rval
def resolve(d):
    """
    .. todo::

        WRITEME
    """
    tag = pylearn2.config.get_tag(d)

    if tag != 'dataset':
        raise TypeError('pylearn2.datasets.config asked to resolve a config '
                        'dictionary with tag "%s"' % tag)

    typename = pylearn2.config.get_str(d, 'typename')

    try:
        resolver = resolvers[typename]
    except KeyError:
        reraise_as(TypeError('pylearn2.datasets does not know of a dataset '
                             'type "%s"' % typename))

    return resolver(d)
def preprocess(string, environ=None):
    """
    Preprocesses a string, by replacing `${VARNAME}` with
    `os.environ['VARNAME']` and ~ with the path to the user's
    home directory

    Parameters
    ----------
    string : str
        String object to preprocess
    environ : dict, optional
        If supplied, preferentially accept values from
        this dictionary as well as `os.environ`. That is,
        if a key appears in both, this dictionary takes
        precedence.

    Returns
    -------
    rval : str
        The preprocessed string
    """
    if environ is None:
        environ = {}

    split = string.split('${')

    rval = [split[0]]

    for candidate in split[1:]:
        subsplit = candidate.split('}')

        if len(subsplit) < 2:
            raise ValueError('Open ${ not followed by } before '
                             'end of string or next ${ in "' +
                             string + '"')

        varname = subsplit[0]
        try:
            val = (environ[varname] if varname in environ
                   else os.environ[varname])
        except KeyError:
            if varname == 'PYLEARN2_DATA_PATH':
                reraise_as(NoDataPathError())
            if varname == 'PYLEARN2_VIEWER_COMMAND':
                reraise_as(EnvironmentVariableError(
                    viewer_command_error_essay + environment_variable_essay))

            reraise_as(ValueError('Unrecognized environment variable "' +
                                  varname + '". Did you mean ' +
                                  match(varname, os.environ.keys()) +
                                  '?'))

        rval.append(val)

        rval.append('}'.join(subsplit[1:]))

    rval = ''.join(rval)

    # Expand ~ to the user's home directory
    rval = os.path.expanduser(rval)

    return rval
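# Minimal usage sketch for preprocess above; the path and variable value are
# made up for illustration. The `environ` override takes precedence over
# os.environ, so this runs even if PYLEARN2_DATA_PATH is not set.
example = preprocess('${PYLEARN2_DATA_PATH}/mnist/train.pkl',
                     environ={'PYLEARN2_DATA_PATH': '/data/lisa/data'})
assert example == '/data/lisa/data/mnist/train.pkl'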
def read_bin_lush_matrix(filepath):
    """
    .. todo::

        WRITEME
    """
    f = open(filepath, 'rb')
    try:
        magic = read_int(f)
    except ValueError:
        reraise_as(ValueError("Couldn't read magic number"))
    ndim = read_int(f)
    if ndim == 0:
        shape = ()
    else:
        shape = read_int(f, max(3, ndim))
    total_elems = 1
    for dim in shape:
        total_elems *= dim
    try:
        dtype = lush_magic[magic]
    except KeyError:
        reraise_as(ValueError('Unrecognized lush magic number ' +
                              str(magic)))
    rval = np.fromfile(file=f, dtype=dtype, count=total_elems)
    excess = f.read(-1)
    if excess != '':
        raise ValueError(str(len(excess)) +
                         ' extra bytes found at end of file.'
                         ' This indicates mismatch between header and content')
    rval = rval.reshape(*shape)
    f.close()
    return rval
def on_monitor(self, model, dataset, algorithm):
    """
    Adjusts the learning rate based on the contents of model.monitor

    Parameters
    ----------
    model : a Model instance
    dataset : Dataset
    algorithm : WRITEME
    """
    model = algorithm.model
    lr = algorithm.learning_rate
    current_learning_rate = lr.get_value()
    assert hasattr(model, 'monitor'), ("no monitor associated with " +
                                       str(model))
    monitor = model.monitor
    monitor_channel_specified = True

    if self.channel_name is None:
        monitor_channel_specified = False
        channels = [elem for elem in monitor.channels
                    if elem.endswith("objective")]
        if len(channels) < 1:
            raise ValueError("There are no monitoring channels that end "
                             "with \"objective\". Please specify either "
                             "channel_name or dataset_name.")
        elif len(channels) > 1:
            datasets = algorithm.monitoring_dataset.keys()
            raise ValueError("There are multiple monitoring channels that "
                             "end with \"_objective\". The list of available "
                             "datasets are: " + str(datasets) +
                             " . Please specify either "
                             "channel_name or dataset_name in the "
                             "MonitorBasedLRAdjuster constructor to "
                             "disambiguate.")
        else:
            self.channel_name = channels[0]
            warnings.warn('The channel that has been chosen for '
                          'monitoring is: ' + str(self.channel_name) + '.')

    try:
        v = monitor.channels[self.channel_name].val_record
    except KeyError:
        err_input = ''
        if monitor_channel_specified:
            if self.dataset_name:
                err_input = 'The dataset_name \'' + str(
                    self.dataset_name) + '\' is not valid.'
            else:
                err_input = 'The channel_name \'' + str(
                    self.channel_name) + '\' is not valid.'
        err_message = 'There is no monitoring channel named \'' + \
            str(self.channel_name) + '\'. You probably need to ' + \
            'specify a valid monitoring channel by using either ' + \
            'dataset_name or channel_name in the ' + \
            'MonitorBasedLRAdjuster constructor. ' + err_input
        reraise_as(ValueError(err_message))

    if len(v) < 1:
        if monitor.dataset is None:
            assert len(v) == 0
            raise ValueError("You're trying to use a monitor-based "
                             "learning rate adjustor but the monitor has no "
                             "entries because you didn't specify a "
                             "monitoring dataset.")

        raise ValueError("For some reason there are no monitor entries "
                         "yet the MonitorBasedLRAdjuster has been "
                         "called. This should never happen. The Train"
                         " object should call the monitor once on "
                         "initialization, then call the callbacks. "
                         "It seems you are either calling the "
                         "callback manually rather than as part of a "
                         "training algorithm, or there is a problem "
                         "with the Train object.")

    if len(v) == 1:
        # only the initial monitoring has happened
        # no learning has happened, so we can't adjust the learning rate yet
        # just do nothing
        return

    rval = current_learning_rate
    log.info("monitoring channel is {0}".format(self.channel_name))
    if v[-1] > self.high_trigger * v[-2]:
        rval *= self.shrink_amt
        log.info("shrinking learning rate to %f" % rval)
    elif v[-1] > self.low_trigger * v[-2]:
        rval *= self.grow_amt
        log.info("growing learning rate to %f" % rval)

    rval = max(self.min_lr, rval)
    rval = min(self.max_lr, rval)

    lr.set_value(np.cast[lr.dtype](rval))
def get_weights_report(model_path=None, model=None, rescale='individual',
                       border=False, norm_sort=False, dataset=None):
    """
    Returns a PatchViewer displaying a grid of filter weights

    Parameters
    ----------
    model_path : str
        Filepath of the model to make the report on.
    rescale : str
        A string specifying how to rescale the filter images:
            - 'individual' (default) : scale each filter so that it
              uses as much as possible of the dynamic range
              of the display under the constraint that 0
              is gray and no value gets clipped
            - 'global' : scale the whole ensemble of weights
            - 'none' : don't rescale
    dataset : pylearn2.datasets.dataset.Dataset
        Dataset object to do view conversion for displaying the weights. If
        not provided one will be loaded from the model's dataset_yaml_src.

    Returns
    -------
    WRITEME
    """

    if model is None:
        logger.info('making weights report')
        logger.info('loading model')
        model = serial.load(model_path)
        logger.info('loading done')
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale=' + rescale +
                         ", must be 'none', 'global', or 'individual'")

    if isinstance(model, dict):
        # assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        keys = [key for key in model
                if hasattr(model[key], 'ndim') and model[key].ndim == 2]
        if len(keys) > 2:
            key = None
            while key not in keys:
                logger.info('Which is the weights?')
                for key in keys:
                    logger.info('\t{0}'.format(key))
                key = input()
        else:
            key, = keys
        weights = model[key]

        norms = np.sqrt(np.square(weights).sum(axis=1))
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

        return patch_viewer.make_viewer(weights,
                                        is_color=weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            logger.info('loading dataset...')
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            logger.info('...done')

        try:
            W = model.get_weights()
        except AttributeError as e:
            reraise_as(AttributeError("""
Encountered an AttributeError while trying to call get_weights on a model.
This probably means you need to implement get_weights for this model class,
but look at the original exception to be sure.
If this is an older model class, it may have weights stored as weightsShared,
etc.
Original exception: """ + str(e)))

    if W is None and weights_view is None:
        raise ValueError("model doesn't support any weights interfaces")

    if weights_view is None:
        weights_format = model.get_weights_format()
        assert hasattr(weights_format, '__iter__')
        assert len(weights_format) == 2
        assert weights_format[0] in ['v', 'h']
        assert weights_format[1] in ['v', 'h']
        assert weights_format[0] != weights_format[1]

        if weights_format[0] == 'v':
            W = W.T
        h = W.shape[0]

        if norm_sort:
            norms = np.sqrt(1e-8 + np.square(W).sum(axis=1))
            norm_prop = norms / norms.max()

        weights_view = dataset.get_weights_view(W)
        assert weights_view.shape[0] == h

    try:
        hr, hc = model.get_weights_view_shape()
    except NotImplementedError:
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

    pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                  patch_shape=weights_view.shape[1:3],
                                  is_color=weights_view.shape[-1] == 3)

    if global_rescale:
        weights_view /= np.abs(weights_view).max()

    if norm_sort:
        logger.info('sorting weights by decreasing norm')
        idx = sorted(range(h), key=lambda l: -norm_prop[l])
    else:
        idx = range(h)

    if border:
        act = 0
    else:
        act = None

    for i in range(0, h):
        patch = weights_view[idx[i], ...]
        pv.add_patch(patch, rescale=patch_rescale, activation=act)

    abs_weights = np.abs(weights_view)
    logger.info('smallest enc weight magnitude: {0}'.format(
        abs_weights.min()))
    logger.info('mean enc weight magnitude: {0}'.format(abs_weights.mean()))
    logger.info('max enc weight magnitude: {0}'.format(abs_weights.max()))

    if W is not None:
        norms = np.sqrt(np.square(W).sum(axis=1))
        assert norms.shape == (h, )
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

    return pv
def test_bad_monitoring_input_in_monitor_based_lr():
    # tests that the class MonitorBasedLRAdjuster in sgd.py avoids wrong
    # settings of channel_name or dataset_name in the constructor.
    dim = 3
    m = 10

    rng = np.random.RandomState([6, 2, 2014])

    X = rng.randn(m, dim)

    learning_rate = 1e-2
    batch_size = 5

    # We need to include this so the test actually stops running at some point
    epoch_num = 2

    dataset = DenseDesignMatrix(X=X)

    # including a monitoring datasets lets us test that
    # the monitor works with supervised data
    monitoring_dataset = DenseDesignMatrix(X=X)

    cost = DummyCost()

    model = SoftmaxModel(dim)

    termination_criterion = EpochCounter(epoch_num)

    algorithm = SGD(learning_rate, cost, batch_size=batch_size,
                    monitoring_batches=2,
                    monitoring_dataset=monitoring_dataset,
                    termination_criterion=termination_criterion,
                    update_callbacks=None,
                    init_momentum=None,
                    set_batch_size=False)

    # testing for bad dataset_name input
    dummy = 'void'

    monitor_lr = MonitorBasedLRAdjuster(dataset_name=dummy)

    train = Train(dataset,
                  model,
                  algorithm,
                  save_path=None,
                  save_freq=0,
                  extensions=[monitor_lr])

    try:
        train.main_loop()
    except ValueError as e:
        pass
    except Exception:
        reraise_as(AssertionError("MonitorBasedLRAdjuster takes dataset_name "
                                  "that is invalid "))

    # testing for bad channel_name input
    monitor_lr2 = MonitorBasedLRAdjuster(channel_name=dummy)

    model2 = SoftmaxModel(dim)
    train2 = Train(dataset,
                   model2,
                   algorithm,
                   save_path=None,
                   save_freq=0,
                   extensions=[monitor_lr2])

    try:
        train2.main_loop()
    except ValueError as e:
        pass
    except Exception:
        reraise_as(AssertionError("MonitorBasedLRAdjuster takes channel_name "
                                  "that is invalid "))

    return
def show(image):
    """
    .. todo::

        WRITEME

    Parameters
    ----------
    image : PIL Image object or ndarray
        If ndarray, integer formats are assumed to use 0-255
        and float formats are assumed to use 0-1
    """
    if hasattr(image, '__array__'):
        # do some shape checking because PIL just raises a tuple indexing
        # error that doesn't make it very clear what the problem is
        if len(image.shape) < 2 or len(image.shape) > 3:
            raise ValueError('image must have either 2 or 3 dimensions but '
                             'its shape is ' + str(image.shape))

        if image.dtype == 'int8':
            image = np.cast['uint8'](image)
        elif str(image.dtype).startswith('float'):
            # don't use *=, we don't want to modify the input array
            image = image * 255.
            image = np.cast['uint8'](image)

        # PIL is too stupid to handle single-channel arrays
        if len(image.shape) == 3 and image.shape[2] == 1:
            image = image[:, :, 0]

        try:
            ensure_Image()
            image = Image.fromarray(image)
        except TypeError:
            reraise_as(TypeError("PIL issued TypeError on ndarray of shape " +
                                 str(image.shape) + " and dtype " +
                                 str(image.dtype)))

    # Create a temporary file with the suffix '.png'.
    fd, name = mkstemp(suffix='.png')
    os.close(fd)

    # Note:
    #   Although we can use tempfile.NamedTemporaryFile() to create
    #   a temporary file, the function should be used with care.
    #
    #   In Python earlier than 2.7, a temporary file created by the
    #   function will be deleted just after the file is closed.
    #   We can re-use the name of the temporary file, but there is an
    #   instant where a file with the name does not exist in the file
    #   system before we re-use the name. This may cause a race
    #   condition.
    #
    #   In Python 2.7 or later, tempfile.NamedTemporaryFile() has
    #   the 'delete' argument which can control whether a temporary
    #   file will be automatically deleted or not. With the argument,
    #   the above race condition can be avoided.
    #
    image.save(name)
    viewer_command = string.preprocess('${PYLEARN2_VIEWER_COMMAND}')
    if os.name == 'nt':
        subprocess.Popen(viewer_command + ' ' + name + ' && del ' + name,
                         shell=True)
    else:
        subprocess.Popen(viewer_command + ' ' + name + ' ; rm ' + name,
                         shell=True)
def update_wrapper(wrapper, wrapped, assigned=WRAPPER_ASSIGNMENTS,
                   concatenated=WRAPPER_CONCATENATIONS, append=False,
                   updated=WRAPPER_UPDATES, replace_before=None):
    """
    A Python decorator which acts like `functools.update_wrapper` but
    also has the ability to concatenate attributes.

    Parameters
    ----------
    wrapper : function
        Function to be updated
    wrapped : function
        Original function
    assigned : tuple, optional
        Tuple naming the attributes assigned directly from the wrapped
        function to the wrapper function.
        Defaults to `utils.WRAPPER_ASSIGNMENTS`.
    concatenated : tuple, optional
        Tuple naming the attributes from the wrapped function
        concatenated with the ones from the wrapper function.
        Defaults to `utils.WRAPPER_CONCATENATIONS`.
    append : bool, optional
        If True, appends wrapped attributes to wrapper attributes
        instead of prepending them. Defaults to False.
    updated : tuple, optional
        Tuple naming the attributes of the wrapper that are updated
        with the corresponding attribute from the wrapped function.
        Defaults to `functools.WRAPPER_UPDATES`.
    replace_before : str, optional
        If `append` is `False` (meaning we are prepending), delete
        docstring lines occurring before the first line equal to this
        string (the docstring line is stripped of leading/trailing
        whitespace before comparison). The newline of the line preceding
        this string is preserved.

    Returns
    -------
    wrapper : function
        Updated wrapper function

    Notes
    -----
    This can be used to concatenate the wrapper's docstring with the
    wrapped's docstring and should help reduce the amount of documentation
    to write: one can use this decorator on child classes' functions when
    their implementation is similar to the one of the parent class.
    Conversely, if a function defined in a child class departs from its
    parent's implementation, one can simply explain the differences in a
    'Notes' section without re-writing the whole docstring.
    """
    assert not (append and replace_before), ("replace_before cannot "
                                             "be used with append")
    for attr in assigned:
        setattr(wrapper, attr, getattr(wrapped, attr))
    for attr in concatenated:
        # Make sure attributes are not None
        if getattr(wrapped, attr) is None:
            setattr(wrapped, attr, "")
        if getattr(wrapper, attr) is None:
            setattr(wrapper, attr, "")
        if append:
            setattr(wrapper, attr,
                    getattr(wrapped, attr) + getattr(wrapper, attr))
        else:
            if replace_before:
                assert replace_before.strip() == replace_before, (
                    'value for replace_before "%s" contains leading/'
                    'trailing whitespace' % replace_before)
                split = getattr(wrapped, attr).split("\n")
                # Potentially wasting time/memory by stripping everything
                # and duplicating it but probably not enough to worry about.
                split_stripped = [line.strip() for line in split]
                try:
                    index = split_stripped.index(replace_before.strip())
                except ValueError:
                    reraise_as(ValueError('no line equal to "%s" in wrapped '
                                          'function\'s attribute %s' %
                                          (replace_before, attr)))
                wrapped_val = '\n' + '\n'.join(split[index:])
            else:
                wrapped_val = getattr(wrapped, attr)
            setattr(wrapper, attr,
                    getattr(wrapper, attr) + wrapped_val)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper
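# A minimal usage sketch for update_wrapper above. The two toy functions are
# invented for illustration, and the attribute tuples are passed explicitly
# because the module-level defaults (WRAPPER_ASSIGNMENTS, etc.) are not shown
# in this excerpt.
def parent_method():
    """Docstring inherited from the parent implementation."""


def child_method():
    """Notes specific to the child implementation."""


child_method = update_wrapper(child_method, parent_method,
                              assigned=('__module__', '__name__'),
                              concatenated=('__doc__',),
                              updated=('__dict__',))
# With append=False (the default), child_method.__doc__ now combines both
# docstrings, so the child only has to document what differs.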
def _load(filepath, recurse_depth=0, retry=True):
    """
    Recursively tries to load a file until success or maximum number of
    attempts.

    Parameters
    ----------
    filepath : str
        A path to a file to load. Should be a pickle, Matlab, or NumPy
        file; or a .txt or .amat file that numpy.loadtxt can load.
    recurse_depth : int, optional
        End users should not use this argument. It is used by the function
        itself to implement the `retry` option recursively.
    retry : bool, optional
        If True, will make a handful of attempts to load the file before
        giving up. This can be useful if you are for example calling
        show_weights.py on a file that is actively being written to by a
        training script--sometimes the load attempt might fail if the
        training script writes at the same time show_weights tries to
        read, but if you try again after a few seconds you should be able
        to open the file.

    Returns
    -------
    loaded_object : object
        The object that was stored in the file.
    """
    try:
        import joblib
        joblib_available = True
    except ImportError:
        joblib_available = False
    if recurse_depth == 0:
        filepath = preprocess(filepath)
    if filepath.endswith('.npy') or filepath.endswith('.npz'):
        return np.load(filepath)
    if filepath.endswith('.amat') or filepath.endswith('txt'):
        try:
            return np.loadtxt(filepath)
        except Exception:
            reraise_as("{0} cannot be loaded by serial.load (trying "
                       "to use np.loadtxt)".format(filepath))
    if filepath.endswith('.mat'):
        global io
        if io is None:
            import scipy.io
            io = scipy.io
        try:
            return io.loadmat(filepath)
        except NotImplementedError as nei:
            if str(nei).find('HDF reader') != -1:
                global hdf_reader
                if hdf_reader is None:
                    import h5py
                    hdf_reader = h5py
                return hdf_reader.File(filepath, 'r')
            else:
                raise
        # this code should never be reached
        assert False

    # for loading PY2 pickle in PY3
    encoding = {'encoding': 'latin-1'} if six.PY3 else {}

    def exponential_backoff():
        if recurse_depth > 9:
            logger.info('Max number of tries exceeded while trying to open '
                        '{0}'.format(filepath))
            logger.info('attempting to open via reading string')
            with open(filepath, 'rb') as f:
                content = f.read()
            return cPickle.loads(content, **encoding)
        else:
            nsec = 0.5 * (2.0 ** float(recurse_depth))
            logger.info("Waiting {0} seconds and trying again".format(nsec))
            time.sleep(nsec)
            return _load(filepath, recurse_depth + 1, retry)

    try:
        if not joblib_available:
            with open(filepath, 'rb') as f:
                obj = cPickle.load(f, **encoding)
        else:
            try:
                obj = joblib.load(filepath)
            except Exception as e:
                if os.path.exists(filepath) and not os.path.isdir(filepath):
                    raise
                raise_cannot_open(filepath)
    except MemoryError as e:
        # We want to explicitly catch this exception because for MemoryError
        # __str__ returns the empty string, so some of our default printouts
        # below don't make a lot of sense.
        # Also, a lot of users assume any exception is a bug in the library,
        # so we can cut down on mail to pylearn-users by adding a message
        # that makes it clear this exception is caused by their machine not
        # meeting requirements.
        if os.path.splitext(filepath)[1] == ".pkl":
            improve_memory_error_message(e,
                                         ("You do not have enough memory to "
                                          "open %s \n"
                                          " + Try using numpy.{save,load} "
                                          "(file with extension '.npy') "
                                          "to save your file. It uses less "
                                          "memory when reading and "
                                          "writing files than pickled "
                                          "files.") % filepath)
        else:
            improve_memory_error_message(e,
                                         "You do not have enough memory to "
                                         "open %s" % filepath)
    except (BadPickleGet, EOFError, KeyError) as e:
        if not retry:
            reraise_as(e.__class__('Failed to open {0}'.format(filepath)))
        obj = exponential_backoff()
    except ValueError:
        logger.exception

        if not retry:
            reraise_as(ValueError('Failed to open {0}'.format(filepath)))
        obj = exponential_backoff()
    except Exception:
        # assert False
        reraise_as("Couldn't open {0}".format(filepath))

    # if the object has no yaml_src, we give it one that just says it
    # came from this file. could cause trouble if you save obj again
    # to a different location
    if not hasattr(obj, 'yaml_src'):
        try:
            obj.yaml_src = '!pkl: "' + os.path.abspath(filepath) + '"'
        except Exception:
            pass

    return obj
def _save(filepath, obj):
    """
    .. todo::

        WRITEME
    """
    try:
        import joblib
        joblib_available = True
    except ImportError:
        joblib_available = False
    if filepath.endswith('.npy'):
        np.save(filepath, obj)
        return
    # This is dumb
    # assert filepath.endswith('.pkl')
    save_dir = os.path.dirname(filepath)
    # Handle current working directory case.
    if save_dir == '':
        save_dir = '.'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
        raise IOError("save path %s exists, not a directory" % save_dir)
    elif not os.access(save_dir, os.W_OK):
        raise IOError("permission error creating %s" % filepath)

    try:
        if joblib_available and filepath.endswith('.joblib'):
            joblib.dump(obj, filepath)
        else:
            if filepath.endswith('.joblib'):
                warnings.warn('Warning: .joblib suffix specified but joblib '
                              'unavailable. Using ordinary pickle.')
            with open(filepath, 'wb') as filehandle:
                cPickle.dump(obj, filehandle, get_pickle_protocol())
    except Exception as e:
        logger.exception("cPickle has failed to write an object to "
                         "{0}".format(filepath))
        if str(e).find('maximum recursion depth exceeded') != -1:
            raise
        try:
            logger.info('retrying with pickle')
            with open(filepath, "wb") as f:
                pickle.dump(obj, f)
        except Exception as e2:
            if str(e) == '' and str(e2) == '':
                logger.exception('neither cPickle nor pickle could write to '
                                 '{0}'.format(filepath))
                logger.exception(
                    'moreover, neither of them raised an exception that '
                    'can be converted to a string'
                )
                logger.exception(
                    'now re-attempting to write with cPickle outside the '
                    'try/catch loop so you can see if it prints anything '
                    'when it dies'
                )
                with open(filepath, 'wb') as f:
                    cPickle.dump(obj, f, get_pickle_protocol())
                logger.info('Somehow or other, the file write worked once '
                            'we quit using the try/catch.')
            else:
                if str(e2) == 'env':
                    raise

                import pdb
                tb = pdb.traceback.format_exc()
                reraise_as(IOError(str(obj) +
                                   ' could not be written to ' +
                                   str(filepath) +
                                   ' by cPickle due to ' + str(e) +
                                   ' nor by pickle due to ' + str(e2) +
                                   '. \nTraceback ' + tb))
        logger.warning('{0} was written by pickle instead of cPickle, due to '
                       '{1} (perhaps your object'
                       ' is really big?)'.format(filepath, e))
    try:
        # Try to figure out what the wrong field name was
        # If we fail to do it, just fall back to giving the usual
        # attribute error
        pieces = tag_suffix.split('.')
        module = '.'.join(pieces[:-1])
        field = pieces[-1]
        candidates = dir(eval(module))

        msg = ('Could not evaluate %s. ' % tag_suffix +
               'Did you mean ' + match(field, candidates) + '? ' +
               'Original error was ' + str(e))
    except Exception:
        warnings.warn("Attempt to decipher AttributeError failed")
        reraise_as(AttributeError('Could not evaluate %s. ' % tag_suffix +
                                  'Original error was ' + str(e)))
    reraise_as(AttributeError(msg))

    return obj


def initialize():
    """
    Initialize the configuration system by installing YAML handlers.
    Automatically done on first call to load() specified in this file.
    """
    global is_initialized
    # Add the custom multi-constructor
    yaml.add_multi_constructor('!obj:', multi_constructor_obj)
    yaml.add_multi_constructor('!pkl:', multi_constructor_pkl)
    yaml.add_multi_constructor('!import:', multi_constructor_import)
def __init__(self, dataset, subset_iterator, data_specs=None,
             return_tuple=False, convert=None):
    self._data_specs = data_specs
    self._dataset = dataset
    self._subset_iterator = subset_iterator
    self._return_tuple = return_tuple

    # Keep only the needed sources in self._raw_data.
    # Remember what source they correspond to in self._source
    assert is_flat_specs(data_specs)

    dataset_space, dataset_source = self._dataset.get_data_specs()
    assert is_flat_specs((dataset_space, dataset_source))

    # the dataset's data spec is either a single (space, source) pair,
    # or a pair of (non-nested CompositeSpace, non-nested tuple).
    # We could build a mapping and call flatten(..., return_tuple=True)
    # but simply putting spaces, sources and data in tuples is simpler.
    if not isinstance(dataset_source, (tuple, list)):
        dataset_source = (dataset_source,)

    if not isinstance(dataset_space, CompositeSpace):
        dataset_sub_spaces = (dataset_space,)
    else:
        dataset_sub_spaces = dataset_space.components
    assert len(dataset_source) == len(dataset_sub_spaces)

    space, source = data_specs
    if not isinstance(source, tuple):
        source = (source,)
    if not isinstance(space, CompositeSpace):
        sub_spaces = (space,)
    else:
        sub_spaces = space.components
    assert len(source) == len(sub_spaces)

    # If `dataset` is incompatible with the new interface, fall back to the
    # old interface
    if not hasattr(self._dataset, 'get'):
        all_data = self._dataset.get_data()
        if not isinstance(all_data, tuple):
            all_data = (all_data,)
        raw_data = []
        for s in source:
            try:
                raw_data.append(all_data[dataset_source.index(s)])
            except ValueError as e:
                msg = str(e) + '\nThe dataset does not provide '\
                               'a source with name: ' + s + '.'
                reraise_as(ValueError(msg))
        self._raw_data = tuple(raw_data)

    self._source = source
    self._space = sub_spaces

    if convert is None:
        self._convert = [None for s in source]
    else:
        assert len(convert) == len(source)
        self._convert = convert

    for i, (so, sp) in enumerate(safe_izip(source, sub_spaces)):
        try:
            idx = dataset_source.index(so)
        except ValueError as e:
            msg = str(e) + '\nThe dataset does not provide '\
                           'a source with name: ' + so + '.'
            reraise_as(ValueError(msg))
        dspace = dataset_sub_spaces[idx]

        fn = self._convert[i]

        # If there is a fn, it is supposed to take care of the formatting,
        # and it should be an error if it does not. If there was no fn,
        # then the iterator will try to format using the generic
        # space-formatting functions.
        if fn is None:
            # "dspace" and "sp" have to be passed as parameters
            # to lambda, in order to capture their current value,
            # otherwise they would change in the next iteration
            # of the loop.
            fn = (lambda batch, dspace=dspace, sp=sp:
                  dspace.np_format_as(batch, sp))

        self._convert[i] = fn
""" TrainExtensions for doing random spatial windowing and flipping of an image dataset on every epoch. TODO: fill out properly.""" import warnings import numpy from . import TrainExtension from pylearn2.datasets.preprocessing import CentralWindow from pylearn2.utils.exc import reraise_as from pylearn2.utils.rng import make_np_rng try: from ..utils._window_flip import random_window_and_flip_c01b from ..utils._window_flip import random_window_and_flip_b01c except ImportError: reraise_as( ImportError("Import of Cython module failed. Please make sure " "you have run 'python setup.py develop' in the " "pylearn2 directory")) __authors__ = "David Warde-Farley" __copyright__ = "Copyright 2010-2012, Universite de Montreal" __credits__ = ["David Warde-Farley"] __license__ = "3-clause BSD" __maintainer__ = "David Warde-Farley" __email__ = "wardefar@iro" def _zero_pad(array, amount, axes=(1, 2)): """ Returns a copy of <array> with zero-filled padding around the margins. The new array has the same dimensions as the input array, except for
def raise_cannot_open(path):
    """
    Raise an exception saying we can't open `path`.

    Parameters
    ----------
    path : str
        The path we cannot open
    """
    pieces = path.split('/')
    for i in xrange(1, len(pieces) + 1):
        so_far = '/'.join(pieces[0:i])
        if not os.path.exists(so_far):
            if i == 1:
                if so_far == '':
                    continue
                reraise_as(IOError('Cannot open ' + path + ' (' + so_far +
                                   ' does not exist)'))
            parent = '/'.join(pieces[0:i - 1])
            bad = pieces[i - 1]

            if not os.path.isdir(parent):
                reraise_as(IOError("Cannot open " + path + " because " +
                                   parent + " is not a directory."))

            candidates = os.listdir(parent)

            if len(candidates) == 0:
                reraise_as(IOError("Cannot open " + path + " because " +
                                   parent + " is empty."))

            if len(candidates) > 100:
                # Don't attempt to guess the right name if the directory is
                # huge
                reraise_as(IOError("Cannot open " + path + " but can open " +
                                   parent + "."))

            if os.path.islink(path):
                reraise_as(IOError(path + " appears to be a symlink to a "
                                   "non-existent file"))
            reraise_as(IOError("Cannot open " + path + " but can open " +
                               parent + ". Did you mean " +
                               match(bad, candidates) + " instead of " +
                               bad + "?"))
        # end if
    # end for
    assert False
def __init__(self, dataset, subset_iterator, data_specs=None,
             return_tuple=False, convert=None):
    self._data_specs = data_specs
    self._dataset = dataset
    self._subset_iterator = subset_iterator
    self._return_tuple = return_tuple

    # Keep only the needed sources in self._raw_data.
    # Remember what source they correspond to in self._source
    assert is_flat_specs(data_specs)

    dataset_space, dataset_source = self._dataset.get_data_specs()
    assert is_flat_specs((dataset_space, dataset_source))

    # the dataset's data spec is either a single (space, source) pair,
    # or a pair of (non-nested CompositeSpace, non-nested tuple).
    # We could build a mapping and call flatten(..., return_tuple=True)
    # but simply putting spaces, sources and data in tuples is simpler.
    if not isinstance(dataset_source, (tuple, list)):
        dataset_source = (dataset_source, )

    if not isinstance(dataset_space, CompositeSpace):
        dataset_sub_spaces = (dataset_space, )
    else:
        dataset_sub_spaces = dataset_space.components
    assert len(dataset_source) == len(dataset_sub_spaces)

    space, source = data_specs
    if not isinstance(source, tuple):
        source = (source, )
    if not isinstance(space, CompositeSpace):
        sub_spaces = (space, )
    else:
        sub_spaces = space.components
    assert len(source) == len(sub_spaces)

    # If `dataset` is incompatible with the new interface, fall back to the
    # old interface
    if not hasattr(self._dataset, 'get'):
        all_data = self._dataset.get_data()
        if not isinstance(all_data, tuple):
            all_data = (all_data, )
        raw_data = []
        for s in source:
            try:
                raw_data.append(all_data[dataset_source.index(s)])
            except ValueError as e:
                msg = str(e) + '\nThe dataset does not provide '\
                               'a source with name: ' + s + '.'
                reraise_as(ValueError(msg))
        self._raw_data = tuple(raw_data)

    self._source = source
    self._space = sub_spaces

    if convert is None:
        self._convert = [None for s in source]
    else:
        assert len(convert) == len(source)
        self._convert = convert

    for i, (so, sp) in enumerate(safe_izip(source, sub_spaces)):
        try:
            idx = dataset_source.index(so)
        except ValueError as e:
            msg = str(e) + '\nThe dataset does not provide '\
                           'a source with name: ' + so + '.'
            reraise_as(ValueError(msg))
        dspace = dataset_sub_spaces[idx]

        fn = self._convert[i]

        # If there is a fn, it is supposed to take care of the formatting,
        # and it should be an error if it does not. If there was no fn,
        # then the iterator will try to format using the generic
        # space-formatting functions.
        if fn is None:
            # "dspace" and "sp" have to be passed as parameters
            # to lambda, in order to capture their current value,
            # otherwise they would change in the next iteration
            # of the loop.
            fn = (lambda batch, dspace=dspace, sp=sp:
                  dspace.np_format_as(batch, sp))

        self._convert[i] = fn