def test_subclass_save_model(self):
    """Check that saved weights reload unchanged into a fresh model.

    A `SimpleConvTestModel` is built from a `tf.TensorShape` whose batch
    dimension is `None`. Its weights are written to a 'ckpt' path and,
    when `h5py` is available, to a 'weights.h5' path. A second, freshly
    built model then loads each file in turn and must report weights
    close to the ones captured from the first model.
    """
    class_count = 10
    # Input size, e.g. image; the batch dimension is left unspecified.
    image_dims = (32, 32, 3)
    unknown_batch = None
    build_dims = (unknown_batch,) + image_dims

    source_model = model_util.SimpleConvTestModel(class_count)

    # Nothing may exist before `build` is called.
    self.assertFalse(source_model.built, "Model should not have been built")
    self.assertFalse(
        source_model.weights,
        "Model should have no weights since it has not been built.",
    )

    source_model.build(input_shape=tf.TensorShape(build_dims))

    self.assertTrue(
        source_model.weights,
        "Model should have weights now that it has been properly built.",
    )
    self.assertTrue(
        source_model.built, "Model should be built after calling `build`."
    )
    expected_weights = source_model.get_weights()

    # Always write the 'ckpt' file; the HDF5 file only if h5py imported.
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
    source_model.save_weights(checkpoint_path)
    h5_available = h5py is not None
    if h5_available:
        h5_path = os.path.join(self.get_temp_dir(), "weights.h5")
        source_model.save_weights(h5_path)

    # A brand-new model must be built before weights are loaded into it.
    restored_model = model_util.SimpleConvTestModel(class_count)
    restored_model.build(input_shape=tf.TensorShape(build_dims))
    if h5_available:
        restored_model.load_weights(h5_path)
        self.assertAllClose(expected_weights, restored_model.get_weights())
    restored_model.load_weights(checkpoint_path)
    self.assertAllClose(expected_weights, restored_model.get_weights())
def test_tensorshape_io_subclass_build(self):
    """Check that `build` accepts a `tf.TensorShape` with a None batch.

    The model must be unbuilt and weightless beforehand, built and
    holding weights afterwards, and must then be callable on a concrete
    batch of 32 examples.
    """
    class_count = 10
    # Input size, e.g. image; the batch dimension is left unspecified.
    image_dims = (32, 32, 3)
    unknown_batch = None

    net = model_util.SimpleConvTestModel(class_count)

    # Nothing may exist before `build` is called.
    self.assertFalse(net.built, "Model should not have been built")
    self.assertFalse(
        net.weights,
        "Model should have no weights since it has not been built.",
    )

    symbolic_shape = tf.TensorShape((unknown_batch,) + image_dims)
    net.build(input_shape=symbolic_shape)

    self.assertTrue(
        net.weights,
        "Model should have weights now that it has been properly built.",
    )
    self.assertTrue(net.built, "Model should be built after calling `build`.")

    # Run the built model on a concrete batch of 32 examples.
    concrete_batch = tf.ones((32,) + image_dims)
    net(concrete_batch)
def test_multidim_io_subclass_build(self):
    """Check that `build` accepts a plain tuple with a fixed batch size.

    The model must be unbuilt and weightless beforehand, built and
    holding weights afterwards, and must then be callable on a tensor of
    ones with the same shape that was passed to `build`.
    """
    class_count = 10
    # Input size, e.g. image, prefixed with a fixed batch size.
    image_dims = (32, 32, 3)
    examples_per_batch = 32
    full_dims = (examples_per_batch,) + image_dims

    net = model_util.SimpleConvTestModel(class_count)

    # Nothing may exist before `build` is called.
    self.assertFalse(net.built, "Model should not have been built")
    self.assertFalse(
        net.weights,
        "Model should have no weights since it has not been built.",
    )

    net.build(input_shape=full_dims)

    self.assertTrue(
        net.weights,
        "Model should have weights now that it has been properly built.",
    )
    self.assertTrue(net.built, "Model should be built after calling `build`.")

    # Call the model with exactly the shape it was built for.
    net(tf.ones(full_dims))