Example #1
import copy
import math

from tqdm import tqdm

# `regressionnet` is a project-local module; `create_sumamry` (spelled this
# way at its call site below) is assumed to be defined elsewhere in the project.
import regressionnet


def evaluate(net,
             pose_loss_op,
             test_iterator,
             summary_writer,
             tag='test/pose_loss'):
    """Run one pass over the test iterator and log the average pose loss."""
    test_it = copy.copy(test_iterator)
    total_loss = 0.0
    cnt = 0
    # use true division so the ceil is meaningful under Python 2 as well
    num_batches = int(math.ceil(len(test_it.dataset) / float(test_it.batch_size)))
    print('Evaluating on {} test examples'.format(len(test_it.dataset)))
    for batch in tqdm(test_it, total=num_batches):
        feed_dict = regressionnet.fill_joint_feed_dict(
            net,
            regressionnet.batch2feeds(batch)[:3],
            conv_lr=0.0,
            fc_lr=0.0,
            phase='test')
        global_step, loss_value = net.sess.run(
            [net.global_iter_counter, pose_loss_op], feed_dict=feed_dict)
        total_loss += loss_value * len(batch)
        cnt += len(batch)
    avg_loss = total_loss / len(test_it.dataset)
    print('Step {} {} = {:.3f}'.format(global_step, tag, avg_loss))
    summary_writer.add_summary(create_sumamry(tag, avg_loss),  # (sic) project helper name
                               global_step=global_step)
    # sanity check: this pipeline expects the test split to hold exactly 1000 examples
    assert cnt == 1000, 'cnt = {}'.format(cnt)
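
Usage sketch (not from the project): a hedged illustration of how evaluate() might be wired up. make_net and make_test_iterator are hypothetical stand-ins for whatever constructors the surrounding project actually provides.

# Hypothetical wiring -- every name below except `evaluate` and the
# TensorFlow calls is a placeholder, not the project's real API.
import tensorflow as tf

net = make_net()  # assumed to expose .sess, .graph and .global_iter_counter
pose_loss_op = net.graph.get_tensor_by_name('pose_loss:0')  # assumed tensor name
test_iterator = make_test_iterator(batch_size=32)  # assumed to expose .dataset and .batch_size

with net.graph.as_default():
    # pre-1.0 TensorFlow summary writer, matching the API used above
    summary_writer = tf.train.SummaryWriter('results', net.sess.graph)

evaluate(net, pose_loss_op, test_iterator, summary_writer, tag='test/pose_loss')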
Example #2
import os
import time

import tensorflow as tf

import regressionnet  # project-local module, assumed importable

# Note: this example targets the pre-1.0 TensorFlow API (tf.train.SummaryWriter,
# tf.merge_all_summaries) and Python 2 (xrange, iterator.next()).
# `evaluate_pcp` is a project-local helper assumed to be in scope.


def train_loop(net, saver, loss_op, pose_loss_op, train_op,
               dataset_name, train_iterator, test_iterator,
               val_iterator=None,
               max_iter=None,
               test_step=None,
               snapshot_step=None,
               log_step=1,
               batch_size=None,
               conv_lr=None,
               fc_lr=None,
               fix_conv_iter=None,
               output_dir='results'):

    summary_step = 50  # write merged summaries every 50 training steps

    with net.graph.as_default():
        summary_writer = tf.train.SummaryWriter(output_dir, net.sess.graph)
        summary_op = tf.merge_all_summaries()
        # optimizer op that updates only the fully-connected layers
        fc_train_op = net.graph.get_operation_by_name('fc_train_op')
    global_step = None

    for step in xrange(max_iter + 1):

        # test, snapshot
        if step % test_step == 0 or step + 1 == max_iter or step == fix_conv_iter:
            global_step = net.sess.run(net.global_iter_counter)
            evaluate_pcp(net, pose_loss_op, test_iterator, summary_writer,
                         dataset_name=dataset_name,
                         tag_prefix='test')
            if val_iterator is not None:
                evaluate_pcp(net, pose_loss_op, val_iterator, summary_writer,
                             dataset_name=dataset_name,
                             tag_prefix='val')

        if step % snapshot_step == 0 and step > 1:
            checkpoint_prefix = os.path.join(output_dir, 'checkpoint')
            assert global_step is not None
            saver.save(net.sess, checkpoint_prefix, global_step=global_step)
        if step == max_iter:
            break

        # training
        start_time = time.time()
        feed_dict = regressionnet.fill_joint_feed_dict(net,
                                                       regressionnet.batch2feeds(train_iterator.next())[:3],
                                                       conv_lr=conv_lr,
                                                       fc_lr=fc_lr,
                                                       phase='train')
        # while the conv tower is frozen, zero its learning rate and step
        # only the FC-layer optimizer
        if step < fix_conv_iter:
            feed_dict['lr/conv_lr:0'] = 0.0
            cur_train_op = fc_train_op
        else:
            cur_train_op = train_op

        # periodically run the merged summary op alongside the training step
        if step % summary_step == 0:
            global_step, summary_str, _, loss_value = net.sess.run(
                [net.global_iter_counter,
                 summary_op,
                 cur_train_op,
                 pose_loss_op],
                feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, global_step=global_step)
        else:
            global_step, _, loss_value = net.sess.run(
                [net.global_iter_counter, cur_train_op, pose_loss_op],
                feed_dict=feed_dict)

        duration = time.time() - start_time

        if step % log_step == 0 or step + 1 == max_iter:
            # true division: `//` would floor the rate despite the %.2f format
            print('Step %d: train/pose_loss = %.2f (%.3f s, %.2f im/s)'
                  % (global_step, loss_value, duration,
                     batch_size / duration))