import tensorflow as tf

import fgo
# ModelFromCaffe is assumed to be importable from the surrounding project.


def main(output):
  # Build the FGO graph from weights converted out of the original Caffe model.
  m = ModelFromCaffe()
  images, labels = fgo.inputs()
  m.graph(images, labels)

  # Initialize the converted weights and write them out as a TF checkpoint.
  init = tf.initialize_all_variables()  # pre-1.0 API; later TF uses tf.global_variables_initializer()
  saver = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(init)
    path = saver.save(sess, output)
    print("saved variables to %s" % path)

def setup(param):
  env = Inference()  # Inference is assumed to be a plain holder object for the ops below

  # Rebuild the inference graph: input placeholder and FGO16 logits.
  env.images_op, _ = fgo.inputs()
  env.logits_op = fgo.FGO16().graph(env.images_op, n_classes=61, wd=5e-4)
  # prob_op = fgo.inference(logits_op)

  saver = tf.train.Saver()

  # Hide the GPU from TensorFlow when it is disabled on the command line.
  if not param.gpu:
    print("no gpu")
    config = tf.ConfigProto(device_count={'GPU': 0})
  else:
    print("with gpu")
    config = tf.ConfigProto()

  env.sess = tf.Session(config=config)
  with env.sess.as_default():
    # Initialize, then overwrite the variables from the saved checkpoint.
    tf.initialize_all_variables().run()
    saver.restore(env.sess, param.saved)

  return env
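A sketch of how the environment returned by setup might be used for a single forward pass. It assumes fgo.inputs() returns a feedable placeholder and works off the raw logits, since prob_op is commented out above; classify is a hypothetical helper, not part of the listing.

import numpy as np

def classify(env, batch):
  # batch: float32 images shaped to match the fgo.inputs() placeholder.
  logits = env.sess.run(env.logits_op, feed_dict={env.images_op: batch})
  return np.argmax(logits, axis=1)  # predicted class index per image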

Example #3

import argparse

import tensorflow as tf

import fgo
import input_data


def main():  # entry point; the opening of this example is reconstructed from context
  parser = argparse.ArgumentParser()
  parser.add_argument('--data')  # dataset root, consumed below as param.data
  parser.add_argument('--batch', default=64, type=int)
  parser.add_argument('--steps', default=None, type=int)
  parser.add_argument('--save_steps', default=100, type=int)
  parser.add_argument('--eval_size', default=0, type=int)
  parser.add_argument('--gpu', default=1, type=int)  # argparse's type=bool is broken: bool("0") is still True
  parser.add_argument('--eval_steps', default=50, type=int)
  param = parser.parse_args()

  train_set, validation_set, test_set = input_data.make_split(param.data)
  print("Train set (%d), validation set (%d), test set (%d)" % (len(train_set), len(validation_set), len(test_set)))
  # Infinite, shuffled training stream with random flip/shift/crop augmentation.
  train_batches = input_data.load_batches(train_set, param.batch, finite=False, shuffle=True, randflip=True, randshift=True, randcrop=True)

  if param.eval_size > 0:
    validation_set = validation_set[:param.eval_size]

  # Build the training graph: inputs, FGO16 logits, softmax probabilities, and the loss.
  images_op, labels_op = fgo.inputs()
  logits_op = fgo.FGO16().graph(images_op, n_classes=61, wd=5e-4)
  prob_op = fgo.inference(logits_op)
  loss_op = fgo.cost(logits_op, labels_op, param.batch)

  saver = tf.train.Saver()

  # Optimizer step plus bookkeeping: global step, learning rate, accuracy, summaries.
  train_op, global_step_op, lr_op = fgo.training(loss_op)
  accuracy_op = fgo.accuracy(prob_op, labels_op)
  summary_op = fgo.summaries(images_op, loss_op)

  accuracies = []

  def eval_accuracy(sess, file_set):
    # Evaluate on a single finite, un-augmented pass over file_set.
    batches = input_data.load_batches(file_set, param.batch, finite=True, shuffle=False, randflip=False, randshift=False, randcrop=False)
    return accuracy(sess, images_op, labels_op, accuracy_op, loss_op, batches, len(file_set))
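The excerpt ends before the actual training loop; below is a hedged sketch of what typically follows, built only from the ops and flags defined above. The checkpoint path literal and the (images, labels) shape of the items yielded by train_batches are assumptions.

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for images, labels in train_batches:  # infinite stream (finite=False above)
      _, step = sess.run([train_op, global_step_op],
                         feed_dict={images_op: images, labels_op: labels})
      if step % param.eval_steps == 0:
        accuracies.append(eval_accuracy(sess, validation_set))
      if step % param.save_steps == 0:
        saver.save(sess, "fgo16.ckpt", global_step=step)  # path is a placeholder
      if param.steps is not None and step >= param.steps:
        break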