def _save_checkpoint(sess, saver, model, args, save_dir, epoch, global_step):
    """Write a checkpoint plus the pickled run config into ``save_dir``.

    The in-graph epoch/step counters are synced *before* saving so that the
    checkpoint records the position it was actually taken at; otherwise a
    resumed run would read counters left over from the previous sync.

    Args:
        sess: Active ``tf.Session`` holding the variables to save.
        saver: ``tf.train.Saver`` used to write the checkpoint.
        model: Model exposing ``update_op``, ``epoch_update`` and
            ``step_update``.
        args: Run configuration; pickled to ``config.pkl``.
        save_dir: Directory receiving ``model.ckpt`` files and ``config.pkl``.
        epoch: Current epoch number, stored via ``model.epoch_update``.
        global_step: Steps taken so far; stored via ``model.step_update`` and
            used as the checkpoint's ``global_step`` suffix.

    Returns:
        The checkpoint path returned by ``saver.save``.
    """
    sess.run([model.update_op], {
        model.epoch_update: epoch,
        model.step_update: global_step,
    })
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    cp = saver.save(sess, checkpoint_path, global_step=global_step)
    # protocol=2 is presumably chosen so Python 2 tooling can read the config.
    with open(os.path.join(save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f, protocol=2)
    print("model saved to {}".format(cp))
    return cp


def train(args):
    """Train ``Model`` on batches from ``DataLoader``, resuming when possible.

    Restores variables from ``args.checkpoint`` or, if that is empty, from the
    latest checkpoint in ``args.save_dir``; otherwise initializes all
    variables. Training resumes from the epoch/step counters stored in the
    graph. Each step runs ``model.train_op`` and prints the loss. A
    ``day3/loss`` summary is written every ``summary_every`` steps and a
    checkpoint every ``args.save_every`` steps. A final checkpoint is written
    when the epoch loop finishes.

    Args:
        args: Namespace providing ``batch_size``, ``seq_length``,
            ``num_epochs``, ``save_every``, ``save_dir``, ``data_dir``,
            ``augment_data`` and ``checkpoint``. It is also passed to
            ``Model`` and pickled next to each checkpoint. An optional
            ``summary_every`` attribute (default 200) sets the summary
            interval. A non-positive ``save_every`` disables periodic saves,
            but the final save still happens.
    """
    batch_size = args.batch_size
    seq_length = args.seq_length
    num_epochs = args.num_epochs
    save_every = args.save_every
    save_dir = args.save_dir
    data_dir = args.data_dir
    augment_data = args.augment_data
    checkpoint = args.checkpoint
    summary_every = getattr(args, 'summary_every', 200)

    # Best effort: the directory commonly exists already. Only OS-level
    # failures are ignored; interrupts and programming errors propagate.
    try:
        os.makedirs(data_dir)
    except OSError:
        pass

    data_loader = DataLoader(data_dir=data_dir, augment_data=augment_data)
    model = Model(args)
    writer = tf.summary.FileWriter(save_dir, graph=tf.get_default_graph())
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options={'allow_growth': True})

    try:
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver()
            if not checkpoint:
                checkpoint = tf.train.latest_checkpoint(save_dir)
            if checkpoint:
                saver.restore(sess, checkpoint)
                print('=== graph restored ===', checkpoint, file=sys.stderr)
            else:
                tf.global_variables_initializer().run()

            start, global_step = sess.run([model.epoch, model.global_step])

            # Keeps `epoch` defined for the final save even if the range is empty.
            epoch = start
            for epoch in range(start, num_epochs + 1):
                _, train_batch = data_loader.batch_data(batch_size, seq_length)
                for x, y, ids in train_batch:
                    t_start = time.time()
                    feed = {
                        model.input_data: x,
                        model.target_data: y,
                        model.motion_id: one_hot(ids, ID_SIZE),
                        model.seq_length: [len(t) for t in x],
                    }
                    _, loss = sess.run([model.train_op, model.loss], feed)
                    global_step += 1

                    if summary_every > 0 and global_step % summary_every == 0:
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag='day3/loss', simple_value=loss),
                        ])
                        writer.add_summary(summary, global_step=global_step)

                    t_elapsed = time.time() - t_start
                    print("epoch {}, step {}, loss = {:.5f}, elapsed = {:.3f}".format(
                        epoch, global_step, loss, t_elapsed))

                    # Guarding on save_every > 0 avoids ZeroDivisionError and
                    # lets a non-positive value disable periodic saves.
                    if save_every > 0 and global_step % save_every == 0:
                        _save_checkpoint(sess, saver, model, args, save_dir,
                                         epoch, global_step)

                # Sync the in-graph epoch/step counters once per epoch.
                # NOTE(review): if model.train_op reads model.global_step
                # (e.g. a learning-rate schedule), confirm that per-epoch
                # rather than per-step syncing is intended.
                sess.run([model.update_op], {
                    model.epoch_update: epoch,
                    model.step_update: global_step,
                })

            _save_checkpoint(sess, saver, model, args, save_dir,
                             epoch, global_step)
    finally:
        # Flush pending summaries even if training is interrupted.
        writer.close()