Example #1
    def update_batch_id_and_shuffle(self):
        self.batch_id += 1
        if self.batch_id >= self.num_batch:
            # one epoch finished: advance the epoch counter and reshuffle
            self.epoch_id += 1
            self.batch_id = 0
            self.data_factory.shuffle_train()
            # save model
            if self.epoch_id % FLAGS.checkpoint_step == 0:
                checkpoint_ = os.path.join(CHECKPOINT_PATH, 'model.ckpt')
                self.model.save_model(checkpoint_)
                print("Saved model:", checkpoint_)
            # save generated sample
            if self.epoch_id % FLAGS.vis_freq == 0:
                print('epoch_id:', self.epoch_id)
                data_idx = self.batch_id * FLAGS.batch_size
                f_train = self.data_factory.f_train
                seq_train = self.data_factory.seq_train
                seq_feat = f_train[data_idx:data_idx + FLAGS.batch_size]
                seq_ = seq_train[data_idx:data_idx + FLAGS.batch_size]

                recon = reconstruct_(self.model, seq_, z_samples(), seq_feat)
                sample = recon[:, :, :22]
                samples = self.data_factory.recover_BALL_and_A(sample)
                samples = self.data_factory.recover_B(samples)
                game_visualizer.plot_data(
                    samples[0],
                    FLAGS.seq_length,
                    file_path=SAMPLE_PATH +
                    'reconstruct{}.mp4'.format(self.epoch_id),
                    if_save=True)
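
The method above advances a batch counter, rolls it over into an epoch counter, reshuffles the training set, and periodically checkpoints the model and renders a reconstruction. Below is a minimal, self-contained sketch of that bookkeeping; BatchCounter and its arguments are illustrative stand-ins, not part of the original trainer class.

# Sketch only: a stripped-down version of the batch/epoch bookkeeping above.
class BatchCounter:
    def __init__(self, num_batch):
        self.batch_id = 0
        self.epoch_id = 0
        self.num_batch = num_batch

    def update(self, on_epoch_end=None):
        """Advance one batch; call on_epoch_end(epoch_id) when an epoch wraps."""
        self.batch_id += 1
        if self.batch_id >= self.num_batch:
            self.epoch_id += 1
            self.batch_id = 0
            if on_epoch_end is not None:
                on_epoch_end(self.epoch_id)


counter = BatchCounter(num_batch=100)
for _ in range(250):
    # in the real trainer the callback would shuffle, checkpoint and visualize
    counter.update(on_epoch_end=lambda e: print('finished epoch', e))
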
Example #2
def mode_8(sess, graph, save_path):
    """ to find high-openshot-penalty data in 1000 real data
    """
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    train_data, valid_data = data_factory.fetch_data()
    # placeholder tensor
    real_data_t = graph.get_tensor_by_name('real_data:0')
    matched_cond_t = graph.get_tensor_by_name('matched_cond:0')
    # result tensor
    heuristic_penalty_pframe = graph.get_tensor_by_name(
        'Critic/C_inference/heuristic_penalty/Min:0')
    # 'Generator/G_loss/C_inference/linear_result/Reshape:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    real_hp_pframe_all = []
    for batch_id in range(train_data['A'].shape[0] // FLAGS.batch_size):
        index_id = batch_id * FLAGS.batch_size
        real_data = train_data['B'][index_id:index_id + FLAGS.batch_size]
        cond_data = train_data['A'][index_id:index_id + FLAGS.batch_size]
        # real
        feed_dict = {real_data_t: real_data, matched_cond_t: cond_data}
        real_hp_pframe = sess.run(heuristic_penalty_pframe,
                                  feed_dict=feed_dict)
        real_hp_pframe_all.append(real_hp_pframe)
    real_hp_pframe_all = np.concatenate(real_hp_pframe_all, axis=0)
    print('real_hp_pframe_all.shape:', real_hp_pframe_all.shape)
    real_hp_pdata = np.mean(real_hp_pframe_all, axis=1)
    mean_ = np.mean(real_hp_pdata)
    std_ = np.std(real_hp_pdata)
    print('mean:', mean_)
    print('std:', std_)

    concat_AB = np.concatenate([train_data['A'], train_data['B']], axis=-1)
    recovered = data_factory.recover_data(concat_AB)
    for i, v in enumerate(real_hp_pdata):
        if v > (mean_ + 2 * std_):
            print('bad', i, v)
            game_visualizer.plot_data(recovered[i],
                                      recovered.shape[1],
                                      file_path=save_path + 'bad_' + str(i) +
                                      '_' + str(v) + '.mp4',
                                      if_save=True)
        if v < 0.0025:
            print('good', i, v)
            game_visualizer.plot_data(recovered[i],
                                      recovered.shape[1],
                                      file_path=save_path + 'good_' + str(i) +
                                      '_' + str(v) + '.mp4',
                                      if_save=True)

    print('!!Completely Saved!!')
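
mode_8 above flags a sequence as "bad" when its mean per-frame heuristic penalty exceeds mean + 2 * std over the whole training set, and as "good" when it falls below the fixed 0.0025 threshold. The NumPy sketch below reproduces just that selection; the array is random toy data standing in for the critic output, not real penalties.

import numpy as np

# toy stand-in for real_hp_pdata: one mean penalty value per sequence
rng = np.random.RandomState(0)
hp_pdata = rng.gamma(shape=2.0, scale=0.01, size=1000)

mean_, std_ = np.mean(hp_pdata), np.std(hp_pdata)
bad_idx = np.where(hp_pdata > mean_ + 2 * std_)[0]   # unusually high penalty
good_idx = np.where(hp_pdata < 0.0025)[0]            # fixed low-penalty threshold
print('bad sequences:', len(bad_idx))
print('good sequences:', len(good_idx))
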
def mode_6(sess, graph, save_path):
    """ to draw different length result
    """
    # normalize
    real_data = np.load(FLAGS.data_path)
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    target_data = np.load('FEATURES-7.npy')[:, :]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape(
                [target_data.shape[0], target_data.shape[1], 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2])
        ], axis=-1
    )
    team_AB = data_factory.normalize(team_AB)
    # condition: ball (3) + team A players (5 * 2) = 13 dims
    team_A = team_AB[:, :, :13]
    # target: team B players (5 * 2) = 10 dims
    team_B = team_AB[:, :, 13:]
    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('latent_input:0')
    team_a_t = graph.get_tensor_by_name('team_a:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # sample latent inputs for the whole dataset
    latents = z_samples(team_AB.shape[0])
    feed_dict = {
        latent_input_t: latents,
        team_a_t: team_A
    }
    result_fake_B = sess.run(result_t, feed_dict=feed_dict)
    results_A_fake_B = np.concatenate([team_A, result_fake_B], axis=-1)
    results_A_fake_B = data_factory.recover_data(results_A_fake_B)
    for i in range(results_A_fake_B.shape[0]):
        game_visualizer.plot_data(
            results_A_fake_B[i], target_data.shape[1],
            file_path=save_path + str(i) + '.mp4', if_save=True)

    print('!!Completely Saved!!')
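
mode_6 above flattens the loaded tensor of shape [N, T, 11, F] (ball plus ten players per frame) into 23 features per frame: ball (x, y, z), then five team-A and five team-B players' (x, y). The shape-only sketch below reproduces that reshaping with random data; the array sizes are illustrative assumptions.

import numpy as np

# toy stand-in for np.load('FEATURES-7.npy'):
# N sequences, T frames, 11 entities (ball + 10 players), last axis = (x, y, z)
N, T = 4, 100
target_data = np.random.rand(N, T, 11, 3)

team_AB = np.concatenate(
    [
        target_data[:, :, 0, :3].reshape([N, T, 1 * 3]),     # ball (x, y, z)
        target_data[:, :, 1:6, :2].reshape([N, T, 5 * 2]),   # team A (x, y)
        target_data[:, :, 6:11, :2].reshape([N, T, 5 * 2]),  # team B (x, y)
    ], axis=-1)

print(team_AB.shape)                     # (4, 100, 23)
team_A, team_B = team_AB[..., :13], team_AB[..., 13:]
print(team_A.shape, team_B.shape)        # (4, 100, 13) (4, 100, 10)
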
def training(real_data, normer, config, graph):
    """ training
    """
    # number of batches
    num_batches = real_data.shape[0] // FLAGS.batch_size
    shuffled_indexes = np.random.permutation(real_data.shape[0])
    real_data = real_data[shuffled_indexes]
    real_data, valid_data = np.split(real_data, [real_data.shape[0] // 10 * 9])
    print('train data shape:', real_data.shape)
    print('valid data shape:', valid_data.shape)
    # 90% of the batches for training, 10% for validation
    num_valid_batches = num_batches // 10 * 1
    num_batches = num_batches // 10 * 9
    # model
    C = C_MODEL(config, graph)
    G = G_MODEL(config, C.inference, graph)
    init = tf.global_variables_initializer()
    # saver for later restore
    saver = tf.train.Saver()
    # use a separate name so the model `config` argument is not shadowed
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        sess.run(init)
        # restore model if exist
        if FLAGS.restore_path is not None:
            saver.restore(sess, FLAGS.restore_path)
            print('successfully restore model from checkpoint: %s' %
                  (FLAGS.restore_path))

        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]

            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches * 5  # TODO
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # wrap the index so the batch never runs past the end of the data
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            real_data.shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = real_data[data_idx:data_idx +
                                             FLAGS.batch_size]
                    # samples
                    fake_samples = G.generate(sess, z_samples())
                    # train Critic
                    D_loss_mean, global_steps = C.step(sess, fake_samples,
                                                       real_samples)
                    batch_id += 1
                    log_counter += 1

                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data.shape[0] - FLAGS.batch_size)
                    valid_real_samples = valid_data[data_idx:data_idx +
                                                    FLAGS.batch_size]
                    D_valid_loss_mean = C.log_valid_loss(
                        sess, fake_samples, valid_real_samples)

                # train G
                G_loss_mean, global_steps = G.step(sess, z_samples())
                log_counter += 1

                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print(
                        "%d epochs, %d steps, mean D_loss: %f, mean D_valid_loss: %f, mean G_loss: %f, time cost: %f (sec)"
                        % (epoch_id, global_steps, D_loss_mean,
                           D_valid_loss_mean, G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()  # reset the timer
            # save model
            if (epoch_id % FLAGS.save_model_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(sess,
                                       CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                samples = G.generate(sess, z_samples())
                # scale recovering
                samples = normer.recover_data(samples)
                # plot
                game_visualizer.plot_data(samples[0:],
                                          FLAGS.seq_length,
                                          file_path=SAMPLE_PATH +
                                          str(global_steps) + '.gif',
                                          if_save=True)
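
This training loop (and the two that follow) indexes minibatches with a modulo wrap, data_idx = batch_id * batch_size % (N - batch_size), so a slice of batch_size samples never runs past the end of the array. A tiny sketch of that indexing pattern:

import numpy as np

data = np.arange(10 * 3).reshape(10, 3)   # 10 samples, 3 features each
batch_size = 4

for batch_id in range(6):
    # wrap so that data_idx + batch_size never exceeds len(data)
    data_idx = batch_id * batch_size % (data.shape[0] - batch_size)
    batch = data[data_idx:data_idx + batch_size]
    print(batch_id, data_idx, batch.shape)   # shape is always (4, 3)
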
def training(train_data, valid_data, data_factory, config, default_graph,
             baseline_graph):
    """ training
    """
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    default_sess = tf.Session(config=tfconfig, graph=default_graph)

    if baseline_graph is not None:
        baseline_sess = tf.Session(config=tfconfig, graph=baseline_graph)
        with baseline_graph.as_default() as graph:
            baseline_C = C_MODEL_BASE(config, graph, if_training=False)
            saver = tf.train.Saver()
            saver.restore(baseline_sess, FLAGS.baseline_checkpoint)
            print('successfully restore baseline critic from checkpoint: %s' %
                  (FLAGS.baseline_checkpoint))
    with default_graph.as_default() as graph:
        # number of batches
        num_batches = train_data['A'].shape[0] // FLAGS.batch_size
        num_valid_batches = valid_data['A'].shape[0] // FLAGS.batch_size
        print('num_batches', num_batches)
        print('num_valid_batches', num_valid_batches)
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        init = tf.global_variables_initializer()
        # saver for later restore
        saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all

        default_sess.run(init)
        # restore model if exist
        if FLAGS.restore_path is not None:
            saver.restore(default_sess, FLAGS.restore_path)
            print('successfully restore model from checkpoint: %s' %
                  (FLAGS.restore_path))
        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            train_data, valid_data = data_factory.shuffle()

            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # wrap the index so the batch never runs past the end of the data
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            train_data['B'].shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = train_data['B'][data_idx:data_idx +
                                                   FLAGS.batch_size]
                    real_conds = train_data['A'][data_idx:data_idx +
                                                 FLAGS.batch_size]
                    # samples
                    fake_samples = G.generate(default_sess, z_samples(),
                                              real_conds)
                    # train Critic
                    D_loss_mean, global_steps = C.step(default_sess,
                                                       fake_samples,
                                                       real_samples,
                                                       real_conds)
                    batch_id += 1
                    log_counter += 1

                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data['B'].shape[0] - FLAGS.batch_size)
                    valid_real_samples = valid_data['B'][data_idx:data_idx +
                                                         FLAGS.batch_size]
                    valid_real_conds = valid_data['A'][data_idx:data_idx +
                                                       FLAGS.batch_size]
                    fake_samples = G.generate(default_sess, z_samples(),
                                              valid_real_conds)
                    D_valid_loss_mean = C.log_valid_loss(
                        default_sess, fake_samples, valid_real_samples,
                        valid_real_conds)

                    if baseline_graph is not None:
                        # baseline critic eval
                        baseline_C.eval_EM_distance(baseline_sess,
                                                    fake_samples,
                                                    valid_real_samples,
                                                    valid_real_conds,
                                                    global_steps)

                # train G
                G_loss_mean, global_steps = G.step(default_sess, z_samples(),
                                                   real_conds)
                log_counter += 1

                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print(
                        "%d epochs, %d steps, mean C_loss: %f, mean C_valid_loss: %f, mean G_loss: %f, time cost: %f (sec)"
                        % (epoch_id, global_steps, D_loss_mean,
                           D_valid_loss_mean, G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()  # reset the timer
            # save model
            if (epoch_id % FLAGS.save_model_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(default_sess,
                                       CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                # fake
                samples = G.generate(default_sess, z_samples(), real_conds)
                # print(samples)
                real_samples = train_data['B'][data_idx:data_idx +
                                               FLAGS.batch_size]
                concat_ = np.concatenate([real_conds, samples], axis=-1)
                # print(concat_)
                fake_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(fake_result[0],
                                          FLAGS.seq_length,
                                          file_path=SAMPLE_PATH +
                                          str(global_steps) + '_fake.mp4',
                                          if_save=True)
                # real
                concat_ = np.concatenate([real_conds, real_samples], axis=-1)
                real_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(real_result[0],
                                          FLAGS.seq_length,
                                          file_path=SAMPLE_PATH +
                                          str(global_steps) + '_real.mp4',
                                          if_save=True)
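
This version keeps the trainable model and the frozen baseline critic in separate tf.Graph instances, each driven by its own tf.Session, so restoring the baseline checkpoint cannot collide with the variables being trained. The sketch below shows the two-graph pattern with toy variables in place of the real C_MODEL / C_MODEL_BASE networks; it is an illustration under that assumption, not the project's code.

import tensorflow as tf  # TensorFlow 1.x API, as in the code above


def build_toy_graph(initial_value):
    """Build an isolated graph holding a single variable."""
    graph = tf.Graph()
    with graph.as_default():
        var = tf.Variable(initial_value, name='v')
        init = tf.global_variables_initializer()
    return graph, var, init


train_graph, train_var, train_init = build_toy_graph(1.0)
baseline_graph, baseline_var, baseline_init = build_toy_graph(2.0)

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
train_sess = tf.Session(config=tfconfig, graph=train_graph)
baseline_sess = tf.Session(config=tfconfig, graph=baseline_graph)

train_sess.run(train_init)
baseline_sess.run(baseline_init)
# each session only sees the variables of its own graph
print(train_sess.run(train_var), baseline_sess.run(baseline_var))

A tf.train.Saver created inside baseline_graph.as_default() only covers that graph's variables, which is why the baseline restore above cannot touch the weights being trained in default_graph.
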
def training(sess, model, real_data, num_batches, saver, normer, is_pretrain=False):
    """
    """
    shuffled_indexes = np.random.permutation(real_data.shape[0])
    real_data = real_data[shuffled_indexes]
    real_data, valid_data = np.split(real_data, [real_data.shape[0] // 10 * 9])
    print('train data shape:', real_data.shape)
    print('valid data shape:', valid_data.shape)
    # 90% of the batches for training, 10% for validation
    num_valid_batches = num_batches // 10 * 1
    num_batches = num_batches // 10 * 9

    # fixed sampled result input noise
    sampled_noise = z_samples(real_data)

    if is_pretrain:
        G_loss_mean = 0.0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.pretrain_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]

            batch_id = 0
            for batch_id in range(num_batches):
                # wrap the index so the batch never runs past the end of the data
                data_idx = batch_id * FLAGS.batch_size % (
                    real_data.shape[0] - FLAGS.batch_size)
                # data
                real_data_batch = real_data[data_idx:data_idx +
                                            FLAGS.batch_size]
                # pretrain G
                G_loss_mean, global_steps = model.G_pretrain_step(
                    sess, real_data_batch)
                
                # log validation loss
                data_idx = global_steps * FLAGS.batch_size % (
                    valid_data.shape[0] - FLAGS.batch_size)
                valid_data_batch = valid_data[data_idx:data_idx +
                                              FLAGS.batch_size]
                G_valid_loss_mean = model.G_pretrain_log_valid_loss(
                    sess, valid_data_batch)
                # logging
                if batch_id % FLAGS.log_freq == 0:
                    end_time = time.time()
                    print("%d, epoches, %d steps, mean G_loss: %f, mean G_valid_loss: %f, time cost: %f(sec)" %
                          (epoch_id,
                           global_steps,
                           G_loss_mean,
                           G_valid_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()  # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(
                    sess, FLAGS.checkpoints_dir + "model.ckpt",
                    global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq) == 0 or epoch_id == FLAGS.total_epoches - 1:
                # training result
                samples = model.generate_pretrain(
                    sess, real_data_batch)
                samples = normer.recover_data(samples)
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=FLAGS.sample_dir + str(global_steps) +
                    '_pretrain_train.gif',
                    if_save=True)
                # testing result
                samples = model.generate(
                    sess, real_data_batch[:, 0, :])
                samples = normer.recover_data(samples)
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=FLAGS.sample_dir + str(global_steps) +
                    '_pretrain_test.gif',
                    if_save=True)
    else:
        D_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]

            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                real_data_batch = None
                if epoch_id < FLAGS.num_pretrain_D or (epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches * 5  # TODO
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # wrap the index so the batch never runs past the end of the data
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            real_data.shape[0] - FLAGS.batch_size)
                    # data
                    real_data_batch = real_data[data_idx:data_idx +
                                                FLAGS.batch_size]
                    # train D
                    D_loss_mean, global_steps = model.D_step(
                        sess, z_samples(real_data), real_data_batch)
                    batch_id += 1
                    log_counter += 1

                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data.shape[0] - FLAGS.batch_size)
                    valid_data_batch = valid_data[data_idx:data_idx +
                                                  FLAGS.batch_size]
                    D_valid_loss_mean = model.D_log_valid_loss(
                        sess, z_samples(real_data), valid_data_batch)

                # train G
                G_loss_mean, global_steps = model.G_step(
                    sess, z_samples(real_data))
                log_counter += 1

                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print("%d, epoches, %d steps, mean D_loss: %f, mean D_valid_loss: %f, mean G_loss: %f, time cost: %f(sec)" %
                          (epoch_id,
                           global_steps,
                           D_loss_mean,
                           D_valid_loss_mean,
                           G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()  # save checkpoints
            # save model
            if (epoch_id % FLAGS.save_model_freq) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(
                    sess, FLAGS.checkpoints_dir + "model.ckpt",
                    global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq) == 0 or epoch_id == FLAGS.total_epoches - 1:
                samples = model.generate(
                    sess, z_samples(real_data))
                # scale recovering
                samples = normer.recover_data(samples)
                # plot
                game_visualizer.plot_data(
                    samples[0:], FLAGS.seq_length,
                    file_path=FLAGS.sample_dir + str(global_steps) + '.gif',
                    if_save=True)
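
All three training loops share the same schedule for alternating critic and generator updates: the critic gets a long burn-in during the first num_pretrain_D epochs and on every freq_train_D-th epoch, and only num_train_D updates per generator step otherwise. A pure-Python sketch of that schedule follows; the flag values are illustrative, not the ones used in training.

# illustrative stand-ins for FLAGS.num_pretrain_D, FLAGS.freq_train_D,
# FLAGS.num_train_D and the per-epoch batch count
NUM_PRETRAIN_D, FREQ_TRAIN_D, NUM_TRAIN_D, NUM_BATCHES = 2, 5, 5, 100


def critic_steps(epoch_id):
    """How many critic updates to run before each generator update."""
    if epoch_id < NUM_PRETRAIN_D or (epoch_id + 1) % FREQ_TRAIN_D == 0:
        return NUM_BATCHES   # burn-in / periodic full pass over the data
    return NUM_TRAIN_D       # regular alternation


for epoch_id in range(10):
    print(epoch_id, critic_steps(epoch_id))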