Exemplo n.º 1
0
def test(save_dir='./save'):
    """Restore the most recent MLP checkpoint and print its test-set accuracy.

    Args:
        save_dir: Directory searched with ``tf.train.latest_checkpoint``.
            Defaults to ``'./save'``, the path this function always used.

    Raises:
        FileNotFoundError: If ``save_dir`` contains no checkpoint. Without this
            check ``checkpoint.restore(None)`` would be called, which loads no
            weights, and the accuracy of an untrained model would be printed.
    """
    model_to_be_restored = MLP()
    # tf.train.Checkpoint matches saved objects by keyword name, so
    # 'myAwesomeModel' must be the same keyword used when the checkpoint
    # was written.
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
    latest_path = tf.train.latest_checkpoint(save_dir)
    if latest_path is None:
        raise FileNotFoundError("no checkpoint found in %r" % (save_dir,))
    checkpoint.restore(latest_path)
    # Assumes predict() returns per-class scores with classes on the last
    # axis; argmax turns them into one predicted class index per example.
    y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data),
                       axis=-1)
    # np.sum counts matches in a single vectorized pass instead of iterating
    # element by element with the builtin sum().
    num_correct = np.sum(y_pred == data_loader.test_label)
    print("test accuracy: %f" % (num_correct / data_loader.num_test_data))
Exemplo n.º 2
0
        # One training step: fetch a mini-batch, compute the mean loss under a
        # GradientTape, print it, then apply the gradients with the optimizer.
        # batch_index comes from the enclosing loop (outside this view).
        X, y = data_loader.get_batch(batch_size)
        with tf.GradientTape() as tape:
            y_pred = model(X)
            # sparse_categorical_crossentropy defaults to from_logits=False, so
            # this assumes model(X) outputs probabilities (e.g. softmax) —
            # TODO confirm against the model definition.
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                y_true=y, y_pred=y_pred)
            # Reduce the per-example losses to a single scalar to differentiate.
            loss = tf.reduce_mean(loss)
            print("batch %d: loss %f" % (batch_index, loss.numpy()))
        grads = tape.gradient(loss, model.variables)
        # NOTE(review): model.variables also includes non-trainable variables,
        # whose gradients come back as None; model.trainable_variables is the
        # usual choice for both tape.gradient and apply_gradients.
        optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

    # Evaluate the model on the test split in mini-batches. The metric keeps a
    # running total and count across update_state() calls, so batches of
    # different sizes are weighted correctly.
    sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # Ceiling division so the final partial batch is evaluated too; floor
    # division silently skipped the last num_test_data % batch_size examples.
    num_batches = int(
        (data_loader.num_test_data + batch_size - 1) // batch_size)
    for batch_index in range(num_batches):
        start_index = batch_index * batch_size
        # Slicing past the end is safe: the last slice is simply shorter.
        end_index = start_index + batch_size
        y_pred = model.predict(data_loader.test_data[start_index:end_index])
        sparse_categorical_accuracy.update_state(
            y_true=data_loader.test_label[start_index:end_index],
            y_pred=y_pred)
    print("test accuracy: %f" % sparse_categorical_accuracy.result())
if training_loop == 'graph':
    # Graph-mode (TF1-style) training path built from tf.compat.v1 APIs.
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    # Total training steps over all epochs: full batches per epoch
    # (floor of num_train_data / batch_size) multiplied by num_epochs.
    num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
    # Build the computation graph.
    # NOTE(review): tf.compat.v1.placeholder cannot be used while eager
    # execution is enabled; confirm eager mode is disabled before this point.
    # Input batch: any batch size, 28x28 with a single channel, float32.
    X_placeholder = tf.compat.v1.placeholder(name='X',
                                             shape=[None, 28, 28, 1],
                                             dtype=tf.float32)
    # One int32 target per example (presumably class labels).
    y_placeholder = tf.compat.v1.placeholder(name='y',
                                             shape=[None],
                                             dtype=tf.int32)
    # Symbolic predictions for whatever is later fed into X_placeholder.
    y_pred = model(X_placeholder)