def training(train_data, valid_data, data_factory, config, default_graph, baseline_graph):
    """ training """
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    default_sess = tf.Session(config=tfconfig, graph=default_graph)
    if baseline_graph is not None:
        baseline_sess = tf.Session(config=tfconfig, graph=baseline_graph)
        with baseline_graph.as_default() as graph:
            baseline_C = C_MODEL_BASE(config, graph, if_training=False)
            saver = tf.train.Saver()
            saver.restore(baseline_sess, FLAGS.baseline_checkpoint)
            print('successfully restore baseline critic from checkpoint: %s' %
                  (FLAGS.baseline_checkpoint))
    with default_graph.as_default() as graph:
        # number of batches
        num_batches = train_data['A'].shape[0] // FLAGS.batch_size
        num_valid_batches = valid_data['A'].shape[0] // FLAGS.batch_size
        print('num_batches', num_batches)
        print('num_valid_batches', num_valid_batches)
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        init = tf.global_variables_initializer()
        # saver for later restore
        saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all
        default_sess.run(init)
        # restore model if it exists
        if FLAGS.restore_path is not None:
            saver.restore(default_sess, FLAGS.restore_path)
            print('successfully restore model from checkpoint: %s' %
                  (FLAGS.restore_path))
        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            train_data, valid_data = data_factory.shuffle()
            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure not to exceed the boundary
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            train_data['B'].shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = train_data['B'][data_idx:data_idx +
                                                   FLAGS.batch_size]
                    real_conds = train_data['A'][data_idx:data_idx +
                                                 FLAGS.batch_size]
                    # samples
                    fake_samples = G.generate(
                        default_sess, z_samples(), real_conds)
                    # train Critic
                    D_loss_mean, global_steps = C.step(
                        default_sess, fake_samples, real_samples, real_conds)
                    batch_id += 1
                    log_counter += 1
                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data['B'].shape[0] - FLAGS.batch_size)
                    valid_real_samples = valid_data['B'][data_idx:data_idx +
                                                         FLAGS.batch_size]
                    valid_real_conds = valid_data['A'][data_idx:data_idx +
                                                       FLAGS.batch_size]
                    fake_samples = G.generate(
                        default_sess, z_samples(), valid_real_conds)
                    D_valid_loss_mean = C.log_valid_loss(
                        default_sess, fake_samples, valid_real_samples,
                        valid_real_conds)
                    if baseline_graph is not None:
                        # baseline critic eval
                        baseline_C.eval_EM_distance(
                            baseline_sess, fake_samples, valid_real_samples,
                            valid_real_conds, global_steps)
                # train G
                G_loss_mean, global_steps = G.step(
                    default_sess, z_samples(), real_conds)
                log_counter += 1
                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print(
                        "%d epoches, %d steps, mean C_loss: %f, mean C_valid_loss: %f, mean G_loss: %f, time cost: %f(sec)"
                        % (epoch_id, global_steps, D_loss_mean,
                           D_valid_loss_mean, G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()
            # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(
                    default_sess, CHECKPOINTS_PATH + "model.ckpt",
                    global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                # fake
                samples = G.generate(default_sess, z_samples(), real_conds)
                # print(samples)
                real_samples = train_data['B'][data_idx:data_idx +
                                               FLAGS.batch_size]
                concat_ = np.concatenate([real_conds, samples], axis=-1)
                # print(concat_)
                fake_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(
                    fake_result[0], FLAGS.seq_length,
                    file_path=SAMPLE_PATH + str(global_steps) + '_fake.mp4',
                    if_save=True)
                # real
                concat_ = np.concatenate([real_conds, real_samples], axis=-1)
                real_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(
                    real_result[0], FLAGS.seq_length,
                    file_path=SAMPLE_PATH + str(global_steps) + '_real.mp4',
                    if_save=True)
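

# The training loops in this file draw generator noise from z_samples(), which
# is assumed to be defined elsewhere in this module. The function below is NOT
# the original implementation; it is a minimal, hypothetical sketch that assumes
# standard-normal latent vectors of shape [batch_size, latent_dims] and an
# assumed flag name FLAGS.latent_dims. It is named *_sketch to avoid shadowing
# the real helper.
def z_samples_sketch(batch_size=None):
    """ hypothetical latent sampler mirroring how z_samples() is called above """
    if batch_size is None:
        batch_size = FLAGS.batch_size  # default batch size used by the training loops
    # one standard-normal latent vector per generated sequence
    return np.random.normal(
        size=[batch_size, FLAGS.latent_dims]).astype(np.float32)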


def training(real_data, normer, config, graph):
    """ training """
    # number of batches
    num_batches = real_data.shape[0] // FLAGS.batch_size
    shuffled_indexes = np.random.permutation(real_data.shape[0])
    real_data = real_data[shuffled_indexes]
    # 90/10 split into training and validation data
    real_data, valid_data = np.split(real_data,
                                     [real_data.shape[0] // 10 * 9])
    print(real_data.shape)
    print(valid_data.shape)
    num_batches = num_batches // 10 * 9
    num_valid_batches = num_batches // 10 * 1
    # model
    C = C_MODEL(config, graph)
    G = G_MODEL(config, C.inference, graph)
    init = tf.global_variables_initializer()
    # saver for later restore
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(init)
        # restore model if it exists
        if FLAGS.restore_path is not None:
            saver.restore(sess, FLAGS.restore_path)
            print('successfully restore model from checkpoint: %s' %
                  (FLAGS.restore_path))
        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]
            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches * 5  # TODO
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure not to exceed the boundary
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            real_data.shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = real_data[data_idx:data_idx +
                                             FLAGS.batch_size]
                    # samples
                    fake_samples = G.generate(sess, z_samples())
                    # train Critic
                    D_loss_mean, global_steps = C.step(sess, fake_samples,
                                                       real_samples)
                    batch_id += 1
                    log_counter += 1
                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data.shape[0] - FLAGS.batch_size)
                    valid_real_samples = valid_data[data_idx:data_idx +
                                                    FLAGS.batch_size]
                    D_valid_loss_mean = C.log_valid_loss(
                        sess, fake_samples, valid_real_samples)
                # train G
                G_loss_mean, global_steps = G.step(sess, z_samples())
                log_counter += 1
                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print(
                        "%d epoches, %d steps, mean D_loss: %f, mean D_valid_loss: %f, mean G_loss: %f, time cost: %f(sec)"
                        % (epoch_id, global_steps, D_loss_mean,
                           D_valid_loss_mean, G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()
            # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(sess, CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                samples = G.generate(sess, z_samples())
                # scale recovering
                samples = normer.recover_data(samples)
                # plot
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=SAMPLE_PATH + str(global_steps) + '.gif',
                    if_save=True)


def rnn():
    """ to collect results varying in length

    Saved Result
    ------------
    results_A_fake_B : float, numpy ndarray, shape=[n_latents=100, n_conditions=100, length=100, features=23]
        Real A + Fake B
    results_A_real_B : float, numpy ndarray, shape=[n_conditions=100, length=100, features=23]
        Real A + Real B
    results_critic_scores : float, numpy ndarray, shape=[n_latents=100, n_conditions=100]
        critic scores for each input data
    """
    save_path = os.path.join(COLLECT_PATH, 'rnn')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # DataFactory
    data_factory = DataFactory(real_data)
    # target data
    target_data = np.load('../../data/FixedFPS5.npy')[-100:]
    target_length = np.load('../../data/FixedFPS5Length.npy')[-100:]
    print('target_data.shape', target_data.shape)
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape(
                [target_data.shape[0], target_data.shape[1], 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2])
        ], axis=-1
    )
    team_AB = data_factory.normalize(team_AB)
    team_A = team_AB[:, :, :13]
    team_B = team_AB[:, :, 13:]
    # result collector
    results_A_fake_B = []
    results_A_real_B = []

    config = TrainingConfig(235)
    with tf.get_default_graph().as_default() as graph:
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True
        default_sess = tf.Session(config=tfconfig, graph=graph)
        # saver for later restore
        saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all
        # restore model if it exists
        saver.restore(default_sess, FLAGS.restore_path)
        print('successfully restore model from checkpoint: %s' %
              (FLAGS.restore_path))
        for idx in range(team_AB.shape[0]):
            # given 100 (FLAGS.n_latents) latents, generate 100 results on the same condition at once
            real_samples = team_B[idx:idx + 1, :]
            real_samples = np.concatenate(
                [real_samples for _ in range(FLAGS.n_latents)], axis=0)
            real_conds = team_A[idx:idx + 1, :]
            real_conds = np.concatenate(
                [real_conds for _ in range(FLAGS.n_latents)], axis=0)
            # generate result
            latents = z_samples(FLAGS.n_latents)
            result = G.generate(default_sess, latents, real_conds)
            # calculate em distance
            recovered_A_fake_B = data_factory.recover_data(
                np.concatenate([real_conds, result], axis=-1))
            # padding to length=200
            dummy = np.zeros(
                shape=[FLAGS.n_latents,
                       team_AB.shape[1] - target_length[idx],
                       team_AB.shape[2]])
            temp_A_fake_B_concat = np.concatenate(
                [recovered_A_fake_B[:, :target_length[idx]], dummy], axis=1)
            results_A_fake_B.append(temp_A_fake_B_concat)
            print(np.array(results_A_fake_B).shape)
        # stack along the conditions dimension (axis=1)
        results_A_fake_B = np.stack(results_A_fake_B, axis=1)
        # real data
        results_A = data_factory.recover_BALL_and_A(team_A)
        results_real_B = data_factory.recover_B(team_B)
        results_A_real_B = data_factory.recover_data(team_AB)
        # saved as numpy
        print(np.array(results_A_fake_B).shape)
        print(np.array(results_A_real_B).shape)
        np.save(
            os.path.join(save_path, 'results_A_fake_B.npy'),
            np.array(results_A_fake_B).astype(np.float32).reshape(
                [FLAGS.n_latents, team_AB.shape[0], team_AB.shape[1], 23]))
        np.save(
            os.path.join(save_path, 'results_A_real_B.npy'),
            np.array(results_A_real_B).astype(np.float32).reshape(
                [team_AB.shape[0], team_AB.shape[1], 23]))
        print('!!Completely Saved!!')
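

# A small, hypothetical post-processing sketch (not part of the original
# source) showing how the arrays saved by rnn() above could be loaded back and
# checked against the shapes documented in its docstring. The helper name is
# illustrative only; the file names and shapes come from rnn() itself.
def _example_load_rnn_results():
    """ load results_A_fake_B.npy / results_A_real_B.npy saved by rnn() """
    save_path = os.path.join(COLLECT_PATH, 'rnn')
    # shape=[n_latents, n_conditions, length, features=23]
    results_A_fake_B = np.load(os.path.join(save_path, 'results_A_fake_B.npy'))
    # shape=[n_conditions, length, features=23]
    results_A_real_B = np.load(os.path.join(save_path, 'results_A_real_B.npy'))
    print('results_A_fake_B.shape', results_A_fake_B.shape)
    print('results_A_real_B.shape', results_A_real_B.shape)
    return results_A_fake_B, results_A_real_B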