# Example #1
def train(args):
    """Run the training loop for ``Model`` on data from ``DataLoader``.

    Restores the most recent checkpoint in ``args.save_dir`` (or the one
    explicitly named by ``args.checkpoint``), trains for
    ``args.num_epochs`` epochs, logs the loss to TensorBoard every 200
    steps, and saves a checkpoint plus the pickled run config every
    ``args.save_every`` steps and once more at the end.

    Args:
        args: parsed command-line namespace; the fields read here are
            batch_size, seq_length, num_epochs, save_every, save_dir,
            data_dir, augment_data and checkpoint.
    """
    batch_size = args.batch_size
    seq_length = args.seq_length
    num_epochs = args.num_epochs
    save_every = args.save_every
    save_dir = args.save_dir
    data_dir = args.data_dir
    augment_data = args.augment_data
    checkpoint = args.checkpoint

    # exist_ok replaces the old bare ``except: pass`` — only "already
    # exists" is expected; any other OSError should surface.
    os.makedirs(data_dir, exist_ok=True)

    data_loader = DataLoader(data_dir=data_dir, augment_data=augment_data)

    model = Model(args)

    writer = tf.summary.FileWriter(save_dir, graph=tf.get_default_graph())

    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options={'allow_growth': True})

    with tf.Session(config=config) as sess:

        saver = tf.train.Saver()
        if not checkpoint:
            checkpoint = tf.train.latest_checkpoint(save_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            # Fixed py2->py3 artifact: print the message and path,
            # not a tuple repr.
            print('=== graph restored ===', checkpoint, file=sys.stderr)
        else:
            tf.global_variables_initializer().run()

        # epoch/global_step live in the graph so a restored run resumes
        # from where it left off.
        start, global_step = sess.run([model.epoch, model.global_step])

        for e in range(start, num_epochs + 1):

            num_batches, train_batch = data_loader.batch_data(
                batch_size, seq_length)

            for x, y, ids in train_batch:

                t_start = time.time()

                feed = {
                    model.input_data: x,
                    model.target_data: y,
                    model.motion_id: one_hot(ids, ID_SIZE),
                    # per-sample lengths, presumably for dynamic RNN
                    # unrolling — confirm against Model
                    model.seq_length: [len(t) for t in x],
                }
                _, loss = sess.run([model.train_op, model.loss], feed)

                global_step += 1

                if global_step % 200 == 0:
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='day3/loss', simple_value=loss),
                    ])
                    writer.add_summary(summary, global_step=global_step)

                    # t_start is reset at the top of every batch, so this
                    # is the duration of the current batch only.
                    t_elapsed = time.time() - t_start
                    print("epoch {}, step {}, loss = {:.5f}, elapsed = {:.3f}".
                          format(e, global_step, loss, t_elapsed))

                if global_step % save_every == 0 and global_step > 0:
                    _save_checkpoint(saver, sess, save_dir, args, global_step)

                # Persist progress counters into the graph so a later
                # restore resumes at the right epoch/step.
                sess.run([model.update_op], {
                    model.epoch_update: e,
                    model.step_update: global_step
                })

        # Final save after the last epoch completes.
        _save_checkpoint(saver, sess, save_dir, args, global_step)


def _save_checkpoint(saver, sess, save_dir, args, global_step):
    """Save model weights and the pickled run config into *save_dir*.

    Writes ``model.ckpt-<global_step>`` via the saver and dumps *args* to
    ``config.pkl`` (protocol 2, kept for compatibility with older readers).
    """
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    cp = saver.save(sess, checkpoint_path, global_step=global_step)
    with open(os.path.join(save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f, protocol=2)
    print("model saved to {}".format(cp))