Beispiel #1
0
 def remove_training_directory(self, train_dir):
     """Recursively delete ``train_dir``; log instead of raising on failure.

     Args:
       train_dir: Path of the directory to remove.

     If ``gfile.DeleteRecursively`` raises, the error is logged together
     with a request to delete the directory manually, and the method
     returns normally.
     """
     logging.info("%s: Removing existing train directory.",
                  task_as_string(self.task))
     try:
         gfile.DeleteRecursively(train_dir)
     except Exception as err:  # best-effort cleanup: report and continue
         # train_dir is passed as a logging argument rather than concatenated
         # into the format string, so a '%' in the path cannot break the
         # %-style formatting of this message.
         logging.error(
             "%s: Failed to delete directory %s when starting a new model. "
             "Please delete it manually and try again. (%s)",
             task_as_string(self.task), train_dir, err)
Beispiel #2
0
 def remove_training_directory(self, train_dir):
     """Remove the training directory if it exists.

     Args:
       train_dir: Path of the directory to delete recursively.

     Nothing happens when ``train_dir`` does not exist. If the delete
     raises, the error is logged together with a request to remove the
     directory manually, and the method returns normally.
     """
     # Use the same ``gfile`` handle for the existence check and the delete
     # so both go through one file-system API.
     if not gfile.Exists(train_dir):
         return
     logging.info("%s: Removing existing train dir.", task_as_string(self.task))
     try:
         gfile.DeleteRecursively(train_dir)
     except Exception as err:  # best-effort cleanup: report and continue
         logging.error(
             "%s: Failed to delete dir %s when starting a new model. "
             "Delete it manually and try again. (%s)",
             task_as_string(self.task), train_dir, err)
Beispiel #3
0
def create_captcha_dataset(size=100,
                           data_dir='./data/',
                           height=60,
                           width=160,
                           image_format='.png'):
    """Generate ``size`` random CAPTCHA images and write them to ``data_dir``.

    Any existing ``data_dir`` is deleted first and then recreated, so the
    directory holds only the images written by this call.

    Args:
      size: Number of CAPTCHA texts to generate. Each image is named after
        its text, so if ``gen_random_text`` returns the same text twice the
        later image overwrites the earlier one.
      data_dir: Output directory; a trailing separator is optional.
      height: Image height, forwarded to ``ImageCaptcha``.
      width: Image width, forwarded to ``ImageCaptcha``.
      image_format: File-name suffix appended to each text, e.g. ``'.png'``.

    Returns:
      None.
    """
    import os  # local import keeps this definition self-contained

    if gfile.Exists(data_dir):
        gfile.DeleteRecursively(data_dir)
    gfile.MakeDirs(data_dir)
    captcha = ImageCaptcha(width=width, height=height)
    for _ in range(size):
        text = gen_random_text(CAPTCHA_CHARSET, CAPTCHA_LENGTH)
        # os.path.join instead of plain concatenation: with data_dir='./data'
        # concatenation produced './data<text>.png', a file beside the
        # directory rather than inside it.
        captcha.write(text, os.path.join(data_dir, text + image_format))
def create_captcha_dataset(size=100,
                           data_dir='./data/',
                           height=60,
                           width=160,
                           image_format='.png'):
    """Create and save a CAPTCHA image dataset.

    Any existing ``data_dir`` is deleted first and then recreated, so the
    directory holds only the images written by this call.

    Args:
      size: Number of CAPTCHA texts to generate. Each image is named after
        its text, so if ``gen_random_text`` returns the same text twice the
        later image overwrites the earlier one.
      data_dir: Output directory; a trailing separator is optional.
      height: Image height, forwarded to ``ImageCaptcha``.
      width: Image width, forwarded to ``ImageCaptcha``.
      image_format: File-name suffix appended to each text, e.g. ``'.png'``.
    """
    import os  # local import keeps this definition self-contained

    # Clear the output directory and recreate it.
    if gfile.Exists(data_dir):
        gfile.DeleteRecursively(data_dir)
    gfile.MakeDirs(data_dir)

    # Create the ImageCaptcha renderer.
    captcha = ImageCaptcha(width=width, height=height)

    for _ in range(size):
        # Generate a random CAPTCHA text.
        # NOTE(review): this uses CAPTCHA_LEN; confirm that constant (and not
        # CAPTCHA_LENGTH) is the one defined in this module.
        text = gen_random_text(CAPTCHA_CHARSET, CAPTCHA_LEN)
        # os.path.join so a data_dir without a trailing separator still
        # places the image inside the directory.
        captcha.write(text, os.path.join(data_dir, text + image_format))
Beispiel #5
0
def recover_session(self):
  """Return a Saver restored from the latest checkpoint, or None.

  Outcomes:
    * ``self.config.start_new_model`` is set: ``self.train_dir`` is deleted
      (a failed delete is logged, not raised) and None is returned.
    * ``tf.train.latest_checkpoint(self.train_dir)`` finds nothing, or the
      checkpoint's ``.meta`` file is missing: None is returned.
    * Otherwise: the meta graph is imported and the result of
      ``tf.train.import_meta_graph`` is returned.

  The log messages on the None paths say a new model will be built.
  """
  if self.config.start_new_model:
    logging.info("'start_new_model' flag is set. Removing existing train dir.")
    try:
      gfile.DeleteRecursively(self.train_dir)
    except Exception as err:  # best-effort cleanup: report and continue
      logging.error(
          "Failed to delete directory %s when starting a new model. "
          "Please delete it manually and try again. (%s)",
          self.train_dir, err)
    return None

  # Only look up the checkpoint when we might actually restore from it.
  latest_checkpoint = tf.train.latest_checkpoint(self.train_dir)
  if not latest_checkpoint:
    logging.info("No checkpoint file found. Building a new model.")
    return None

  meta_filename = latest_checkpoint + ".meta"
  if not gfile.Exists(meta_filename):
    logging.info("No meta graph file found. Building a new model.")
    return None

  logging.info("Restoring from meta graph file %s", meta_filename)
  return tf.train.import_meta_graph(meta_filename)
Beispiel #6
0
def _maybe_restore_saver(train_dir, start_new_model):
  """Return a Saver for the latest checkpoint in ``train_dir``, or None.

  When ``start_new_model`` is set, ``train_dir`` is deleted instead (a failed
  delete is logged, not raised) and None is returned. None is also returned
  when no checkpoint exists or its ``.meta`` file is missing.
  """
  if start_new_model:
    logging.info("'start_new_model' flag is set. Removing existing train dir.")
    try:
      gfile.DeleteRecursively(train_dir)
    except Exception as err:  # best-effort cleanup: report and continue
      logging.error(
          "Failed to delete directory %s when starting a new model. "
          "Please delete it manually and try again. (%s)", train_dir, err)
    return None

  latest_checkpoint = tf.train.latest_checkpoint(train_dir)
  if not latest_checkpoint:
    logging.info("No checkpoint file found. Building a new model.")
    return None

  meta_filename = latest_checkpoint + ".meta"
  if not gfile.Exists(meta_filename):
    logging.info("No meta graph file found. Building a new model.")
    return None

  logging.info("Restoring from meta graph file %s", meta_filename)
  return tf.train.import_meta_graph(meta_filename)


def _build_new_graph():
  """Build a fresh training graph from FLAGS and return a new Saver."""
  # convert feature_names and feature_sizes to lists of values
  feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
      FLAGS.feature_names, FLAGS.feature_sizes)

  reader_class = (readers.YT8MFrameFeatureReader if FLAGS.frame_features
                  else readers.YT8MAggregatedFeatureReader)
  reader = reader_class(feature_names=feature_names,
                        feature_sizes=feature_sizes)

  model = find_class_by_name(FLAGS.model,
                             [frame_level_models, video_level_models])()
  label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
  optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])
  build_graph(reader=reader,
              model=model,
              optimizer_class=optimizer_class,
              clip_gradient_norm=FLAGS.clip_gradient_norm,
              train_data_pattern=FLAGS.train_data_pattern,
              label_loss_fn=label_loss_fn,
              base_learning_rate=FLAGS.base_learning_rate,
              learning_rate_decay=FLAGS.learning_rate_decay,
              learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
              regularization_penalty=FLAGS.regularization_penalty,
              num_readers=FLAGS.num_readers,
              batch_size=FLAGS.batch_size,
              num_epochs=FLAGS.num_epochs)
  logging.info("built graph")
  # keep_checkpoint_every_n_hours=0.25 -> one retained checkpoint per 15 min.
  return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25)


def main(unused_argv):
  """Restore or build the training graph from FLAGS, then run train_loop.

  Task 0 is treated as the chief. A Saver is restored from the newest
  checkpoint's meta graph in ``FLAGS.train_dir`` when possible; otherwise a
  new graph is built and a new Saver created.
  """
  logging.set_verbosity(tf.logging.INFO)
  print("tensorflow version: %s" % tf.__version__)
  is_chief = (FLAGS.task == 0)

  # Recover session
  # NOTE(review): with start_new_model every task (not only the chief)
  # attempts to delete train_dir; confirm that is intended.
  saver = _maybe_restore_saver(FLAGS.train_dir, FLAGS.start_new_model)
  if saver is None:
    saver = _build_new_graph()

  train_loop(is_chief=is_chief,
             train_dir=FLAGS.train_dir,
             saver=saver,
             master=FLAGS.master)
# Finish the loss plot on the current pyplot axes: add the validation-loss
# curve, then title, axis labels and legend. NOTE(review): the start of this
# plot (figure/subplot creation and, presumably, the training-loss curve)
# lies outside the visible chunk.
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('loss')
plt.xlabel('epoch')
# Legend labels apply in plotting order; the validation curve is labelled
# 'test' here (assumes it is the second curve on these axes).
plt.legend(['train', 'test'], loc='upper right')
plt.tight_layout()

plt.show()

import os
import tensorflow.gfile as gfile

# Persist the trained Keras model as ./mnist/model/keras_mnist.h5.
save_dir = "./mnist/model/"
model_name = 'keras_mnist.h5'
model_path = os.path.join(save_dir, model_name)

# Start from an empty directory: drop it if present, then recreate it.
if gfile.Exists(save_dir):
    gfile.DeleteRecursively(save_dir)
gfile.MakeDirs(save_dir)

model.save(model_path)
print('Saved trained model at {} '.format(model_path))

from keras.models import load_model

# Reload the model just written to model_path and score it on the test split.
mnist_model = load_model(model_path)

# verbose=2 is passed through to Keras' evaluate().
loss_and_metrics = mnist_model.evaluate(X_test, Y_test, verbose=2)

# Index 0 is reported as the loss and index 1 as accuracy (scaled to a
# percentage). NOTE(review): assumes the model was compiled with accuracy as
# its first metric — confirm.
print("Test Loss: {}".format(loss_and_metrics[0]))
print("Test Accuracy: {}%".format(loss_and_metrics[1] * 100))
# Data visualization: accuracy curves on top, loss curves below.
# (The trailing commas previously on every line built throw-away tuples and
# made `fig` a 1-tuple instead of the Figure.)
fig = plt.figure()

# NOTE(review): newer Keras versions record these metrics as 'accuracy' /
# 'val_accuracy'; confirm against history.history.keys().
plt.subplot(2, 1, 1)
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model Accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='lower right')

plt.subplot(2, 1, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper right')

plt.tight_layout()
plt.show()

# Save the model as ./model/softmax.h5, wiping any previous contents of
# ./model/ first.
save_path = './model/'
model_name = 'softmax.h5'
model_path = os.path.join(save_path, model_name)

if gfile.Exists(save_path):
    gfile.DeleteRecursively(save_path)
gfile.MakeDirs(save_path)

model.save(model_path)