示例#1
0
 def run(self):
     """Steer the user towards a `develop` install of Pylearn2.

     Prints why `setup.py develop` is preferred over `setup.py install`,
     then prompts until the answer is '', 'develop', 'install' or
     'cancel'. An empty answer or 'develop' runs the `develop` command
     instead; 'install' runs the regular `install.run` and returns its
     result; 'cancel' does nothing.
     """
     print("Because Pylearn2 is under heavy development, we generally do "
           "not advise using the `setup.py install` command. Please "
           "consider using the `setup.py develop` command instead for the "
           "following reasons:\n\n1. Using `setup.py install` creates a "
           "copy of the Pylearn2 source code in your Python installation "
           "path. In order to update Pylearn2 afterwards you will need to "
           "rerun `setup.py install` (!). Simply using `git pull` to "
           "update your local copy of Pylearn2 code will not suffice. \n\n"
           "2. When using `sudo` to install Pylearn2, all files, "
           "including the tutorials, will be copied to a directory owned "
           "by root. Not only is running tutorials as root unsafe, it "
           "also means that all Pylearn2-related environment variables "
           "which were defined for the user will be unavailable.\n\n"
           "Pressing enter will continue the installation of Pylearn2 in "
           "`develop` mode instead. Note that this means that you need to "
           "keep this folder with the Pylearn2 code in its current "
           "location. If you know what you are doing, and are very sure "
           "that you want to install Pylearn2 using the `install` "
           "command instead, please type `install`.\n")
     question = "Installation mode: [develop]/install/cancel: "
     mode = input(question)
     # Keep asking until one of the accepted answers is given.
     while mode not in ('', 'install', 'develop', 'cancel'):
         print("Please try again")
         mode = input(question)
     if mode in ('', 'develop'):
         self.distribution.run_command('develop')
     elif mode == 'install':
         return install.run(self)
示例#2
0
def get_choice(choice_to_explanation):
    """
    Prompt the user until they type one of the keys of
    `choice_to_explanation`.

    Every key and its explanation are logged at INFO level first; the
    prompt lists the keys separated by '/'.

    Parameters
    ----------
    choice_to_explanation : dict
        Dictionary mapping possible user responses to strings describing
        what that response will cause the script to do

    Returns
    -------
    choice : str
        The user's response, guaranteed to be a key of
        `choice_to_explanation`.
    """
    for key, explanation in choice_to_explanation.items():
        logger.info('\t{0}: {1}'.format(key, explanation))
    prompt = '/'.join(choice_to_explanation) + '? '

    choice = input(prompt)
    while choice not in choice_to_explanation:
        # Log rather than warnings.warn: the default warnings filter shows a
        # given warning only once per call site, so every bad answer after
        # the first would otherwise be met with silence.
        logger.warning('unrecognized choice')
        choice = input(prompt)
    return choice
示例#3
0
 def run(self):
     """Steer the user towards a `develop` install of Pylearn2.

     Prints why `setup.py develop` is preferred over `setup.py install`,
     then prompts until the answer is '', 'develop', 'install' or
     'cancel'. An empty answer or 'develop' runs the `develop` command
     instead; 'install' runs the regular `install.run` and returns its
     result; 'cancel' does nothing.
     """
     print("Because Pylearn2 is under heavy development, we generally do "
           "not advise using the `setup.py install` command. Please "
           "consider using the `setup.py develop` command instead for the "
           "following reasons:\n\n1. Using `setup.py install` creates a "
           "copy of the Pylearn2 source code in your Python installation "
           "path. In order to update Pylearn2 afterwards you will need to "
           "rerun `setup.py install` (!). Simply using `git pull` to "
           "update your local copy of Pylearn2 code will not suffice. \n\n"
           "2. When using `sudo` to install Pylearn2, all files, "
           "including the tutorials, will be copied to a directory owned "
           "by root. Not only is running tutorials as root unsafe, it "
           "also means that all Pylearn2-related environment variables "
           "which were defined for the user will be unavailable.\n\n"
           "Pressing enter will continue the installation of Pylearn2 in "
           "`develop` mode instead. Note that this means that you need to "
           "keep this folder with the Pylearn2 code in its current "
           "location. If you know what you are doing, and are very sure "
           "that you want to install Pylearn2 using the `install` "
           "command instead, please type `install`.\n")
     question = "Installation mode: [develop]/install/cancel: "
     mode = input(question)
     # Keep asking until one of the accepted answers is given.
     while mode not in ('', 'install', 'develop', 'cancel'):
         print("Please try again")
         mode = input(question)
     if mode in ('', 'develop'):
         self.distribution.run_command('develop')
     elif mode == 'install':
         return install.run(self)
示例#4
0
def get_choice(choice_to_explanation):
    """
    Prompt the user until they type one of the keys of
    `choice_to_explanation`.

    Every key and its explanation are logged at INFO level first; the
    prompt lists the keys separated by '/'.

    Parameters
    ----------
    choice_to_explanation : dict
        Dictionary mapping possible user responses to strings describing
        what that response will cause the script to do

    Returns
    -------
    choice : str
        The user's response, guaranteed to be a key of
        `choice_to_explanation`.
    """
    for key, explanation in choice_to_explanation.items():
        logger.info('\t{0}: {1}'.format(key, explanation))
    prompt = '/'.join(choice_to_explanation) + '? '

    choice = input(prompt)
    while choice not in choice_to_explanation:
        # Log rather than warnings.warn: the default warnings filter shows a
        # given warning only once per call site, so every bad answer after
        # the first would otherwise be met with silence.
        logger.warning('unrecognized choice')
        choice = input(prompt)
    return choice
示例#5
0
def remove_packages(packages_to_remove):
    """
    Uninstall packages after asking the user for confirmation.

    A requested package is queued for removal only if it is listed in
    `packages_sources`, present in `installed_packages_list`, its install
    directory exists, and that directory is writable; every other name is
    reported with a warning and skipped. Names listed more than once are
    queued only once.

    NOTE(review): the original docstring claimed packages are removed
    "whether or not they are found in the source.lst (so it can remove
    datasets installed from file)", but unknown packages are rejected
    below -- confirm which behavior is intended.

    :param packages_to_remove: list of package names
    :raises: RuntimeError if `packages_to_remove` is empty;
             IOErrors from the removal itself
    """
    if not packages_to_remove:
        raise RuntimeError("[rm] fatal: need packages names to remove.")

    packages_really_to_remove = []
    for this_package in packages_to_remove:
        if this_package in packages_really_to_remove:
            # Already queued; a second removal of the same directory would
            # presumably fail.
            continue
        if this_package not in packages_sources:
            logger.warning("[rm] unknown package '{0}'".format(this_package))
            continue
        if this_package not in installed_packages_list:
            logger.warning("[rm] package '{0}' "
                           "not installed".format(this_package))
            continue

        installed = installed_packages_list[this_package]
        location = os.path.join(installed.where, installed.name)
        if not os.path.exists(location):
            logger.warning("[rm] package '{0}' found in config file "
                           "but not installed".format(this_package))
        elif file_access_rights(location, os.W_OK):
            # ok, you may have rights to delete it
            packages_really_to_remove.append(this_package)
        else:
            logger.warning("[rm] insufficient rights "
                           "to remove '{0}'".format(this_package))

    if not packages_really_to_remove:
        # Nothing to remove; the warnings above explain why.
        return

    logger.info("[rm] the following packages will be removed permanently:")
    logger.info(' '.join(packages_really_to_remove))

    r = input("Proceed? [yes/N] ")
    if r in ('y', 'yes'):
        for this_package in packages_really_to_remove:
            remove_package(installed_packages_list[this_package],
                           dataset_data_path)
    else:
        logger.info("[rm] Taking '{0}' for no, so there.".format(r))
示例#6
0
def install_packages_from_file(packages_to_install):
    """
    (Force)Installs packages from files, but does
    not update installed.lst files.

    caveat: not as tested as everything else.

    Files that do not exist are reported and skipped. The rest are listed
    and, once the user answers 'y' or 'yes', installed one by one. If a
    package of the same name already exists under `dataset_data_path`, the
    user is asked separately whether to overwrite it.

    :param packages_to_install: list of files to install
    :raises: RuntimeError if `packages_to_install` is empty;
             IOErrors from the installation itself
    """
    if not packages_to_install:
        raise RuntimeError("[in] fatal: need packages names to install.")

    packages_really_to_install = []
    for this_package in packages_to_install:
        if os.path.exists(this_package):
            packages_really_to_install.append(this_package)
        else:
            logger.warning("[in] package '{0}' not found".format(this_package))

    if not packages_really_to_install:
        return

    logger.info("[in] The following package(s) will be installed:")
    logger.info(' '.join(corename(p) for p in packages_really_to_install))

    r = input("Proceed? [yes/N] ")
    if r not in ('y', 'yes'):
        logger.info("[in] Taking '{0}' for no, so there.".format(r))
        return

    for this_package in packages_really_to_install:
        name = corename(this_package)
        # os.path.join is correct whether or not dataset_data_path ends with
        # a separator; plain string concatenation is only right if it does.
        if os.path.exists(os.path.join(dataset_data_path, name)):
            overwrite = input(
                "[in] '%s' already installed, overwrite? [yes/N] " % name)
            if overwrite not in ('y', 'yes'):
                logger.info("[in] skipping package '{0}'".format(name))
                continue
        install_package(name, this_package, dataset_data_path)
        # TODO: update installed.lst here (not done by this function).
示例#7
0
def install_packages_from_file(packages_to_install):
    """
    (Force)Installs packages from files, but does
    not update installed.lst files.

    caveat: not as tested as everything else.

    Files that do not exist are reported and skipped. The rest are listed
    and, once the user answers 'y' or 'yes', installed one by one. If a
    package of the same name already exists under `dataset_data_path`, the
    user is asked separately whether to overwrite it.

    :param packages_to_install: list of files to install
    :raises: RuntimeError if `packages_to_install` is empty;
             IOErrors from the installation itself
    """
    if not packages_to_install:
        raise RuntimeError("[in] fatal: need packages names to install.")

    packages_really_to_install = []
    for this_package in packages_to_install:
        if os.path.exists(this_package):
            packages_really_to_install.append(this_package)
        else:
            logger.warning("[in] package '{0}' not found".format(this_package))

    if not packages_really_to_install:
        return

    logger.info("[in] The following package(s) will be installed:")
    logger.info(' '.join(corename(p) for p in packages_really_to_install))

    r = input("Proceed? [yes/N] ")
    if r not in ('y', 'yes'):
        logger.info("[in] Taking '{0}' for no, so there.".format(r))
        return

    for this_package in packages_really_to_install:
        name = corename(this_package)
        # os.path.join is correct whether or not dataset_data_path ends with
        # a separator; plain string concatenation is only right if it does.
        if os.path.exists(os.path.join(dataset_data_path, name)):
            overwrite = input(
                "[in] '%s' already installed, overwrite? [yes/N] " % name)
            if overwrite not in ('y', 'yes'):
                logger.info("[in] skipping package '{0}'".format(name))
                continue
        install_package(name, this_package, dataset_data_path)
        # TODO: update installed.lst here (not done by this function).
def show_reconstructions(m, model_path):
    """
    Show reconstructions of a given DBM model.

    Loads the model and its dataset (asking whether to use the test set),
    then repeatedly displays a batch of `m` examples together with the
    model's reconstructions. After each display, ENTER draws a fresh batch
    and 'q' exits the program.

    Parameters
    ----------
    m: int
        rows * cols
    model_path: str
        Path of the model.

    Notes
    -----
    The viewer layout uses the names `rows` and `cols`, which are not
    defined in this function (presumably module-level); `m` is expected to
    equal ``rows * cols``.
    """
    import sys

    model = load_model(model_path, m)

    x = input('use test set? (y/n) ')
    dataset = load_dataset(model.dataset_yaml_src, x)
    vis_batch = dataset.get_batch_topo(m)
    pv = init_viewer(dataset, rows, cols, vis_batch)

    batch = model.visible_layer.space.make_theano_batch()
    reconstruction = model.reconstruct(batch)
    recons_func = function([batch], reconstruction)

    if hasattr(model.visible_layer, 'beta'):
        beta = model.visible_layer.beta.get_value()
        print('beta: ', (beta.min(), beta.mean(), beta.max()))

    while True:
        update_viewer(dataset, batch, rows, cols, pv, recons_func, vis_batch)
        pv.show()
        print('Displaying reconstructions. (q to quit, ENTER = show more)')
        while True:
            response = input()
            if response == 'q':
                # quit() is a site-module helper meant for the interactive
                # shell; sys.exit raises the same SystemExit in any program.
                sys.exit()
            if response == '':
                break
            print('Invalid input, try again')

        vis_batch = dataset.get_batch_topo(m)
def show_reconstructions(m, model_path):
    """
    Show reconstructions of a given DBM model.

    Loads the model and its dataset (asking whether to use the test set),
    then repeatedly displays a batch of `m` examples together with the
    model's reconstructions. After each display, ENTER draws a fresh batch
    and 'q' exits the program.

    Parameters
    ----------
    m: int
        rows * cols
    model_path: str
        Path of the model.

    Notes
    -----
    The viewer layout uses the names `rows` and `cols`, which are not
    defined in this function (presumably module-level); `m` is expected to
    equal ``rows * cols``.
    """
    import sys

    model = load_model(model_path, m)

    x = input('use test set? (y/n) ')
    dataset = load_dataset(model.dataset_yaml_src, x)
    vis_batch = dataset.get_batch_topo(m)
    pv = init_viewer(dataset, rows, cols, vis_batch)

    batch = model.visible_layer.space.make_theano_batch()
    reconstruction = model.reconstruct(batch)
    recons_func = function([batch], reconstruction)

    if hasattr(model.visible_layer, 'beta'):
        beta = model.visible_layer.beta.get_value()
        print('beta: ', (beta.min(), beta.mean(), beta.max()))

    while True:
        update_viewer(dataset, batch, rows, cols, pv, recons_func, vis_batch)
        pv.show()
        print('Displaying reconstructions. (q to quit, ENTER = show more)')
        while True:
            response = input()
            if response == 'q':
                # quit() is a site-module helper meant for the interactive
                # shell; sys.exit raises the same SystemExit in any program.
                sys.exit()
            if response == '':
                break
            print('Invalid input, try again')

        vis_batch = dataset.get_batch_topo(m)
示例#10
0
def install_packages(packages_to_install, force_install=False, hook=None):
    """
    Installs the packages, possibly forcing installs.

    Only names present in `packages_sources` are considered; packages that
    are already installed are skipped unless `force_install` is set, and a
    name listed more than once is installed only once. The selection is
    shown with its readable sizes and installed only after the user answers
    'y' or 'yes'.

    :param packages_to_install: list of package names
    :param force_install: if True, re-installs even if installed.
    :param hook: download progress hook
    :raises: RuntimeError if `packages_to_install` is empty;
             IOErrors from downloads/installation
    """
    if not packages_to_install:
        raise RuntimeError("[in] fatal: need packages names to install.")

    if force_install:
        logger.warning("[in] using the force")

    packages_really_to_install = []
    for this_package in packages_to_install:
        if this_package in packages_really_to_install:
            # Listed twice: install (and download) it only once.
            continue
        if this_package not in packages_sources:
            logger.warning("[in] unknown package '{0}'".format(this_package))
        elif force_install or this_package not in installed_packages_list:
            packages_really_to_install.append(this_package)
        else:
            logger.warning("[in] package '{0}' "
                           "is already installed".format(this_package))

    if not packages_really_to_install:
        # Nothing to install; the warnings above explain why.
        return

    logger.info("[in] The following package(s) will be installed:")
    for this_package in packages_really_to_install:
        readable_size = packages_sources[this_package].readable_size
        logger.info("{0} ({1})".format(this_package, readable_size))

    r = input("Proceed? [yes/N] ")
    if r in ('y', 'yes'):
        for this_package in packages_really_to_install:
            install_upgrade(packages_sources[this_package],
                            upgrade=False,
                            progress_hook=hook)
    else:
        logger.info("[in] Taking '{0}' for no, so there.".format(r))
示例#11
0
def install_packages(packages_to_install, force_install=False, hook=None):
    """
    Installs the packages, possibly forcing installs.

    Only names present in `packages_sources` are considered; packages that
    are already installed are skipped unless `force_install` is set, and a
    name listed more than once is installed only once. The selection is
    shown with its readable sizes and installed only after the user answers
    'y' or 'yes'.

    :param packages_to_install: list of package names
    :param force_install: if True, re-installs even if installed.
    :param hook: download progress hook
    :raises: RuntimeError if `packages_to_install` is empty;
             IOErrors from downloads/installation
    """
    if not packages_to_install:
        raise RuntimeError("[in] fatal: need packages names to install.")

    if force_install:
        logger.warning("[in] using the force")

    packages_really_to_install = []
    for this_package in packages_to_install:
        if this_package in packages_really_to_install:
            # Listed twice: install (and download) it only once.
            continue
        if this_package not in packages_sources:
            logger.warning("[in] unknown package '{0}'".format(this_package))
        elif force_install or this_package not in installed_packages_list:
            packages_really_to_install.append(this_package)
        else:
            logger.warning("[in] package '{0}' "
                           "is already installed".format(this_package))

    if not packages_really_to_install:
        # Nothing to install; the warnings above explain why.
        return

    logger.info("[in] The following package(s) will be installed:")
    for this_package in packages_really_to_install:
        readable_size = packages_sources[this_package].readable_size
        logger.info("{0} ({1})".format(this_package, readable_size))

    r = input("Proceed? [yes/N] ")
    if r in ('y', 'yes'):
        for this_package in packages_really_to_install:
            install_upgrade(packages_sources[this_package],
                            upgrade=False,
                            progress_hook=hook)
    else:
        logger.info("[in] Taking '{0}' for no, so there.".format(r))
示例#12
0
def create_archive(source, archive_name):
    """
    Pack every file found under `source` into a bzip2-compressed tarball.

    If `archive_name` already exists, the user is asked whether to
    overwrite it; any answer other than 'y' or 'yes' leaves it untouched.
    A failure to open the archive is logged and the function returns;
    errors while adding files propagate, but the archive is closed in
    either case.

    :param source: directory walked recursively for files to add
    :param archive_name: path of the ``.tar.bz2`` archive to write
    """
    if os.path.exists(archive_name):
        r = input("'%s' exists, overwrite? [yes/N] " % archive_name)
        if r not in ("y", "yes"):
            logger.info("taking '{0}' for no, so there.".format(r))
            # bail out
            return

    try:
        tar = tarfile.open(archive_name, mode="w:bz2")
    except Exception as e:
        logger.exception(e)
        return

    # The context manager closes the archive even if tar.add() raises.
    with tar:
        for root, _dirs, files in os.walk(source):
            for filename in files:
                this_file = os.path.join(root, filename)
                logger.info("adding '{0}'".format(this_file))
                tar.add(this_file)
示例#13
0
def create_archive(source, archive_name):
    """
    Pack every file found under `source` into a bzip2-compressed tarball.

    If `archive_name` already exists, the user is asked whether to
    overwrite it; any answer other than 'y' or 'yes' leaves it untouched.
    A failure to open the archive is logged and the function returns;
    errors while adding files propagate, but the archive is closed in
    either case.

    :param source: directory walked recursively for files to add
    :param archive_name: path of the ``.tar.bz2`` archive to write
    """
    if os.path.exists(archive_name):
        r = input("'%s' exists, overwrite? [yes/N] " % archive_name)
        if r not in ("y", "yes"):
            logger.info("taking '{0}' for no, so there.".format(r))
            # bail out
            return

    try:
        tar = tarfile.open(archive_name, mode="w:bz2")
    except Exception as e:
        logger.exception(e)
        return

    # The context manager closes the archive even if tar.add() raises.
    with tar:
        for root, _dirs, files in os.walk(source):
            for filename in files:
                this_file = os.path.join(root, filename)
                logger.info("adding '{0}'".format(this_file))
                tar.add(this_file)
# NOTE(review): `rows` is not assigned anywhere in this fragment; it must be
# defined before this point -- confirm.
cols = 10
m = rows * cols

if len(sys.argv) != 2:
    sys.exit('usage: %s <model_path>' % sys.argv[0])
model_path = sys.argv[1]

print('Loading model...')
model = serial.load(model_path)
model.set_batch_size(m)


dataset_yaml_src = model.dataset_yaml_src

print('Loading data...')
dataset = yaml_parse.load(dataset_yaml_src)

# Re-prompt instead of asserting: asserts disappear under `python -O` and
# give the user no chance to correct a typo.
x = input('use test set? (y/n) ')
while x not in ('y', 'n'):
    print('Please answer y or n')
    x = input('use test set? (y/n) ')

if x == 'y':
    dataset = dataset.get_test_set()

vis_batch = dataset.get_batch_topo(m)

batch_size, patch_rows, patch_cols, channels = vis_batch.shape

assert batch_size == m

mapback = hasattr(dataset, 'mapback_for_viewer')

# Each of `mapback` and two-channel data doubles the column count.
actual_cols = 2 * cols * (1 + mapback) * (1 + (channels == 2))
示例#15
0
import sys

pv = PatchViewer((1, 2), (r, c), is_color=False)

i = 0
while True:
    patch = topo[i, :, :, :]
    # Scale into [-1, 1]; an all-zero patch is left as is instead of being
    # divided by zero (which would fill it with NaNs).
    peak = np.abs(patch).max()
    if peak > 0:
        patch = patch / peak

    pv.add_patch(patch[:, :, 1], rescale=False)
    pv.add_patch(patch[:, :, 0], rescale=False)

    pv.show()

    print(dataset.y[i])

    choices = {'g': 'goto image', 'q': 'quit'}

    if i + 1 < b:
        choices['n'] = 'next image'

    choice = get_choice(choices)

    if choice == 'q':
        sys.exit()

    if choice == 'n':
        i += 1

    if choice == 'g':
        # Assumes `b` is the number of images in `topo` (it bounds 'next
        # image' above). Keep asking until the answer is a valid index, so a
        # typo neither crashes the viewer nor wraps around via negative
        # indexing.
        while True:
            answer = input('index: ')
            try:
                target = int(answer)
            except ValueError:
                print('Please enter an integer')
                continue
            if 0 <= target < b:
                i = target
                break
            print('Index must be between 0 and {0}'.format(b - 1))
示例#16
0
def upgrade_packages(packages_to_upgrade, hook=None):
    """
    Upgrades packages.

    If no packages are supplied, it will perform
    an "update-all" operation, finding all packages
    that are out of date.

    If packages names are supplied, only those
    are checked for upgrade (and upgraded if out
    of date)

    A package is out of date when its installed timestamp is older than the
    timestamp in `packages_sources`. A name listed more than once is
    upgraded only once. The candidates are shown with their readable sizes
    and upgraded only after the user answers 'y' or 'yes'.

    :param packages_to_upgrade: list of package names.
    :param hook: download progress hook, passed to `install_upgrade`.
    :raises: IOErrors (from downloads/rights)
    """
    # An empty selection means "everything that is installed".
    all_packages = not packages_to_upgrade
    if all_packages:
        packages_to_upgrade = list(installed_packages_list.keys())

    # check what packages are in the list,
    # and really to be upgraded.
    packages_really_to_upgrade = []
    for this_package in packages_to_upgrade:
        if this_package in packages_really_to_upgrade:
            continue
        if this_package not in installed_packages_list:
            # Only complain about packages the user explicitly asked for.
            if not all_packages:
                logger.warning("[up] '{0}' is not installed, "
                               "cannot upgrade.".format(this_package))
            continue
        if this_package not in packages_sources:
            logger.warning("[up] '{0}' is unknown "
                           "(installed from file?).".format(this_package))
            continue

        installed_date = installed_packages_list[this_package].timestamp
        repo_date = packages_sources[this_package].timestamp
        if installed_date < repo_date:
            # ok, there's a newer version
            logger.info(this_package)
            packages_really_to_upgrade.append(this_package)

    if not packages_really_to_upgrade:
        # ok, nothing to upgrade, move along.
        return

    # once we have determined which packages are to be updated, we show
    # them to the user for confirmation
    logger.info("[up] the following package(s) will be upgraded:")
    for this_package in packages_really_to_upgrade:
        readable_size = packages_sources[this_package].readable_size
        logger.info("{0} ({1})".format(this_package, readable_size))

    r = input("Proceed? [yes/N] ")
    if r in ('y', 'yes'):
        for this_package in packages_really_to_upgrade:
            install_upgrade(packages_sources[this_package],
                            upgrade=True,
                            progress_hook=hook)
    else:
        logger.info("[up] Taking '{0}' for no, so there.".format(r))
示例#17
0
def get_weights_report(model_path=None,
                       model=None,
                       rescale='individual',
                       border=False,
                       norm_sort=False,
                       dataset=None):
    """
    Returns a PatchViewer displaying a grid of filter weights

    Parameters
    ----------
    model_path : str
        Filepath of the model to make the report on. Must be None when
        `model` is given.
    model : object or dict, optional
        An already-loaded model. A dict is treated as a saved MATLAB
        dictionary whose 2-D array entries are candidate weight matrices;
        if there are several, the user is asked which one to show. The
        dict is not modified.
    rescale : str
        A string specifying how to rescale the filter images:
            - 'individual' (default) : scale each filter so that it
                  uses as much as possible of the dynamic range
                  of the display under the constraint that 0
                  is gray and no value gets clipped
            - 'global' : scale the whole ensemble of weights
            - 'none' :   don't rescale
    border : bool
        If True, patches are added with ``activation=0``; otherwise with
        ``activation=None``.
    norm_sort : bool
        If True, display the filters in order of decreasing norm.
    dataset : pylearn2.datasets.dataset.Dataset
        Dataset object to do view conversion for displaying the weights. If
        not provided one will be loaded from the model's dataset_yaml_src.

    Returns
    -------
    pv : PatchViewer
        A viewer holding one patch per filter (for a dict `model`, the
        viewer returned by `patch_viewer.make_viewer`).

    Raises
    ------
    ValueError
        If `rescale` is not an accepted string, if a dict `model` holds no
        2-D array, or if the model supports no weights interface.
    """
    if model is None:
        logger.info('making weights report')
        logger.info('loading model')
        model = serial.load(model_path)
        logger.info('loading done')
    else:
        assert model_path is None
    assert model is not None

    # rescale -> (global_rescale, patch_rescale)
    rescale_modes = {'none': (False, False),
                     'global': (True, False),
                     'individual': (False, True)}
    if rescale not in rescale_modes:
        raise ValueError('rescale=' + rescale +
                         ", must be 'none', 'global', or 'individual'")
    global_rescale, patch_rescale = rescale_modes[rescale]

    if isinstance(model, dict):
        # Assume this was a saved matlab dictionary. Its metadata entries are
        # skipped rather than deleted, so the caller's dict is left untouched
        # and a dict lacking one of them does not raise KeyError.
        metadata_keys = ('__version__', '__header__', '__globals__')
        keys = [key for key in model
                if key not in metadata_keys
                and hasattr(model[key], 'ndim') and model[key].ndim == 2]
        if not keys:
            raise ValueError("no 2-D array found in the weights dictionary")
        if len(keys) > 1:
            # More than one candidate matrix: let the user pick.
            key = None
            while key not in keys:
                logger.info('Which is the weights?')
                for candidate in keys:
                    logger.info('\t{0}'.format(candidate))
                key = input()
        else:
            key, = keys
        weights = model[key]

        norms = np.sqrt(np.square(weights).sum(axis=1))
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

        return patch_viewer.make_viewer(weights,
                                        is_color=weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            logger.info('loading dataset...')
            control.push_load_data(False)
            try:
                dataset = yaml_parse.load(model.dataset_yaml_src)
            finally:
                # Restore the global load-data setting even if loading fails.
                control.pop_load_data()
            logger.info('...done')

        try:
            W = model.get_weights()
        except AttributeError as e:
            reraise_as(AttributeError("""
Encountered an AttributeError while trying to call get_weights on a model.
This probably means you need to implement get_weights for this model class,
but look at the original exception to be sure.
If this is an older model class, it may have weights stored as weightsShared,
etc.
Original exception: """+str(e)))

    if W is None and weights_view is None:
        raise ValueError("model doesn't support any weights interfaces")

    if weights_view is None:
        weights_format = model.get_weights_format()
        assert hasattr(weights_format, '__iter__')
        assert len(weights_format) == 2
        assert weights_format[0] in ['v', 'h']
        assert weights_format[1] in ['v', 'h']
        assert weights_format[0] != weights_format[1]

        if weights_format[0] == 'v':
            W = W.T
        h = W.shape[0]

        weights_view = dataset.get_weights_view(W)
        assert weights_view.shape[0] == h

    if norm_sort:
        # Use the rows of W when the model exposes it; models that only
        # provide topological weights are sorted by each flattened filter.
        if W is not None:
            sort_norms = np.sqrt(1e-8 + np.square(W).sum(axis=1))
        else:
            flat = weights_view.reshape(h, -1)
            sort_norms = np.sqrt(1e-8 + np.square(flat).sum(axis=1))
        norm_prop = sort_norms / sort_norms.max()

    try:
        hr, hc = model.get_weights_view_shape()
    except NotImplementedError:
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

    pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                  patch_shape=weights_view.shape[1:3],
                                  is_color=weights_view.shape[-1] == 3)

    if global_rescale:
        # Out of place: if get_weights_topo hands back an array shared with
        # the model, it must not be modified here.
        weights_view = weights_view / np.abs(weights_view).max()

    if norm_sort:
        logger.info('sorting weights by decreasing norm')
        idx = sorted(range(h), key=lambda l: -norm_prop[l])
    else:
        idx = range(h)

    act = 0 if border else None

    for filter_index in idx:
        pv.add_patch(weights_view[filter_index, ...],
                     rescale=patch_rescale, activation=act)

    abs_weights = np.abs(weights_view)
    logger.info('smallest enc weight magnitude: {0}'.format(abs_weights.min()))
    logger.info('mean enc weight magnitude: {0}'.format(abs_weights.mean()))
    logger.info('max enc weight magnitude: {0}'.format(abs_weights.max()))

    if W is not None:
        norms = np.sqrt(np.square(W).sum(axis=1))
        assert norms.shape == (h,)
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

    return pv
示例#18
0
def main():
    """
    Interactively plot the monitoring channels of one or more saved models.

    Command line: ``model_paths`` (one or more saved models), optional
    ``--out`` (save the figure there using the Agg backend instead of
    showing it) and ``--yrange`` (``ymin:ymax``). The channels of every
    model are collected into ``channels`` -- a mapping not defined in this
    function, presumably module-level; confirm -- with a ``:<model name>``
    suffix when several models are given.

    With more than one channel the user is prompted repeatedly for channel
    codes (e.g. ``A, C,F-G, h, <test_err>``); ``e``/``b``/``s``/``h`` switch
    the x axis, ``o`` offers smoothing and ``q`` quits. A single channel is
    plotted once without prompting.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    parser.add_argument("--yrange",
                        help='The y-range to be used for plotting, e.g.  0:1')

    options = parser.parse_args()
    model_paths = options.model_paths

    if options.out is not None:
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    print('generating names...')
    model_names = [
        model_path.replace('.pkl', '!') for model_path in model_paths
    ]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in model_names]
    print('...done')

    # Disambiguate channel names only when several models are plotted
    # together; counting sys.argv would also count --out / --yrange.
    multiple_models = len(model_paths) > 1

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                print(arg + " is a yaml config file, " +
                      "you need to load a trained model.",
                      file=sys.stderr)
                sys.exit(-1)
            raise
        this_model_channels = model.monitor.channels

        postfix = ":" + model_names[i] if multiple_models else ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]
        del model
        gc.collect()

    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    styles = list(colors)
    styles += [color + '--' for color in colors]
    styles += [color + ':' for color in colors]

    while True:
        # Make a list of short codes for each channel so user can specify them
        # easily
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels,
                                   key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'example'
        print('set x_axis to example')

        if len(channels) == 0:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:
            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print()

            print("Put e, b, s or h in the list somewhere to plot " +
                  "epochs, batches, seconds, or hours, respectively.")
            response = input('Enter a list of channels to plot ' +
                             '(example: A, C,F-G, h, <test_err>) or q to quit' +
                             ' or o for options: ')

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = input('Enter your choice: ')
                if response == '1':
                    _smooth_channels(channels.values())
                continue

            if response == 'q':
                break

            final_codes, x_axis = _parse_channel_codes(response, sorted_codes,
                                                       codebook, x_axis)
        else:
            # Exactly one channel: plot it without asking. (The codebook
            # holds two entries for it, its code and '<name>'.)
            final_codes = set(sorted_codes)

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):

            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if contains_nan(y):
                print(channel_name + ' contains NaNs')

            if contains_inf(y):
                print(channel_name + ' contains infinite values')

            if x_axis == 'example':
                x = np.asarray(channel.example_record)
            elif x_axis == 'batche':
                x = np.asarray(channel.batch_record)
            elif x_axis == 'epoch':
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == 'second':
                x = np.asarray(channel.time_record)
            elif x_axis == 'hour':
                x = np.asarray(channel.time_record) / 3600.
            else:
                assert False

            ax.plot(
                x,
                y,
                styles[idx % len(styles)],
                marker='.',  # add point markers to lines
                label=channel_name)

        # 'batche' + 's' reads as 'batches'.
        plt.xlabel('# ' + x_axis + 's')
        ax.ticklabel_format(scilimits=(-3, 3), axis='both')

        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles,
                  labels,
                  loc='upper center',
                  bbox_to_anchor=(0.5, -0.1))
        # 0.046 is the size of 1 legend box
        fig.subplots_adjust(bottom=0.11 + 0.046 * len(final_codes))

        if options.yrange is not None:
            ymin, ymax = map(float, options.yrange.split(':'))
            plt.ylim(ymin, ymax)

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)
        # Release the figure; otherwise every prompt round leaks one.
        plt.close(fig)

        if not prompt:
            break


def _parse_channel_codes(response, sorted_codes, codebook, x_axis):
    """Parse a comma-separated channel selection into ``(codes, x_axis)``.

    ``e``/``b``/``s``/``h`` switch the x axis, ``X-Y`` selects the run of
    consecutive codes from X to Y, and any other entry must be a key of
    `codebook` (a short code or ``<channel name>``). Spaces and empty
    entries are ignored. Prints a message and exits on malformed input.
    """
    final_codes = set()
    for code in response.replace(' ', '').split(','):
        if code == '':
            # Tolerate stray commas such as "A,,B" or a trailing "A,".
            continue
        if code == 'e':
            x_axis = 'epoch'
        elif code == 'b':
            x_axis = 'batche'
        elif code == 's':
            x_axis = 'second'
        elif code == 'h':
            x_axis = 'hour'
        elif code.startswith('<'):
            if not code.endswith('>') or code not in codebook:
                print("Invalid code: " + code)
                sys.exit(-1)
            final_codes.add(code)
        elif '-' in code:
            # The current list element is a range of codes
            rng = code.split('-')
            if len(rng) != 2:
                print("Input not understood: " + code)
                sys.exit(-1)
            start, end = rng
            if start not in sorted_codes:
                print("Invalid code: " + start)
                sys.exit(-1)
            i = sorted_codes.index(start)
            if end not in sorted_codes[i:]:
                print("Invalid code: " + end)
                sys.exit(-1)
            j = sorted_codes.index(end, i)
            final_codes.update(sorted_codes[i:j + 1])
        else:
            # The current list element is just a single code
            if code not in codebook:
                print("Invalid code: " + code)
                sys.exit(-1)
            final_codes.add(code)
    return final_codes, x_axis


def _smooth_channels(channel_list, k=5):
    """Replace each channel's ``val_record`` with a trailing moving average.

    Entry ``i`` becomes the mean of entries ``max(0, i - k)`` through ``i``
    inclusive, i.e. a window of up to ``k + 1`` values.
    """
    for channel in channel_list:
        record = channel.val_record
        smoothed = []
        for i in range(len(record)):
            window = record[max(0, i - k):i + 1]
            smoothed.append(sum(window) / float(len(window)))
        channel.val_record = smoothed
示例#19
0
def _x_values_for(channel, x_axis):
    """
    Return the x-axis values for `channel` on the requested axis.

    Parameters
    ----------
    channel : object
        A monitor channel exposing the ``*_record`` attribute matching
        `x_axis` (``example_record``, ``batch_record``, ``time_record``,
        or optionally ``epoch_record``).
    x_axis : str
        One of 'example', 'batche', 'epoch', 'second', 'hour'. The
        spelling 'batche' is deliberate: the axis label is built as
        ``'# ' + x_axis + 's'``.

    Returns
    -------
    numpy.ndarray
    """
    if x_axis == 'example':
        return np.asarray(channel.example_record)
    if x_axis == 'batche':
        return np.asarray(channel.batch_record)
    if x_axis == 'epoch':
        try:
            return np.asarray(channel.epoch_record)
        except AttributeError:
            # older saved monitors won't have epoch_record
            return np.arange(len(channel.batch_record))
    if x_axis == 'second':
        return np.asarray(channel.time_record)
    if x_axis == 'hour':
        return np.asarray(channel.time_record) / 3600.
    raise ValueError("unknown x_axis: " + repr(x_axis))


def _smooth_channels(channels, k=5):
    """
    Replace every channel's ``val_record`` with a trailing moving average.

    Entry ``i`` becomes the mean of the original entries
    ``max(0, i - k) .. i`` (inclusive), i.e. a window of up to ``k + 1``
    points.
    """
    for channel in channels.values():
        record = channel.val_record
        smoothed = []
        for i in range(len(record)):
            window = record[max(0, i - k):i + 1]
            smoothed.append(sum(window, 0.) / len(window))
        channel.val_record = smoothed


def main():
    """
    Interactively plot monitor channels from one or more saved models.

    Command line: one or more ``model_paths`` and an optional ``--out``
    file. With ``--out`` the figure is saved there (Agg backend) instead
    of being shown.

    The channels of every model are merged into the global ``channels``
    mapping. When more than one model is given, each channel name gets a
    ``":" + <short model name>`` suffix. If more than one channel is
    available, the user is repeatedly asked which ones to plot.
    Otherwise the single channel is plotted once.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    options = parser.parse_args()
    model_paths = options.model_paths

    if options.out is not None:
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    print('generating names...')
    model_names = [model_path.replace('.pkl', '!') for model_path in
                   model_paths]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in
                   model_names]
    print('...done')

    # Disambiguate channel names only when several models are merged.
    # Count the positional model paths, not sys.argv, which also holds
    # options such as --out.
    multiple_models = len(model_paths) > 1

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                print(arg + " is a yaml config file, " +
                      "you need to load a trained model.", file=sys.stderr)
                quit(-1)
            raise
        this_model_channels = model.monitor.channels

        postfix = ":" + model_names[i] if multiple_models else ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]
        del model
        gc.collect()

    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    styles = list(colors)
    styles += [color + '--' for color in colors]
    styles += [color + ':' for color in colors]

    while True:
        # Make a list of short codes for each channel so user can specify
        # them easily. Each channel is reachable by its short code and by
        # '<channel_name>'.
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels,
                                   key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'example'
        print('set x_axis to example')

        if not channels:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:
            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print()

            print("Put e, b, s or h in the list somewhere to plot " +
                  "epochs, batches, seconds, or hours, respectively.")
            response = input('Enter a list of channels to plot ' +
                             '(example: A, C,F-G, h, <test_err>) or q to quit' +
                             ' or o for options: ')

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = input('Enter your choice: ')
                if response == '1':
                    _smooth_channels(channels, k=5)
                continue

            if response == 'q':
                break

            # Remove spaces, then split into a list of codes
            codes = response.replace(' ', '').split(',')

            final_codes = set()

            for code in codes:
                if code == '':
                    # tolerate empty entries such as "A,,B" or a trailing ","
                    continue
                elif code == 'e':
                    x_axis = 'epoch'
                elif code == 'b':
                    x_axis = 'batche'
                elif code == 's':
                    x_axis = 'second'
                elif code == 'h':
                    x_axis = 'hour'
                elif code.startswith('<') or '-' not in code:
                    # A single short code, or a full '<channel_name>'
                    # (checked before ranges so names may contain '-').
                    if code not in codebook:
                        print("Invalid code: " + code)
                        quit(-1)
                    final_codes.add(code)
                else:
                    # The current list element is a range of codes
                    rng = code.split('-')
                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        quit(-1)
                    first, last = rng

                    if first not in sorted_codes:
                        print("Invalid code: " + first)
                        quit(-1)
                    start = sorted_codes.index(first)

                    # the end of the range must not precede its start
                    if last not in sorted_codes[start:]:
                        print("Invalid code: " + last)
                        quit(-1)
                    stop = sorted_codes.index(last, start)

                    final_codes.update(sorted_codes[start:stop + 1])
            # end for code in codes

            if not final_codes:
                print('No channels selected, try again.')
                continue
        else:
            # Exactly one channel. The codebook holds two keys for it (its
            # short code and its '<name>' alias), so select the short code.
            final_codes = set(sorted_codes)

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):
            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if contains_nan(y):
                print(channel_name + ' contains NaNs')

            if contains_inf(y):
                print(channel_name + ' contains infinite values')

            x = _x_values_for(channel, x_axis)

            ax.plot(x,
                    y,
                    styles[idx % len(styles)],
                    marker='.',  # add point markers to lines
                    label=channel_name)

        plt.xlabel('# ' + x_axis + 's')
        ax.ticklabel_format(scilimits=(-3, 3), axis='both')

        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc='upper center',
                  bbox_to_anchor=(0.5, -0.1))
        # 0.046 is the size of 1 legend box
        fig.subplots_adjust(bottom=0.11 + 0.046 * len(final_codes))

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)

        if not prompt:
            break
示例#20
0
def show_samples(m, model_path):
    """
    Show samples given a DBM model.

    Loads the model and the dataset described by its
    ``dataset_yaml_src``, then displays a seed batch. It optionally
    clamps the seed data for a user-chosen number of Gibbs steps, then
    repeatedly advances the sampler and redisplays the visible state
    until the user enters ``q``.

    Uses the module-level ``rows`` and ``cols`` for the display grid.

    Parameters
    ----------
    m: int
        rows * cols
    model_path: str
        Path of the model.
    """
    model = load_model(model_path, m)

    print('Loading data (used for setting up visualization '
          'and seeding gibbs chain) ...')
    dataset_yaml_src = model.dataset_yaml_src
    dataset = yaml_parse.load(dataset_yaml_src)

    pv = init_viewer(dataset, rows, cols)

    if hasattr(model.visible_layer, 'beta'):
        beta = model.visible_layer.beta.get_value()
        print('beta: ', (beta.min(), beta.mean(), beta.max()))

    print('showing seed data...')
    vis_batch = dataset.get_batch_topo(m)
    update_viewer(dataset, pv, vis_batch, rows, cols)
    pv.show()

    print('How many Gibbs steps should I run with the seed data clamped? '
          '(negative = ignore seed data)')
    # Re-prompt on non-integer input instead of crashing with ValueError.
    while True:
        try:
            clamped_steps = int(input())
            break
        except ValueError:
            print('Invalid input, try again')

    # Make shared variables representing the sampling state of the model
    layer_to_state = model.make_layer_to_state(m)
    # Seed the sampling with the data batch
    vis_sample = layer_to_state[model.visible_layer]

    validate_all_samples(model, layer_to_state)

    if clamped_steps >= 0:
        if vis_sample.ndim == 4:
            vis_sample.set_value(vis_batch)
        else:
            design_matrix = dataset.get_design_matrix(vis_batch)
            vis_sample.set_value(design_matrix)

    validate_all_samples(model, layer_to_state)

    sample_func = get_sample_func(model, layer_to_state, clamped_steps)

    while True:
        print('Displaying samples. '
              'How many steps to take next? (q to quit, ENTER=1)')
        while True:
            response = input()
            if response == 'q':
                quit()
            if response == '':
                n_steps = 1
                break
            try:
                n_steps = int(response)
                break
            except ValueError:
                print('Invalid input, try again')

        for i in xrange(n_steps):
            print(i)
            sample_func()

        validate_all_samples(model, layer_to_state)

        vis_batch = vis_sample.get_value()
        update_viewer(dataset, pv, vis_batch, rows, cols)
        pv.show()

        # When the top hidden layer is a softmax, print its argmax per
        # sample, one display row (cols samples) per line.
        if 'Softmax' in str(type(model.hidden_layers[-1])):
            state = layer_to_state[model.hidden_layers[-1]]
            value = state.get_value()
            y = np.argmax(value, axis=1)
            assert y.ndim == 1
            for i in xrange(0, y.shape[0], cols):
                print(y[i:i + cols])
示例#21
0
            pv.add_patch(display_batch[row_start+j,:,:,:], rescale = False)
            if mapback:
                pv.add_patch(mapped_batch[row_start+j,:,:,:], rescale = False)
    pv.show()


if hasattr(model.visible_layer, 'beta'):
    beta = model.visible_layer.beta.get_value()
    print('beta: ', (beta.min(), beta.mean(), beta.max()))

print('showing seed data...')
show()

print('How many Gibbs steps should I run with the seed data clamped? (negative = ignore seed data) ')
# Re-prompt on non-integer input instead of crashing with ValueError.
while True:
    try:
        x = int(input())
        break
    except ValueError:
        print('Invalid input, try again')


# Make shared variables representing the sampling state of the model
layer_to_state = model.make_layer_to_state(m)
# Seed the sampling with the data batch
vis_sample = layer_to_state[model.visible_layer]

def validate_all_samples():
    # Run some checks on the samples, this should help catch any bugs
    layers = [ model.visible_layer ] + model.hidden_layers

    def check_batch_size(l):
        if isinstance(l, (list, tuple)):
            map(check_batch_size, l)
        else:
示例#22
0
rows = 5
cols = 10
m = rows * cols

# Exactly one argument is expected: the path of the saved model.
if len(sys.argv) != 2:
    print('usage: ' + sys.argv[0] + ' <model_path>')
    quit(-1)
model_path = sys.argv[1]

print('Loading model...')
model = serial.load(model_path)
model.set_batch_size(m)

dataset_yaml_src = model.dataset_yaml_src

print('Loading data...')
dataset = yaml_parse.load(dataset_yaml_src)

# Validate the answer explicitly. An `assert` would be skipped under
# `python -O`, silently treating every answer other than 'y' as 'n'.
x = input('use test set? (y/n) ')
while x not in ('y', 'n'):
    print("Please answer 'y' or 'n'.")
    x = input('use test set? (y/n) ')

if x == 'y':
    dataset = dataset.get_test_set()

vis_batch = dataset.get_batch_topo(m)

# Leading dimension is the batch size; sanity-check it against m.
_, patch_rows, patch_cols, channels = vis_batch.shape

assert _ == m

mapback = hasattr(dataset, 'mapback_for_viewer')

# Doubled once more when the dataset has a mapback view, and again when
# the batch has exactly 2 channels.
actual_cols = 2 * cols * (1 + mapback) * (1 + (channels == 2))
示例#23
0
def remove_packages(packages_to_remove):
    """
    Uninstall packages, whether or not they
    are found in the source.lst (so it can
    remove datasets installed from file).

    Each requested package is checked in turn. Unknown, not-installed,
    missing-on-disk or non-writable packages are reported with a warning
    and skipped. The remaining packages are listed and removed only after
    the user answers ``y`` or ``yes`` (case-insensitive).

    :param packages_to_remove: list of package names
    :raises: RuntimeError if no package names are given;
             IOErrors (from the removal itself)
    """

    if not packages_to_remove:
        raise RuntimeError("[rm] fatal: need packages names to remove.")

    packages_really_to_remove = []
    for this_package in packages_to_remove:
        # Removal is driven by the installed list, not by the sources, so
        # packages installed from a local file can be removed as well.
        if this_package not in installed_packages_list:
            if this_package in packages_sources:
                logger.warning("[rm] package '{0}' "
                               "not installed".format(this_package))
            else:
                logger.warning("[rm] unknown package '{0}'".format(this_package))
            continue

        installed = installed_packages_list[this_package]
        this_data_set_location = os.path.join(installed.where, installed.name)

        # check that the directory actually exists
        # and that you have rights to remove it
        if not os.path.exists(this_data_set_location):
            logger.warning("[rm] package '{0}' found in config file "
                           "but not installed".format(this_package))
        elif not file_access_rights(this_data_set_location, os.W_OK):
            logger.warning("[rm] insufficient rights "
                           "to remove '{0}'".format(this_package))
        else:
            # ok, you may have rights to delete it
            packages_really_to_remove.append(this_package)

    if not packages_really_to_remove:
        # ok, nothing to remove, the names were all bad.
        return

    logger.info("[rm] the following packages will be removed permanently:")
    logger.info(' '.join(packages_really_to_remove))

    r = input("Proceed? [yes/N] ")
    if r.strip().lower() in ('y', 'yes'):
        for this_package in packages_really_to_remove:
            remove_package(installed_packages_list[this_package],
                           dataset_data_path)
    else:
        logger.info("[rm] Taking '{0}' for no, so there.".format(r))
示例#24
0
def show_samples(m, model_path):
    """
    Show samples given a DBM model.

    Loads the model and the dataset described by its
    ``dataset_yaml_src``, then displays a seed batch. It optionally
    clamps the seed data for a user-chosen number of Gibbs steps, then
    repeatedly advances the sampler and redisplays the visible state
    until the user enters ``q``.

    Uses the module-level ``rows`` and ``cols`` for the display grid.

    Parameters
    ----------
    m: int
        rows * cols
    model_path: str
        Path of the model.
    """
    model = load_model(model_path, m)

    print('Loading data (used for setting up visualization '
          'and seeding gibbs chain) ...')
    dataset_yaml_src = model.dataset_yaml_src
    dataset = yaml_parse.load(dataset_yaml_src)

    pv = init_viewer(dataset, rows, cols)

    if hasattr(model.visible_layer, 'beta'):
        beta = model.visible_layer.beta.get_value()
        print('beta: ', (beta.min(), beta.mean(), beta.max()))

    print('showing seed data...')
    vis_batch = dataset.get_batch_topo(m)
    update_viewer(dataset, pv, vis_batch, rows, cols)
    pv.show()

    print('How many Gibbs steps should I run with the seed data clamped? '
          '(negative = ignore seed data)')
    # Re-prompt on non-integer input instead of crashing with ValueError.
    while True:
        try:
            clamped_steps = int(input())
            break
        except ValueError:
            print('Invalid input, try again')

    # Make shared variables representing the sampling state of the model
    layer_to_state = model.make_layer_to_state(m)
    # Seed the sampling with the data batch
    vis_sample = layer_to_state[model.visible_layer]

    validate_all_samples(model, layer_to_state)

    if clamped_steps >= 0:
        if vis_sample.ndim == 4:
            vis_sample.set_value(vis_batch)
        else:
            design_matrix = dataset.get_design_matrix(vis_batch)
            vis_sample.set_value(design_matrix)

    validate_all_samples(model, layer_to_state)

    sample_func = get_sample_func(model, layer_to_state, clamped_steps)

    while True:
        print('Displaying samples. '
              'How many steps to take next? (q to quit, ENTER=1)')
        while True:
            response = input()
            if response == 'q':
                quit()
            if response == '':
                n_steps = 1
                break
            try:
                n_steps = int(response)
                break
            except ValueError:
                print('Invalid input, try again')

        for i in xrange(n_steps):
            print(i)
            sample_func()

        validate_all_samples(model, layer_to_state)

        vis_batch = vis_sample.get_value()
        update_viewer(dataset, pv, vis_batch, rows, cols)
        pv.show()

        # When the top hidden layer is a softmax, print its argmax per
        # sample, one display row (cols samples) per line.
        if 'Softmax' in str(type(model.hidden_layers[-1])):
            state = layer_to_state[model.hidden_layers[-1]]
            value = state.get_value()
            y = np.argmax(value, axis=1)
            assert y.ndim == 1
            for i in xrange(0, y.shape[0], cols):
                print(y[i:i + cols])
示例#25
0
def upgrade_packages(packages_to_upgrade, hook=None):
    """
    Upgrades packages.

    If no packages are supplied (empty list or None), it will perform
    an "update-all" operation, finding all packages
    that are out of date.

    If packages names are supplied, only those
    are checked for upgrade (and upgraded if out
    of date)

    A package is out of date when its installed timestamp is older than
    the timestamp in the package sources. The packages to upgrade are
    listed with their readable size and upgraded only after the user
    answers ``y`` or ``yes`` (case-insensitive).

    :param packages_to_upgrade: list of package names.
    :param hook: optional progress hook, forwarded to ``install_upgrade``
        as ``progress_hook``.
    :raises: IOErrors (from downloads/rights)
    """

    # get names only
    all_packages = not packages_to_upgrade
    if all_packages:
        packages_to_upgrade = list(installed_packages_list.keys())  # all installed!

    # check what packages are in the list,
    # and really to be upgraded.
    packages_really_to_upgrade = []
    for this_package in packages_to_upgrade:
        if this_package not in installed_packages_list:
            # not installed? only worth a warning if the user named it
            if not all_packages:
                logger.warning("[up] '{0}' is not installed, "
                               "cannot upgrade.".format(this_package))
            continue

        if this_package not in packages_sources:
            logger.warning("[up] '{0}' is unknown "
                           "(installed from file?).".format(this_package))
            continue

        installed_date = installed_packages_list[this_package].timestamp
        repo_date = packages_sources[this_package].timestamp
        if installed_date < repo_date:
            # ok, there's a newer version
            logger.info(this_package)
            packages_really_to_upgrade.append(this_package)

    # once we have determined which packages
    # are to be updated, we show them to the
    # user for confirmation
    if not packages_really_to_upgrade:
        # ok, nothing to upgrade, move along.
        return

    logger.info("[up] the following package(s) will be upgraded:")
    for this_package in packages_really_to_upgrade:
        readable_size = packages_sources[this_package].readable_size
        logger.info("{0} ({1})".format(this_package, readable_size))

    r = input("Proceed? [yes/N] ")
    if r.strip().lower() in ('y', 'yes'):
        for this_package in packages_really_to_upgrade:
            install_upgrade(packages_sources[this_package],
                            upgrade=True,
                            progress_hook=hook)
    else:
        logger.info("[up] Taking '{0}' for no, so there.".format(r))
示例#26
0
def get_weights_report(model_path=None,
                       model=None,
                       rescale='individual',
                       border=False,
                       norm_sort=False,
                       dataset=None):
    """
    Returns a PatchViewer displaying a grid of filter weights

    Parameters
    ----------
    model_path : str
        Filepath of the model to make the report on.
    rescale : str
        A string specifying how to rescale the filter images: \
            'individual' (default): scale each filter so that it \
                uses as much as possible of the dynamic range \
                of the display under the constraint that 0 \
                is gray and no value gets clipped \
            'global' : scale the whole ensemble of weights \
            'none' :   don't rescale
    dataset: pylearn2.datasets.dataset.Dataset
        Dataset object to do view conversion for displaying the weights. If \
        not provided one will be loaded from the model's dataset_yaml_src.

    Returns
    -------
    WRITEME
    """

    if model is None:
        print 'making weights report'
        print 'loading model'
        model = serial.load(model_path)
        print 'loading done'
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale=' + rescale +
                         ", must be 'none', 'global', or 'individual'")

    if hasattr(model, 'layers'):
        if isinstance(model.layers[0], mlp_models.PretrainedLayer):
            model = model.layers[0].layer_content

    if isinstance(model, dict):
        #assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        keys = [key for key in model \
                if hasattr(model[key], 'ndim') and model[key].ndim == 2]
        if len(keys) > 2:
            key = None
            while key not in keys:
                logger.info('Which is the weights?')
                for key in keys:
                    logger.info('\t{0}'.format(key))
                key = input()
        else:
            key, = keys
        weights = model[key]

        norms = np.sqrt(np.square(weights).sum(axis=1))
        print 'min norm: ',norms.min()
        print 'mean norm: ',norms.mean()
        print 'max norm: ',norms.max()

        return patch_viewer.make_viewer(weights,
                                        is_color=weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            print 'loading dataset...'
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            print '...done'

        try:
            W = model.get_weights()
示例#27
0
            #print 'recons: '
            #for ch in xrange(3):
            #    chv = r[:,:,ch]
            #    print '\t',ch,(chv.min(),chv.mean(),chv.max())
            if mapback:
                pv.add_patch(dataset.adjust_to_be_viewed_with(
                    mapped_r_batch[row_start+j,:,:,:].copy(),
                    mapped_batch[row_start+j,:,:,:].copy()),rescale = False)
    pv.show()


# Report min / mean / max of the visible layer's `beta` parameter, when
# the layer has one.
if hasattr(grbm.visible_layer, 'beta'):
    beta = grbm.visible_layer.beta.get_value()
    beta_stats = (beta.min(), beta.mean(), beta.max())
    print('beta: ', beta_stats)

while True:
    show()
    print('Displaying reconstructions. (q to quit, ENTER = show more)')
    while True:
        x = input()
        if x == 'q':
            quit()
        if x == '':
            x = 1
            break
        else:
            print('Invalid input, try again')

    vis_batch = dataset.get_batch_topo(m)