def main(argv):
    """Copy the latest checkpoint into a scoped WRN model and save it with joblib.

    Restores the latest checkpoint found in ``FLAGS.checkpoint_dir`` into an
    unscoped ``make_wresnet()`` model. It then copies every variable into a
    second ``make_wresnet`` model built under the ``cifar10_challenge`` scope,
    attaches a CIFAR dataset factory to it, and serializes that scoped model
    to ``model.joblib`` in the current working directory.

    Args:
      argv: Command-line arguments; not used by this function (presumably
        supplied by an ``app.run``-style launcher — confirm).

    Raises:
      SystemExit: With status 1 if no checkpoint is found.
      ValueError: If the two models' variable sets cannot be matched by name.
    """
    model_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)

    if model_file is None:
        # A string argument to sys.exit is printed to stderr and the process
        # exits with status 1, so shell scripts can detect the failure.
        sys.exit('No model found')

    set_log_level(logging.DEBUG)

    # Entering the session makes it the default session (as sess.as_default()
    # does) and also closes it when the block exits.
    with tf.Session() as sess:
        model = make_wresnet()
        # Built before model2 exists, so it only covers the unscoped model's
        # variables.
        saver = tf.train.Saver()
        # Restore the checkpoint
        saver.restore(sess, model_file)

        scope = "cifar10_challenge"
        model2 = make_wresnet(scope=scope)

        src_vars = model.get_vars()
        dst_vars = model2.get_vars()
        if len(src_vars) != len(dst_vars):
            raise ValueError(
                "Variable count mismatch: %d unscoped vs %d scoped"
                % (len(src_vars), len(dst_vars)))

        # Index destination variables by name once, instead of rescanning the
        # whole list for each source variable; the first match wins.
        dst_index = {}
        for idx, var2 in enumerate(dst_vars):
            dst_index.setdefault(var2.name, idx)

        found = [False] * len(dst_vars)
        assign_ops = []
        for var1 in src_vars:
            idx = dst_index.get(scope + "/" + var1.name)
            if idx is None:
                raise ValueError(
                    "No scoped counterpart for variable %s" % var1.name)
            found[idx] = True
            assign_ops.append(tf.assign(dst_vars[idx], var1))
        if not all(found):
            raise ValueError("Some scoped variables were not assigned")

        # Run all copies in one session call without fetching their values.
        sess.run(tf.group(*assign_ops))

        model2.dataset_factory = Factory(CIFAR, {"max_val": 255})

        serial.save("model.joblib", model2)
 def __init__(self, n_classes=10):
     """Create the weight and bias variables of a two-conv, two-FC network.

     Args:
       n_classes: Number of output units of the last fully connected layer
         (the width of ``W_fc2`` and ``b_fc2``); also forwarded to the
         parent constructor.
     """
     # NOTE(review): these variables get no explicit names or scopes, so their
     # default names presumably depend on creation order. Keep this order if
     # checkpoints must restore into them — confirm.
     # Kernel shapes look like [height, width, in_channels, out_channels]
     # (presumably; depends on _weight_variable and the forward pass).
     self.W_conv1 = self._weight_variable([5, 5, 1, 32])
     self.b_conv1 = self._bias_variable([32])
     self.W_conv2 = self._weight_variable([5, 5, 32, 64])
     self.b_conv2 = self._bias_variable([64])
     # 7 * 7 * 64 is presumably a 28x28 input after two 2x downsamplings with
     # 64 channels — confirm against the forward pass.
     self.W_fc1 = self._weight_variable([7 * 7 * 64, 1024])
     self.b_fc1 = self._bias_variable([1024])
     self.W_fc2 = self._weight_variable([1024, n_classes])
     self.b_fc2 = self._bias_variable([n_classes])
     # super().__init__ is already bound to self, so self must not be passed
     # again; the arguments are (scope, number of classes, hparams), matching
     # the explicit Model.__init__(self, "", nb_classes, {}) form used elsewhere.
     super().__init__('', n_classes, {})
     self.dataset_factory = Factory(MNIST, {"center": False})
Example #3
0
 def __init__(self, nb_classes=10):
     """Create the weight and bias variables of a two-conv, two-FC network.

     Args:
       nb_classes: Number of output units of the last fully connected layer
         (the width of ``W_fc2`` and ``b_fc2``); also passed to
         ``Model.__init__``.
     """
     # NOTE: for compatibility with Madry Lab downloadable checkpoints,
     # we cannot use scopes, give these variables names, etc.
     # NOTE(review): with no explicit names, the default variable names
     # presumably depend on creation order, so do not reorder the lines
     # below — confirm against the checkpoint.
     # Kernel shapes look like [height, width, in_channels, out_channels]
     # (presumably; depends on _weight_variable and the forward pass).
     self.W_conv1 = self._weight_variable([5, 5, 1, 32])
     self.b_conv1 = self._bias_variable([32])
     self.W_conv2 = self._weight_variable([5, 5, 32, 64])
     self.b_conv2 = self._bias_variable([64])
     # 7 * 7 * 64 is presumably a 28x28 input after two 2x downsamplings with
     # 64 channels — confirm against the forward pass.
     self.W_fc1 = self._weight_variable([7 * 7 * 64, 1024])
     self.b_fc1 = self._bias_variable([1024])
     self.W_fc2 = self._weight_variable([1024, nb_classes])
     self.b_fc2 = self._bias_variable([nb_classes])
     # Arguments presumably are (scope, number of classes, hparams); the
     # empty scope keeps variable names unprefixed.
     Model.__init__(self, "", nb_classes, {})
     self.dataset_factory = Factory(MNIST, {"center": False})
    def __init__(self, nb_classes=10):
        """Build the network as a ``tf.keras.Sequential`` model.

        Layers, in order:
          1. 5x5 convolution, 32 filters, ReLU ('conv1').
          2. 2x2 max-pool.
          3. 5x5 convolution, 64 filters, ReLU ('conv2').
          4. 2x2 max-pool.
          5. Flatten.
          6. 1024-unit dense layer with ReLU ('fc1').
          7. ``nb_classes``-unit dense layer with no activation ('fc2').

        Each input example has shape (28, 28, 1).

        Args:
          nb_classes: Width of the final 'fc2' layer; also passed to
            ``Model.__init__``.
        """
        keras_model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (5, 5),
                                   activation='relu',
                                   padding='same',
                                   name='conv1',
                                   input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
            tf.keras.layers.Conv2D(64, (5, 5),
                                   activation='relu',
                                   padding='same',
                                   name='conv2'),
            tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1024, activation='relu', name='fc1'),
            # Sized from nb_classes (not a hard-coded 10) so the output width
            # agrees with the class count passed to Model.__init__.
            tf.keras.layers.Dense(nb_classes, name='fc2'),
        ])

        self.keras_model = keras_model
        Model.__init__(self, '', nb_classes, {})
        self.dataset_factory = Factory(MNIST, {"center": False})