コード例 #1
0
def vis_data(off_data, off_label, def_data, def_label, outdir, start_idx=0):
    """ vis the dataset of transitions

    Replay the recorded offense and defense labels, starting at row
    ``start_idx``, through the 'bball-pretrain-v0' env and record the result
    with ``gym.wrappers.Monitor`` into ``outdir``. Offense and defense take
    alternating ``env.step`` calls until the env reports ``done``; the env is
    always closed, even if stepping raises.

    Args
    ----
    off_data : array-like, offense-side transition data. At row ``start_idx``,
        ``[idx, 0, -1, 0]``, ``[idx, 0, -1, 1:6]`` and ``[idx, 0, -1, 6:11]``
        are passed to the wrapper as ``init_positions`` (presumably
        ball / offense / defense -- TODO confirm), and the difference of the
        last two frames of ``[idx, 0, :, 1:6]`` seeds the offense velocity.
    off_label : array-like, offense actions; ``off_label[idx, 0]`` is split
        into ``[:3]`` and ``[3:]`` for ``pack_action(team='offense')``.
    def_data : array-like, defense-side transition data; the difference of
        the last two frames of ``[idx, 0, :, 6:11]`` seeds the defense velocity.
    def_label : array-like, defense actions; ``def_label[idx, 0]`` is passed
        to ``pack_action(team='defense')``.
    outdir : str, directory for the Monitor recordings.
    start_idx : int, first row of the data/label arrays to replay.
    """

    idx_ = start_idx

    # ``np.float`` was only an alias of the builtin ``float`` and was removed
    # in NumPy 1.24; ``float`` yields the same dtype.
    init_pos = [
        np.array(off_data[idx_, 0, -1, 0, :]),
        np.array(off_data[idx_, 0, -1, 1:6, :], dtype=float),
        np.array(off_data[idx_, 0, -1, 6:11, :], dtype=float)
    ]
    env = gym.make('bball-pretrain-v0')
    env = BBallWrapper(env,
                       if_clip=False,
                       if_norm_obs=False,
                       if_norm_act=False,
                       init_mode=2,
                       if_vis_visual_aid=True,
                       if_vis_trajectory=False,
                       init_positions=init_pos)
    env = gym.wrappers.Monitor(env,
                               outdir,
                               lambda unused_episode_number: True,
                               force=False,
                               resume=True)
    try:
        env.reset()

        while True:
            # copy the labels so the caller's arrays are not modified
            temp_off_label = np.array(off_label[idx_, 0])
            temp_def_label = np.array(def_label[idx_, 0])
            if idx_ == start_idx:
                # the env's velocity is zero, so we add the last velocity after env reset.
                last_vel = off_data[idx_, 0, -1, 1:6, :] - \
                    off_data[idx_, 0, -2, 1:6, :]
                temp_off_label[5:] += last_vel.reshape([
                    10,
                ])
                last_vel = def_data[idx_, 0, -1, 6:11, :] - \
                    def_data[idx_, 0, -2, 6:11, :]
                temp_def_label += last_vel.reshape([
                    10,
                ])
            # offense
            action = pack_action([temp_off_label[:3], temp_off_label[3:]],
                                 team='offense')
            _, _, done, _ = env.step(action)
            if done:
                break
            # defense
            action = pack_action(temp_def_label, team='defense')
            _, _, done, _ = env.step(action)
            if done:
                break
            idx_ += 1
    finally:
        # close (and flush the recording) on normal termination and on errors
        env.close()
コード例 #2
0
def f(ppo_policy, config, conditions):
    """Generate one fake episode by letting ``ppo_policy`` drive the defense.

    The env is seeded with ``conditions[None][:, :, -1]`` (the newest state of
    each condition step) and stepped ``config.max_length`` times. When
    ``config.if_back_real`` is set, the policy is fed the real condition slice
    for the current step; otherwise it is fed the env's latest observation.
    Actions are sampled (``stochastic=True``). Only the last component of the
    action list sent to ``env.step`` comes from the policy; the rest are zeros.

    Returns
    -------
    (list, list) : the last row of each post-step observation, and each
        policy action reshaped to ``[5, 2]``.
    """
    # env to generate fake state
    base_env = BBallWrapper(gym.make(config.env),
                            init_mode=3,
                            fps=config.FPS,
                            if_back_real=config.if_back_real,
                            time_limit=config.max_length)
    rollout_env = MonitorWrapper(
        base_env,
        directory=os.path.join(config.logdir, 'gail_training/'),
        if_back_real=config.if_back_real,
        # init from dataset in order
        init_mode=3)

    # add a leading batch axis; index -1 of axis 2 is the newest state
    batched = conditions[None]
    rollout_env.data = batched[:, :, -1]
    state = rollout_env.reset()

    fake_states, fake_actions = [], []
    for step in range(config.max_length):
        policy_input = (np.array(batched[:, step:step + 1, :])
                        if config.if_back_real
                        else np.array(state)[None, None])
        action = ppo_policy.act(policy_input, stochastic=True)
        state, _, _, _ = rollout_env.step([
            0,                                        # Discrete(3) must be int
            np.array([0.0, 0.0], dtype=np.float32),   # Box(2,)
            np.zeros(shape=[5, 2], dtype=np.float32),  # Box(5, 2)
            np.reshape(action, [5, 2]),               # Box(5, 2)
        ])
        fake_states.append(state[-1])
        fake_actions.append(action.reshape([5, 2]))
    return fake_states, fake_actions
コード例 #3
0
def collect_results(config, steps, ppo_policy, D, denormalize_observ, generated_amount=100):
    """ test policy
    - draw each episode into a video (recorded by ``MonitorWrapper``)
    - collect the denormalized observations of every episode into .npy files

    Args
    -----
    config : object, providing configurations via attributes
        (``env``, ``FPS`` and ``logdir`` are read).
    steps : int, iteration count of the Discriminator; currently unused.
    ppo_policy : object, policy to generate actions (``act`` is called with
        ``stochastic=False``).
    D : object, discriminator; currently unused.
    denormalize_observ : function, denorm the returned observation
    generated_amount : int, number of episodes to generate.

    Outputs (under ``config.logdir``)
    -------
    collect_result/video/ : videos of every episode.
    collect_result/total_output.npy : array of the denormalized observations,
        one entry per episode.
    collect_result/total_output_length.npy : ``data_len[data_idx] - 2`` per
        episode, where ``data_idx`` is read from the env's ``info`` dict.
    """
    timer = time.time()
    # read condition length
    data_len = np.load('bball_strategies/data/FixedFPS5Length.npy')
    # Keep a handle on the HDF5 file so it can be closed when collection ends
    # (previously it was opened inline and never closed).
    testing_data = h5py.File(
        'bball_strategies/data/OrderedGAILTransitionData_Testing.hdf5', 'r')
    vanilla_env = None
    try:
        # env to testing
        vanilla_env = gym.make(config.env)
        vanilla_env = BBallWrapper(vanilla_env, data=testing_data, init_mode=1,
                                   fps=config.FPS, time_limit=np.max(data_len)-2)
        vanilla_env = MonitorWrapper(vanilla_env, directory=os.path.join(config.logdir, 'collect_result/video/'), video_callable=lambda _: True,
                                     # init from dataset
                                     init_mode=1)
        total_output = []
        index_list = []
        for i in range(generated_amount):
            print('generating # {} episode'.format(i))
            numpy_collector = []
            vanilla_obs = vanilla_env.reset()
            for _ in range(vanilla_env.time_limit):
                vanilla_act = ppo_policy.act(
                    np.array(vanilla_obs)[None, None], stochastic=False)
                vanilla_trans_act = [
                    # Discrete(3) must be int
                    int(0),
                    # Box(2,)
                    np.array([0.0, 0.0], dtype=np.float32),
                    # Box(5, 2)
                    np.zeros(shape=[5, 2], dtype=np.float32),
                    # Box(5, 2)
                    np.reshape(vanilla_act, [5, 2])
                ]
                vanilla_obs, _, _, info = vanilla_env.step(
                    vanilla_trans_act)
                numpy_collector.append(vanilla_obs)
            # dataset row this episode was initialised from (as reported by the env)
            index_list.append(info['data_idx'])
            total_output.append(denormalize_observ(np.array(numpy_collector)))
        total_output = np.array(total_output)
        # save numpy
        np.save(os.path.join(config.logdir,
                             'collect_result/total_output.npy'), total_output)
        np.save(os.path.join(config.logdir,
                             'collect_result/total_output_length.npy'), data_len[index_list]-2)

        # guard the average against generated_amount == 0
        print('collect_results time cost: {} per episode'.format(
            (time.time() - timer)/max(generated_amount, 1)))
    finally:
        if vanilla_env is not None:
            vanilla_env.close()
        testing_data.close()
コード例 #4
0
def vis_result(sess, model, off_data, off_label, def_data, def_label, outdir,
               num_video):
    """ vis the results by using the pretrain output interacting with env

    For each of the first ``num_video`` episodes, rebuild the
    'bball-pretrain-v0' env at the episode's first row and record a video
    into ``outdir``. ``FLAGS.config`` selects which team is driven by
    ``model`` ('offense' or 'defense'); the other team replays its recorded
    labels. Offense and defense alternate ``env.step`` calls until ``done``.

    Args
    ----
    sess : session object passed through to ``model.perform``.
    model : object exposing ``perform(sess, obs)``.
    off_data, def_data : array-like transition data; the row at an episode's
        start provides ``init_positions`` (``off_data[idx, 0, -1, 0]``,
        ``[..., 1:6]``, ``[..., 6:11]``) and the first-step velocities.
    off_label, def_label : array-like recorded actions, indexed ``[idx, 0]``.
    outdir : str, directory for the Monitor recordings.
    num_video : int, number of episodes to visualise.

    Raises
    ------
    ValueError : if ``FLAGS.config`` is neither 'offense' nor 'defense'.
    """
    if FLAGS.config not in ('offense', 'defense'):
        # neither branch below would step the env, so the loop would only
        # stop by running off the end of the label arrays
        raise ValueError(
            "FLAGS.config must be 'offense' or 'defense', got {!r}".format(
                FLAGS.config))

    data_len = np.load('bball_strategies/data/FixedFPS5Length.npy')
    # Row at which episode i starts = total length of episodes 0..i-1
    # (exclusive prefix sum). The old in-place accumulation produced the
    # inclusive sum, i.e. the start of episode i+1: episode 0 was skipped and
    # the last episode's start pointed past the end of the data.
    # NOTE(review): assumes the data/label rows are laid out episode by
    # episode with data_len[i] rows each -- confirm against the dataset builder.
    episode_starts = np.cumsum(data_len) - data_len
    for i in range(num_video):
        start_idx = episode_starts[i]
        idx_ = start_idx
        # ``np.float`` (removed in NumPy 1.24) was an alias of ``float``
        init_pos = [
            np.array(off_data[idx_, 0, -1, 0, :]),
            np.array(off_data[idx_, 0, -1, 1:6, :], dtype=float),
            np.array(off_data[idx_, 0, -1, 6:11, :], dtype=float)
        ]
        env = gym.make('bball-pretrain-v0')
        env = BBallWrapper(env,
                           if_clip=False,
                           if_norm_obs=False,
                           if_norm_act=False,
                           init_mode=2,
                           if_vis_visual_aid=True,
                           if_vis_trajectory=False,
                           init_positions=init_pos)
        env = gym.wrappers.Monitor(env,
                                   outdir,
                                   lambda unused_episode_number: True,
                                   force=False,
                                   resume=True)
        try:
            obs = env.reset()

            while True:
                # copy the labels so the caller's arrays are not modified
                temp_off_label = np.array(off_label[idx_, 0])
                temp_def_label = np.array(def_label[idx_, 0])
                if idx_ == start_idx:
                    # the env's velocity is zero, so we add the last velocity after env reset.
                    last_vel = off_data[idx_, 0, -1, 1:6, :] - \
                        off_data[idx_, 0, -2, 1:6, :]
                    temp_off_label[5:] += last_vel.reshape([
                        10,
                    ])
                    last_vel = def_data[idx_, 0, -1, 6:11, :] - \
                        def_data[idx_, 0, -2, 6:11, :]
                    temp_def_label += last_vel.reshape([
                        10,
                    ])
                if FLAGS.config == 'offense':
                    # offense turn: model
                    obs = norm_obs(env, obs)
                    logits, actions = model.perform(sess, obs[None, None])
                    actions = pack_action([logits[0, 0], actions[0, 0]],
                                          FLAGS.config)
                    obs, _, done, _ = env.step(actions)
                    if done:
                        break
                    # defense turn: recorded labels
                    actions = pack_action(temp_def_label, team='defense')
                    obs, _, done, _ = env.step(actions)
                    if done:
                        break
                else:
                    # offense turn: recorded labels
                    actions = pack_action([temp_off_label[:3], temp_off_label[3:]],
                                          team='offense')
                    obs, _, done, _ = env.step(actions)
                    if done:
                        break
                    # defense turn: model
                    obs = norm_obs(env, obs)
                    actions = model.perform(sess, obs[None, None])
                    actions = pack_action(actions, FLAGS.config)
                    obs, _, done, _ = env.step(actions)
                    if done:
                        break
                idx_ += 1
        finally:
            # close (and flush the recording) on normal termination and on errors
            env.close()
コード例 #5
0
def tally_reward_line_chart(config, steps, ppo_policy, D, denormalize_observ, normalize_observ, normalize_action):
    """ tally 100 episodes as line chart to show how well the discriminator judge on each state of real and fake episode

    Real episodes are sampled (with replacement) from the first 90% of
    ``GAILTransitionData_51.hdf5``. Fake episodes are rolled out with
    ``ppo_policy`` driving only the defense, using deterministic actions.
    Both sets are scored with ``D.get_rewards_value`` and drawn with
    ``vis_line_chart``.

    - ``config.is_gail``: episodes run for the env's ``time_limit`` (50), and
      the rewards are reshaped to ``[episode_amount, -1]``.
    - otherwise: episodes run for ``config.max_length`` steps. The rewards are
      reshaped to a column and tiled across ``config.max_length`` frames
      (presumably one reward per episode -- TODO confirm with D).

    Args
    -----
    config : object, providing configurations via attributes.
    steps : int, used to name the chart file.
    ppo_policy : object, policy to generate actions.
    D : object, discriminator to judge realistic.
    denormalize_observ : function, currently unused.
    normalize_observ : function, norm the states of real data.
    normalize_action : function, norm the actions of real data.
    """
    episode_amount = 100

    def _defense_only_action(act):
        # ball/offense components are zero; only the defense comes from the policy
        return [
            # Discrete(3) must be int
            int(0),
            # Box(2,)
            np.array([0.0, 0.0], dtype=np.float32),
            # Box(5, 2)
            np.zeros(shape=[5, 2], dtype=np.float32),
            # Box(5, 2)
            np.reshape(act, [5, 2])
        ]

    # Real data: read each dataset once (``Dataset.value`` was removed in
    # h5py 3.0; ``[()]`` reads the full dataset in every version) and close
    # the file afterwards.
    with h5py.File('bball_strategies/data/GAILTransitionData_51.hdf5', 'r') as all_data:
        all_obs = all_data['OBS'][()]
        all_def_act = all_data['DEF_ACT'][()]
    # keep the first 90% (same as np.split(x, [n * 9 // 10])[0])
    expert_data = all_obs[:all_obs.shape[0] * 9 // 10]
    expert_action = all_def_act[:all_def_act.shape[0] * 9 // 10]
    # env
    vanilla_env = gym.make(config.env)
    vanilla_env = BBallWrapper(vanilla_env, init_mode=1, fps=config.FPS, if_back_real=False,
                               time_limit=50 if config.is_gail else config.max_length)
    vanilla_env.data = np.load('bball_strategies/data/GAILEnvData_51.npy')
    # real
    selected_idx = np.random.choice(expert_data.shape[0], episode_amount)
    try:
        if config.is_gail:
            # frame 0 is condition
            batch_real_states = expert_data[selected_idx, 1:]
            real_action = expert_action[selected_idx, :-1]
            batch_real_states = np.concatenate(
                batch_real_states, axis=0)
            real_action = np.concatenate(real_action[:, None], axis=0)
            batch_real_states = normalize_observ(batch_real_states)
            real_rewards = D.get_rewards_value(
                batch_real_states, normalize_action(real_action)).reshape([episode_amount, -1])
            # fake
            numpy_collector = []
            act_collector = []
            for _ in range(episode_amount):
                vanilla_obs = vanilla_env.reset()
                for _ in range(vanilla_env.time_limit):
                    vanilla_act = ppo_policy.act(
                        np.array(vanilla_obs)[None, None], stochastic=False)
                    act_collector.append(vanilla_act.reshape([1, 5, 2]))
                    vanilla_obs, _, _, _ = vanilla_env.step(
                        _defense_only_action(vanilla_act))
                    numpy_collector.append(vanilla_obs)
            fake_rewards = D.get_rewards_value(
                np.array(numpy_collector), np.array(act_collector)).reshape([episode_amount, -1])
        else:
            # frame 0 is condition
            batch_real_states = expert_data[selected_idx, 1:config.max_length+1, -1]
            real_action = expert_action[selected_idx, :config.max_length]
            batch_real_states = normalize_observ(batch_real_states)
            real_rewards = D.get_rewards_value(
                batch_real_states, normalize_action(real_action)).reshape([-1, 1])
            real_rewards = np.tile(real_rewards, [1, config.max_length])
            # fake
            numpy_collector = []
            act_collector = []
            for _ in range(episode_amount):
                vanilla_obs = vanilla_env.reset()
                epi_obs = []
                epi_act = []
                for _ in range(config.max_length):
                    vanilla_act = ppo_policy.act(
                        np.array(vanilla_obs)[None, None], stochastic=False)
                    vanilla_obs, _, _, _ = vanilla_env.step(
                        _defense_only_action(vanilla_act))
                    epi_obs.append(vanilla_obs[-1])
                    epi_act.append(vanilla_act.reshape([5, 2]))
                numpy_collector.append(epi_obs)
                act_collector.append(epi_act)
            fake_rewards = D.get_rewards_value(
                np.array(numpy_collector), np.array(act_collector)).reshape([-1, 1])
            fake_rewards = np.tile(fake_rewards, [1, config.max_length])
    finally:
        vanilla_env.close()
    # vis
    vis_line_chart(real_rewards, fake_rewards, config.logdir, str(steps))
コード例 #6
0
def tally_reward_line_chart(config, steps, ppo_policy, D, normalize_observ, normalize_action, stochastic=False):
    """ tally a line chart by 100 episodes to show how well the discriminator judge on each state

    Each episode is initialised from ``OrderedGAILTransitionData_522.hdf5``,
    and ``ppo_policy`` drives the defense for 50 steps. Every step yields two
    (state, action) pairs:

    - a fake pair from the env's observation and the policy's action;
    - a real pair from ``info['expert_s']`` / ``info['expert_a']``. Its state
      is rows 0:6 and 11:14 of the fake observation with the normalized
      expert slice placed between them (presumably the defense rows -- TODO
      confirm).

    Both sets are scored by ``D.get_rewards_value`` and drawn with
    ``vis_line_chart``. The HDF5 file and the env are closed afterwards.

    Args
    -----
    config : object, providing configurations via attributes.
    steps : int, to name the file with number of iterations of Discriminator
    ppo_policy : object, policy to generate actions
    D : object, discriminator to judge realistic
    normalize_observ : function, norm the states of real data
    normalize_action : function, norm the actions of real data
    stochastic : bool, decide the methods to generate actions, True->sampled, False->select mode
    """
    timer = time.time()
    episode_amount = 100
    # Keep a handle so the file can be closed. It stays open until the
    # rewards are computed, in case values in ``info`` reference it lazily.
    transition_data = h5py.File(
        'bball_strategies/data/OrderedGAILTransitionData_522.hdf5', 'r')
    vanilla_env = None
    try:
        # env
        vanilla_env = gym.make(config.env)
        vanilla_env = BBallWrapper(vanilla_env, data=transition_data, init_mode=1, fps=config.FPS, time_limit=50)
        # fake and real collectors
        numpy_collector = []
        act_collector = []
        real_numpy_collector = []
        real_act_collector = []
        for _ in range(episode_amount):
            vanilla_obs = vanilla_env.reset()
            epi_obs = []
            epi_act = []
            real_epi_obs = []
            real_epi_act = []
            for _ in range(vanilla_env.time_limit):
                vanilla_act = ppo_policy.act(
                    np.array(vanilla_obs)[None, None], stochastic=stochastic)
                vanilla_trans_act = [
                    # Discrete(3) must be int
                    int(0),
                    # Box(2,), ball
                    np.array([0.0, 0.0], dtype=np.float32),
                    # Box(5, 2), offense
                    np.zeros(shape=[5, 2], dtype=np.float32),
                    # Box(5, 2), defense
                    np.reshape(vanilla_act, [5, 2])
                ]
                vanilla_obs, _, _, info = vanilla_env.step(
                    vanilla_trans_act)
                epi_obs.append(vanilla_obs[-1])
                epi_act.append(vanilla_act.reshape([5, 2]))
                real_epi_obs.append(np.concatenate(
                    [vanilla_obs[-1, 0:6], normalize_observ(info['expert_s']), vanilla_obs[-1, 11:14]], axis=0))
                real_epi_act.append(info['expert_a'])
            numpy_collector.append(epi_obs)
            act_collector.append(epi_act)
            real_numpy_collector.append(real_epi_obs)
            real_act_collector.append(real_epi_act)
        numpy_collector = np.array(numpy_collector)
        act_collector = np.array(act_collector)
        real_numpy_collector = np.array(real_numpy_collector)
        real_act_collector = np.array(real_act_collector)
        fake_rewards = D.get_rewards_value(
            numpy_collector, act_collector)
        real_act_collector = normalize_action(real_act_collector)
        real_rewards = D.get_rewards_value(
            real_numpy_collector, real_act_collector)
    finally:
        if vanilla_env is not None:
            vanilla_env.close()
        transition_data.close()
    # vis
    line_chart_dir = 'line_chart_stochastic/' if stochastic else 'line_chart/'
    save_path = os.path.join(config.logdir, 'gail_testing_G{}_D{}'.format(
        config.max_length, config.D_len), line_chart_dir)
    vis_line_chart(real_rewards, fake_rewards, save_path, str(steps))
    print('tally_reward_line_chart time cost: {}'.format(time.time() - timer))