# Example 1: SSD training driver (command-line entry point)
def main():
    """Train an SSD network built from a VGG base, logging to TensorBoard.

    Parses command-line options, constructs the network inside a TF1
    session, runs the train/validation loop, writes per-epoch scalar loss
    summaries, and saves periodic plus final checkpoints.

    Returns:
        int: 0 on success (shell-style exit code).
    """
    parser = argparse.ArgumentParser(description='Train Network')
    parser.add_argument('--data-dir', default='data', help='data directory')
    # BUG FIX: the options below were read from `args` later in this
    # function but were never declared, so the script crashed with
    # AttributeError. Declaring them (with defaults) is backward-compatible.
    parser.add_argument('--vgg-dir', default='vgg',
                        help='directory with the pretrained VGG weights')
    parser.add_argument('--name', default='model',
                        help='checkpoint directory/prefix')
    parser.add_argument('--tensorboard-dir', default='tb',
                        help='TensorBoard log directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='batch size')
    parser.add_argument('--checkpoint-interval', type=int, default=10,
                        help='save a checkpoint every N epochs')
    args = parser.parse_args()

    td = TrainingData(args.data_dir)

    with tf.Session() as sess:
        net = SSD(sess)
        net.create_from_vgg(args.vgg_dir, td.num_classes, td.conf)

        # Label tensor: last dim is num_classes + 5 extra values per anchor
        # (presumably 4 box offsets + 1 mask/weight — confirm against SSD).
        labels = tf.placeholder(tf.float32, shape=[None, None, td.num_classes+5])
        optimizer, loss = net.get_optimizer(labels)
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir, sess.graph)
        saver = tf.train.Saver(max_to_keep=10)
        # BUG FIX: this was assigned to `n_batches` but read as
        # `n_train_batches` in the epoch loop below (NameError).
        n_train_batches = int(math.ceil(td.num_train/args.batch_size))
        init_vars(sess)

        # Placeholders feed the per-epoch aggregate losses into the
        # scalar summary ops.
        validation_loss = tf.placeholder(tf.float32)
        validation_loss_summary_op = tf.summary.scalar('validation_loss', validation_loss)
        training_loss = tf.placeholder(tf.float32)
        training_loss_summary_op = tf.summary.scalar('training_loss', training_loss)

        for e in range(args.epochs):
            # ---- training pass ----
            generator = td.train_generator(args.batch_size)
            description = 'Epoch {}/{}'.format(e+1, args.epochs)
            training_loss_total = 0
            for x, y in tqdm(generator, total=n_train_batches, desc=description, unit='batches'):
                # NOTE(review): keep_prob=1 disables dropout during training —
                # confirm this is intentional.
                feed = {net.image_input: x,
                        labels: y, net.keep_prob: 1}
                loss_batch, _ = sess.run([loss, optimizer], feed_dict=feed)
                # Weight by actual batch size so the mean stays exact even
                # when the final batch is short.
                training_loss_total += loss_batch * x.shape[0]
            training_loss_total /= td.num_train

            # ---- validation pass (no optimizer step) ----
            # BUG FIX: was `tf.valid_generator(...)` (AttributeError); the
            # generator lives on the TrainingData instance, as on the
            # training side above.
            generator = td.valid_generator(args.batch_size)
            validation_loss_total = 0
            for x, y in generator:
                feed = {net.image_input: x,
                        labels: y, net.keep_prob: 1}
                # BUG FIX: a single fetch yields a single value; the original
                # `loss_batch, _ = sess.run([loss], ...)` raised ValueError
                # trying to unpack a one-element list.
                loss_batch = sess.run(loss, feed_dict=feed)
                validation_loss_total += loss_batch * x.shape[0]
            validation_loss_total /= td.num_valid

            # ---- per-epoch scalar summaries ----
            feed = {validation_loss: validation_loss_total,
                    training_loss:   training_loss_total}
            loss_summary = sess.run([validation_loss_summary_op,
                                     training_loss_summary_op],
                                    feed_dict=feed)
            summary_writer.add_summary(loss_summary[0], e)
            summary_writer.add_summary(loss_summary[1], e)

            # Periodic checkpoint every `checkpoint_interval` epochs.
            if (e+1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e+1)
                saver.save(sess, checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)

    return 0