コード例 #1
0
def _merge_scalar_summaries(*tensors):
  """Create a scalar summary for each tensor and merge them into one op.

  Each summary is tagged with the tensor's own ``.name``. Summaries are
  created in the order the tensors are passed.
  """
  return tf.summary.merge(
      [tf.summary.scalar(tensor.name, tensor) for tensor in tensors])


# One merged summary op per model (tn_dis, tn_gen, tn_tch): its learning rate
# plus its loss tensor.
dis_summary_op = _merge_scalar_summaries(
    tn_dis.learning_rate, tn_dis.gan_loss)
gen_summary_op = _merge_scalar_summaries(
    tn_gen.learning_rate, tn_gen.kdgan_loss)
tch_summary_op = _merge_scalar_summaries(
    tn_tch.learning_rate, tn_tch.kdgan_loss)
# Initializer for the global variables that exist in the default graph at the
# moment this line runs.
# NOTE(review): any variables created after this point (e.g. by the readers
# below, if they create any) would not be covered — confirm.
init_op = tf.global_variables_initializer()

# Three independent YFCCDATA readers built from the same flags; presumably one
# input stream per model (dis / gen / tch) — TODO confirm against main().
yfccdata_d = data_utils.YFCCDATA(flags)
yfccdata_g = data_utils.YFCCDATA(flags)
yfccdata_t = data_utils.YFCCDATA(flags)
# Evaluation helper, also configured from flags.
yfcceval = data_utils.YFCCEVAL(flags)

def main(_):
  """Script entry point; the positional argument is ignored.

  Creates a summary writer, opens a MonitoredTrainingSession, runs the
  variable initializer, then restores the tn_dis, tn_gen and tn_tch models
  from their checkpoints (paths taken from ``flags``) and starts a timer.

  NOTE(review): the visible body ends right after ``start = time.time()``;
  the rest (presumably the training loop) looks cut off — confirm against
  the full file.
  """
  # Presumably the best precision seen so far and the epoch it occurred at.
  best_prec, bst_epk = 0.0, 0
  # Presumably collects per-epoch scores; not used in the visible code.
  epk_score_list = []
  # Event writer rooted at config.logs_dir; also records the default graph.
  # NOTE(review): the writer is never flushed/closed in the visible code.
  writer = tf.summary.FileWriter(config.logs_dir, graph=tf.get_default_graph())
  with tf.train.MonitoredTrainingSession() as sess:
    # NOTE(review): MonitoredTrainingSession normally initializes variables
    # through its own Scaffold; this explicit run may be redundant — confirm.
    sess.run(init_op)
    # Overwrite the freshly initialized values with the pretrained weights.
    tn_dis.saver.restore(sess, flags.dis_model_ckpt)
    tn_gen.saver.restore(sess, flags.gen_model_ckpt)
    tn_tch.saver.restore(sess, flags.tch_model_ckpt)
    # Wall-clock start time.
    start = time.time()
コード例 #2
0
# Switch the current variable scope to reuse mode before building the
# evaluation copy of the generator (is_training=False).
# Presumably this makes vd_gen share weights with tn_gen. That assumes GEN
# creates its variables via tf.get_variable under the same scope names
# — TODO confirm.
scope = tf.get_variable_scope()
scope.reuse_variables()
vd_gen = GEN(flags, is_training=False)

# Scalar summaries for the training generator, tagged with the tensor names.
tf.summary.scalar(tn_gen.learning_rate.name, tn_gen.learning_rate)
tf.summary.scalar(tn_gen.pre_loss.name, tn_gen.pre_loss)
# merge_all() merges every summary registered in the graph, not only the two
# above (including any that GEN may create internally).
summary_op = tf.summary.merge_all()
# Built after both generator graphs, so it covers the variables they created.
init_op = tf.global_variables_initializer()

# Print the name and parameter count of every trainable variable.
for variable in tf.trainable_variables():
    # TensorShape.num_elements() multiplies the static dimensions for us.
    # Unlike the manual ``dim.value`` product, it works whether dimensions
    # are Dimension objects (TF1) or plain ints (TF2 behavior). It returns
    # None for a not-fully-defined shape instead of raising a TypeError.
    num_params = variable.shape.num_elements()
    if num_params is None:
        # Shape not statically known: report it rather than crash the script.
        print('%-50s (unknown params, shape %s)' % (variable.name,
                                                    variable.shape))
    else:
        print('%-50s (%d params)' % (variable.name, num_params))

# Training batch source (main() draws batches via yfccdata.next_batch) and,
# presumably, an evaluation helper — both configured from flags.
yfccdata = data_utils.YFCCDATA(flags)
yfcceval = data_utils.YFCCEVAL(flags)


def main(_):
    best_prec = 0.0
    writer = tf.summary.FileWriter(config.logs_dir,
                                   graph=tf.get_default_graph())
    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(init_op)
        start = time.time()
        for tn_batch in range(tn_num_batch):
            tn_image_np, _, tn_label_np = yfccdata.next_batch(flags, sess)
            feed_dict = {
                tn_gen.image_ph: tn_image_np,
                tn_gen.hard_label_ph: tn_label_np