def main(_):
    """Pre-train ``tn_dis`` on CIFAR and checkpoint the best validation model.

    Runs ``tn_num_batch`` steps of ``tn_dis.pre_train`` on batches drawn from
    ``cifar.next_batch(sess)``. Every ``eval_interval`` batches, and after the
    very last batch, it scores ``vd_dis`` with ``cifar.compute_acc``, prints a
    progress line, and saves ``tn_dis`` to ``flags.dis_model_ckpt`` whenever
    the fresh accuracy is not below the best seen so far.

    Note that the ``acc=`` field of the progress line reports the best
    accuracy so far, not the accuracy of the current evaluation.

    Args:
      _: Unused positional argument (presumably argv from ``tf.app.run`` --
        confirm against the caller).

    Relies on module-level names defined elsewhere in the file: ``tf``,
    ``time``, ``cifar``, ``init_op``, ``tn_dis``, ``vd_dis``, ``tn_num_batch``,
    ``eval_interval``, ``flags`` and ``utils``.
    """
    best = 0.0
    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init_op)
        t0 = time.time()
        for step in range(tn_num_batch):
            images, labels = cifar.next_batch(sess)
            sess.run(
                tn_dis.pre_train,
                feed_dict={
                    tn_dis.image_ph: images,
                    tn_dis.hard_label_ph: labels,
                },
            )
            done = step + 1
            # Evaluate on every eval_interval-th batch and on the final one.
            should_eval = done % eval_interval == 0 or done == tn_num_batch
            if not should_eval:
                continue
            acc = cifar.compute_acc(sess, vd_dis)
            best = max(acc, best)
            per_batch = (time.time() - t0) / done
            print('#batch=%d acc=%.4f time=%.4fs/batch est=%.4fh'
                  % (done, best, per_batch, per_batch * tn_num_batch / 3600))
            if acc < best:
                continue
            # Reached only when acc ties or sets the best, so ties re-save.
            tn_dis.saver.save(utils.get_session(sess), flags.dis_model_ckpt)
    print('#cifar=%d final=%.4f' % (flags.train_size, best))
def main(_):
    """Pre-train ``tn_dis`` on MNIST and checkpoint the best validation model.

    Runs ``tn_num_batch`` steps of ``tn_dis.pre_update`` on
    ``mnist.train.next_batch(flags.batch_size)`` batches, writing each step's
    ``summary_op`` result to a FileWriter under ``config.logs_dir``. Every
    ``eval_interval`` batches, and once more after the final batch, it runs
    ``vd_dis.accuracy`` on the full ``mnist.test`` split, prints progress, and
    saves ``tn_dis`` to ``flags.dis_model_ckpt`` whenever that accuracy ties
    or beats the best seen so far. The FileWriter is always closed on exit.

    Args:
      _: Unused positional argument (presumably argv from ``tf.app.run`` --
        confirm against the caller).

    Relies on module-level names defined elsewhere in the file: ``tf``,
    ``time``, ``config``, ``flags``, ``utils``, ``mnist``, ``init_op``,
    ``summary_op``, ``tn_dis``, ``vd_dis``, ``tn_num_batch``,
    ``eval_interval`` and ``tn_size``.
    """
    bst_acc = 0.0
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    try:
        with tf.train.MonitoredTrainingSession() as sess:
            sess.run(init_op)
            start = time.time()
            for tn_batch in range(tn_num_batch):
                tn_image_np, tn_label_np = mnist.train.next_batch(flags.batch_size)
                feed_dict = {
                    tn_dis.image_ph: tn_image_np,
                    tn_dis.hard_label_ph: tn_label_np,
                }
                _, summary = sess.run([tn_dis.pre_update, summary_op], feed_dict=feed_dict)
                writer.add_summary(summary, tn_batch)
                # Also evaluate after the last batch so updates past the last
                # multiple of eval_interval still take part in best-model
                # selection.
                is_last = (tn_batch + 1) == tn_num_batch
                if (tn_batch + 1) % eval_interval != 0 and not is_last:
                    continue
                feed_dict = {
                    vd_dis.image_ph: mnist.test.images,
                    vd_dis.hard_label_ph: mnist.test.labels,
                }
                acc = sess.run(vd_dis.accuracy, feed_dict=feed_dict)
                bst_acc = max(acc, bst_acc)
                tot_time = time.time() - start
                global_step = sess.run(tn_dis.global_step)
                # Time per step scaled by the number of steps in one pass over
                # mnist.train (reported as seconds per epoch).
                avg_time = (tot_time / global_step) * (mnist.train.num_examples / flags.batch_size)
                print('#%08d curacc=%.4f curbst=%.4f tot=%.0fs avg=%.2fs/epoch' %
                      (tn_batch, acc, bst_acc, tot_time, avg_time))
                if acc < bst_acc:
                    continue
                # Ties also save, so the most recent equally-good model is kept.
                tn_dis.saver.save(utils.get_session(sess), flags.dis_model_ckpt)
    finally:
        # Flush any pending summary events and release the event file.
        writer.close()
    tot_time = time.time() - start
    print('#mnist=%d bstacc=%.4f et=%.0fs' % (tn_size, bst_acc, tot_time))
def main(_):
    """Pre-train ``tn_gen`` on YFCC data and checkpoint the best-precision model.

    Runs ``tn_num_batch`` steps of ``tn_gen.pre_update`` on batches from
    ``yfccdata.next_batch(flags, sess)``, writing each step's ``summary_op``
    result to a FileWriter under ``config.logs_dir``. Every ``eval_interval``
    batches, and once more after the final batch, it scores ``vd_gen`` with
    ``yfcceval.compute_prec`` (reported as prec@``flags.cutoff``), prints
    progress, and saves ``tn_gen`` to ``flags.gen_model_ckpt`` whenever that
    precision ties or beats the best seen so far. The FileWriter is always
    closed on exit.

    Args:
      _: Unused positional argument (presumably argv from ``tf.app.run`` --
        confirm against the caller).

    Relies on module-level names defined elsewhere in the file: ``tf``,
    ``time``, ``config``, ``flags``, ``utils``, ``yfccdata``, ``yfcceval``,
    ``init_op``, ``summary_op``, ``tn_gen``, ``vd_gen``, ``tn_num_batch``,
    ``eval_interval`` and ``tn_size``.
    """
    best_prec = 0.0
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    try:
        with tf.train.MonitoredTrainingSession() as sess:
            sess.run(init_op)
            start = time.time()
            for tn_batch in range(tn_num_batch):
                # The middle element returned by next_batch is not used here.
                tn_image_np, _, tn_label_np = yfccdata.next_batch(flags, sess)
                feed_dict = {
                    tn_gen.image_ph: tn_image_np,
                    tn_gen.hard_label_ph: tn_label_np,
                }
                _, summary = sess.run([tn_gen.pre_update, summary_op], feed_dict=feed_dict)
                writer.add_summary(summary, tn_batch)
                # Also evaluate after the last batch so updates past the last
                # multiple of eval_interval still take part in best-model
                # selection.
                is_last = (tn_batch + 1) == tn_num_batch
                if (tn_batch + 1) % eval_interval != 0 and not is_last:
                    continue
                prec = yfcceval.compute_prec(flags, sess, vd_gen)
                best_prec = max(prec, best_prec)
                tot_time = time.time() - start
                global_step = sess.run(tn_gen.global_step)
                # Time per step scaled by tn_size / flags.batch_size steps
                # (reported as seconds per epoch).
                avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
                print('#%08d prec@%d=%.4f best=%.4f tot=%.0fs avg=%.2fs/epoch' %
                      (global_step, flags.cutoff, prec, best_prec, tot_time, avg_time))
                if prec < best_prec:
                    continue
                # Ties also save, so the most recent equally-good model is kept.
                tn_gen.saver.save(utils.get_session(sess), flags.gen_model_ckpt)
    finally:
        # Flush any pending summary events and release the event file.
        writer.close()
    tot_time = time.time() - start
    print('best@%d=%.4f et=%.0fs' % (flags.cutoff, best_prec, tot_time))
def main(_):
    """Pre-train ``tn_gen`` on MNIST, checkpoint the best model, optionally log accuracy.

    Runs ``tn_num_batch`` steps of ``tn_gen.pre_update`` on
    ``mnist.train.next_batch(flags.batch_size)`` batches, writing each step's
    ``summary_op`` result to a FileWriter under ``config.logs_dir``. Every
    ``eval_interval`` batches, and once more after the final batch, it runs
    ``vd_gen.accuracy`` on the full ``mnist.test`` split, prints progress, and
    saves ``tn_gen`` to ``flags.gen_model_ckpt`` whenever that accuracy ties
    or beats the best seen so far.

    When ``flags.log_accuracy`` is set, test accuracy is computed after every
    batch and the resulting list is pickled to
    ``flags.all_learning_curve_p`` at the end.

    Args:
      _: Unused positional argument (presumably argv from ``tf.app.run`` --
        confirm against the caller).

    Relies on module-level names defined elsewhere in the file: ``tf``,
    ``time``, ``pickle``, ``config``, ``flags``, ``utils``, ``mnist``,
    ``init_op``, ``summary_op``, ``tn_gen``, ``vd_gen``, ``tn_num_batch``,
    ``eval_interval`` and ``tn_size``.
    """
    bst_acc = 0.0
    acc_list = []
    writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
    try:
        with tf.train.MonitoredTrainingSession() as sess:
            sess.run(init_op)
            start = time.time()
            for tn_batch in range(tn_num_batch):
                tn_image_np, tn_label_np = mnist.train.next_batch(flags.batch_size)
                feed_dict = {
                    tn_gen.image_ph: tn_image_np,
                    tn_gen.hard_label_ph: tn_label_np,
                }
                _, summary = sess.run([tn_gen.pre_update, summary_op], feed_dict=feed_dict)
                writer.add_summary(summary, tn_batch)
                # Evaluate on every eval_interval-th batch and also after the
                # last batch, so trailing updates take part in best-model
                # selection.
                is_last = (tn_batch + 1) == tn_num_batch
                is_eval_step = (tn_batch + 1) % eval_interval == 0 or is_last
                if not (flags.log_accuracy or is_eval_step):
                    continue
                # One accuracy run serves both the learning-curve log and the
                # periodic evaluation; no update happens in between.
                feed_dict = {
                    vd_gen.image_ph: mnist.test.images,
                    vd_gen.hard_label_ph: mnist.test.labels,
                }
                acc = sess.run(vd_gen.accuracy, feed_dict=feed_dict)
                if flags.log_accuracy:
                    acc_list.append(acc)
                if not is_eval_step:
                    continue
                bst_acc = max(acc, bst_acc)
                tot_time = time.time() - start
                global_step = sess.run(tn_gen.global_step)
                # Time per step scaled by tn_size / flags.batch_size steps
                # (reported as seconds per epoch).
                avg_time = (tot_time / global_step) * (tn_size / flags.batch_size)
                print('#%08d curacc=%.4f curbst=%.4f tot=%.0fs avg=%.2fs/epoch' %
                      (tn_batch, acc, bst_acc, tot_time, avg_time))
                if acc < bst_acc:
                    continue
                # Ties also save, so the most recent equally-good model is kept.
                tn_gen.saver.save(utils.get_session(sess), flags.gen_model_ckpt)
    finally:
        # Flush any pending summary events and release the event file.
        writer.close()
    tot_time = time.time() - start
    print('#mnist=%d bstacc=%.4f et=%.0fs' % (tn_size, bst_acc, tot_time))
    if flags.log_accuracy:
        utils.create_pardir(flags.all_learning_curve_p)
        # Context manager guarantees the pickle is flushed and the file closed.
        with open(flags.all_learning_curve_p, 'wb') as fout:
            pickle.dump(acc_list, fout)