Example #1
    def sample_memories(self,
                        batch_size,
                        use_gpu=False) -> rlt.PreprocessedMemoryNetworkInput:
        """
        :param batch_size: number of samples to return
        :param use_gpu: whether to put samples on gpu
        State's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
        By default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
        """
        sample_indices = np.random.randint(self.memory_size, size=batch_size)
        device = torch.device("cuda") if use_gpu else torch.device("cpu")
        # state/next state shape: batch_size x seq_len x state_dim
        # action shape: batch_size x seq_len x action_dim
        # reward/not_terminal shape: batch_size x seq_len
        state, action, next_state, reward, not_terminal = map(
            lambda x: stack(x).float().to(device),
            zip(*self.deque_sample(sample_indices)),
        )

        # make shapes seq_len x batch_size x feature_dim
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal)

        training_input = rlt.PreprocessedMemoryNetworkInput(
            state=rlt.FeatureData(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            next_state=rlt.FeatureData(float_features=next_state),
            not_terminal=not_terminal,
            step=None,
        )
        return training_input
    def embed_state(self, state):
        """ Embed state after either reset() or step() """
        assert len(self.recent_states) == len(self.recent_actions)
        old_mdnrnn_mode = self.mdnrnn.mdnrnn.training
        self.mdnrnn.mdnrnn.eval()

        # Embed the state as the RNN's hidden output over the history
        # up to the previous step, concatenated with the current state
        if len(self.recent_states) == 0:
            mdnrnn_state = np.zeros((1, self.raw_state_dim))
            mdnrnn_action = np.zeros((1, self.action_dim))
        else:
            mdnrnn_state = np.array(list(self.recent_states))
            mdnrnn_action = np.array(list(self.recent_actions))

        mdnrnn_state = torch.tensor(mdnrnn_state, dtype=torch.float).unsqueeze(1)
        mdnrnn_action = torch.tensor(mdnrnn_action, dtype=torch.float).unsqueeze(1)
        mdnrnn_output = self.mdnrnn(
            rlt.FeatureData(mdnrnn_state), rlt.FeatureData(mdnrnn_action)
        )
        hidden_embed = (
            mdnrnn_output.all_steps_lstm_hidden[-1].squeeze().detach().cpu().numpy()
        )
        state_embed = np.hstack((hidden_embed, state))
        self.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
        logger.debug(
            f"Embed_state\nrecent states: {np.array(self.recent_states)}\n"
            f"recent actions: {np.array(self.recent_actions)}\n"
            f"state_embed{state_embed}\n"
        )
        return state_embed
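sample_memories above relies on stack and transpose helpers that are not shown in this snippet. A minimal sketch of what they could look like, assuming they simply stack per-sample arrays into a batch and swap the batch and sequence dimensions (the implementations are assumptions, not the library's actual helpers):

import torch

def stack(samples):
    # Hypothetical helper: stack per-sample arrays into one (batch_size, ...) tensor.
    return torch.stack([torch.as_tensor(s) for s in samples], dim=0)

def transpose(*tensors):
    # Hypothetical helper: (batch_size, seq_len, ...) -> (seq_len, batch_size, ...).
    return tuple(t.transpose(0, 1) for t in tensors)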
Example #3
def top_k_policy(q_network, obs: Tuple[torch.Tensor, torch.Tensor,
                                       DocumentFeature], recsim: RecSim):
    active_user_idxs, user_features, candidate_features = obs

    slate_with_null = recsim.select(
        candidate_features,
        torch.repeat_interleave(torch.arange(recsim.m).unsqueeze(dim=0),
                                active_user_idxs.shape[0],
                                dim=0),
        add_null=True,
    )
    _user_choice, interest = recsim.compute_user_choice(slate_with_null)
    propensity = F.softmax(interest, dim=1)[:, :recsim.m]

    tiled_user_features = torch.repeat_interleave(user_features,
                                                  recsim.m,
                                                  dim=0)
    candidate_feature_vector = candidate_features.as_vector()
    action_dim = candidate_feature_vector.shape[2]
    flatten_candidate_features = candidate_feature_vector.view(-1, action_dim)

    q_network_input = (
        rlt.FeatureData(tiled_user_features),
        rlt.FeatureData(flatten_candidate_features),
    )
    q_values = q_network(*q_network_input).view(-1, recsim.m)

    values = q_values * propensity

    top_values, item_idxs = torch.topk(values, recsim.k, dim=1)
    return item_idxs
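The core of top_k_policy is weighting Q-values by the user's choice propensity before taking the top k items. A standalone sketch of that step with dummy tensors (shapes and values are hypothetical):

import torch
import torch.nn.functional as F

num_users, m, k = 4, 10, 3
q_values = torch.rand(num_users, m)             # Q-value per (user, candidate)
interest = torch.rand(num_users, m + 1)         # includes the null item beyond index m
propensity = F.softmax(interest, dim=1)[:, :m]  # drop the null item's propensity
values = q_values * propensity                  # propensity-weighted values
top_values, item_idxs = torch.topk(values, k, dim=1)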
Example #4
 def test_get_Q(self):
     NUM_ACTION = 2
     MULTI_STEPS = 3
     BATCH_SIZE = 2
     STATE_DIM = 4
     all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
     seq2reward_network = FakeSeq2RewardNetwork()
     batch = rlt.MemoryNetworkInput(
         state=rlt.FeatureData(
             float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
         ),
         next_state=rlt.FeatureData(
             float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
         ),
         action=rlt.FeatureData(
             float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, NUM_ACTION)
         ),
         reward=torch.zeros(1),
         time_diff=torch.zeros(1),
         step=torch.zeros(1),
         not_terminal=torch.zeros(1),
     )
     q_values = get_Q(seq2reward_network, batch, all_permut)
     expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
     logger.info(f"q_values: {q_values}")
     assert torch.all(expected_q_values == q_values)
 def __call__(self, batch: Dict[str, torch.Tensor]) -> rlt.PolicyNetworkInput:
     batch = batch_to_device(batch, self.device)
     preprocessed_state = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"]
     )
     preprocessed_next_state = self.state_preprocessor(
         batch["next_state_features"], batch["next_state_features_presence"]
     )
     preprocessed_action = self.action_preprocessor(
         batch["action"], batch["action_presence"]
     )
     preprocessed_next_action = self.action_preprocessor(
         batch["next_action"], batch["next_action_presence"]
     )
     return rlt.PolicyNetworkInput(
         state=rlt.FeatureData(preprocessed_state),
         next_state=rlt.FeatureData(preprocessed_next_state),
         action=rlt.FeatureData(preprocessed_action),
         next_action=rlt.FeatureData(preprocessed_next_action),
         reward=batch["reward"].unsqueeze(1),
         time_diff=batch["time_diff"].unsqueeze(1),
         step=batch["step"].unsqueeze(1),
         not_terminal=batch["not_terminal"].unsqueeze(1),
         extras=rlt.ExtraData(
             mdp_id=batch["mdp_id"].unsqueeze(1),
             sequence_number=batch["sequence_number"].unsqueeze(1),
             action_probability=batch["action_probability"].unsqueeze(1),
         ),
     )
 def __call__(self, batch):
     not_terminal = 1.0 - batch.terminal.float()
     assert (len(batch.state.shape) == 2
             ), f"{batch.state.shape} is not (batch_size, state_dim)."
     batch_size, _ = batch.state.shape
     action, next_action = one_hot_actions(self.num_actions, batch.action,
                                           batch.next_action,
                                           batch.terminal)
     possible_actions = get_possible_actions_for_gym(
         batch_size, self.num_actions)
     possible_next_actions = possible_actions.clone()
     possible_actions_mask = torch.ones((batch_size, self.num_actions))
     possible_next_actions_mask = possible_actions_mask.clone()
     return rlt.ParametricDqnInput(
         state=rlt.FeatureData(float_features=batch.state),
         action=rlt.FeatureData(float_features=action),
         next_state=rlt.FeatureData(float_features=batch.next_state),
         next_action=rlt.FeatureData(float_features=next_action),
         possible_actions=possible_actions,
         possible_actions_mask=possible_actions_mask,
         possible_next_actions=possible_next_actions,
         possible_next_actions_mask=possible_next_actions_mask,
         reward=batch.reward,
         not_terminal=not_terminal,
         step=None,
         time_diff=None,
         extras=rlt.ExtraData(
             mdp_id=None,
             sequence_number=None,
             action_probability=batch.log_prob.exp(),
             max_num_actions=None,
             metrics=None,
         ),
     )
Example #7
    def test_parametric_wrapper(self):
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        action_normalization_parameters = {
            i: _cont_norm()
            for i in range(5, 9)
        }
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        action_preprocessor = Preprocessor(action_normalization_parameters,
                                           False)
        dqn = models.FullyConnectedCritic(
            state_dim=len(state_normalization_parameters),
            action_dim=len(action_normalization_parameters),
            sizes=[16],
            activations=["relu"],
        )
        dqn_with_preprocessor = ParametricDqnWithPreprocessor(
            dqn,
            state_preprocessor=state_preprocessor,
            action_preprocessor=action_preprocessor,
        )
        wrapper = ParametricDqnPredictorWrapper(dqn_with_preprocessor)

        input_prototype = dqn_with_preprocessor.input_prototype()
        output_action_names, q_value = wrapper(*input_prototype)
        self.assertEqual(output_action_names, ["Q"])
        self.assertEqual(q_value.shape, (1, 1))

        expected_output = dqn(
            rlt.FeatureData(state_preprocessor(*input_prototype[0])),
            rlt.FeatureData(action_preprocessor(*input_prototype[1])),
        )
        self.assertTrue((expected_output == q_value).all())
    def __call__(self, batch):
        # RB's state is (batch_size, state_dim, stack_size), whereas
        # we want (stack_size, batch_size, state_dim).
        # For scalar fields like reward and terminal,
        # RB returns (batch_size, stack_size), whereas
        # we want (stack_size, batch_size).
        # Also convert action to float.

        if len(batch.state.shape) == 2:
            # this is stack_size = 1 (i.e. we squeezed in RB)
            state = batch.state.unsqueeze(2)
            next_state = batch.next_state.unsqueeze(2)
        else:
            # shapes are already (batch_size, state_dim, stack_size)
            state = batch.state
            next_state = batch.next_state
        # Now shapes should be (batch_size, state_dim, stack_size)
        # Turn shapes into (stack_size, batch_size, feature_dim) where
        # feature \in {state, action}; also, make action a float
        permutation = [2, 0, 1]
        not_terminal = 1.0 - batch.terminal.transpose(0, 1).float()
        batch_action = batch.action
        if batch_action.ndim == 3:
            batch_action = batch_action.squeeze(1)
        action = F.one_hot(batch_action,
                           self.num_actions).transpose(1, 2).float()
        return rlt.PreprocessedMemoryNetworkInput(
            state=rlt.FeatureData(state.permute(permutation)),
            next_state=rlt.FeatureData(next_state.permute(permutation)),
            action=action.permute(permutation).float(),
            reward=batch.reward.transpose(0, 1),
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        )
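The permutation [2, 0, 1] above turns the replay buffer's (batch_size, state_dim, stack_size) layout into the sequence-first layout the memory network expects. A tiny sketch with hypothetical sizes:

import torch

state = torch.rand(4, 3, 2)          # (batch_size=4, state_dim=3, stack_size=2)
seq_first = state.permute(2, 0, 1)   # -> (stack_size, batch_size, state_dim)
assert seq_first.shape == (2, 4, 3)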
 def forward(self, batch: Dict[str,
                               torch.Tensor]) -> rlt.ParametricDqnInput:
     batch = batch_to_device(batch, self.device)
     # first preprocess state and action
     preprocessed_state = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"])
     preprocessed_next_state = self.state_preprocessor(
         batch["next_state_features"],
         batch["next_state_features_presence"])
     preprocessed_action = self.action_preprocessor(
         batch["action"], batch["action_presence"])
     preprocessed_next_action = self.action_preprocessor(
         batch["next_action"], batch["next_action_presence"])
     return rlt.ParametricDqnInput(
         state=rlt.FeatureData(preprocessed_state),
         next_state=rlt.FeatureData(preprocessed_next_state),
         action=rlt.FeatureData(preprocessed_action),
         next_action=rlt.FeatureData(preprocessed_next_action),
         reward=batch["reward"].unsqueeze(1),
         time_diff=batch["time_diff"].unsqueeze(1),
         step=batch["step"].unsqueeze(1),
         not_terminal=batch["not_terminal"].unsqueeze(1),
         possible_actions=batch["possible_actions"],
         possible_actions_mask=batch["possible_actions_mask"],
         possible_next_actions=batch["possible_next_actions"],
         possible_next_actions_mask=batch["possible_next_actions_mask"],
         extras=rlt.ExtraData(
             mdp_id=batch["mdp_id"].unsqueeze(1),
             sequence_number=batch["sequence_number"].unsqueeze(1),
             action_probability=batch["action_probability"].unsqueeze(1),
         ),
     )
 def __call__(self, batch: Dict[str, torch.Tensor]) -> rlt.DiscreteDqnInput:
     batch = batch_to_device(batch, self.device)
     preprocessed_state = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"]
     )
     preprocessed_next_state = self.state_preprocessor(
         batch["next_state_features"], batch["next_state_features_presence"]
     )
     # not terminal iff at least one next action is possible
     not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
     action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
     # next action can take the value self.num_actions when no next action is available
     next_action = F.one_hot(
         batch["next_action"].to(torch.int64), self.num_actions + 1
     )[:, : self.num_actions]
     return rlt.DiscreteDqnInput(
         state=rlt.FeatureData(preprocessed_state),
         next_state=rlt.FeatureData(preprocessed_next_state),
         action=action,
         next_action=next_action,
         reward=batch["reward"].unsqueeze(1),
         time_diff=batch["time_diff"].unsqueeze(1),
         step=batch["step"].unsqueeze(1),
         not_terminal=not_terminal.unsqueeze(1),
         possible_actions_mask=batch["possible_actions_mask"],
         possible_next_actions_mask=batch["possible_next_actions_mask"],
         extras=rlt.ExtraData(
             mdp_id=batch["mdp_id"].unsqueeze(1),
             sequence_number=batch["sequence_number"].unsqueeze(1),
             action_probability=batch["action_probability"].unsqueeze(1),
         ),
     )
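The next_action encoding above relies on a small trick: an unavailable next action is stored as the out-of-range index self.num_actions, one-hot encoded with num_actions + 1 classes, and the extra column is sliced off so terminal rows become all-zero. An isolated illustration (assuming num_actions = 3):

import torch
import torch.nn.functional as F

num_actions = 3
next_action = torch.tensor([0, 2, 3])  # 3 == num_actions marks "no next action"
one_hot = F.one_hot(next_action, num_actions + 1)[:, :num_actions]
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 0, 0]])  <- the terminal transition gets an all-zero next action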
 def as_slate_q_training_batch(self):
     batch_size, state_dim = self.states.shape
     action_dim = self.actions.shape[1]
     return rlt.PreprocessedSlateQInput(
         state=rlt.FeatureData(float_features=self.states),
         next_state=rlt.FeatureData(float_features=self.next_states),
         action=rlt.PreprocessedSlateFeatureVector(
             float_features=self.possible_actions_state_concat[:, state_dim:].view(
                 batch_size, -1, action_dim
             ),
             item_mask=self.possible_actions_mask,
             item_probability=self.propensities,
         ),
         next_action=rlt.PreprocessedSlateFeatureVector(
             float_features=self.possible_next_actions_state_concat[
                 :, state_dim:
             ].view(batch_size, -1, action_dim),
             item_mask=self.possible_next_actions_mask,
             item_probability=self.next_propensities,
         ),
         reward=self.rewards,
         reward_mask=self.rewards_mask,
         time_diff=self.time_diffs,
         step=self.step,
         not_terminal=self.not_terminal,
         extras=rlt.ExtraData(
             mdp_id=self.mdp_ids,
             sequence_number=self.sequence_numbers,
             action_probability=self.propensities,
             max_num_actions=self.max_num_actions,
             metrics=self.metrics,
         ),
     )
    def __call__(self, batch):
        action = batch.action
        if self.num_actions is not None:
            assert len(action.shape) == 2, f"{action.shape}"
            # one hot makes shape (batch_size, stack_size, feature_dim)
            action = F.one_hot(batch.action, self.num_actions).float()
            # make shape to (batch_size, feature_dim, stack_size)
            action = action.transpose(1, 2)

        # For (1-dimensional) vector fields, RB returns (batch_size, state_dim)
        # or (batch_size, state_dim, stack_size).
        # We want these all to be (stack_size, batch_size, state_dim), so
        # unsqueeze in the former case (note this only happens for stack_size = 1),
        # then permute.
        permutation = [2, 0, 1]
        vector_fields = {
            "state": batch.state,
            "action": action,
            "next_state": batch.next_state,
        }
        for name, tensor in vector_fields.items():
            if len(tensor.shape) == 2:
                tensor = tensor.unsqueeze(2)
            assert len(tensor.shape) == 3, f"{name} has shape {tensor.shape}"
            vector_fields[name] = tensor.permute(permutation)

        # For scalar fields, RB returns (batch_size,) or (batch_size, stack_size).
        # Do the same as above, except transpose instead.
        scalar_fields = {
            "reward": batch.reward,
            "not_terminal": 1.0 - batch.terminal.float(),
        }
        for name, tensor in scalar_fields.items():
            if len(tensor.shape) == 1:
                tensor = tensor.unsqueeze(1)
            assert len(tensor.shape) == 2, f"{name} has shape {tensor.shape}"
            scalar_fields[name] = tensor.transpose(0, 1)

        # stack_size > 1, so let's pad not_terminal with 1's, since
        # previous states couldn't have been terminal.
        if scalar_fields["reward"].shape[0] > 1:
            batch_size = scalar_fields["reward"].shape[1]
            assert scalar_fields["not_terminal"].shape == (
                1,
                batch_size,
            ), f"{scalar_fields['not_terminal'].shape}"
            stacked_not_terminal = torch.ones_like(scalar_fields["reward"])
            stacked_not_terminal[-1] = scalar_fields["not_terminal"]
            scalar_fields["not_terminal"] = stacked_not_terminal

        return rlt.MemoryNetworkInput(
            state=rlt.FeatureData(float_features=vector_fields["state"]),
            next_state=rlt.FeatureData(
                float_features=vector_fields["next_state"]),
            action=vector_fields["action"],
            reward=scalar_fields["reward"],
            not_terminal=scalar_fields["not_terminal"],
            step=None,
            time_diff=None,
        )
    def __call__(self, batch):
        not_terminal = 1.0 - batch.terminal.float()
        action, next_action = one_hot_actions(self.num_actions, batch.action,
                                              batch.next_action,
                                              batch.terminal)
        if self.trainer_preprocessor is not None:
            state = self.trainer_preprocessor(batch.state)
            next_state = self.trainer_preprocessor(batch.next_state)
        else:
            state = rlt.FeatureData(float_features=batch.state)
            next_state = rlt.FeatureData(float_features=batch.next_state)

        return rlt.DiscreteDqnInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            possible_actions_mask=torch.ones_like(action).float(),
            possible_next_actions_mask=torch.ones_like(next_action).float(),
            reward=batch.reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
            extras=rlt.ExtraData(
                mdp_id=None,
                sequence_number=None,
                action_probability=batch.log_prob.exp(),
                max_num_actions=None,
                metrics=None,
            ),
        )
Example #14
def get_Q(seq2reward_network, batch: rlt.MemoryNetworkInput,
          all_permut: torch.Tensor) -> torch.Tensor:
    batch_size = batch.state.float_features.shape[1]
    _, num_permut, num_action = all_permut.shape
    num_permut_per_action = int(num_permut / num_action)

    preprocessed_state = (
        batch.state.float_features[0].unsqueeze(0).repeat_interleave(
            num_permut, dim=1))
    state_feature_vector = rlt.FeatureData(preprocessed_state)

    # expand action to match the expanded state sequence
    action = rlt.FeatureData(all_permut.repeat(1, batch_size, 1))
    acc_reward = seq2reward_network(state_feature_vector,
                                    action).acc_reward.reshape(
                                        batch_size, num_action,
                                        num_permut_per_action)

    # The permutations are generated in lexicographic order;
    # the output has shape [num_permut, num_action, 1],
    # so we can aggregate by taking the max reward
    # and then reshape it to (BATCH_SIZE, ACT_DIM)
    max_acc_reward = (
        # pyre-fixme[16]: `Tuple` has no attribute `values`.
        torch.max(acc_reward,
                  dim=2).values.detach().reshape(batch_size, num_action))

    return max_acc_reward
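As a shape sanity check, the all_permut tensor consumed by get_Q can be built with a gen_permutations-style helper (the same construction appears inline in the get_Q method further down this page); a minimal sketch assuming 2 actions and a 3-step horizon:

import torch
import torch.nn.functional as F

num_action, seq_len = 2, 3
all_permut = torch.cartesian_prod(*[torch.arange(num_action)] * seq_len)  # (8, 3)
all_permut = F.one_hot(all_permut, num_action).transpose(0, 1).float()    # (3, 8, 2)
_, num_permut, _ = all_permut.shape
assert num_permut == num_action ** seq_len  # one row per possible action sequence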
 def __call__(self, batch):
     not_terminal = 1.0 - batch.terminal.float()
     action = F.one_hot(batch.action, self.num_actions).squeeze(1).float()
     # next action is garbage for terminal transitions (so just zero them)
     next_action = torch.zeros_like(action)
     non_terminal_indices = (batch.terminal == 0).squeeze(1)
     next_action[non_terminal_indices] = (F.one_hot(
         batch.next_action[non_terminal_indices],
         self.num_actions).squeeze(1).float())
     return rlt.DiscreteDqnInput(
         state=rlt.FeatureData(float_features=batch.state),
         action=action,
         next_state=rlt.FeatureData(float_features=batch.next_state),
         next_action=next_action,
         possible_actions_mask=torch.ones_like(action).float(),
         possible_next_actions_mask=torch.ones_like(next_action).float(),
         reward=batch.reward,
         not_terminal=not_terminal,
         step=None,
         time_diff=None,
         extras=rlt.ExtraData(
             mdp_id=None,
             sequence_number=None,
             action_probability=batch.log_prob.exp(),
             max_num_actions=None,
             metrics=None,
         ),
     )
Example #16
 def internal_reward_estimation(self, state, action):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     reward_estimates = self.reward_network(rlt.FeatureData(state),
                                            rlt.FeatureData(action))
     self.reward_network.train()
     return reward_estimates.cpu()
Example #17
 def internal_prediction(self, state, action):
     """
     Only used by Gym
     """
     self.q_network.eval()
     q_values = self.q_network(rlt.FeatureData(state),
                               rlt.FeatureData(action))
     self.q_network.train()
     return q_values.cpu()
 def as_policy_network_training_batch(self):
     return rlt.PolicyNetworkInput(
         state=rlt.FeatureData(float_features=self.states),
         action=rlt.FeatureData(float_features=self.actions),
         next_state=rlt.FeatureData(float_features=self.next_states),
         next_action=rlt.FeatureData(float_features=self.next_actions),
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
         extras=rlt.ExtraData(),
     )
Example #19
 def forward(
     self,
     state_with_presence: Tuple[torch.Tensor, torch.Tensor],
     action_with_presence: Tuple[torch.Tensor, torch.Tensor],
 ):
     preprocessed_state = self.state_preprocessor(state_with_presence[0],
                                                  state_with_presence[1])
     preprocessed_action = self.action_preprocessor(action_with_presence[0],
                                                    action_with_presence[1])
     state = rlt.FeatureData(preprocessed_state)
     action = rlt.FeatureData(preprocessed_action)
     q_value = self.model(state, action)
     return q_value
    def get_Q(
        self,
        batch: rlt.MemoryNetworkInput,
        batch_size: int,
        seq_len: int,
        num_action: int,
    ) -> torch.Tensor:
        if not self.view_q_value:
            return torch.zeros(batch_size, num_action)
        try:
            # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `all_permut`.
            self.all_permut
        except AttributeError:

            def gen_permutations(seq_len: int,
                                 num_action: int) -> torch.Tensor:
                """
                Generate all length-seq_len action sequences for a given action set.
                The return shape is (SEQ_LEN, PERM_NUM, ACTION_DIM).
                """
                all_permut = torch.cartesian_prod(*[torch.arange(num_action)] *
                                                  seq_len)
                all_permut = F.one_hot(all_permut, num_action).transpose(0, 1)
                return all_permut.float()

            self.all_permut = gen_permutations(seq_len, num_action)
            # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `num_permut`.
            self.num_permut = self.all_permut.size(1)

        preprocessed_state = batch.state.float_features.repeat(
            1, self.num_permut, 1)
        state_feature_vector = rlt.FeatureData(preprocessed_state)

        # expand action to match the expanded state sequence
        action = self.all_permut.repeat(1, batch_size, 1)
        reward = self.seq2reward_network(
            state_feature_vector, rlt.FeatureData(action)).acc_reward.reshape(
                batch_size, num_action, self.num_permut // num_action)

        # The permutations are generated in lexicographic order;
        # the output has shape [num_permut, num_action, 1],
        # so we can aggregate by taking the max reward
        # and then reshape it to (BATCH_SIZE, ACT_DIM)
        max_reward = (
            # pyre-fixme[16]: `Tuple` has no attribute `values`.
            torch.max(reward,
                      2).values.cpu().detach().reshape(batch_size, num_action))

        return max_reward
Example #21
    def acc_rewards_of_one_solution(
        self, init_state: torch.Tensor, solution: torch.Tensor, solution_idx: int
    ):
        """
        ensemble_pop_size trajectories will be sampled to evaluate a
        CEM solution. Each trajectory is generated by one world model.

        :param init_state: its shape is (state_dim, )
        :param solution: its shape is (plan_horizon_length, action_dim)
        :param solution_idx: the index of the solution
        :return reward: Reward of each of ensemble_pop_size trajectories
        """
        reward_matrix = np.zeros((self.ensemble_pop_size, self.plan_horizon_length))

        for i in range(self.ensemble_pop_size):
            state = init_state
            mem_net_idx = np.random.randint(0, len(self.mem_net_list))
            for j in range(self.plan_horizon_length):
                # state shape:
                # (1, 1, state_dim)
                # action shape:
                # (1, 1, action_dim)
                (
                    reward,
                    next_state,
                    not_terminal,
                    not_terminal_prob,
                ) = self.sample_reward_next_state_terminal(
                    state=rlt.FeatureData(state.reshape((1, 1, self.state_dim))),
                    action=rlt.FeatureData(
                        solution[j, :].reshape((1, 1, self.action_dim))
                    ),
                    mem_net=self.mem_net_list[mem_net_idx],
                )
                reward_matrix[i, j] = reward * (self.gamma ** j)

                if not not_terminal:
                    logger.debug(
                        f"Solution {solution_idx}: predict terminal at step {j}"
                        f" with prob. {1.0 - not_terminal_prob}"
                    )

                if not not_terminal:
                    break

                state = next_state

        return np.sum(reward_matrix, axis=1)
    def __call__(self, obs):
        user = torch.tensor(obs["user"]).float().unsqueeze(0)

        doc_obs = obs["doc"]

        if self.discrete_keys or self.box_keys:
            # Dict space
            discrete_features: List[torch.Tensor] = []
            for k, n in self.discrete_keys:
                vals = torch.tensor([v[k] for v in doc_obs.values()])
                assert vals.shape == (self.num_docs, )
                discrete_features.append(F.one_hot(vals, n).float())

            box_features: List[torch.Tensor] = []
            for k, d in self.box_keys:
                vals = np.vstack([v[k] for v in doc_obs.values()])
                assert vals.shape == (self.num_docs, d)
                box_features.append(torch.tensor(vals).float())

            doc_features = torch.cat(discrete_features + box_features,
                                     dim=1).unsqueeze(0)
        else:
            # Simply a Box space
            vals = np.vstack(list(doc_obs.values()))
            doc_features = torch.tensor(vals).float().unsqueeze(0)

        # This comes from ValueWrapper
        value = (torch.tensor([
            v["value"] for v in obs["augmentation"].values()
        ]).float().unsqueeze(0))

        candidate_docs = rlt.DocList(float_features=doc_features, value=value)
        return rlt.FeatureData(float_features=user,
                               candidate_docs=candidate_docs)
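The observation layout this preprocessor assumes (in the simple Box-space case) looks roughly like the following; keys and values here are hypothetical, shown only to illustrate the user/doc/augmentation structure being read above:

import numpy as np

obs = {
    "user": np.array([0.1, 0.2]),            # user feature vector
    "doc": {                                 # one entry per candidate document
        "doc_1": np.array([0.3, 0.4, 0.5]),
        "doc_2": np.array([0.6, 0.7, 0.8]),
    },
    "augmentation": {                        # added by ValueWrapper
        "doc_1": {"value": 0.9},
        "doc_2": {"value": 0.1},
    },
}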
Example #23
 def obs_preprocessor(self, obs: np.ndarray) -> rlt.FeatureData:
     # pyre-fixme[16]: `Gym` has no attribute `observation_space`.
     obs_space = self.observation_space
     if isinstance(obs_space, spaces.Box):
         return rlt.FeatureData(torch.tensor(obs).float().unsqueeze(0))
     else:
         raise NotImplementedError(f"{obs_space} obs space not supported for Gym.")
    def get_loss(self, training_batch: rlt.MemoryNetworkInput):
        """
        Compute losses:
            MSE(predicted_acc_reward, target_acc_reward)

        :param training_batch:
            training_batch has these fields:
            - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
            - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
            - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

        :returns: mse loss on reward
        """

        seq2reward_output = self.seq2reward_network(
            training_batch.state, rlt.FeatureData(training_batch.action))

        predicted_acc_reward = seq2reward_output.acc_reward
        target_rewards = training_batch.reward
        target_acc_reward = torch.sum(target_rewards, 0).unsqueeze(1)
        # make sure the prediction and target tensors have the same size;
        # both should be (BATCH_SIZE, 1) in this case.
        assert predicted_acc_reward.size() == target_acc_reward.size()
        mse = F.mse_loss(predicted_acc_reward, target_acc_reward)
        return mse
Example #25
 def forward(
     self,
     state: torch.Tensor,
     src_seq: torch.Tensor,
     tgt_out_seq: torch.Tensor,
     src_src_mask: torch.Tensor,
     tgt_out_idx: torch.Tensor,
 ) -> torch.Tensor:
     return self.model(
         rlt.PreprocessedRankingInput(
             state=rlt.FeatureData(float_features=state),
             src_seq=rlt.FeatureData(float_features=src_seq),
             tgt_out_seq=rlt.FeatureData(float_features=tgt_out_seq),
             src_src_mask=src_src_mask,
             tgt_out_idx=tgt_out_idx,
         )).predicted_reward
Example #26
 def forward(self, state_with_presence: Tuple[torch.Tensor, torch.Tensor]):
     preprocessed_state = self.state_preprocessor(
         state_with_presence[0], state_with_presence[1]
     )
     state_feature_vector = rlt.FeatureData(preprocessed_state)
     q_values = self.model(state_feature_vector)
     return q_values
Example #27
    def get_loss(self, training_batch: rlt.MemoryNetworkInput):
        """
        Compute losses:
            MSE(predicted_acc_reward, target_acc_reward)

        :param training_batch:
            training_batch has these fields:
            - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
            - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
            - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

        :returns: mse loss on reward
        """

        seq2reward_output = self.seq2reward_network(
            training_batch.state, rlt.FeatureData(training_batch.action))

        predicted_acc_reward = seq2reward_output.acc_reward
        target_rewards = training_batch.reward
        seq_len, batch_size = target_rewards.size()
        gamma = self.params.gamma
        gamma_mask = torch.Tensor([[gamma**i for i in range(seq_len)]
                                   for _ in range(batch_size)
                                   ]).transpose(0, 1)
        target_acc_reward = torch.sum(target_rewards * gamma_mask,
                                      0).unsqueeze(1)
        # make sure the prediction and target tensors have the same size;
        # both should be (BATCH_SIZE, 1) in this case.
        assert (predicted_acc_reward.size() == target_acc_reward.size()
                ), f"{predicted_acc_reward.size()}!={target_acc_reward.size()}"
        mse = F.mse_loss(predicted_acc_reward, target_acc_reward)
        return mse
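The gamma_mask above is built with a Python double loop; an equivalent vectorized sketch under the same (SEQ_LEN, BATCH_SIZE) convention:

import torch

seq_len, batch_size, gamma = 4, 2, 0.9
gamma_mask = torch.pow(gamma, torch.arange(seq_len, dtype=torch.float32))  # (seq_len,)
gamma_mask = gamma_mask.unsqueeze(1).expand(seq_len, batch_size)           # (seq_len, batch_size)
# same values as the list-comprehension version, entry [i, b] == gamma ** i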
Example #28
    def test_discrete_wrapper(self):
        ids = range(1, 5)
        state_normalization_parameters = {i: _cont_norm() for i in ids}
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        action_dim = 2
        dqn = models.FullyConnectedDQN(
            state_dim=len(state_normalization_parameters),
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        )
        state_feature_config = rlt.ModelFeatureConfig(float_feature_infos=[
            rlt.FloatFeatureInfo(feature_id=i, name=f"feat_{i}") for i in ids
        ])
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            dqn, state_preprocessor, state_feature_config)
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(dqn_with_preprocessor,
                                              action_names,
                                              state_feature_config)
        input_prototype = dqn_with_preprocessor.input_prototype()[0]
        output_action_names, q_values = wrapper(input_prototype)
        self.assertEqual(action_names, output_action_names)
        self.assertEqual(q_values.shape, (1, 2))

        state_with_presence = input_prototype.float_features_with_presence
        expected_output = dqn(
            rlt.FeatureData(state_preprocessor(*state_with_presence)))
        self.assertTrue((expected_output == q_values).all())
Example #29
    def test_actor_wrapper(self):
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        action_normalization_parameters = {
            i: _cont_action_norm()
            for i in range(101, 105)
        }
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        postprocessor = Postprocessor(action_normalization_parameters, False)

        # Test with FullyConnectedActor to make behavior deterministic
        actor = models.FullyConnectedActor(
            state_dim=len(state_normalization_parameters),
            action_dim=len(action_normalization_parameters),
            sizes=[16],
            activations=["relu"],
        )
        actor_with_preprocessor = ActorWithPreprocessor(
            actor, state_preprocessor, postprocessor)
        wrapper = ActorPredictorWrapper(actor_with_preprocessor)
        input_prototype = actor_with_preprocessor.input_prototype()
        action = wrapper(*input_prototype)
        self.assertEqual(action.shape,
                         (1, len(action_normalization_parameters)))

        expected_output = postprocessor(
            actor(rlt.FeatureData(
                state_preprocessor(*input_prototype[0]))).action)
        self.assertTrue((expected_output == action).all())
Example #30
 def forward(
     self,
     state_with_presence: Tuple[torch.Tensor, torch.Tensor],
     candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]],
 ):
     assert (
         len(candidate_with_presence_list) == self.num_candidates
     ), f"{len(candidate_with_presence_list)} != {self.num_candidates}"
     preprocessed_state = self.state_preprocessor(*state_with_presence)
     # each is batch_size x candidate_dim, result is batch_size x num_candidates x candidate_dim
     preprocessed_candidates = torch.stack(
         [
             self.candidate_preprocessor(*x)
             for x in candidate_with_presence_list
         ],
         dim=1,
     )
     input = rlt.FeatureData(
         float_features=preprocessed_state,
         candidate_docs=rlt.DocList(
             float_features=preprocessed_candidates,
             mask=torch.tensor(-1),
             value=torch.tensor(-1),
         ),
     )
     input = rlt._embed_states(input)
     action = self.model(input).action
     if self.action_postprocessor is not None:
         # pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
         action = self.action_postprocessor(action)
     return action