def update_batch_id_and_shuffle(self):
    self.batch_id += 1
    if self.batch_id >= self.num_batch:
        self.epoch_id += 1
        self.batch_id = 0
        self.data_factory.shuffle_train()
        # save model
        if self.epoch_id % FLAGS.checkpoint_step == 0:
            checkpoint_ = os.path.join(CHECKPOINT_PATH, 'model.ckpt')
            self.model.save_model(checkpoint_)
            print("Saved model:", checkpoint_)
        # save generated sample
        if self.epoch_id % FLAGS.vis_freq == 0:
            print('epoch_id:', self.epoch_id)
            data_idx = self.batch_id * FLAGS.batch_size
            f_train = self.data_factory.f_train
            seq_train = self.data_factory.seq_train
            seq_feat = f_train[data_idx:data_idx + FLAGS.batch_size]
            seq_ = seq_train[data_idx:data_idx + FLAGS.batch_size]
            recon = reconstruct_(self.model, seq_, z_samples(), seq_feat)
            sample = recon[:, :, :22]
            samples = self.data_factory.recover_BALL_and_A(sample)
            samples = self.data_factory.recover_B(samples)
            game_visualizer.plot_data(
                samples[0], FLAGS.seq_length,
                file_path=SAMPLE_PATH +
                'reconstruct{}.mp4'.format(self.epoch_id),
                if_save=True)
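
# `z_samples()` above is defined elsewhere in the repo. A minimal sketch of
# the latent sampler it likely corresponds to, for readers following along
# (the flag name `FLAGS.latent_dims` is an assumption; some trainers below
# call it with an explicit count or with the data array instead):
def z_samples(batch_size=None):
    """Draw a batch of latent vectors from a standard normal distribution."""
    if batch_size is None:
        batch_size = FLAGS.batch_size
    return np.random.normal(0., 1., size=[batch_size, FLAGS.latent_dims])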
def mode_8(sess, graph, save_path):
    """ Find high open-shot-penalty data among 1000 real sequences. """
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    train_data, valid_data = data_factory.fetch_data()
    # placeholder tensors
    real_data_t = graph.get_tensor_by_name('real_data:0')
    matched_cond_t = graph.get_tensor_by_name('matched_cond:0')
    # result tensor
    heuristic_penalty_pframe = graph.get_tensor_by_name(
        'Critic/C_inference/heuristic_penalty/Min:0')
    # 'Generator/G_loss/C_inference/linear_result/Reshape:0')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # collect the per-frame heuristic penalty over all real training batches
    real_hp_pframe_all = []
    for batch_id in range(train_data['A'].shape[0] // FLAGS.batch_size):
        index_id = batch_id * FLAGS.batch_size
        real_samples = train_data['B'][index_id:index_id + FLAGS.batch_size]
        cond_data = train_data['A'][index_id:index_id + FLAGS.batch_size]
        feed_dict = {real_data_t: real_samples, matched_cond_t: cond_data}
        real_hp_pframe = sess.run(heuristic_penalty_pframe,
                                  feed_dict=feed_dict)
        real_hp_pframe_all.append(real_hp_pframe)
    real_hp_pframe_all = np.concatenate(real_hp_pframe_all, axis=0)
    print('real_hp_pframe_all.shape', real_hp_pframe_all.shape)
    # per-datum penalty, averaged over frames
    real_hp_pdata = np.mean(real_hp_pframe_all, axis=1)
    mean_ = np.mean(real_hp_pdata)
    std_ = np.std(real_hp_pdata)
    print('mean:', mean_)
    print('std:', std_)
    concat_AB = np.concatenate([train_data['A'], train_data['B']], axis=-1)
    recovered = data_factory.recover_data(concat_AB)
    for i, v in enumerate(real_hp_pdata):
        # outliers: more than two standard deviations above the mean
        if v > (mean_ + 2 * std_):
            print('bad', i, v)
            game_visualizer.plot_data(
                recovered[i], recovered.shape[1],
                file_path=save_path + 'bad_' + str(i) + '_' + str(v) + '.mp4',
                if_save=True)
        if v < 0.0025:
            print('good', i, v)
            game_visualizer.plot_data(
                recovered[i], recovered.shape[1],
                file_path=save_path + 'good_' + str(i) + '_' + str(v) + '.mp4',
                if_save=True)
    print('!!Completely Saved!!')
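
# A sketch of how `mode_8` might be driven: restore the trained graph from a
# checkpoint and hand the session/graph pair to the mode function. The
# output directory is illustrative only; the repo's actual dispatcher may
# differ.
def run_mode_8():
    with tf.Graph().as_default() as graph:
        saver = tf.train.import_meta_graph(FLAGS.restore_path + '.meta')
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            saver.restore(sess, FLAGS.restore_path)
            mode_8(sess, graph, save_path='v_mode_8/')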
def mode_6(sess, graph, save_path):
    """ Draw results generated from sequences of a different length. """
    # normalize
    real_data = np.load(FLAGS.data_path)
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    target_data = np.load('FEATURES-7.npy')[:, :]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape(
                [target_data.shape[0], target_data.shape[1], 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2])
        ], axis=-1
    )
    team_AB = data_factory.normalize(team_AB)
    team_A = team_AB[:, :, :13]
    team_B = team_AB[:, :, 13:]  # unused here; kept for reference
    # placeholder tensors
    latent_input_t = graph.get_tensor_by_name('latent_input:0')
    team_a_t = graph.get_tensor_by_name('team_a:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # generate fake B conditioned on team A
    latents = z_samples(team_AB.shape[0])
    feed_dict = {
        latent_input_t: latents,
        team_a_t: team_A
    }
    result_fake_B = sess.run(result_t, feed_dict=feed_dict)
    results_A_fake_B = np.concatenate([team_A, result_fake_B], axis=-1)
    results_A_fake_B = data_factory.recover_data(results_A_fake_B)
    for i in range(results_A_fake_B.shape[0]):
        game_visualizer.plot_data(
            results_A_fake_B[i], target_data.shape[1],
            file_path=save_path + str(i) + '.mp4', if_save=True)
    print('!!Completely Saved!!')
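
# The flattened feature layout built in `mode_6`: ball (x, y, z) occupies
# dims 0-2, the five team-A players (x, y) dims 3-12, and the five team-B
# players dims 13-22, hence the 13/10 split into condition `team_A` and
# target `team_B`. A hypothetical helper making that split explicit:
def split_condition_and_target(team_AB):
    """Split a [N, T, 23] array into condition (ball + team A) and target."""
    return team_AB[..., :13], team_AB[..., 13:]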
def training(real_data, normer, config, graph):
    """ training """
    # number of batches
    num_batches = real_data.shape[0] // FLAGS.batch_size
    # shuffle, then split into 90% training / 10% validation data
    shuffled_indexes = np.random.permutation(real_data.shape[0])
    real_data = real_data[shuffled_indexes]
    real_data, valid_data = np.split(
        real_data, [real_data.shape[0] // 10 * 9])
    print(real_data.shape)
    print(valid_data.shape)
    num_valid_batches = num_batches // 10 * 1
    num_batches = num_batches // 10 * 9
    # model
    C = C_MODEL(config, graph)
    G = G_MODEL(config, C.inference, graph)
    init = tf.global_variables_initializer()
    # saver for later restore
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(init)
        # restore model if it exists
        if FLAGS.restore_path is not None:
            saver.restore(sess, FLAGS.restore_path)
            print('successfully restored model from checkpoint: %s' %
                  (FLAGS.restore_path))
        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]
            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches * 5  # TODO
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure not to exceed the boundary
                    data_idx = batch_id * FLAGS.batch_size % (
                        real_data.shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = real_data[data_idx:data_idx +
                                             FLAGS.batch_size]
                    # fake samples
                    fake_samples = G.generate(sess, z_samples())
                    # train Critic
                    D_loss_mean, global_steps = C.step(
                        sess, fake_samples, real_samples)
                    batch_id += 1
                    log_counter += 1
                # log validation loss
                data_idx = global_steps * FLAGS.batch_size % (
                    valid_data.shape[0] - FLAGS.batch_size)
                valid_real_samples = valid_data[data_idx:data_idx +
                                                FLAGS.batch_size]
                D_valid_loss_mean = C.log_valid_loss(
                    sess, fake_samples, valid_real_samples)
                # train G
                G_loss_mean, global_steps = G.step(sess, z_samples())
                log_counter += 1
                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print("epoch %d, %d steps, mean D_loss: %f, "
                          "mean D_valid_loss: %f, mean G_loss: %f, "
                          "time cost: %f(sec)"
                          % (epoch_id, global_steps, D_loss_mean,
                             D_valid_loss_mean, G_loss_mean,
                             end_time - start_time))
                    start_time = time.time()
            # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(sess, CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                samples = G.generate(sess, z_samples())
                # scale recovering
                samples = normer.recover_data(samples)
                # plot
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=SAMPLE_PATH + str(global_steps) + '.gif',
                    if_save=True)
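
# The Critic/Generator schedule above follows the usual WGAN practice of
# keeping the critic near-optimal: during the first `FLAGS.num_pretrain_D`
# epochs (and every `FLAGS.freq_train_D`-th epoch thereafter) the critic
# gets many extra updates per generator step. A standalone illustration of
# the schedule; the default values are assumptions, not the repo's flags:
def critic_steps_for_epoch(epoch_id, num_batches,
                           num_pretrain_D=25, freq_train_D=100,
                           num_train_D=5):
    """Return how many critic updates precede one generator update."""
    if epoch_id < num_pretrain_D or (epoch_id + 1) % freq_train_D == 0:
        return num_batches * 5  # extended critic training
    return num_train_D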
def training(train_data, valid_data, data_factory, config, default_graph,
             baseline_graph):
    """ training """
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    default_sess = tf.Session(config=tfconfig, graph=default_graph)
    if baseline_graph is not None:
        baseline_sess = tf.Session(config=tfconfig, graph=baseline_graph)
        with baseline_graph.as_default() as graph:
            baseline_C = C_MODEL_BASE(config, graph, if_training=False)
            saver = tf.train.Saver()
            saver.restore(baseline_sess, FLAGS.baseline_checkpoint)
            print('successfully restored baseline critic from checkpoint: %s'
                  % (FLAGS.baseline_checkpoint))
    with default_graph.as_default() as graph:
        # number of batches
        num_batches = train_data['A'].shape[0] // FLAGS.batch_size
        num_valid_batches = valid_data['A'].shape[0] // FLAGS.batch_size
        print('num_batches', num_batches)
        print('num_valid_batches', num_valid_batches)
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        init = tf.global_variables_initializer()
        # saver for later restore
        saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all
        default_sess.run(init)
        # restore model if it exists
        if FLAGS.restore_path is not None:
            saver.restore(default_sess, FLAGS.restore_path)
            print('successfully restored model from checkpoint: %s' %
                  (FLAGS.restore_path))
        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            train_data, valid_data = data_factory.shuffle()
            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure not to exceed the boundary
                    data_idx = batch_id * FLAGS.batch_size % (
                        train_data['B'].shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = train_data['B'][data_idx:data_idx +
                                                   FLAGS.batch_size]
                    real_conds = train_data['A'][data_idx:data_idx +
                                                 FLAGS.batch_size]
                    # fake samples
                    fake_samples = G.generate(default_sess, z_samples(),
                                              real_conds)
                    # train Critic
                    D_loss_mean, global_steps = C.step(
                        default_sess, fake_samples, real_samples, real_conds)
                    batch_id += 1
                    log_counter += 1
                # log validation loss
                data_idx = global_steps * FLAGS.batch_size % (
                    valid_data['B'].shape[0] - FLAGS.batch_size)
                valid_real_samples = valid_data['B'][data_idx:data_idx +
                                                     FLAGS.batch_size]
                valid_real_conds = valid_data['A'][data_idx:data_idx +
                                                   FLAGS.batch_size]
                fake_samples = G.generate(default_sess, z_samples(),
                                          valid_real_conds)
                D_valid_loss_mean = C.log_valid_loss(
                    default_sess, fake_samples, valid_real_samples,
                    valid_real_conds)
                if baseline_graph is not None:
                    # baseline critic evaluation
                    baseline_C.eval_EM_distance(baseline_sess, fake_samples,
                                                valid_real_samples,
                                                valid_real_conds,
                                                global_steps)
                # train G
                G_loss_mean, global_steps = G.step(default_sess, z_samples(),
                                                   real_conds)
                log_counter += 1
                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print("epoch %d, %d steps, mean C_loss: %f, "
                          "mean C_valid_loss: %f, mean G_loss: %f, "
                          "time cost: %f(sec)"
                          % (epoch_id, global_steps, D_loss_mean,
                             D_valid_loss_mean, G_loss_mean,
                             end_time - start_time))
                    start_time = time.time()
            # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(default_sess,
                                       CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                    ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                # fake
                samples = G.generate(default_sess, z_samples(), real_conds)
                real_samples = train_data['B'][data_idx:data_idx +
                                               FLAGS.batch_size]
                concat_ = np.concatenate([real_conds, samples], axis=-1)
                fake_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(
                    fake_result[0], FLAGS.seq_length,
                    file_path=SAMPLE_PATH + str(global_steps) + '_fake.mp4',
                    if_save=True)
                # real
                concat_ = np.concatenate([real_conds, real_samples], axis=-1)
                real_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(
                    real_result[0], FLAGS.seq_length,
                    file_path=SAMPLE_PATH + str(global_steps) + '_real.mp4',
                    if_save=True)
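
# A sketch of how the two graphs might be built and handed to the
# conditional trainer above. `TrainingConfig` is a hypothetical stand-in
# for the repo's actual config object; `DataFactory.fetch_data` is used as
# in `mode_8`:
def main(_):
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    data_factory = DataFactory(real_data)
    train_data, valid_data = data_factory.fetch_data()
    default_graph = tf.Graph()
    baseline_graph = tf.Graph() if FLAGS.baseline_checkpoint else None
    config = TrainingConfig()  # hypothetical config object
    training(train_data, valid_data, data_factory, config,
             default_graph, baseline_graph)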
def training(sess, model, real_data, num_batches, saver, normer,
             is_pretrain=False):
    """ training """
    # shuffle, then split into 90% training / 10% validation data
    shuffled_indexes = np.random.permutation(real_data.shape[0])
    real_data = real_data[shuffled_indexes]
    real_data, valid_data = np.split(
        real_data, [real_data.shape[0] // 10 * 9])
    print(real_data.shape)
    print(valid_data.shape)
    num_valid_batches = num_batches // 10 * 1
    num_batches = num_batches // 10 * 9
    # fixed input noise for sampled results
    sampled_noise = z_samples(real_data)
    if is_pretrain:
        G_loss_mean = 0.0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.pretrain_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]
            for batch_id in range(num_batches):
                # make sure not to exceed the boundary
                data_idx = batch_id * FLAGS.batch_size % (
                    real_data.shape[0] - FLAGS.batch_size)
                # data
                real_data_batch = real_data[data_idx:data_idx +
                                            FLAGS.batch_size]
                # pretrain G
                G_loss_mean, global_steps = model.G_pretrain_step(
                    sess, real_data_batch)
                # log validation loss
                data_idx = global_steps * FLAGS.batch_size % (
                    valid_data.shape[0] - FLAGS.batch_size)
                valid_data_batch = valid_data[data_idx:data_idx +
                                              FLAGS.batch_size]
                G_valid_loss_mean = model.G_pretrain_log_valid_loss(
                    sess, valid_data_batch)
                # logging
                if batch_id % FLAGS.log_freq == 0:
                    end_time = time.time()
                    print("epoch %d, %d steps, mean G_loss: %f, "
                          "mean G_valid_loss: %f, time cost: %f(sec)"
                          % (epoch_id, global_steps, G_loss_mean,
                             G_valid_loss_mean, end_time - start_time))
                    start_time = time.time()
            # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq) == 0 \
                    or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(
                    sess, FLAGS.checkpoints_dir + "model.ckpt",
                    global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq) == 0 \
                    or epoch_id == FLAGS.total_epoches - 1:
                # training result
                samples = model.generate_pretrain(sess, real_data_batch)
                samples = normer.recover_data(samples)
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=FLAGS.sample_dir + str(global_steps) +
                    '_pretrain_train.gif',
                    if_save=True)
                # testing result
                samples = model.generate(sess, real_data_batch[:, 0, :])
                samples = normer.recover_data(samples)
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=FLAGS.sample_dir + str(global_steps) +
                    '_pretrain_test.gif',
                    if_save=True)
    else:
        D_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]
            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches * 5  # TODO
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure not to exceed the boundary
                    data_idx = batch_id * FLAGS.batch_size % (
                        real_data.shape[0] - FLAGS.batch_size)
                    # data
                    real_data_batch = real_data[data_idx:data_idx +
                                                FLAGS.batch_size]
                    # train D
                    D_loss_mean, global_steps = model.D_step(
                        sess, z_samples(real_data), real_data_batch)
                    batch_id += 1
                    log_counter += 1
                # log validation loss
                data_idx = global_steps * FLAGS.batch_size % (
                    valid_data.shape[0] - FLAGS.batch_size)
                valid_data_batch = valid_data[data_idx:data_idx +
                                              FLAGS.batch_size]
                D_valid_loss_mean = model.D_log_valid_loss(
                    sess, z_samples(real_data), valid_data_batch)
                # train G
                G_loss_mean, global_steps = model.G_step(
                    sess, z_samples(real_data))
                log_counter += 1
                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print("epoch %d, %d steps, mean D_loss: %f, "
                          "mean D_valid_loss: %f, mean G_loss: %f, "
                          "time cost: %f(sec)"
                          % (epoch_id, global_steps, D_loss_mean,
                             D_valid_loss_mean, G_loss_mean,
                             end_time - start_time))
                    start_time = time.time()
            # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq) == 0 \
                    or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(
                    sess, FLAGS.checkpoints_dir + "model.ckpt",
                    global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq) == 0 \
                    or epoch_id == FLAGS.total_epoches - 1:
                samples = model.generate(sess, z_samples(real_data))
                # scale recovering
                samples = normer.recover_data(samples)
                # plot
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=FLAGS.sample_dir + str(global_steps) + '.gif',
                    if_save=True)
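
# The shuffle-then-split pattern at the top of the trainers in this file
# recurs verbatim; a small helper capturing it (a sketch, assuming a fixed
# 90/10 train/validation ratio along the first axis):
def shuffle_and_split(data, train_ratio=0.9):
    """Shuffle along the first axis and split into train/validation parts."""
    indexes = np.random.permutation(data.shape[0])
    data = data[indexes]
    split_at = int(data.shape[0] * train_ratio)
    return data[:split_at], data[split_at:]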