# Standard imports used below (assumed at module level in the original file).
import os
import time

import numpy as np
import tensorflow as tf


def train(model, resume_path=None):
  with tf.Graph().as_default():
    # Configure the training session.
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=config)

    global_step = tf.get_variable(
        'global_step', [], initializer=tf.constant_initializer(0),
        trainable=False)
    num_batches_per_epoch = FLAGS.train_size // FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(
        FLAGS.initial_learning_rate, global_step, decay_steps,
        FLAGS.learning_rate_decay_factor, staircase=True)

    # Input pipelines for the training and validation partitions.
    train_data, train_labels, train_names, train_weights = reader.inputs(
        shuffle=True, num_epochs=FLAGS.max_epochs, dataset_partition='train')
    val_data, val_labels, val_names, val_weights = reader.inputs(
        shuffle=False, num_epochs=FLAGS.max_epochs, dataset_partition='val')

    # Build the training graph and a weight-sharing validation graph.
    with tf.variable_scope('model'):
      logits, loss, init_op, init_feed = model.build(
          train_data, train_labels, train_weights)
    with tf.variable_scope('model', reuse=True):
      logits_val, loss_val = model.build(
          val_data, val_labels, val_weights, is_training=False)

    tf.summary.scalar('learning_rate', lr)
    opt = tf.train.AdamOptimizer(lr)
    grads = opt.compute_gradients(loss)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      tf.summary.histogram(var.op.name, var)

    # Add histograms for gradients.
    grad_tensors = []
    for grad, var in grads:
      grad_tensors += [grad]
      if grad is not None:
        tf.summary.histogram(var.op.name + '/gradients', grad)

    # Track exponential moving averages of the trainable variables and make
    # the train op depend on both the gradient update and the average update.
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
      train_op = tf.no_op(name='train')

    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=FLAGS.max_epochs, sharded=False)

    # Prefer an explicitly passed checkpoint path over the flag value.
    resume_path = resume_path or FLAGS.resume_path
    if resume_path:
      print('\nRestoring params from:', resume_path)
      saver.restore(sess, resume_path)
      sess.run(tf.local_variables_initializer())
    else:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(init_op, feed_dict=init_feed)

    summary_op = tf.summary.merge_all()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)

    num_batches = num_batches_per_epoch
    plot_data = {
        'train_loss': [], 'valid_loss': [],
        'train_iou': [], 'valid_iou': [],
        'train_acc': [], 'valid_acc': [],
        'train_prec': [], 'valid_prec': [],
        'train_rec': [], 'valid_rec': [],
        'lr': [],
    }

    ex_start_time = time.time()
    global_step_value = 0
    visualize_dir = os.path.join(FLAGS.train_dir, 'visualize')
    tf.gfile.MakeDirs(visualize_dir)

    for epoch_num in range(1, FLAGS.max_epochs + 1):
      print('\ntensorboard --logdir=' + FLAGS.train_dir + '\n')
      confusion_matrix = np.zeros((FLAGS.num_classes, FLAGS.num_classes),
                                  dtype=np.uint64)
      avg_train_loss = 0
      for step in range(1, num_batches + 1):
        start_time = time.time()
        run_ops = [train_op, logits, loss, global_step, train_labels]
        if global_step_value % 50 == 0:
          # Also fetch and write the merged summaries every 50 steps.
          run_ops += [summary_op]
          ret_val = sess.run(run_ops)
          (_, logits_value, loss_value, global_step_value, batch_labels,
           summary_str) = ret_val
          summary_writer.add_summary(summary_str, global_step_value)
        else:
          ret_val = sess.run(run_ops)
          (_, logits_value, loss_value, global_step_value,
           batch_labels) = ret_val
        duration = time.time() - start_time
        avg_train_loss += loss_value

        if step >= num_batches * 4 // 5:
          # Collect training metrics only over the last fifth of the epoch.
          collect_confusion(logits_value, batch_labels, confusion_matrix)
          acc, iou, rec, prec, size = eval_helper.compute_errors(
              confusion_matrix, 'train', class_names, verbose=False)
          if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / duration
            sec_per_batch = float(duration)
            format_str = ('%s: epoch %d, batch %d / %d,\n'
                          'batch loss = %.2f\n'
                          'avg loss = %.2f\n'
                          'avg acc = %.2f\n'
                          'avg iou = %.2f\n'
                          'avg prec = %.2f\n'
                          'avg rec = %.2f\n'
                          '(%.1f examples/sec; %.3f sec/batch)')
            print(format_str % (train_helper.get_expired_time(ex_start_time),
                                epoch_num, step, num_batches, loss_value,
                                avg_train_loss / step, acc, iou, prec, rec,
                                examples_per_sec, sec_per_batch))
        else:
          if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / duration
            sec_per_batch = float(duration)
            format_str = ('%s: epoch %d, batch %d / %d,\n'
                          'batch loss = %.2f\n'
                          'avg loss = %.2f\n'
                          '(%.1f examples/sec; %.3f sec/batch)')
            print(format_str % (train_helper.get_expired_time(ex_start_time),
                                epoch_num, step, num_batches, loss_value,
                                avg_train_loss / step, examples_per_sec,
                                sec_per_batch))

      # End-of-epoch validation pass.
      train_loss = avg_train_loss / num_batches
      valid_acc, valid_iou, valid_prec, valid_rec, valid_loss = evaluate(
          sess, epoch_num, val_labels, logits_val, loss_val,
          data_size=FLAGS.val_size, name='validation')

      plot_data['train_loss'] += [train_loss]
      plot_data['valid_loss'] += [valid_loss]
      plot_data['train_iou'] += [iou]
      plot_data['valid_iou'] += [valid_iou]
      plot_data['train_acc'] += [acc]
      plot_data['valid_acc'] += [valid_acc]
      plot_data['train_prec'] += [prec]
      plot_data['valid_prec'] += [valid_prec]
      plot_data['train_rec'] += [rec]
      plot_data['valid_rec'] += [valid_rec]
      plot_data['lr'] += [lr.eval(session=sess)]
      eval_helper.plot_training_progress(visualize_dir, plot_data)

      # Checkpoint whenever the current validation IoU is the best so far.
      if valid_iou >= max(plot_data['valid_iou']):
        print('Saving model...')
        t = time.time()
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path)
        print('Model saved! t={}'.format(train_helper.get_expired_time(t)))

    coord.request_stop()
    coord.join(threads)
    sess.close()
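

# --- Hypothetical helper (illustrative only; the real collect_confusion is
# defined elsewhere in the original repository). A minimal sketch of how the
# per-pixel confusion matrix used in train() could be accumulated, assuming
# logits of shape [N, H, W, C] and integer labels of shape [N, H, W] with
# values in [0, num_classes).
def collect_confusion_sketch(logits_value, labels_value, confusion_matrix):
  num_classes = confusion_matrix.shape[0]
  predictions = np.argmax(logits_value, axis=-1).reshape(-1)
  labels = labels_value.reshape(-1).astype(np.int64)
  # Ignore pixels whose label falls outside the valid class range.
  valid = (labels >= 0) & (labels < num_classes)
  idx = labels[valid] * num_classes + predictions[valid]
  counts = np.bincount(idx, minlength=num_classes * num_classes)
  confusion_matrix += counts.reshape(num_classes, num_classes).astype(
      confusion_matrix.dtype)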
def plot_results(train_data, valid_data):
  eval_helper.plot_training_progress(
      os.path.join(FLAGS.train_dir, 'stats'), train_data, valid_data)
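

# --- Hypothetical entry point (not part of the original file). A minimal
# sketch of how train() might be invoked, assuming a `model` module exposing
# build() and that FLAGS (e.g. via tf.app.flags) and the reader/eval helpers
# are defined elsewhere in this file or its imports.
def main(argv=None):  # pylint: disable=unused-argument
  train(model, resume_path=FLAGS.resume_path)


if __name__ == '__main__':
  tf.app.run()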