コード例 #1
0
ファイル: test_train.py プロジェクト: 123fengye741/pylearn2
def test_train_cmd():
    """
    Smoke test for `train`: run it on the autoencoder example YAML
    (dae.yaml) located under the installed pylearn2 package directory.

    No result is checked; the test succeeds when the call returns
    without raising.
    """
    package_root = pylearn2.__path__[0]
    yaml_file = os.path.join(package_root,
                             "scripts/autoencoder_example/dae.yaml")
    train(yaml_file)
コード例 #2
0
def train_mlp():
    """
    Run `train` on the YAML configuration used by the Live Monitoring
    extension tests (an MLP, per the original description).

    The YAML file is resolved relative to the installed pylearn2 package
    directory.
    """
    package_root = pylearn2.__path__[0]
    yaml_file = os.path.join(package_root,
                             'train_extensions/tests/live_monitor_test.yaml')
    train(yaml_file)
コード例 #3
0
 def train(self):
     """
     Train the network described by the YAML file at ``self.yaml_path``.

     Delegates to ``pylearn2train.train``; see the module-level docstring
     of /pylearn2/scripts/train.py for more details.
     """
     # Single-argument print(...) produces the same output whether it is
     # parsed as a Python 2 print statement or a Python 3 function call.
     print('make argument parser')
     # The parser object was never used after construction, so the unused
     # local is dropped; the call itself is kept to preserve behavior.
     pylearn2train.make_argument_parser()
     print('train network')
     pylearn2train.train(self.yaml_path)
コード例 #4
0
if __name__ == "__main__":
    import os
    import tempfile

    weights_file = 'sparserf_example.pkl'
    # Values substituted, in order, into the %-placeholders of the template.
    params = (10, [[3, 0], [0, 3]], weights_file)

    # Create the dataset
    from de.datasets import VanHateren
    VanHateren.create_datasets()

    # Create the yaml file: fill the template's placeholders with params.
    with open("sparserf_template.yaml") as fp:
        config_yaml = fp.read() % params

    # mkstemp() returns an already-open OS-level file descriptor along with
    # the path. Wrap that descriptor with os.fdopen so the config is written
    # through it and it is closed afterwards, instead of being leaked.
    config_fd, config_fn = tempfile.mkstemp()
    with os.fdopen(config_fd, 'w') as config_fp:
        config_fp.write(config_yaml)

    # Train the network
    from pylearn2.scripts.train import train
    train(config=config_fn)

    # Visualize the weights
    from pylearn2.scripts.show_weights import show_weights
    show_weights(model_path=weights_file, border=True)

    # Visualize the reconstruction
    from de.compare_reconstruct import compare_reconstruction
    compare_reconstruction(model_path=weights_file)
コード例 #5
0
                i += 1

            currHiddenUnit += 1

        return connectionMatrix

    @functools.wraps(Model._modify_updates)
    def _modify_updates(self, updates):
        # When an update is pending for this model's weights, multiply it
        # by self.mask (mutating `updates` in place) before handing the
        # dictionary on to the parent class implementation.
        weights = self.weights
        if weights in updates:
            pending = updates[weights]
            updates[weights] = pending * self.mask
        parent = super(SparseRFAutoencoder, self)
        return parent._modify_updates(updates)

if __name__ == "__main__":

    # Model file consumed by both visualization steps below.
    model_pkl = "sparserf.pkl"

    # Step 1: build the dataset.
    from .datasets import VanHateren
    VanHateren.create_datasets()

    # Step 2: run training from the YAML configuration.
    from pylearn2.scripts.train import train
    train(config="sparserf.yaml")

    # Step 3: display the weights.
    from pylearn2.scripts.show_weights import show_weights
    show_weights(model_path=model_pkl, border=True)

    # Step 4: display the reconstruction comparison.
    from .compare_reconstruct import compare_reconstruction
    compare_reconstruction(model_path=model_pkl)
コード例 #6
0
# coding: UTF-8

import os

# Set PYLEARN2_DATA_PATH to "<parent of the current working directory>/data".
_parent_dir = os.path.dirname(os.getcwd())
os.environ["PYLEARN2_DATA_PATH"] = _parent_dir + "/data"
# os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,device=gpu,floatX=float32"

# Reference:
# http://qiita.com/fetaro/items/448407a6964d307e8840

import codecs
def ccc(name):
    """
    Codec search function that aliases the 'windows-31j' encoding to UTF-8.

    Parameters
    ----------
    name : str
        Encoding name as handed to the search function by the codec
        registry.

    Returns
    -------
    codecs.CodecInfo or None
        The UTF-8 codec when `name` is 'windows-31j' (any letter case,
        hyphen or underscore spelling); None for every other name.
    """
    # Python 3.9+ converts hyphens to underscores before calling search
    # functions, while older versions pass the hyphenated spelling, so
    # both forms are accepted here.
    if name.lower().replace('-', '_') == 'windows_31j':
        return codecs.lookup('utf-8')
    return None


# Install the alias so lookups of 'windows-31j' can be answered by ccc.
codecs.register(ccc)

# handle = os.open("C:\\Users\\jgpua_000\\ml\\pylearn2_test\\data\\mnist\\train-images-idx3-ubyte", os.O_RDONLY)

from pylearn2.scripts.train import train
# train(os.path.join(pylearn2.__path__[0],"scripts/autoencoder_example/dae.yaml"))
# Run training using the configuration file mnist.yaml (relative path).
yaml_config = "mnist.yaml"
train(yaml_config)
コード例 #7
0
 def train(self):
     """
     Train the network described by the YAML file at ``self.yaml_path``.

     Delegates to ``pylearn2train.train``; see the module-level docstring
     of /pylearn2/scripts/train.py for more details.
     """
     # Single-argument print(...) produces the same output whether it is
     # parsed as a Python 2 print statement or a Python 3 function call.
     print('train network')
     pylearn2train.train(self.yaml_path)