예제 #1
0
 def eval_accuracy(sess, file_set):
   """Run one deterministic evaluation pass over file_set.

   Batches are finite, unshuffled, and loaded with every augmentation
   (flip/shift/crop) disabled so the reported accuracy is reproducible.
   """
   eval_batches = input_data.load_batches(
       file_set, param.batch,
       finite=True, shuffle=False,
       randflip=False, randshift=False, randcrop=False)
   return accuracy(sess, images_op, labels_op, accuracy_op, loss_op,
                   eval_batches, len(file_set))
예제 #2
0
  # Command-line interface: positional `data` (dataset location) and `saved`
  # (checkpoint to restore), plus optional training knobs.
  parser = argparse.ArgumentParser()
  parser.add_argument('data')
  parser.add_argument('saved')
  parser.add_argument('--save', default=None)
  parser.add_argument('--logdir', default=None)
  parser.add_argument('--batch', default=64, type=int)
  parser.add_argument('--steps', default=None, type=int)
  parser.add_argument('--save_steps', default=100, type=int)
  parser.add_argument('--eval_size', default=0, type=int)
  # BUG FIX: `type=bool` is an argparse pitfall — bool() is applied to the raw
  # string, and bool('0')/bool('False') are both True, so `--gpu 0` could never
  # disable the flag. Parse as int (consistent with the integer default 1) so
  # `--gpu 0` is falsy as intended.
  parser.add_argument('--gpu', default=1, type=int)
  parser.add_argument('--eval_steps', default=50, type=int)
  param = parser.parse_args()

  # Split the dataset directory into train/validation/test file lists.
  train_set, validation_set, test_set = input_data.make_split(param.data)
  print("Train set (%d), validation set (%d), test set (%d)" % (len(train_set), len(validation_set), len(test_set)))
  # Training batches cycle indefinitely (finite=False) with shuffling and
  # flip/shift/crop augmentation enabled — contrast with the deterministic
  # settings used for evaluation.
  train_batches = input_data.load_batches(train_set, param.batch, finite=False, shuffle=True, randflip=True, randshift=True, randcrop=True)

  # Optionally cap the validation set so periodic evaluation stays cheap.
  if param.eval_size > 0:
    validation_set = validation_set[:param.eval_size]

  # Build the model graph: input placeholders, FGO16 backbone configured for
  # 61 classes with weight decay 5e-4, an inference op (presumably softmax
  # over the logits — confirm in fgo), and the batch loss.
  images_op, labels_op = fgo.inputs()
  logits_op = fgo.FGO16().graph(images_op, n_classes=61, wd=5e-4)
  prob_op = fgo.inference(logits_op)
  loss_op = fgo.cost(logits_op, labels_op, param.batch)

  # Saver for restoring from `param.saved` / writing checkpoints later.
  saver = tf.train.Saver()

  # Training op with its global step and learning rate, plus accuracy and
  # summary ops for monitoring during training.
  train_op, global_step_op, lr_op = fgo.training(loss_op)
  accuracy_op = fgo.accuracy(prob_op, labels_op)
  summary_op = fgo.summaries(images_op, loss_op)