def test_subclass_save_model(self):
    """Check that a subclassed model's weights survive a save/load round trip.

    A `SimpleConvTestModel` is built from a `TensorShape` with an unspecified
    batch dimension and its weights are saved to a 'ckpt' path and, when the
    module-level `h5py` is not None, also to a 'weights.h5' path. Each saved
    file is then loaded into its own freshly built model and the loaded
    weights are compared against the ones that were saved.
    """
    num_classes = 10
    # Input size, e.g. image
    batch_size = None
    input_shape = (32, 32, 3)
    build_shape = tensor_shape.TensorShape((batch_size,) + input_shape)

    model = model_util.SimpleConvTestModel(num_classes)
    self.assertFalse(model.built, 'Model should not have been built')
    self.assertFalse(model.weights, ('Model should have no weights since it '
                                     'has not been built.'))
    model.build(input_shape=build_shape)
    self.assertTrue(model.weights, ('Model should have weights now that it '
                                    'has been properly built.'))
    self.assertTrue(model.built, 'Model should be built after calling `build`.')
    weights = model.get_weights()

    tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
    model.save_weights(tf_format_name)
    hdf5_format_name = None
    if h5py is not None:
        hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
        model.save_weights(hdf5_format_name)

    def load_into_fresh_model(path):
        """Build a brand-new model, load `path` into it, return its weights."""
        fresh_model = model_util.SimpleConvTestModel(num_classes)
        fresh_model.build(input_shape=build_shape)
        fresh_model.load_weights(path)
        return fresh_model.get_weights()

    # Each format is restored into its own new model. Reusing one model for
    # both loads would leave it already holding the expected weights after the
    # first load, so the second comparison could pass even if the second load
    # had no effect.
    if hdf5_format_name is not None:
        self.assertAllClose(weights, load_into_fresh_model(hdf5_format_name))
    self.assertAllClose(weights, load_into_fresh_model(tf_format_name))
def test_tensorshape_io_subclass_build(self):
    """Build a subclassed model from a `TensorShape`, then call it.

    The model must report no weights and `built == False` before `build`,
    and must report weights and `built == True` afterwards. It is finally
    invoked on a tensor of ones with a concrete leading dimension of 32.
    """
    num_classes = 10
    # Input size, e.g. image
    input_shape = (32, 32, 3)
    batch_size = None

    model = model_util.SimpleConvTestModel(num_classes)

    # Before build: not built, no weights.
    self.assertFalse(model.built, 'Model should not have been built')
    no_weights_msg = ('Model should have no weights since it '
                      'has not been built.')
    self.assertFalse(model.weights, no_weights_msg)

    # The leading (batch) dimension is left unspecified when building.
    symbolic_shape = tensor_shape.TensorShape((batch_size,) + input_shape)
    model.build(input_shape=symbolic_shape)

    # After build: weights present, built flag set.
    has_weights_msg = ('Model should have weights now that it '
                       'has been properly built.')
    self.assertTrue(model.weights, has_weights_msg)
    self.assertTrue(model.built, 'Model should be built after calling `build`.')

    # Call the built model on concrete data with a leading dimension of 32.
    concrete_input = array_ops.ones((32,) + input_shape)
    model(concrete_input)