Example #1
0
    def extract(self, ws, input_record, extract_record):
        """Assemble a ``mt.TrainingBatch`` from blobs stored in ``ws``.

        Every record field used here is called to obtain a blob reference; the
        string form of that reference is looked up with ``ws.fetch_blob`` and
        the result is wrapped in a ``torch`` tensor.  The training input is a
        ``mt.MaxQLearningInput`` when ``self.max_q_learning`` is truthy and a
        ``mt.SARSAInput`` otherwise.
        """

        def load_tensor(blob_ref):
            return torch.tensor(ws.fetch_blob(str(blob_ref())))

        def load_features(blob_ref):
            return mt.FeatureVector(float_features=load_tensor(blob_ref))

        # Actions are wrapped in a FeatureVector only when sorted action
        # features are configured; otherwise the bare tensor is used.
        has_action_features = self.sorted_action_features is not None
        load_action = load_features if has_action_features else load_tensor

        state = load_features(extract_record.state)
        action = load_action(extract_record.action)
        reward = load_tensor(input_record.reward).reshape(-1, 1)

        # is_terminal should be filled by preprocessor
        if not self.max_q_learning:
            next_state = load_features(extract_record.next_state)
            next_action = load_action(extract_record.next_action)
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                # HACK: Need a better way to check this
                not_terminal=torch.ones_like(reward),
            )
        else:
            # Exactly one of next_state / tiled_next_state is populated,
            # depending on whether action features are configured.
            if has_action_features:
                next_state = None
                tiled_next_state = load_features(
                    extract_record.tiled_next_state)
            else:
                next_state = load_features(extract_record.next_state)
                tiled_next_state = None

            next_actions_record = extract_record.possible_next_actions
            possible_next_actions = mt.PossibleActions(
                lengths=load_tensor(next_actions_record["lengths"]),
                actions=load_action(next_actions_record["values"]),
            )
            # A row counts as non-terminal when it has at least one possible
            # next action.
            has_next_action = possible_next_actions.lengths > 0
            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_next_actions=possible_next_actions,
                reward=reward,
                not_terminal=has_next_action.float().reshape(-1, 1),
            )

        # TODO: stuff other fields in here
        action_probability = load_tensor(
            input_record.action_probability).reshape(-1, 1)
        extras = mt.ExtraData(action_probability=action_probability)

        return mt.TrainingBatch(training_input=training_input, extras=extras)
Example #2
0
 def as_parametric_maxq_training_batch(self):
     """Package this object's fields as a parametric max-Q training batch.

     ``self.possible_next_actions_state_concat`` is split column-wise at
     ``self.states.shape[1]``: the leading columns become
     ``tiled_next_state`` and the remaining columns become
     ``possible_next_actions``.  The returned batch carries an empty
     ``rlt.ExtraData``.
     """

     def as_features(tensor):
         return rlt.PreprocessedFeatureVector(float_features=tensor)

     state_width = self.states.shape[1]
     next_concat = self.possible_next_actions_state_concat

     dqn_input = rlt.PreprocessedParametricDqnInput(
         state=as_features(self.states),
         action=as_features(self.actions),
         next_state=as_features(self.next_states),
         next_action=as_features(self.next_actions),
         tiled_next_state=as_features(next_concat[:, :state_width]),
         possible_actions=None,
         possible_actions_mask=self.possible_actions_mask,
         possible_next_actions=as_features(next_concat[:, state_width:]),
         possible_next_actions_mask=self.possible_next_actions_mask,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     return rlt.PreprocessedTrainingBatch(
         training_input=dqn_input, extras=rlt.ExtraData()
     )
Example #3
0
 def as_discrete_maxq_training_batch(self):
     """Package this object's fields as a discrete max-Q ``rlt.TrainingBatch``.

     States are wrapped in ``rlt.FeatureVector``; actions are passed through
     unchanged.  ``tiled_next_state``, ``possible_actions`` and
     ``possible_next_actions`` are left as ``None``; only the two masks are
     forwarded.
     """
     state = rlt.FeatureVector(float_features=self.states)
     next_state = rlt.FeatureVector(float_features=self.next_states)

     max_q_input = rlt.MaxQLearningInput(
         state=state,
         action=self.actions,
         next_state=next_state,
         next_action=self.next_actions,
         tiled_next_state=None,
         possible_actions=None,
         possible_actions_mask=self.possible_actions_mask,
         possible_next_actions=None,
         possible_next_actions_mask=self.possible_next_actions_mask,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     extra_data = rlt.ExtraData(
         mdp_id=self.mdp_ids,
         sequence_number=self.sequence_numbers,
         action_probability=self.propensities,
         max_num_actions=self.max_num_actions,
         metrics=self.metrics,
     )
     return rlt.TrainingBatch(training_input=max_q_input, extras=extra_data)
Example #4
0
 def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
     """Convert a ten-element training tuple into a policy-network batch.

     ``train_batch`` must unpack into exactly ten items, in this order:
     obs, action, reward, next_obs, next_action, next_reward, terminal,
     idxs, possible_actions_mask, log_prob.

     Only obs, action, reward, next_obs, next_action and terminal are
     consumed.  The other four entries are ignored and are no longer
     converted to tensors, so they may hold any value.

     Returns:
         A ``rlt.PreprocessedTrainingBatch`` whose training input is a
         ``rlt.PreprocessedPolicyNetworkInput`` with ``step`` and
         ``time_diff`` set to ``None`` and empty ``rlt.ExtraData``.
     """
     (
         obs,
         action,
         reward,
         next_obs,
         next_action,
         _next_reward,
         terminal,
         _idxs,
         _possible_actions_mask,
         _log_prob,
     ) = train_batch

     # Dimension 2 of the observations is squeezed out (a no-op unless it has
     # size 1); reward and terminal gain a trailing dimension of size 1.
     obs = torch.tensor(obs).squeeze(2)
     action = torch.tensor(action).float()
     reward = torch.tensor(reward).unsqueeze(1)
     next_obs = torch.tensor(next_obs).squeeze(2)
     next_action = torch.tensor(next_action)
     not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()

     return rlt.PreprocessedTrainingBatch(
         training_input=rlt.PreprocessedPolicyNetworkInput(
             state=rlt.PreprocessedFeatureVector(float_features=obs),
             action=rlt.PreprocessedFeatureVector(float_features=action),
             next_state=rlt.PreprocessedFeatureVector(
                 float_features=next_obs),
             next_action=rlt.PreprocessedFeatureVector(
                 float_features=next_action),
             reward=reward,
             not_terminal=not_terminal,
             step=None,
             time_diff=None,
         ),
         extras=rlt.ExtraData(),
     )
 def setup_extra_data(self, ws, input_record):
     """Feed three fixed action probabilities into ``ws`` and return them.

     The values are stored in a ``rlt.ExtraData`` and written to the blob
     named by ``input_record.action_probability()``; the ``ExtraData``
     object is returned.
     """
     probabilities = np.array([0.11, 0.21, 0.13], dtype=np.float32)
     extra_data = rlt.ExtraData(action_probability=probabilities)
     blob_name = str(input_record.action_probability())
     # Feed the value as held by the ExtraData object, not the local array.
     ws.feed_blob(blob_name, extra_data.action_probability)
     return extra_data
Example #6
0
 def as_parametric_sarsa_training_batch(self):
     """Package this object's fields as a parametric SARSA training batch.

     States and actions (current and next) are each wrapped in
     ``rlt.FeatureVector``; the batch carries an empty ``rlt.ExtraData``.
     """

     def as_features(tensor):
         return rlt.FeatureVector(float_features=tensor)

     sarsa_input = rlt.SARSAInput(
         state=as_features(self.states),
         action=as_features(self.actions),
         next_state=as_features(self.next_states),
         next_action=as_features(self.next_actions),
         reward=self.rewards,
         not_terminal=self.not_terminals,
     )
     return rlt.TrainingBatch(training_input=sarsa_input, extras=rlt.ExtraData())
Example #7
0
 def as_discrete_sarsa_training_batch(self):
     """Package this object's fields as a discrete SARSA training batch.

     States are wrapped in ``rlt.FeatureVector``; actions are passed through
     unchanged.  The batch carries an empty ``rlt.ExtraData``.
     """
     state = rlt.FeatureVector(float_features=self.states)
     next_state = rlt.FeatureVector(float_features=self.next_states)
     sarsa_input = rlt.SARSAInput(
         state=state,
         action=self.actions,
         next_state=next_state,
         next_action=self.next_actions,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     return rlt.TrainingBatch(training_input=sarsa_input, extras=rlt.ExtraData())
Example #8
0
 def as_policy_network_training_batch(self):
     """Package this object's fields as a policy-network training batch.

     States and actions (current and next) are each wrapped in
     ``rlt.FeatureVector``; the batch carries an empty ``rlt.ExtraData``.
     """

     def as_features(tensor):
         return rlt.FeatureVector(float_features=tensor)

     policy_input = rlt.PolicyNetworkInput(
         state=as_features(self.states),
         action=as_features(self.actions),
         next_state=as_features(self.next_states),
         next_action=as_features(self.next_actions),
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     return rlt.TrainingBatch(training_input=policy_input, extras=rlt.ExtraData())
Example #9
0
    def preprocess(self, batch) -> rlt.RawTrainingBatch:
        """Turn a mapping of raw columns into a ``rlt.RawTrainingBatch``.

        ``batch["state_features"]`` and ``batch["next_state_features"]`` are
        run through ``self.sparse_to_dense_processor``, which yields a
        (value, presence) pair for each.  The scalar columns are reshaped to
        a single column.  When ``"action_probability"`` is absent from
        ``batch``, propensities default to ones with the same shape as the
        rewards.  ``step`` and ``not_terminal`` are left as ``None``.
        """
        densify = self.sparse_to_dense_processor
        state_value, state_presence = densify(batch["state_features"])
        next_state_value, next_state_presence = densify(
            batch["next_state_features"]
        )

        def as_column(key, dtype):
            return torch.tensor(batch[key], dtype=dtype).reshape(-1, 1)

        # mdp ids stay a numpy array; every other column becomes a tensor.
        mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
        sequence_numbers = as_column("sequence_number", torch.int32)
        rewards = as_column("reward", torch.float32)
        time_diffs = as_column("time_diff", torch.int32)
        if "action_probability" not in batch:
            propensities = torch.ones(rewards.shape, dtype=torch.float32)
        else:
            propensities = as_column("action_probability", torch.float32)

        state = rlt.FeatureVector(
            float_features=rlt.ValuePresence(
                value=state_value, presence=state_presence
            )
        )
        next_state = rlt.FeatureVector(
            float_features=rlt.ValuePresence(
                value=next_state_value, presence=next_state_presence
            )
        )
        training_input = rlt.RawBaseInput(  # type: ignore
            state=state,
            next_state=next_state,
            reward=rewards,
            time_diff=time_diffs,
            step=None,
            not_terminal=None,
        )
        extras = rlt.ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        )
        return rlt.RawTrainingBatch(training_input=training_input, extras=extras)
Example #10
0
 def as_slate_q_training_batch(self):
     """Package this object's fields as a SlateQ training batch.

     ``self.possible_actions_state_concat`` and
     ``self.possible_next_actions_state_concat`` are each split column-wise
     at the state width: the leading ``state_dim`` columns are the tiled
     state and the remaining columns are the action features.  Every slice
     is then viewed as ``(batch_size, -1, feature_dim)``.
     """
     batch_size, state_dim = self.states.shape
     action_dim = self.actions.shape[1]

     def per_example(flat, feature_dim):
         # Regroup flat rows so the leading dimension is the batch.
         return flat.view(batch_size, -1, feature_dim)

     current_concat = self.possible_actions_state_concat
     next_concat = self.possible_next_actions_state_concat

     tiled_state = rlt.PreprocessedTiledFeatureVector(
         float_features=per_example(current_concat[:, :state_dim], state_dim)
     )
     tiled_next_state = rlt.PreprocessedTiledFeatureVector(
         float_features=per_example(next_concat[:, :state_dim], state_dim)
     )
     action = rlt.PreprocessedSlateFeatureVector(
         float_features=per_example(current_concat[:, state_dim:], action_dim),
         item_mask=self.possible_actions_mask,
         item_probability=self.propensities,
     )
     next_action = rlt.PreprocessedSlateFeatureVector(
         float_features=per_example(next_concat[:, state_dim:], action_dim),
         item_mask=self.possible_next_actions_mask,
         item_probability=self.next_propensities,
     )

     slate_q_input = rlt.PreprocessedSlateQInput(
         state=rlt.PreprocessedFeatureVector(float_features=self.states),
         next_state=rlt.PreprocessedFeatureVector(
             float_features=self.next_states
         ),
         tiled_state=tiled_state,
         tiled_next_state=tiled_next_state,
         action=action,
         next_action=next_action,
         reward=self.rewards,
         reward_mask=self.rewards_mask,
         time_diff=self.time_diffs,
         step=self.step,
         not_terminal=self.not_terminal,
     )
     extra_data = rlt.ExtraData(
         mdp_id=self.mdp_ids,
         sequence_number=self.sequence_numbers,
         action_probability=self.propensities,
         max_num_actions=self.max_num_actions,
         metrics=self.metrics,
     )
     return rlt.PreprocessedTrainingBatch(
         training_input=slate_q_input, extras=extra_data
     )
Example #11
0
 def as_discrete_sarsa_training_batch(self):
     """Package this object's fields as a discrete SARSA training batch.

     States are wrapped in ``rlt.FeatureVector``; actions are passed through
     unchanged.  The extras carry the mdp ids, sequence numbers,
     propensities, ``max_num_actions`` and metrics held on this object.
     """
     state = rlt.FeatureVector(float_features=self.states)
     next_state = rlt.FeatureVector(float_features=self.next_states)
     sarsa_input = rlt.SARSAInput(
         state=state,
         action=self.actions,
         next_state=next_state,
         next_action=self.next_actions,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     extra_data = rlt.ExtraData(
         mdp_id=self.mdp_ids,
         sequence_number=self.sequence_numbers,
         action_probability=self.propensities,
         max_num_actions=self.max_num_actions,
         metrics=self.metrics,
     )
     return rlt.TrainingBatch(training_input=sarsa_input, extras=extra_data)
Example #12
0
    def as_cem_training_batch(self, batch_first=False):
        """
        Produce the one-step samples consumed by the CEM trainer, which uses
        them to train its ensemble of world models.

        A sequence-length axis of size 1 is inserted into states, next states
        and actions; its position depends on ``batch_first``.

        batch_first = True:
            state/next state shape: batch_size x 1 x state_dim
            action shape: batch_size x 1 x action_dim
            reward/terminal shape: batch_size x 1
        batch_first = False (default):
            state/next state shape: 1 x batch_size x state_dim
            action shape: 1 x batch_size x action_dim
            reward/terminal shape: 1 x batch_size
        """
        if not batch_first:
            seq_len_dim = 0
            # Sequence-first layout: reward and not_terminal go through the
            # ``transpose`` helper.
            reward, not_terminal = transpose(self.rewards, self.not_terminal)
        else:
            seq_len_dim = 1
            reward, not_terminal = self.rewards, self.not_terminal

        state = rlt.PreprocessedFeatureVector(
            float_features=self.states.unsqueeze(seq_len_dim))
        action = self.actions.unsqueeze(seq_len_dim)
        next_state = rlt.PreprocessedFeatureVector(
            float_features=self.next_states.unsqueeze(seq_len_dim))

        training_input = rlt.PreprocessedMemoryNetworkInput(
            state=state,
            action=action,
            next_state=next_state,
            reward=reward,
            not_terminal=not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        )
        extras = rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        )
        return rlt.PreprocessedTrainingBatch(
            training_input=training_input, extras=extras)
Example #13
0
    def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
        """Convert a ten-element training tuple into a parametric DQN batch.

        ``train_batch`` must unpack into exactly ten items, in this order:
        obs, action, reward, next_obs, next_action, next_reward, terminal,
        idxs, possible_actions_mask, log_prob.  The next_reward, idxs,
        possible_actions_mask and log_prob entries are ignored; the mask used
        in the result is all-True with the same shape as ``action``.

        ``num_actions`` is not a parameter; it is read from the enclosing
        scope.

        Returns:
            A ``rlt.PreprocessedTrainingBatch`` wrapping a
            ``rlt.PreprocessedParametricDqnInput`` with ``step`` and
            ``time_diff`` set to ``None`` and empty ``rlt.ExtraData``.
        """
        (
            obs,
            action,
            reward,
            next_obs,
            next_action,
            _next_reward,
            terminal,
            _idxs,
            _possible_actions_mask,
            _log_prob,
        ) = train_batch

        obs = torch.tensor(obs).squeeze(2)
        # Read the batch size from the converted tensor so that inputs
        # without a ``.shape`` attribute (e.g. nested lists) also work.
        # squeeze(2) never touches dimension 0.
        batch_size = obs.shape[0]
        action = torch.tensor(action).float()
        next_obs = torch.tensor(next_obs).squeeze(2)
        next_action = torch.tensor(next_action).to(torch.float32)
        reward = torch.tensor(reward).unsqueeze(1)
        not_terminal = 1 - torch.tensor(terminal).unsqueeze(1).to(torch.uint8)
        possible_actions_mask = torch.ones_like(action).to(torch.bool)

        # Each next observation is repeated num_actions times consecutively
        # and paired with the rows of an identity matrix, i.e. one one-hot
        # action per repeated row.
        tiled_next_state = torch.repeat_interleave(
            next_obs, repeats=num_actions, dim=0)
        possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
        # Every next action of a terminal transition is masked out.
        possible_next_actions_mask = not_terminal.repeat(1, num_actions).to(
            torch.bool)

        return rlt.PreprocessedTrainingBatch(
            rlt.PreprocessedParametricDqnInput(
                state=rlt.PreprocessedFeatureVector(float_features=obs),
                action=rlt.PreprocessedFeatureVector(float_features=action),
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=next_obs),
                next_action=rlt.PreprocessedFeatureVector(
                    float_features=next_action),
                possible_actions=None,
                possible_actions_mask=possible_actions_mask,
                possible_next_actions=rlt.PreprocessedFeatureVector(
                    float_features=possible_next_actions),
                possible_next_actions_mask=possible_next_actions_mask,
                tiled_next_state=rlt.PreprocessedFeatureVector(
                    float_features=tiled_next_state),
                reward=reward,
                not_terminal=not_terminal,
                step=None,
                time_diff=None,
            ),
            extras=rlt.ExtraData(),
        )
Example #14
0
 def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
     """Convert a ten-element training tuple into a discrete DQN batch.

     ``train_batch`` must unpack into exactly ten items, in this order:
     obs, action, reward, next_obs, next_action, next_reward, terminal,
     idxs, possible_actions_mask, log_prob.  ``num_actions`` is read from the
     enclosing scope, and ``action`` must have ``num_actions`` columns.
     The logged action probability is recovered as ``exp(log_prob)``.
     """
     (
         raw_obs,
         raw_action,
         raw_reward,
         raw_next_obs,
         raw_next_action,
         _next_reward,
         raw_terminal,
         _idxs,
         raw_mask,
         raw_log_prob,
     ) = train_batch

     obs = torch.tensor(raw_obs).squeeze(2)
     action = torch.tensor(raw_action)
     reward = torch.tensor(raw_reward).unsqueeze(1)
     next_obs = torch.tensor(raw_next_obs).squeeze(2)
     next_action = torch.tensor(raw_next_action)
     not_terminal = 1.0 - torch.tensor(raw_terminal).unsqueeze(1).float()
     possible_actions_mask = torch.tensor(raw_mask)
     # The next-step mask is the not_terminal column repeated per action, so
     # every next action of a terminal transition is masked out.
     next_possible_actions_mask = not_terminal.repeat(1, num_actions)
     log_prob = torch.tensor(raw_log_prob)
     assert (
         action.size(1) == num_actions
     ), f"action size(1) is {action.size(1)} while num_actions is {num_actions}"

     dqn_input = rlt.PreprocessedDiscreteDqnInput(
         state=rlt.PreprocessedFeatureVector(float_features=obs),
         action=action,
         next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
         next_action=next_action,
         possible_actions_mask=possible_actions_mask,
         possible_next_actions_mask=next_possible_actions_mask,
         reward=reward,
         not_terminal=not_terminal,
         step=None,
         time_diff=None,
     )
     extras = rlt.ExtraData(
         mdp_id=None,
         sequence_number=None,
         action_probability=log_prob.exp(),
         max_num_actions=None,
         metrics=None,
     )
     return rlt.PreprocessedTrainingBatch(training_input=dqn_input, extras=extras)
    def test_seq2slate_eval_data_page(self):
        """
        Build 3 slate ranking logs and check the Direct Method (DM), Inverse
        Propensity Scores (IPS) and Doubly Robust (DR) estimates.

        Logged data:
            state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
            indices in logged slates: [3, 2], [3, 2], [3, 2]
            model output indices: [2, 3], [3, 2], [2, 3]
            logged reward: 4, 5, 7
            logged propensities: 0.2, 0.5, 0.4
            predicted rewards on logged slates: 2, 4, 6
            predicted rewards on model outputted slates: 1, 4, 5

        DM averages the predicted rewards on the model outputted slates, so
        the expected value is (1 + 4 + 5) / 3.

        IPS scales the logged reward by 1.0 / logged propensity whenever the
        model output slate equals the logged slate.  Only the second log
        matches, so the expected value is 5 / 0.5 / 3.

        DR is the DM result plus the propensity-scaled reward difference:
            1.0 / logged_propensities
            * (logged reward - predicted reward on logged slate)
            * Indicator(model slate == logged slate)
        Only the second log matches, so the expected value is
        (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3.
        """
        batch_size = state_dim = 3
        src_seq_len = tgt_seq_len = candidate_dim = 2

        reward_net = FakeSeq2SlateRewardNetwork()
        seq2slate_net = FakeSeq2SlateTransformerNet()
        baseline_net = nn.Linear(1, 1)
        trainer = Seq2SlateTrainer(
            seq2slate_net,
            baseline_net,
            parameters=None,
            minibatch_size=3,
            use_gpu=False,
        )

        src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
        tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])
        # Gather the logged candidates: one row index per target position,
        # and the logged indices (2 and 3 here) shifted down by 2 so they
        # address positions inside src_seq.
        row_idx = torch.arange(batch_size).repeat_interleave(  # type: ignore
            tgt_seq_len)
        col_idx = tgt_out_idx.flatten() - 2
        tgt_out_seq = src_seq[row_idx, col_idx].reshape(
            batch_size, tgt_seq_len, candidate_dim)

        ranking_input = rlt.PreprocessedRankingInput(
            state=rlt.PreprocessedFeatureVector(
                float_features=torch.eye(state_dim)),
            src_seq=rlt.PreprocessedFeatureVector(float_features=src_seq),
            tgt_out_seq=rlt.PreprocessedFeatureVector(
                float_features=tgt_out_seq),
            src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
            tgt_out_idx=tgt_out_idx,
            tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
            slate_reward=torch.tensor([4.0, 5.0, 7.0]),
        )
        extras = rlt.ExtraData(
            sequence_number=torch.tensor([0, 0, 0]),
            mdp_id=np.array(["0", "1", "2"]),
        )
        ptb = rlt.PreprocessedTrainingBatch(
            training_input=ranking_input, extras=extras)

        edp = EvaluationDataPage.create_from_training_batch(
            ptb, trainer, reward_net)
        estimates = DoublyRobustEstimator().estimate(edp)
        direct_method, inverse_propensity, doubly_robust = estimates
        logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

        avg_logged_reward = (4 + 5 + 7) / 3
        expected_raw = (
            (direct_method, (1 + 4 + 5) / 3),
            (inverse_propensity, 5 / 0.5 / 3),
            (doubly_robust, direct_method.raw + 1 / 0.5 * (5 - 4) / 3),
        )
        # For each estimator: check the raw value, then check that the
        # normalized value is the raw value over the average logged reward.
        for estimate, raw_value in expected_raw:
            self.assertAlmostEqual(estimate.raw, raw_value, delta=1e-6)
            self.assertAlmostEqual(
                estimate.normalized,
                estimate.raw / avg_logged_reward,
                delta=1e-6,
            )
Example #16
0
    def extract(self, ws, input_record, extract_record):
        """Assemble a ``mt.TrainingBatch`` from blobs stored in ``ws``.

        Every record field used here is called to obtain a blob reference; the
        string form of that reference is looked up with ``ws.fetch_blob`` and
        the result is wrapped in a ``torch`` tensor.  The training input is a
        ``mt.MaxQLearningInput`` when ``self.include_possible_actions`` is
        truthy and a ``mt.SARSAInput`` otherwise.
        """

        def load_tensor(blob_ref):
            return torch.tensor(ws.fetch_blob(str(blob_ref())))

        def load_features(blob_ref):
            return mt.FeatureVector(float_features=load_tensor(blob_ref))

        def load_column(blob_ref):
            return load_tensor(blob_ref).reshape(-1, 1)

        # Actions are wrapped in a FeatureVector only when sorted action
        # features are configured; otherwise the bare tensor is used.
        has_action_features = self.sorted_action_features is not None
        load_action = load_features if has_action_features else load_tensor

        state = load_features(extract_record.state_features)
        next_state = load_features(extract_record.next_state_features)

        action = load_action(extract_record.action)
        next_action = load_action(extract_record.next_action)
        step = (
            load_column(input_record.step)
            if self.multi_steps is not None
            else None
        )
        reward = load_column(input_record.reward)

        # is_terminal should be filled by preprocessor
        not_terminal = load_column(input_record.not_terminal)
        time_diff = load_column(input_record.time_diff)

        if not self.include_possible_actions:
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )
        else:
            # TODO: this will need to be more complicated to support sparse features
            assert self.max_num_actions is not None, "Missing max_num_actions"
            num_actions = self.max_num_actions
            possible_actions_mask = load_tensor(
                extract_record.possible_actions_mask).reshape(-1, num_actions)
            possible_next_actions_mask = load_tensor(
                extract_record.possible_next_actions_mask).reshape(
                    -1, num_actions)

            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None
            if has_action_features:
                possible_actions = load_features(
                    extract_record.possible_actions)
                possible_next_actions = load_features(
                    extract_record.possible_next_actions)
                # Repeat each next-state row num_actions times, back to back.
                next_features = next_state.float_features
                tiled_next_state = mt.FeatureVector(
                    float_features=next_features.repeat(
                        1, num_actions).reshape(-1, next_features.shape[1]))

            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_actions=possible_actions,
                possible_actions_mask=possible_actions_mask,
                possible_next_actions=possible_next_actions,
                possible_next_actions_mask=possible_next_actions_mask,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )

        # TODO: stuff other fields in here
        action_probability = load_column(input_record.action_probability)
        extras = mt.ExtraData(action_probability=action_probability)

        return mt.TrainingBatch(training_input=training_input, extras=extras)
Example #17
0
    def extract(self, ws, input_record, extract_record):
        """Assemble a ``mt.TrainingBatch`` from blobs stored in ``ws``.

        Every record field used here is called to obtain a blob reference; the
        string form of that reference is looked up with ``ws.fetch_blob``.  A
        fetched value that is not a numpy array is treated as an
        uninitialized blob and becomes ``None``.  Only ``mdp_id`` and
        ``sequence_number`` tolerate ``None``; the other reshaped fields do
        not.

        The training input is a ``mt.MaxQLearningInput`` when
        ``self.include_possible_actions`` is truthy and a ``mt.SARSAInput``
        otherwise.
        """

        def load(blob_ref, to_torch=True):
            data = ws.fetch_blob(str(blob_ref()))
            if not isinstance(data, np.ndarray):
                # Blob uninitialized, return None and handle downstream
                return None
            return torch.tensor(data) if to_torch else data

        def load_features(blob_ref):
            return mt.FeatureVector(float_features=load(blob_ref))

        def load_column(blob_ref):
            return load(blob_ref).reshape(-1, 1)

        # Actions are wrapped in a FeatureVector only when sorted action
        # features are configured; otherwise the bare tensor is used.
        has_action_features = self.sorted_action_features is not None
        load_action = load_features if has_action_features else load

        state = load_features(extract_record.state_features)
        next_state = load_features(extract_record.next_state_features)

        action = load_action(extract_record.action)
        next_action = load_action(extract_record.next_action)
        max_num_actions = None
        step = (
            load_column(input_record.step)
            if self.multi_steps is not None
            else None
        )
        reward = load_column(input_record.reward)

        # is_terminal should be filled by preprocessor
        not_terminal = load_column(input_record.not_terminal)
        time_diff = load_column(input_record.time_diff)

        if not self.include_possible_actions:
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )
        else:
            # TODO: this will need to be more complicated to support sparse features
            assert self.max_num_actions is not None, "Missing max_num_actions"
            num_actions = self.max_num_actions
            # Only the current-step mask is cast to a float tensor.
            possible_actions_mask = load(
                extract_record.possible_actions_mask).reshape(
                    -1, num_actions).type(torch.FloatTensor)
            possible_next_actions_mask = load(
                extract_record.possible_next_actions_mask).reshape(
                    -1, num_actions)

            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None
            if has_action_features:
                possible_actions = load_features(
                    extract_record.possible_actions)
                possible_next_actions = load_features(
                    extract_record.possible_next_actions)
                # Repeat each next-state row num_actions times, back to back.
                next_features = next_state.float_features
                tiled_next_state = mt.FeatureVector(
                    float_features=next_features.repeat(
                        1, num_actions).reshape(-1, next_features.shape[1]))
                # max_num_actions is reported in the extras only here.
                max_num_actions = self.max_num_actions

            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_actions=possible_actions,
                possible_actions_mask=possible_actions_mask,
                possible_next_actions=possible_next_actions,
                possible_next_actions_mask=possible_next_actions_mask,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )

        # mdp ids are kept as a numpy array rather than a tensor.
        mdp_id = load(input_record.mdp_id, to_torch=False)
        sequence_number = load(input_record.sequence_number)

        metrics = load(
            extract_record.metrics) if self.metrics_to_score else None

        # TODO: stuff other fields in here
        action_probability = load_column(input_record.action_probability)
        if sequence_number is not None:
            sequence_number = sequence_number.reshape(-1, 1)
        if mdp_id is not None:
            mdp_id = mdp_id.reshape(-1, 1)
        extras = mt.ExtraData(
            action_probability=action_probability,
            sequence_number=sequence_number,
            mdp_id=mdp_id,
            max_num_actions=max_num_actions,
            metrics=metrics,
        )

        return mt.TrainingBatch(training_input=training_input, extras=extras)