예제 #1
0
def test_serialization():
    """Dump an MLP to temporary files and verify what ends up on disk.

    The checks, in order: the dump is readable with ``numpy.load`` and
    holds one entry per weight matrix plus ``'pkl'``; ``load`` restores the
    doubled second weight matrix; an extra non-parameter attribute adds no
    entries; ``load_parameter_values`` reports slash-separated names;
    children sharing one name still get distinct entries; dumping an MLP
    that references an object whose ``__module__`` is ``'__main__'``
    records exactly one warning mentioning ``__main__``.

    Every temporary file is created with ``delete=False`` so that it can be
    reopened by name; the paths are recorded and the files are removed when
    the test finishes, whether it passes or fails.
    """
    import os

    expected_keys = {'mlp-linear_0.W', 'mlp-linear_1.W', 'pkl'}
    temp_paths = []

    def dump_to_tempfile(obj):
        """Dump ``obj`` into a new temporary file and return its path."""
        with NamedTemporaryFile(delete=False) as f:
            # Register the path before dumping so that a failing dump
            # still gets its file cleaned up.
            temp_paths.append(f.name)
            dump(obj, f)
        return f.name

    try:
        # Create a simple brick with two parameters
        mlp = MLP(activations=[None, None],
                  dims=[10, 10, 10],
                  weights_init=Constant(1.),
                  use_bias=False)
        mlp.initialize()
        W = mlp.linear_transformations[1].W
        W.set_value(W.get_value() * 2)

        # Check the data using numpy.load; the context manager closes the
        # archive handle once the assertions are done.
        path = dump_to_tempfile(mlp)
        with numpy.load(path) as numpy_data:
            assert set(numpy_data.keys()) == expected_keys
            assert_allclose(numpy_data['mlp-linear_0.W'],
                            numpy.ones((10, 10)))
            assert numpy_data['mlp-linear_0.W'].dtype == theano.config.floatX

        # Ensure that it can be unpickled
        mlp = load(path)
        assert_allclose(mlp.linear_transformations[1].W.get_value(),
                        numpy.ones((10, 10)) * 2)

        # Ensure that only parameters are saved as NPY files
        mlp.random_data = numpy.random.rand(10)
        path = dump_to_tempfile(mlp)
        with numpy.load(path) as numpy_data:
            assert set(numpy_data.keys()) == expected_keys

        # Ensure that parameters can be loaded with correct names
        parameter_values = load_parameter_values(path)
        assert set(parameter_values.keys()) == \
            {'/mlp/linear_0.W', '/mlp/linear_1.W'}

        # Ensure that duplicate names are dealt with
        for child in mlp.children:
            child.name = 'linear'
        path = dump_to_tempfile(mlp)
        with numpy.load(path) as numpy_data:
            assert set(numpy_data.keys()) == \
                {'mlp-linear.W', 'mlp-linear.W_2', 'pkl'}

        # Ensure warnings are raised when __main__ namespace objects are
        # dumped
        foo.__module__ = '__main__'
        import __main__
        __main__.__dict__['foo'] = foo
        mlp.foo = foo
        with NamedTemporaryFile(delete=False) as f:
            temp_paths.append(f.name)
            # NOTE(review): the recorder uses whatever warning filters are
            # active; if this warning can already have been emitted earlier
            # in the same process, ``warnings.simplefilter('always')`` may
            # be needed here -- confirm against the test runner's setup.
            with warnings.catch_warnings(record=True) as w:
                dump(mlp, f)
                assert len(w) == 1
                assert '__main__' in str(w[-1].message)
    finally:
        # Best-effort cleanup: a file that cannot be removed must not mask
        # the outcome of the test itself.
        for temp_path in temp_paths:
            try:
                os.remove(temp_path)
            except OSError:
                pass
예제 #2
0
def test_serialization():
    """Round-trip an MLP through ``dump`` and inspect the written files.

    Covered in sequence: ``numpy.load`` can read the dump and sees the two
    weight matrices plus a ``'pkl'`` entry; ``load`` gives back an MLP
    whose second weight matrix is still doubled; attaching a plain NumPy
    array to the MLP does not add entries; ``load_parameter_values``
    returns ``/mlp/...`` style names; renaming both children to the same
    name yields ``W`` and ``W_2`` entries; dumping while the MLP references
    an object whose ``__module__`` is ``'__main__'`` records one warning
    that mentions ``__main__``.

    Temporary files are opened with ``delete=False`` (they are reopened by
    name afterwards), so their paths are collected and the files are
    deleted in a ``finally`` clause instead of being left behind.
    """
    import os

    expected_keys = {'mlp-linear_0.W', 'mlp-linear_1.W', 'pkl'}
    temp_paths = []

    def dump_to_tempfile(obj):
        """Write ``obj`` with ``dump`` to a new temp file; return the path."""
        with NamedTemporaryFile(delete=False) as f:
            # Record the path first so cleanup also covers a failed dump.
            temp_paths.append(f.name)
            dump(obj, f)
        return f.name

    try:
        # Create a simple brick with two parameters
        mlp = MLP(activations=[None, None], dims=[10, 10, 10],
                  weights_init=Constant(1.), use_bias=False)
        mlp.initialize()
        W = mlp.linear_transformations[1].W
        W.set_value(W.get_value() * 2)

        # Check the data using numpy.load; ``with`` releases the archive
        # handle as soon as the checks are finished.
        path = dump_to_tempfile(mlp)
        with numpy.load(path) as numpy_data:
            assert set(numpy_data.keys()) == expected_keys
            assert_allclose(numpy_data['mlp-linear_0.W'],
                            numpy.ones((10, 10)))
            assert numpy_data['mlp-linear_0.W'].dtype == theano.config.floatX

        # Ensure that it can be unpickled
        mlp = load(path)
        assert_allclose(mlp.linear_transformations[1].W.get_value(),
                        numpy.ones((10, 10)) * 2)

        # Ensure that only parameters are saved as NPY files
        mlp.random_data = numpy.random.rand(10)
        path = dump_to_tempfile(mlp)
        with numpy.load(path) as numpy_data:
            assert set(numpy_data.keys()) == expected_keys

        # Ensure that parameters can be loaded with correct names
        parameter_values = load_parameter_values(path)
        assert set(parameter_values.keys()) == \
            {'/mlp/linear_0.W', '/mlp/linear_1.W'}

        # Ensure that duplicate names are dealt with
        for child in mlp.children:
            child.name = 'linear'
        path = dump_to_tempfile(mlp)
        with numpy.load(path) as numpy_data:
            assert set(numpy_data.keys()) == \
                {'mlp-linear.W', 'mlp-linear.W_2', 'pkl'}

        # Ensure warnings are raised when __main__ namespace objects are
        # dumped
        foo.__module__ = '__main__'
        import __main__
        __main__.__dict__['foo'] = foo
        mlp.foo = foo
        with NamedTemporaryFile(delete=False) as f:
            temp_paths.append(f.name)
            # NOTE(review): no filter is installed inside the recorder, so
            # the result depends on the ambient warning filters; consider
            # ``warnings.simplefilter('always')`` if the warning may have
            # fired earlier in the process -- verify before changing.
            with warnings.catch_warnings(record=True) as w:
                dump(mlp, f)
                assert len(w) == 1
                assert '__main__' in str(w[-1].message)
    finally:
        # Removal is best effort so that a cleanup problem never hides the
        # real test result.
        for temp_path in temp_paths:
            try:
                os.remove(temp_path)
            except OSError:
                pass
예제 #3
0
def test_serialization():
    """Exercise ``dump``, ``load`` and ``load_parameters`` on a small MLP.

    Steps: dumping an object whose ``__module__`` is ``'__main__'`` records
    exactly one warning mentioning ``__main__``; dumping the MLP with its
    two weight matrices as ``parameters`` lets ``load_parameters`` return
    them under ``/mlp/linear_N.W``; ``load`` restores the doubled second
    weight matrix; children sharing one name produce ``W`` and ``W_2``
    keys; dumping ``None`` with parameters yields a tar archive whose only
    member is ``_parameters``.

    All temporary files use ``delete=False`` because they are reopened by
    name; their paths are tracked and the files are removed once the test
    ends, regardless of its outcome.
    """
    import os

    temp_paths = []

    def dump_to_tempfile(obj, parameters):
        """Dump ``obj`` and ``parameters`` to a temp file; return its path."""
        with NamedTemporaryFile(delete=False) as f:
            # Register the path up front so a failing dump is cleaned too.
            temp_paths.append(f.name)
            dump(obj, f, parameters=parameters)
        return f.name

    try:
        # Create a simple MLP to dump.
        mlp = MLP(activations=[None, None],
                  dims=[10, 10, 10],
                  weights_init=Constant(1.),
                  use_bias=False)
        mlp.initialize()
        W = mlp.linear_transformations[1].W
        W.set_value(W.get_value() * 2)

        # Ensure warnings are raised when __main__ namespace objects are
        # dumped.
        foo.__module__ = '__main__'
        import __main__
        __main__.__dict__['foo'] = foo
        mlp.foo = foo
        with NamedTemporaryFile(delete=False) as f:
            temp_paths.append(f.name)
            # NOTE(review): the recorder relies on the ambient warning
            # filters; ``warnings.simplefilter('always')`` may be needed if
            # this warning can have been emitted earlier in the process --
            # confirm against the test runner's setup.
            with warnings.catch_warnings(record=True) as w:
                dump(mlp.foo, f)
                assert len(w) == 1
                assert '__main__' in str(w[-1].message)

        # Check the parameters.
        path = dump_to_tempfile(mlp,
                                [mlp.children[0].W, mlp.children[1].W])
        with open(path, 'rb') as ff:
            numpy_data = load_parameters(ff)
        assert set(numpy_data.keys()) == \
            {'/mlp/linear_0.W', '/mlp/linear_1.W'}
        assert_allclose(numpy_data['/mlp/linear_0.W'], numpy.ones((10, 10)))
        assert numpy_data['/mlp/linear_0.W'].dtype == theano.config.floatX

        # Ensure that it can be unpickled.
        with open(path, 'rb') as ff:
            mlp = load(ff)
        assert_allclose(mlp.linear_transformations[1].W.get_value(),
                        numpy.ones((10, 10)) * 2)

        # Ensure that duplicate names are dealt with.
        for child in mlp.children:
            child.name = 'linear'
        path = dump_to_tempfile(mlp,
                                [mlp.children[0].W, mlp.children[1].W])
        with open(path, 'rb') as ff:
            numpy_data = load_parameters(ff)
        assert set(numpy_data.keys()) == \
            {'/mlp/linear.W', '/mlp/linear.W_2'}

        # Check when we don't dump the main object.
        path = dump_to_tempfile(None,
                                [mlp.children[0].W, mlp.children[1].W])
        with tarfile.open(path, 'r') as tarball:
            assert set(tarball.getnames()) == {'_parameters'}
    finally:
        # Best-effort cleanup: failing to delete a file must not replace
        # the test's own result.
        for temp_path in temp_paths:
            try:
                os.remove(temp_path)
            except OSError:
                pass
예제 #4
0
def test_serialization():
    """Verify dumping and reloading of an MLP and of its parameters.

    In order, the test checks that: dumping an object whose ``__module__``
    is ``'__main__'`` records a single warning that mentions ``__main__``;
    ``load_parameters`` returns both weight matrices keyed as
    ``/mlp/linear_N.W`` after a dump with explicit ``parameters``;
    ``load`` brings back the MLP with its second weight matrix doubled;
    identically named children are stored as ``W`` and ``W_2``; and a dump
    of ``None`` with parameters is a tar archive containing only
    ``_parameters``.

    Because the temporary files are created with ``delete=False`` (they
    are reopened by name), each path is remembered and the file is deleted
    in a ``finally`` clause rather than being left on disk.
    """
    import os

    temp_paths = []

    def dump_to_tempfile(obj, parameters):
        """Run ``dump`` into a fresh temporary file and return its path."""
        with NamedTemporaryFile(delete=False) as f:
            # Track the path before dumping so errors still get cleaned up.
            temp_paths.append(f.name)
            dump(obj, f, parameters=parameters)
        return f.name

    try:
        # Create a simple MLP to dump.
        mlp = MLP(activations=[None, None], dims=[10, 10, 10],
                  weights_init=Constant(1.), use_bias=False)
        mlp.initialize()
        W = mlp.linear_transformations[1].W
        W.set_value(W.get_value() * 2)

        # Ensure warnings are raised when __main__ namespace objects are
        # dumped.
        foo.__module__ = '__main__'
        import __main__
        __main__.__dict__['foo'] = foo
        mlp.foo = foo
        with NamedTemporaryFile(delete=False) as f:
            temp_paths.append(f.name)
            # NOTE(review): nothing forces the warning to be shown inside
            # the recorder; ``warnings.simplefilter('always')`` might be
            # required if it can have fired earlier in the process --
            # verify before changing.
            with warnings.catch_warnings(record=True) as w:
                dump(mlp.foo, f)
                assert len(w) == 1
                assert '__main__' in str(w[-1].message)

        # Check the parameters.
        path = dump_to_tempfile(mlp,
                                [mlp.children[0].W, mlp.children[1].W])
        with open(path, 'rb') as ff:
            numpy_data = load_parameters(ff)
        assert set(numpy_data.keys()) == \
            {'/mlp/linear_0.W', '/mlp/linear_1.W'}
        assert_allclose(numpy_data['/mlp/linear_0.W'], numpy.ones((10, 10)))
        assert numpy_data['/mlp/linear_0.W'].dtype == theano.config.floatX

        # Ensure that it can be unpickled.
        with open(path, 'rb') as ff:
            mlp = load(ff)
        assert_allclose(mlp.linear_transformations[1].W.get_value(),
                        numpy.ones((10, 10)) * 2)

        # Ensure that duplicate names are dealt with.
        for child in mlp.children:
            child.name = 'linear'
        path = dump_to_tempfile(mlp,
                                [mlp.children[0].W, mlp.children[1].W])
        with open(path, 'rb') as ff:
            numpy_data = load_parameters(ff)
        assert set(numpy_data.keys()) == \
            {'/mlp/linear.W', '/mlp/linear.W_2'}

        # Check when we don't dump the main object.
        path = dump_to_tempfile(None,
                                [mlp.children[0].W, mlp.children[1].W])
        with tarfile.open(path, 'r') as tarball:
            assert set(tarball.getnames()) == {'_parameters'}
    finally:
        # Deleting is best effort; an undeletable file should not turn a
        # passing test into a failure or hide a real one.
        for temp_path in temp_paths:
            try:
                os.remove(temp_path)
            except OSError:
                pass