Esempio n. 1
0
 def testIgnoreSaveCounter(self):
     """Tests restoring a Saver-written checkpoint that has no `save_counter`.

     The model's trainable variables are first written with the name-based
     ``tf.compat.v1.train.Saver``, so the snapshot lacks the Checkpoint's
     ``save_counter``. Restoring it through ``tf.train.Checkpoint`` must still
     pass ``assert_existing_objects_matched``. Once the model itself carries an
     extra variable named ``save_counter`` that the file does not contain, the
     same assertion must raise ``AssertionError``.
     """
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     with self.cached_session() as sess:
         # Flatten -> Dense model saved with the v1 Saver before any
         # Checkpoint exists, so the snapshot has no `save_counter`.
         net = sequential.Sequential()
         net.add(reshaping.Flatten(input_shape=(1,)))
         net.add(core.Dense(1))
         legacy_saver = tf.compat.v1.train.Saver(net.trainable_variables)
         path = legacy_saver.save(
             sess=sess, save_path=prefix, global_step=1)

         # The missing Checkpoint `save_counter` alone must not make the
         # restore fail the existing-objects match.
         checkpoint = tf.train.Checkpoint(model=net)
         checkpoint.restore(path).assert_existing_objects_matched()

         # Attach an unrelated variable named "save_counter" to the Dense
         # layer (index 1). The file has no value for it, so the match check
         # must now fail.
         net.layers[1].var = tf.Variable(0.0, name="save_counter")
         second_status = checkpoint.restore(path)
         with self.assertRaises(AssertionError):
             second_status.assert_existing_objects_matched()
Esempio n. 2
0
def mnist_model(input_shape, enable_histograms=True):
    """Creates a MNIST model.

    The layers are stacked in this order:
      * a pass-through image-summary layer;
      * two 3x3 ReLU convolutions (32 then 64 filters);
      * 2x2 max pooling and dropout 0.25;
      * flatten, a 128-unit ReLU dense layer and dropout 0.5;
      * a ``NUM_CLASSES``-way softmax output.

    Args:
      input_shape: Forwarded as ``input_shape`` to the first ``Conv2D`` layer.
      enable_histograms: If true, a ``LayerForHistogramSummary`` pass-through
        layer is appended after the softmax output for summary recording.

    Returns:
      The assembled ``Sequential`` model.
    """
    model = sequential_model_lib.Sequential()

    stack = [
        # Custom pass-through layer used to visualize the input images.
        LayerForImageSummary(),
        # NOTE(review): input_shape sits on the second layer of the stack, not
        # the first; confirm Sequential actually honours it here.
        conv_layer_lib.Conv2D(32,
                              kernel_size=(3, 3),
                              activation="relu",
                              input_shape=input_shape),
        conv_layer_lib.Conv2D(64, (3, 3), activation="relu"),
        pool_layer_lib.MaxPooling2D(pool_size=(2, 2)),
        regularization_layer_lib.Dropout(0.25),
        reshaping_layer_lib.Flatten(),
        layer_lib.Dense(128, activation="relu"),
        regularization_layer_lib.Dropout(0.5),
        layer_lib.Dense(NUM_CLASSES, activation="softmax"),
    ]
    if enable_histograms:
        # Custom pass-through layer appended last for summary recording.
        stack.append(LayerForHistogramSummary())

    for layer in stack:
        model.add(layer)
    return model