Ejemplo n.º 1
0
    def test_subclass_save_model(self):
        """Round-trip a built subclassed model's weights through save/load.

        Builds a ``model_util.SimpleConvTestModel`` and checks that it has
        weights only after ``build``. It then saves the weights in TF
        checkpoint format, and also in HDF5 format when ``h5py`` is
        available. Each saved file is loaded into its own newly built model,
        and the loaded weights are compared against the originals.
        """
        num_classes = 10
        # Input size, e.g. image
        batch_size = None
        input_shape = (32, 32, 3)

        def build_fresh_model():
            # Construct and build a new model instance. Each format gets its
            # own instance, so it is checked independently rather than
            # against weights already restored by an earlier load.
            fresh = model_util.SimpleConvTestModel(num_classes)
            fresh.build(input_shape=tensor_shape.TensorShape((batch_size, ) +
                                                             input_shape))
            return fresh

        model = model_util.SimpleConvTestModel(num_classes)
        self.assertFalse(model.built, 'Model should not have been built')
        self.assertFalse(model.weights,
                         ('Model should have no weights since it '
                          'has not been built.'))
        model.build(input_shape=tensor_shape.TensorShape((batch_size, ) +
                                                         input_shape))
        self.assertTrue(model.weights,
                        ('Model should have weights now that it '
                         'has been properly built.'))
        self.assertTrue(model.built,
                        'Model should be built after calling `build`.')
        weights = model.get_weights()

        tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
        model.save_weights(tf_format_name)
        if h5py is not None:
            hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
            model.save_weights(hdf5_format_name)

        # HDF5 is only exercised when the optional h5py dependency is present.
        if h5py is not None:
            hdf5_model = build_fresh_model()
            hdf5_model.load_weights(hdf5_format_name)
            self.assertAllClose(weights, hdf5_model.get_weights())

        # Use a separate instance so the TF-format load is verified on its
        # own, not masked by the HDF5 load above.
        tf_model = build_fresh_model()
        tf_model.load_weights(tf_format_name)
        self.assertAllClose(weights, tf_model.get_weights())
Ejemplo n.º 2
0
  def test_tensorshape_io_subclass_build(self):
    """Build a subclassed model from a TensorShape, then invoke it.

    Checks that the model starts unbuilt and has no weights. It should gain
    weights and set `built` after `build` is given a TensorShape whose
    leading dimension is unknown (None). The model must then accept a
    concrete all-ones input.
    """
    n_classes = 10
    # Leading dimension is left unknown when declaring the build shape.
    unknown_batch = None
    # Per-example input dimensions, e.g. an image.
    example_dims = (32, 32, 3)

    net = model_util.SimpleConvTestModel(n_classes)
    self.assertFalse(net.built, 'Model should not have been built')
    self.assertFalse(
        net.weights,
        ('Model should have no weights since it '
         'has not been built.'))

    declared_shape = tensor_shape.TensorShape((unknown_batch,) + example_dims)
    net.build(input_shape=declared_shape)
    self.assertTrue(
        net.weights,
        ('Model should have weights now that it '
         'has been properly built.'))
    self.assertTrue(net.built, 'Model should be built after calling `build`.')

    # Invoke the built model on a concrete input whose leading dimension is 32.
    concrete_input = array_ops.ones((32,) + example_dims)
    net(concrete_input)