Code example #1
def training(train_data, valid_data, data_factory, config, default_graph,
             baseline_graph):
    """ training
    """
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    default_sess = tf.Session(config=tfconfig, graph=default_graph)

    if baseline_graph is not None:
        baseline_sess = tf.Session(config=tfconfig, graph=baseline_graph)
        with baseline_graph.as_default() as graph:
            baseline_C = C_MODEL_BASE(config, graph, if_training=False)
            saver = tf.train.Saver()
            saver.restore(baseline_sess, FLAGS.baseline_checkpoint)
            print('successfully restored baseline critic from checkpoint: %s' %
                  (FLAGS.baseline_checkpoint))
    with default_graph.as_default() as graph:
        # number of batches
        num_batches = train_data['A'].shape[0] // FLAGS.batch_size
        num_valid_batches = valid_data['A'].shape[0] // FLAGS.batch_size
        print('num_batches', num_batches)
        print('num_valid_batches', num_valid_batches)
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        init = tf.global_variables_initializer()
        # saver for later restore
        saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all

        default_sess.run(init)
        # restore model if exist
        if FLAGS.restore_path is not None:
            saver.restore(default_sess, FLAGS.restore_path)
            print('successfully restored model from checkpoint: %s' %
                  (FLAGS.restore_path))
        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            train_data, valid_data = data_factory.shuffle()

            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure the index does not exceed the data boundary
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            train_data['B'].shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = train_data['B'][data_idx:data_idx +
                                                   FLAGS.batch_size]
                    real_conds = train_data['A'][data_idx:data_idx +
                                                 FLAGS.batch_size]
                    # samples
                    fake_samples = G.generate(default_sess, z_samples(),
                                              real_conds)
                    # train Critic
                    D_loss_mean, global_steps = C.step(default_sess,
                                                       fake_samples,
                                                       real_samples,
                                                       real_conds)
                    batch_id += 1
                    log_counter += 1

                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data['B'].shape[0] - FLAGS.batch_size)
                    valid_real_samples = valid_data['B'][data_idx:data_idx +
                                                         FLAGS.batch_size]
                    valid_real_conds = valid_data['A'][data_idx:data_idx +
                                                       FLAGS.batch_size]
                    fake_samples = G.generate(default_sess, z_samples(),
                                              valid_real_conds)
                    D_valid_loss_mean = C.log_valid_loss(
                        default_sess, fake_samples, valid_real_samples,
                        valid_real_conds)

                    if baseline_graph is not None:
                        # baseline critic eval
                        baseline_C.eval_EM_distance(baseline_sess,
                                                    fake_samples,
                                                    valid_real_samples,
                                                    valid_real_conds,
                                                    global_steps)

                # train G
                G_loss_mean, global_steps = G.step(default_sess, z_samples(),
                                                   real_conds)
                log_counter += 1

                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print(
                        "%d, epoches, %d steps, mean C_loss: %f, mean C_valid_loss: %f, mean G_loss: %f, time cost: %f(sec)"
                        % (epoch_id, global_steps, D_loss_mean,
                           D_valid_loss_mean, G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()  # reset the timer
            # save model
            if (epoch_id % FLAGS.save_model_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(default_sess,
                                       CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                # fake
                samples = G.generate(default_sess, z_samples(), real_conds)
                # reuse the last real training batch; it was sliced with the
                # same index as real_conds, so the real/condition pair stays aligned
                concat_ = np.concatenate([real_conds, samples], axis=-1)
                # print(concat_)
                fake_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(fake_result[0],
                                          FLAGS.seq_length,
                                          file_path=SAMPLE_PATH +
                                          str(global_steps) + '_fake.mp4',
                                          if_save=True)
                # real
                concat_ = np.concatenate([real_conds, real_samples], axis=-1)
                real_result = data_factory.recover_data(concat_)
                game_visualizer.plot_data(real_result[0],
                                          FLAGS.seq_length,
                                          file_path=SAMPLE_PATH +
                                          str(global_steps) + '_real.mp4',
                                          if_save=True)
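
For orientation, the sketch below shows one plausible way to drive the training() function above. It is a minimal, hypothetical driver: DataFactory, TrainingConfig, and the FLAGS definitions are assumed to come from the same repository, the (train_data, valid_data) split returned by data_factory.shuffle() follows the call inside training(), and the TrainingConfig(235) argument follows Code example #3.

def main():
    # hypothetical driver for Code example #1; DataFactory, TrainingConfig
    # and FLAGS are assumed to be provided by the surrounding repository
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    data_factory = DataFactory(real_data)
    # shuffle() returns a (train_data, valid_data) pair of dicts with
    # condition 'A' and target 'B' entries, as used inside training()
    train_data, valid_data = data_factory.shuffle()
    config = TrainingConfig(235)
    default_graph = tf.Graph()
    # the baseline critic graph is optional
    baseline_graph = tf.Graph() if FLAGS.baseline_checkpoint is not None else None
    training(train_data, valid_data, data_factory, config,
             default_graph, baseline_graph)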
Code example #2
def training(real_data, normer, config, graph):
    """ training
    """
    # shuffle, then split the data into 90% training / 10% validation
    shuffled_indexes = np.random.permutation(real_data.shape[0])
    real_data = real_data[shuffled_indexes]
    real_data, valid_data = np.split(real_data, [real_data.shape[0] // 10 * 9])
    print(real_data.shape)
    print(valid_data.shape)
    # number of batches
    num_batches = real_data.shape[0] // FLAGS.batch_size
    num_valid_batches = valid_data.shape[0] // FLAGS.batch_size
    # model
    C = C_MODEL(config, graph)
    G = G_MODEL(config, C.inference, graph)
    init = tf.global_variables_initializer()
    # saver for later restore
    saver = tf.train.Saver()
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        sess.run(init)
        # restore model if exist
        if FLAGS.restore_path is not None:
            saver.restore(sess, FLAGS.restore_path)
            print('successfully restored model from checkpoint: %s' %
                  (FLAGS.restore_path))

        D_loss_mean = 0.0
        D_valid_loss_mean = 0.0
        G_loss_mean = 0.0
        log_counter = 0
        # to evaluate time cost
        start_time = time.time()
        for epoch_id in range(FLAGS.total_epoches):
            # shuffle the data
            shuffled_indexes = np.random.permutation(real_data.shape[0])
            real_data = real_data[shuffled_indexes]
            shuffled_indexes = np.random.permutation(valid_data.shape[0])
            valid_data = valid_data[shuffled_indexes]

            batch_id = 0
            while batch_id < num_batches - FLAGS.num_train_D:
                if epoch_id < FLAGS.num_pretrain_D or (
                        epoch_id + 1) % FLAGS.freq_train_D == 0:
                    num_train_D = num_batches * 5  # TODO
                else:
                    num_train_D = FLAGS.num_train_D
                for id_ in range(num_train_D):
                    # make sure the index does not exceed the data boundary
                    data_idx = batch_id * \
                        FLAGS.batch_size % (
                            real_data.shape[0] - FLAGS.batch_size)
                    # data
                    real_samples = real_data[data_idx:data_idx +
                                             FLAGS.batch_size]
                    # samples
                    fake_samples = G.generate(sess, z_samples())
                    # train Critic
                    D_loss_mean, global_steps = C.step(sess, fake_samples,
                                                       real_samples)
                    batch_id += 1
                    log_counter += 1

                    # log validation loss
                    data_idx = global_steps * \
                        FLAGS.batch_size % (
                            valid_data.shape[0] - FLAGS.batch_size)
                    valid_real_samples = valid_data[data_idx:data_idx +
                                                    FLAGS.batch_size]
                    D_valid_loss_mean = C.log_valid_loss(
                        sess, fake_samples, valid_real_samples)

                # train G
                G_loss_mean, global_steps = G.step(sess, z_samples())
                log_counter += 1

                # logging
                if log_counter >= FLAGS.log_freq:
                    end_time = time.time()
                    log_counter = 0
                    print(
                        "%d, epoches, %d steps, mean D_loss: %f, mean D_valid_loss: %f, mean G_loss: %f, time cost: %f(sec)"
                        % (epoch_id, global_steps, D_loss_mean,
                           D_valid_loss_mean, G_loss_mean,
                           (end_time - start_time)))
                    start_time = time.time()  # reset the timer
            # save model
            if (epoch_id % FLAGS.save_model_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                save_path = saver.save(sess,
                                       CHECKPOINTS_PATH + "model.ckpt",
                                       global_step=global_steps)
                print("Model saved in file: %s" % save_path)
            # plot generated sample
            if (epoch_id % FLAGS.save_result_freq
                ) == 0 or epoch_id == FLAGS.total_epoches - 1:
                samples = G.generate(sess, z_samples())
                # scale recovering
                samples = normer.recover_data(samples)
                # plot
                game_visualizer.plot_data(samples[0:],
                                          FLAGS.seq_length,
                                          file_path=SAMPLE_PATH +
                                          str(global_steps) + '.gif',
                                          if_save=True)
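
Both training loops above draw latent vectors from a z_samples() helper that is not shown in these excerpts. Below is a minimal sketch consistent with the call sites (no argument in the training loops, z_samples(FLAGS.n_latents) in Code example #3); the standard-normal distribution and the FLAGS.latent_dims attribute are assumptions, not taken from the repository.

def z_samples(batch_size=None):
    # hypothetical latent sampler: one standard-normal noise vector per
    # sample in the batch; FLAGS.latent_dims is an assumed flag name
    if batch_size is None:
        batch_size = FLAGS.batch_size
    return np.random.normal(
        0., 1., size=[batch_size, FLAGS.latent_dims]).astype(np.float32)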
Code example #3
def rnn():
    """ to collect results vary in length
    Saved Result
    ------------
    results_A_fake_B : float, numpy ndarray, shape=[n_latents=100, n_conditions=100, length=100, features=23]
        Real A + Fake B
    results_A_real_B : float, numpy ndarray, shape=[n_conditions=100, length=100, features=23]
        Real A + Real B
    results_critic_scores : float, numpy ndarray, shape=[n_latents=100, n_conditions=100]
        critic scores for each input data
    """

    save_path = os.path.join(COLLECT_PATH, 'rnn')
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # DataFactory
    data_factory = DataFactory(real_data)
    # target data
    target_data = np.load('../../data/FixedFPS5.npy')[-100:]
    target_length = np.load('../../data/FixedFPS5Length.npy')[-100:]
    print('target_data.shape', target_data.shape)
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape(
                [target_data.shape[0], target_data.shape[1], 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape(
                [target_data.shape[0], target_data.shape[1], 5 * 2])
        ], axis=-1
    )
    team_AB = data_factory.normalize(team_AB)
    team_A = team_AB[:, :, :13]
    team_B = team_AB[:, :, 13:]
    # result collector
    results_A_fake_B = []
    results_A_real_B = []
    config = TrainingConfig(235)
    with tf.get_default_graph().as_default() as graph:
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True
        default_sess = tf.Session(config=tfconfig, graph=graph)
        # saver for later restore
        saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all
        # restore model if exist
        saver.restore(default_sess, FLAGS.restore_path)
        print('successfully restored model from checkpoint: %s' %
              (FLAGS.restore_path))
        for idx in range(team_AB.shape[0]):
            # use FLAGS.n_latents (=100) latent vectors to generate 100 results for the same condition at once
            real_samples = team_B[idx:idx + 1, :]
            real_samples = np.concatenate(
                [real_samples for _ in range(FLAGS.n_latents)], axis=0)
            real_conds = team_A[idx:idx + 1, :]
            real_conds = np.concatenate(
                [real_conds for _ in range(FLAGS.n_latents)], axis=0)
            # generate result
            latents = z_samples(FLAGS.n_latents)
            result = G.generate(default_sess, latents, real_conds)
            # recover the generated trajectories to the original data scale
            recoverd_A_fake_B = data_factory.recover_data(
                np.concatenate([real_conds, result], axis=-1))
            # pad each result up to the full sequence length (team_AB.shape[1])
            dummy = np.zeros(
                shape=[FLAGS.n_latents, team_AB.shape[1] - target_length[idx], team_AB.shape[2]])
            temp_A_fake_B_concat = np.concatenate(
                [recoverd_A_fake_B[:, :target_length[idx]], dummy], axis=1)
            results_A_fake_B.append(temp_A_fake_B_concat)
    print(np.array(results_A_fake_B).shape)
    # stack along a new conditions dimension (axis=1)
    results_A_fake_B = np.stack(results_A_fake_B, axis=1)
    # real data
    results_A = data_factory.recover_BALL_and_A(team_A)
    results_real_B = data_factory.recover_B(team_B)
    results_A_real_B = data_factory.recover_data(team_AB)
    # saved as numpy
    print(np.array(results_A_fake_B).shape)
    print(np.array(results_A_real_B).shape)
    np.save(os.path.join(save_path, 'results_A_fake_B.npy'),
            np.array(results_A_fake_B).astype(np.float32).reshape([FLAGS.n_latents, team_AB.shape[0], team_AB.shape[1], 23]))
    np.save(os.path.join(save_path, 'results_A_real_B.npy'),
            np.array(results_A_real_B).astype(np.float32).reshape([team_AB.shape[0], team_AB.shape[1], 23]))
    print('!!Completely Saved!!')
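
The arrays written by rnn() can be reloaded to verify that their shapes match the docstring; the short check below assumes the same COLLECT_PATH layout used above and is only an illustrative sketch.

rnn_path = os.path.join(COLLECT_PATH, 'rnn')
results_A_fake_B = np.load(os.path.join(rnn_path, 'results_A_fake_B.npy'))
results_A_real_B = np.load(os.path.join(rnn_path, 'results_A_real_B.npy'))
print(results_A_fake_B.shape)  # expected: (n_latents, n_conditions, length, 23)
print(results_A_real_B.shape)  # expected: (n_conditions, length, 23)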