def evaluate(args):
    """Evaluate the trained IQA model on the full test set and log metrics.

    Builds an inference-only VGG graph, restores the latest checkpoint from
    ``args.ckpt_dir``, scores every test image one at a time, and logs
    SROCC / KROCC / PLCC / RMSE / MSE computed by ``evaluate_metric``.

    Args:
        args: parsed CLI namespace; must provide ``ckpt_dir`` and whatever
            ``get_image_list`` needs to enumerate the test split.

    Returns:
        None. Results are emitted through ``logger``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Fixed 224x224 RGB input — the canonical VGG input size.
        images = tf.placeholder(tf.float32, [None, 224, 224, 3])
        model = VggNetModel(num_classes=1)
        # Second argument False => inference mode (no dropout/training ops).
        y_hat = model.inference(images, False)
        y_hat = tf.reshape(y_hat, [-1])
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Loads the latest checkpoint recorded in args.ckpt_dir.
        ckpt = tf.train.get_checkpoint_state(args.ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            logger.info('loading checkpoint:' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Bug fix: the original only logged here and then kept going,
            # evaluating a randomly initialized network. Abort instead —
            # metrics from unrestored weights are meaningless.
            logger.info("please loading checkpoint!")
            return

        test_image_paths, test_scores = get_image_list(args)
        score_set = []  # model predictions
        label_set = []  # ground-truth MOS scores
        for i, (path, raw_score) in enumerate(zip(test_image_paths, test_scores)):
            # parse_test_data returns a preprocessing tensor plus the label;
            # the tensor is materialized first, then fed for prediction.
            # NOTE(review): assumes the evaluated tensor already carries a
            # batch dimension compatible with the placeholder — confirm in
            # parse_test_data.
            image_tensor, score = parse_test_data(str(path), float(raw_score))
            image = sess.run(image_tensor)
            predict_score = sess.run(y_hat, feed_dict={images: image})
            label_set.append(score)
            score_set.append(predict_score[0])
            if i % 50 == 0:
                logger.info("image:{}/{}, true score:{}".format(
                    i, len(test_image_paths), score))
                logger.info("image:{}/{}, predict_score:{}".format(
                    i, len(test_image_paths), predict_score[0]))

        srocc, krocc, plcc, rmse, mse = evaluate_metric(label_set, score_set)
        logger.info(
            "SROCC_v: %.3f\t KROCC: %.3f\t PLCC_v: %.3f\t RMSE_v: %.3f\t mse: %.3f\n"
            % (srocc, krocc, plcc, rmse, mse))
        logger.info("Test finish!")
def train(args):
    """Train the IQA VGG regressor with warmup + step-decay learning rate.

    Builds the training graph, restores either a previous run's checkpoint
    (from ``args.ckpt_dir``) or ImageNet-pretrained weights
    (``args.pretrain_models_path``), then iterates the training generator
    until ``args.iter_max`` steps, periodically logging summaries and
    running validation.

    Args:
        args: parsed CLI namespace — expects crop_height, crop_width,
            batch_size, learning_rate, start_lr, iter_max, summary_step,
            test_step, dropout_keep_prob, save_ckpt_file, ckpt_dir,
            pretrain_models_path, logs_dir, exp_name.

    Returns:
        None. Progress and metrics are emitted through ``logger`` and
        TensorBoard summaries.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Registers a global_step variable in the graph's collections;
        # the handle itself is not needed afterwards.
        tf.train.create_global_step()

        # Placeholders for training data and schedule-controlled scalars.
        imgs = tf.placeholder(tf.float32,
                              [None, args.crop_height, args.crop_width, 3])
        scores = tf.placeholder(tf.float32, [None])
        dropout_keep_prob = tf.placeholder(tf.float32, [])
        lr = tf.placeholder(tf.float32, [])

        with tf.name_scope("create_models"):
            model = VggNetModel(num_classes=1,
                                dropout_keep_prob=dropout_keep_prob)
            y_hat = model.inference(imgs, True)  # True => training mode
            y_hat = tf.reshape(y_hat, [-1])

        with tf.name_scope("create_loss"):
            # `mes` is the project's regression loss (presumably MSE —
            # defined elsewhere in this module).
            reg_loss = mes(y_hat, scores)

        with tf.name_scope("create_optimize"):
            # Plain SGD reportedly failed to converge here; Adam is used.
            var_list = tf.trainable_variables()
            optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(
                reg_loss, var_list=var_list)

        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        tf.summary.scalar('learning_rate', lr)
        tf.summary.scalar('reg_loss', reg_loss)
        # Merge every summary registered above into a single fetchable op.
        summary_op = tf.summary.merge_all()

        # Separate writers for train/test so TensorBoard overlays the curves.
        timestamp = datetime.fromtimestamp(
            time.time()).strftime('%Y%m%d-%H:%M')
        summary_writer = tf.summary.FileWriter(os.path.join(
            args.logs_dir, 'train/{}-{}'.format(args.exp_name, timestamp)),
            filename_suffix=args.exp_name)
        summary_test = tf.summary.FileWriter(os.path.join(
            args.logs_dir, 'test/{}-{}'.format(args.exp_name, timestamp)),
            filename_suffix=args.exp_name)

        train_image_paths, train_scores, test_image_paths, test_scores = \
            get_image_list(args)
        train_loader = train_generator(train_image_paths, train_scores)
        test_loader = val_generator(test_image_paths, test_scores,
                                    args.batch_size)
        test_num_batchs = len(test_image_paths) // args.batch_size + 1

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        # Resume from our own checkpoint if one exists, otherwise start
        # from the pretrained backbone weights.
        ckpt = tf.train.get_checkpoint_state(args.ckpt_dir)
        counter = 0
        if ckpt and ckpt.model_checkpoint_path:
            counter = __load__(saver, sess, args.ckpt_dir)
        else:
            load(saver, sess, args.pretrain_models_path)

        start_time = time.time()
        start_step = counter
        base_lr = args.learning_rate
        for step, (images, targets) in enumerate(train_loader, start_step):
            if step <= 500:
                # Linear warmup from start_lr to learning_rate over 500 steps.
                base_lr = args.start_lr + \
                    (args.learning_rate - args.start_lr) * step / float(500)
            else:
                # Step decay (/5) at 50% and 80% of the schedule.
                # NOTE(review): float modulo only hits zero when iter_max
                # makes 0.5/0.8*iter_max integral — confirm iter_max choice.
                if (step + 1) % (0.5 * args.iter_max) == 0:
                    base_lr = base_lr / 5
                if (step + 1) % (0.8 * args.iter_max) == 0:
                    base_lr = base_lr / 5

            loss_, _ = sess.run(
                [reg_loss, optimizer],
                feed_dict={
                    imgs: images,
                    scores: targets,
                    lr: base_lr,
                    dropout_keep_prob: args.dropout_keep_prob
                })

            if (step + 1) % args.summary_step == 0:
                logger.info(
                    "step %d/%d,reg loss is %f, time %f,learning rate: %.8f"
                    % (step, args.iter_max, loss_,
                       (time.time() - start_time), base_lr))
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           imgs: images,
                                           scores: targets,
                                           lr: base_lr,
                                           dropout_keep_prob:
                                           args.dropout_keep_prob
                                       })
                summary_writer.add_summary(summary_str, step)

            if (step + 1) % args.test_step == 0:
                if args.save_ckpt_file:
                    save(saver, sess, args.ckpt_dir, step)
                # Full pass over the validation loader.
                test_loss = 0
                scores_set = np.array([])
                labels_set = np.array([])
                for i in range(test_num_batchs):
                    # Distinct names: the original reused images/targets and
                    # clobbered the training batch.
                    val_images, val_targets = next(test_loader)
                    loss_, y_hat_ = sess.run(
                        [reg_loss, y_hat],
                        feed_dict={
                            imgs: val_images,
                            scores: val_targets,
                            lr: base_lr,
                            dropout_keep_prob: args.dropout_keep_prob
                        })
                    test_loss += loss_
                    scores_set = np.append(scores_set, y_hat_)
                    labels_set = np.append(labels_set, val_targets)
                    logger.info(
                        'test_loader step/len(test_loader) :{}/{}'.format(
                            i, test_num_batchs))

                srocc, krocc, plcc, rmse, mse = evaluate_metric(
                    labels_set, scores_set)
                test_loss /= test_num_batchs
                logger.info(
                    "SROCC_v: %.3f\t KROCC: %.3f\t PLCC_v: %.3f\t RMSE_v: %.3f\t mse: %.3f\t test loss: %.3f\n"
                    % (srocc, krocc, plcc, rmse, mse, test_loss))
                # Manual summaries: these scalars are computed outside the
                # graph, so they are written as protobuf values directly.
                s1 = tf.Summary(value=[
                    tf.Summary.Value(tag='test_loss', simple_value=test_loss)
                ])
                s2 = tf.Summary(value=[
                    tf.Summary.Value(tag='test_srocc', simple_value=srocc)
                ])
                summary_test.add_summary(s1, step)
                summary_test.add_summary(s2, step)

            # Bug fix: was `==`, which never fires when resuming from a
            # counter already past iter_max (training would run forever).
            if step >= args.iter_max:
                saver.save(sess,
                           args.ckpt_dir + '/final_model_' + timestamp +
                           '.ckpt',
                           write_meta_graph=False)
                logger.info(
                    'save train_iqa final models max_iter: {}...'.format(
                        args.iter_max))
                break
        logger.info("Optimization finish!")