Example #1
def _create_replay_buffer_and_insert(env: EnvWrapper):
    env.seed(1)
    replay_buffer = ReplayBuffer(replay_capacity=6, batch_size=1)
    replay_buffer_inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    inserted = []
    terminal = False
    i = 0
    while not terminal and i < 5:
        logger.info(f"Iteration: {i}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        inserted.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "terminal": terminal,
            }
        )
        transition = Transition(
            mdp_id=0,
            sequence_number=i,
            observation=obs,
            action=action,
            reward=reward,
            terminal=terminal,
            log_prob=0.0,
        )
        replay_buffer_inserter(replay_buffer, transition)
        obs = next_obs
        i += 1

    return replay_buffer, inserted
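
A hedged usage note for the variant above: it expects a ReAgent EnvWrapper rather than a raw gym.Env, so a caller would pass an already-wrapped environment. In the sketch below, make_wrapped_env is a hypothetical stand-in for whatever wrapper constructor the surrounding test suite provides; it is not part of the example.

env = make_wrapped_env("CartPole-v0")  # hypothetical helper, not part of the example
replay_buffer, inserted = _create_replay_buffer_and_insert(env)
print(f"inserted {len(inserted)} transitions")
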
Example #2
def _create_replay_buffer_and_insert(env: gym.Env):
    env.seed(1)
    replay_buffer = ReplayBuffer.create_from_env(
        env, replay_memory_size=10, batch_size=1
    )
    replay_buffer_inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    inserted = []
    terminal = False
    i = 0
    while not terminal and i < 5:
        logger.info(f"Iteration: {i}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        inserted.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "terminal": terminal,
            }
        )
        log_prob = 0.0
        replay_buffer_inserter(replay_buffer, obs, action, reward, terminal, log_prob)
        obs = next_obs
        i += 1

    return replay_buffer, inserted
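
A minimal usage sketch for the gym.Env variant above, assuming the same module-level imports the example itself relies on (gym, ReplayBuffer, make_replay_buffer_inserter); "CartPole-v0" is only an illustrative choice of environment.

env = gym.make("CartPole-v0")
replay_buffer, inserted = _create_replay_buffer_and_insert(env)
# `inserted` mirrors, step by step, exactly what was pushed into the buffer.
print(f"inserted {len(inserted)} transitions")
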
Example #3
def test_cartpole(self):
    env = gym.make("CartPole-v0")
    replay_buffer = ReplayBuffer.create_from_env(
        env, replay_memory_size=10, batch_size=5
    )
    replay_buffer_inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    terminal = False
    i = 0
    while not terminal and i < 5:
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        replay_buffer_inserter(
            replay_buffer, obs, action, reward, terminal, log_prob=0.0
        )
        obs = next_obs
        i += 1  # advance the step counter so the loop bound takes effect
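
Once the loop has finished, the test could read the buffer back the same way Example #4 does; a hedged continuation sketch (sample_all_valid_transitions is the same call used in the example below):

    batch = replay_buffer.sample_all_valid_transitions()
    print(batch.state.shape, batch.reward.shape)
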
Example #4
def create_df_from_replay_buffer(
    env: gym.Env,
    problem_domain: ProblemDomain,
    desired_size: int,
    multi_steps: Optional[int],
    ds: str,
) -> pd.DataFrame:
    # fill the replay buffer
    set_seed(env, SEED)
    if multi_steps is None:
        update_horizon = 1
        return_as_timeline_format = False
    else:
        update_horizon = multi_steps
        return_as_timeline_format = True
    is_multi_steps = multi_steps is not None

    replay_buffer = ReplayBuffer(
        replay_capacity=desired_size,
        batch_size=1,
        update_horizon=update_horizon,
        return_as_timeline_format=return_as_timeline_format,
    )
    fill_replay_buffer(env, replay_buffer, desired_size)

    batch = replay_buffer.sample_all_valid_transitions()
    n = batch.state.shape[0]
    logger.info(f"Creating df of size {n}.")

    def discrete_feat_transform(elem) -> str:
        """ query data expects str format """
        return str(elem.item())

    def continuous_feat_transform(elem: List[float]) -> Dict[int, float]:
        """ query data expects sparse format """
        assert isinstance(elem, torch.Tensor), f"{type(elem)} isn't tensor"
        assert len(elem.shape) == 1, f"{elem.shape} isn't 1-dimensional"
        return {i: s.item() for i, s in enumerate(elem)}

    def make_parametric_feat_transform(one_hot_dim: int):
        """ one-hot and then continuous_feat_transform """
        def transform(elem) -> Dict[int, float]:
            elem_tensor = torch.tensor(elem.item())
            one_hot_feat = F.one_hot(elem_tensor, one_hot_dim).float()
            return continuous_feat_transform(one_hot_feat)

        return transform

    state_features = feature_transform(batch.state, continuous_feat_transform)
    next_state_features = feature_transform(
        batch.next_state,
        continuous_feat_transform,
        is_next_with_multi_steps=is_multi_steps,
    )

    if problem_domain == ProblemDomain.DISCRETE_ACTION:
        # discrete action is str
        action = feature_transform(batch.action, discrete_feat_transform)
        next_action = feature_transform(
            batch.next_action,
            discrete_feat_transform,
            is_next_with_multi_steps=is_multi_steps,
            replace_when_terminal="",
            terminal=batch.terminal,
        )
    elif problem_domain == ProblemDomain.PARAMETRIC_ACTION:
        # parametric action is a sparse Dict[int, float] (one-hot of the discrete action)
        assert isinstance(env.action_space, gym.spaces.Discrete)
        parametric_feat_transform = make_parametric_feat_transform(
            env.action_space.n)
        action = feature_transform(batch.action, parametric_feat_transform)
        next_action = feature_transform(
            batch.next_action,
            parametric_feat_transform,
            is_next_with_multi_steps=is_multi_steps,
            replace_when_terminal={},
            terminal=batch.terminal,
        )
    elif problem_domain == ProblemDomain.CONTINUOUS_ACTION:
        action = feature_transform(batch.action, continuous_feat_transform)
        next_action = feature_transform(
            batch.next_action,
            continuous_feat_transform,
            is_next_with_multi_steps=is_multi_steps,
            replace_when_terminal={},
            terminal=batch.terminal,
        )
    elif problem_domain == ProblemDomain.MDN_RNN:
        action = feature_transform(batch.action, discrete_feat_transform)
        assert multi_steps is not None
        next_action = feature_transform(
            batch.next_action,
            discrete_feat_transform,
            is_next_with_multi_steps=True,
            replace_when_terminal="",
            terminal=batch.terminal,
        )
    else:
        raise NotImplementedError(f"model type: {problem_domain}.")

    if multi_steps is None:
        time_diff = [1] * n
        reward = batch.reward.squeeze(1).tolist()
        metrics = [{"reward": r} for r in reward]
    else:
        time_diff = [[1] * len(ns) for ns in next_state_features]
        reward = [reward_list.tolist() for reward_list in batch.reward]
        metrics = [[{
            "reward": r.item()
        } for r in reward_list] for reward_list in batch.reward]

    # TODO(T67265031): change this to int
    mdp_id = [str(i.item()) for i in batch.mdp_id]
    sequence_number = batch.sequence_number.squeeze(1).tolist()
    # In the product data, all sequence_number_ordinal values start from 1,
    # so add 1 here to stay consistent with the product data.
    sequence_number_ordinal = (batch.sequence_number.squeeze(1) + 1).tolist()
    action_probability = batch.log_prob.exp().squeeze(1).tolist()
    df_dict = {
        "state_features": state_features,
        "next_state_features": next_state_features,
        "action": action,
        "next_action": next_action,
        "reward": reward,
        "action_probability": action_probability,
        "metrics": metrics,
        "time_diff": time_diff,
        "mdp_id": mdp_id,
        "sequence_number": sequence_number,
        "sequence_number_ordinal": sequence_number_ordinal,
        "ds": [ds] * n,
    }

    if problem_domain == ProblemDomain.PARAMETRIC_ACTION:
        # Possible actions are List[Dict[int, float]]
        assert isinstance(env.action_space, gym.spaces.Discrete)
        possible_actions = [{i: 1.0} for i in range(env.action_space.n)]

    elif problem_domain == ProblemDomain.DISCRETE_ACTION:
        # Possible actions are List[str]
        assert isinstance(env.action_space, gym.spaces.Discrete)
        possible_actions = [str(i) for i in range(env.action_space.n)]

    elif problem_domain == ProblemDomain.MDN_RNN:
        # Possible actions are List[str]
        assert isinstance(env.action_space, gym.spaces.Discrete)
        possible_actions = [str(i) for i in range(env.action_space.n)]

    # these are fillers; they only need to have the correct shape
    pa_features = range(n)
    pna_features = time_diff
    if problem_domain in (
            ProblemDomain.DISCRETE_ACTION,
            ProblemDomain.PARAMETRIC_ACTION,
            ProblemDomain.MDN_RNN,
    ):

        def pa_transform(x):
            return possible_actions

        df_dict["possible_actions"] = feature_transform(
            pa_features, pa_transform)
        df_dict["possible_next_actions"] = feature_transform(
            pna_features,
            pa_transform,
            is_next_with_multi_steps=is_multi_steps,
            replace_when_terminal=[],
            terminal=batch.terminal,
        )

    df = pd.DataFrame(df_dict)
    # validate df
    validate_mdp_ids_seq_nums(df)
    # shuffling (sample the whole batch)
    df = df.reindex(np.random.permutation(df.index))
    return df
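
A hedged call sketch for the function above: the concrete argument values are illustrative only, and ProblemDomain, gym, and the module-level helpers referenced inside the function (set_seed, fill_replay_buffer, feature_transform, validate_mdp_ids_seq_nums) are assumed to be importable from the surrounding module.

df = create_df_from_replay_buffer(
    env=gym.make("CartPole-v0"),
    problem_domain=ProblemDomain.DISCRETE_ACTION,
    desired_size=100,
    multi_steps=None,   # single-step transitions
    ds="2021-01-01",    # value written into the "ds" (partition date) column
)
print(df.columns.tolist())
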