def __new_from_description__(cls, description):
    """Rebuild a network from its description dictionary.

    The network is first built from ``description['architecture']``. Then the
    handler, initializers, gradient modifiers and weight modifiers are each
    recreated with ``create_from_description`` and applied in that order.

    Args:
        cls: The class this constructor is invoked on. It is not used; the
            network is always built through ``Network.from_architecture``.
        description (dict):
            Must provide the keys ``'architecture'``, ``'handler'``,
            ``'initializers'``, ``'gradient_modifiers'`` and
            ``'weight_modifiers'``. ``'output_name'`` is optional.

    Returns:
        Network: The reconstructed network.
    """
    network = Network.from_architecture(description['architecture'])

    # Each component is recreated right before it is applied. This keeps the
    # original interleaving, e.g. the handler is installed before
    # ``initialize`` runs.
    handler = create_from_description(description['handler'])
    network.set_handler(handler)

    initializers = create_from_description(description['initializers'])
    network.initialize(initializers)

    gradient_modifiers = create_from_description(
        description['gradient_modifiers'])
    network.set_gradient_modifiers(gradient_modifiers)

    weight_modifiers = create_from_description(
        description['weight_modifiers'])
    network.set_weight_modifiers(weight_modifiers)

    # A missing 'output_name' entry leaves the attribute set to None.
    network.output_name = description.get('output_name')
    return network
def test_describe_numpy_handler():
    """A NumpyHandler round-trips through its description and keeps its dtype."""
    handler = NumpyHandler(np.float32)
    description = get_description(handler)
    assert description == {'@type': 'NumpyHandler', 'dtype': 'float32'}

    restored = create_from_description(description)
    assert isinstance(restored, NumpyHandler)
    assert restored.dtype == np.float32
def test_describe_pycuda_handler():
    """A PyCudaHandler round-trips through its description.

    If the PyCUDA handler module cannot be imported (e.g. PyCUDA is not
    installed), the test is skipped instead of failing with ImportError.
    """
    # importorskip returns the imported module, or skips this test when the
    # import raises ImportError.
    pycuda_handler = pytest.importorskip('brainstorm.handlers.pycuda_handler')
    PyCudaHandler = pycuda_handler.PyCudaHandler

    pch = PyCudaHandler()
    d = get_description(pch)
    assert d == {'@type': 'PyCudaHandler'}
    pch2 = create_from_description(d)
    assert isinstance(pch2, PyCudaHandler)
def test_create_from_description():
    """Primitive values pass through create_from_description unchanged."""
    # Numbers and strings only need to compare equal.
    for value in (13, 0xff, 23.5, 1.3e-7):
        assert create_from_description(value) == value
    # Singletons must come back as the very same object.
    for singleton in (True, False):
        assert create_from_description(singleton) is singleton
    assert create_from_description('foo') == 'foo'
    assert create_from_description(None) is None
def test_seedable_initializes_from_description1():
    """A Seedable Describable built from a bare '@type' description gets a
    usable ``rnd`` RandomState."""
    class Foo2(Seedable, Describable):
        pass

    # The class is looked up by the name given in the '@type' field.
    obj = create_from_description({'@type': 'Foo2'})
    assert hasattr(obj, 'rnd')
    rng = obj.rnd
    assert isinstance(rng, RandomState)
    rng.randint(100)  # must not raise
def test_seedable_initializes_from_description2():
    """A Seedable subclass that overrides ``__init_from_description__`` and
    calls super() still ends up with a usable ``rnd`` RandomState."""
    class Foo3(Seedable, Describable):
        def __init_from_description__(self, description):
            super(Foo3, self).__init_from_description__(description)

    obj = create_from_description({'@type': 'Foo3'})
    assert hasattr(obj, 'rnd')
    rng = obj.rnd
    assert isinstance(rng, RandomState)
    rng.randint(100)  # must not raise
def test_get_network_from_description():
    """A network rebuilt from its description matches the original.

    Layer names are compared as ordered lists, because comparing two
    ``keys()`` views is order-insensitive on Python 3. Parameters are compared
    with ``np.array_equal``, so a shape mismatch cannot be hidden by
    broadcasting.
    """
    net = bs.Network.from_architecture(arch)
    net.initialize(1)
    d = get_description(net)
    net2 = create_from_description(d)
    assert isinstance(net2, bs.Network)
    assert list(net2.layers.keys()) == list(net.layers.keys())
    assert isinstance(net2.handler, NumpyHandler)
    assert net2.handler.dtype == np.float32
    assert np.array_equal(net2.buffer.parameters, net.buffer.parameters)
def test_recreate_trainer_from_description():
    """A Trainer with a stepper and hooks survives a description round-trip,
    keeping hook order and the stepper's learning rate."""
    stepper = bs.training.SgdStepper(learning_rate=0.7)
    trainer = bs.Trainer(stepper, verbose=False)
    trainer.add_hook(bs.hooks.StopAfterEpoch(23))
    trainer.add_hook(bs.hooks.StopOnNan())

    description = get_description(trainer)
    restored = create_from_description(description)

    assert isinstance(restored, bs.Trainer)
    assert restored.verbose is False
    assert list(restored.hooks.keys()) == ['StopAfterEpoch', 'StopOnNan']
    assert restored.hooks['StopAfterEpoch'].max_epochs == 23

    restored_stepper = restored.stepper
    assert isinstance(restored_stepper, bs.training.SgdStepper)
    assert restored_stepper.learning_rate == 0.7
def from_hdf5(cls, filename):
    """
    Load network from HDF5 file.

    The file must contain a ``description`` dataset holding the JSON network
    description and a ``parameters`` dataset holding the parameter values.

    Args:
        filename (str):
            Name of the file that the network should be loaded from.

    Returns:
        Network: The loaded network.

    See Also:
        :meth:`.save_as_hdf5`
    """
    with h5py.File(filename, 'r') as f:
        # ``Dataset.value`` was removed in h5py 3.0. ``ds[()]`` reads the whole
        # dataset and also works on h5py 2.x.
        raw_description = f['description'][()]
        # The stored string may come back as bytes or as str, depending on
        # how it was written and on the h5py version. Only bytes need
        # decoding.
        if isinstance(raw_description, bytes):
            raw_description = raw_description.decode()
        description = json.loads(raw_description)
        net = create_from_description(description)
        net.handler.set_from_numpy(net.buffer.parameters,
                                   f['parameters'][()])
    return net
def test_create_from_description_describable():
    """A bare '@type' description instantiates the matching Describable."""
    class Foo22(Describable):
        pass

    instance = create_from_description({'@type': 'Foo22'})
    assert isinstance(instance, Foo22)
def test_create_from_description_list():
    """Lists of primitives are reproduced element for element."""
    # The expected values are separate objects, so an in-place mutation of
    # the input cannot make a comparison pass trivially.
    cases = [
        ([], []),
        ([1, 2, 3], [1, 2, 3]),
        ([1, 'b'], [1, 'b']),
    ]
    for given, expected in cases:
        assert create_from_description(given) == expected
def test_create_from_description_dict():
    """Dicts without an '@type' key are reproduced as plain dicts."""
    # The expected values are separate objects, so an in-place mutation of
    # the input cannot make a comparison pass trivially.
    cases = [
        ({}, {}),
        ({'a': 1, 'b': 'ar'}, {'a': 1, 'b': 'ar'}),
    ]
    for given, expected in cases:
        assert create_from_description(given) == expected
def test_create_from_description_with_invalid_description_raises():
    """Passing an object that is not a valid description (here a module)
    raises a TypeError whose message mentions 'module'."""
    with pytest.raises(TypeError) as excinfo:
        create_from_description(pytest)
    message = excinfo.value.args[0]
    assert 'module' in message
def test_create_from_description_with_undescribable_raises():
    """An '@type' naming no known Describable raises a TypeError whose message
    includes the unknown type name."""
    with pytest.raises(TypeError) as excinfo:
        create_from_description({'@type': 'Unknown66'})
    message = excinfo.value.args[0]
    assert "Unknown66" in message