from __future__ import division  # makes len(dataset) / batch_size below a float division

import copy
import math
import os
import time

import tensorflow as tf
from tqdm import tqdm

import regressionnet  # project module providing fill_joint_feed_dict / batch2feeds


def evaluate(net, pose_loss_op, test_iterator, summary_writer, tag='test/pose_loss'):
    """Run the net over the whole test set and log the average pose loss."""
    test_it = copy.copy(test_iterator)
    total_loss = 0.0
    cnt = 0
    num_batches = int(math.ceil(len(test_it.dataset) / test_it.batch_size))
    print('Evaluating on {} examples'.format(len(test_it.dataset)))
    for batch in tqdm(test_it, total=num_batches):
        feed_dict = regressionnet.fill_joint_feed_dict(
            net,
            regressionnet.batch2feeds(batch)[:3],
            conv_lr=0.0,
            fc_lr=0.0,
            phase='test')
        global_step, loss_value = net.sess.run(
            [net.global_iter_counter, pose_loss_op],
            feed_dict=feed_dict)
        # Weight each batch by its size so an incomplete final batch
        # does not skew the average.
        total_loss += loss_value * len(batch)
        cnt += len(batch)
    # Sanity check: we must have seen every test example exactly once
    # (the original hardcoded the expected count as 1000).
    assert cnt == len(test_it.dataset), 'cnt = {}'.format(cnt)
    avg_loss = total_loss / cnt
    print('Step {} {} = {:.3f}'.format(global_step, tag, avg_loss))
    summary_writer.add_summary(create_sumamry(tag, avg_loss),
                               global_step=global_step)
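# `evaluate` above writes its scalar through a `create_sumamry` helper that is
# not defined in this file (the spelling follows the call site). A minimal
# sketch of what such a helper could look like, assuming the TF 0.x Summary
# protobuf API; the project's actual definition may differ:
def create_sumamry(tag, value):
    # Wrap a single scalar in a Summary proto so it can be handed to
    # SummaryWriter.add_summary().
    return tf.Summary(value=[tf.Summary.Value(tag=tag,
                                              simple_value=float(value))])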
def train_loop(net, saver, loss_op, pose_loss_op, train_op, dataset_name,
               train_iterator, test_iterator,
               val_iterator=None,
               max_iter=None,
               test_step=None,
               snapshot_step=None,
               log_step=1,
               batch_size=None,
               conv_lr=None,
               fc_lr=None,
               fix_conv_iter=None,
               output_dir='results'):
    summary_step = 50
    with net.graph.as_default():
        summary_writer = tf.train.SummaryWriter(output_dir, net.sess.graph)
        summary_op = tf.merge_all_summaries()
        # Op that updates only the fully connected layers, used while the
        # convolutional weights are kept frozen.
        fc_train_op = net.graph.get_operation_by_name('fc_train_op')

        global_step = None
        for step in xrange(max_iter + 1):
            # Periodically evaluate and snapshot.
            # evaluate_pcp (PCP metric evaluation) is defined elsewhere in the project.
            if step % test_step == 0 or step + 1 == max_iter or step == fix_conv_iter:
                global_step = net.sess.run(net.global_iter_counter)
                evaluate_pcp(net, pose_loss_op, test_iterator, summary_writer,
                             dataset_name=dataset_name, tag_prefix='test')
                if val_iterator is not None:
                    evaluate_pcp(net, pose_loss_op, val_iterator, summary_writer,
                                 dataset_name=dataset_name, tag_prefix='val')
            if step % snapshot_step == 0 and step > 1:
                checkpoint_prefix = os.path.join(output_dir, 'checkpoint')
                assert global_step is not None
                saver.save(net.sess, checkpoint_prefix, global_step=global_step)
            if step == max_iter:
                break

            # Training step.
            start_time = time.time()
            feed_dict = regressionnet.fill_joint_feed_dict(
                net,
                regressionnet.batch2feeds(next(train_iterator))[:3],
                conv_lr=conv_lr,
                fc_lr=fc_lr,
                phase='train')
            if step < fix_conv_iter:
                # Freeze the conv tower: zero its learning rate and run the
                # op that trains only the fully connected layers.
                feed_dict['lr/conv_lr:0'] = 0.0
                cur_train_op = fc_train_op
            else:
                cur_train_op = train_op

            if step % summary_step == 0:
                global_step, summary_str, _, loss_value = net.sess.run(
                    [net.global_iter_counter, summary_op, cur_train_op, pose_loss_op],
                    feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, global_step=global_step)
            else:
                global_step, _, loss_value = net.sess.run(
                    [net.global_iter_counter, cur_train_op, pose_loss_op],
                    feed_dict=feed_dict)
            duration = time.time() - start_time

            if step % log_step == 0 or step + 1 == max_iter:
                # True division here; the original floor division truncated
                # the images/sec figure.
                print('Step %d: train/pose_loss = %.2f (%.3f s, %.2f im/s)'
                      % (global_step, loss_value, duration, batch_size / duration))
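# Illustrative driver (not part of the original file): a sketch of how
# evaluate() and train_loop() are typically wired together. The functions
# create_regression_net, create_training_ops and make_iterators, and all
# hyperparameter values below, are hypothetical placeholders for the
# project's actual construction code.
if __name__ == '__main__':
    net = regressionnet.create_regression_net()  # hypothetical constructor
    loss_op, pose_loss_op, train_op = regressionnet.create_training_ops(net)  # hypothetical
    train_it, test_it = make_iterators(batch_size=128)  # hypothetical
    with net.graph.as_default():
        saver = tf.train.Saver()
    train_loop(net, saver, loss_op, pose_loss_op, train_op,
               dataset_name='lsp',  # example value
               train_iterator=train_it,
               test_iterator=test_it,
               max_iter=20000,
               test_step=2500,
               snapshot_step=5000,
               log_step=100,
               batch_size=128,
               conv_lr=0.0005,
               fc_lr=0.0005,
               fix_conv_iter=0)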