def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    action = fetch_action(extract_record.action)
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    if self.max_q_learning:
        if self.sorted_action_features is not None:
            next_state = None
            tiled_next_state = mt.FeatureVector(
                float_features=fetch(extract_record.tiled_next_state)
            )
        else:
            next_state = mt.FeatureVector(
                float_features=fetch(extract_record.next_state)
            )
            tiled_next_state = None
        possible_next_actions = mt.PossibleActions(
            lengths=fetch(extract_record.possible_next_actions["lengths"]),
            actions=fetch_action(extract_record.possible_next_actions["values"]),
        )
        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_next_actions=possible_next_actions,
            reward=reward,
            not_terminal=(possible_next_actions.lengths > 0).float().reshape(-1, 1),
        )
    else:
        next_state = mt.FeatureVector(
            float_features=fetch(extract_record.next_state)
        )
        next_action = fetch_action(extract_record.next_action)
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            # HACK: Need a better way to check this
            not_terminal=torch.ones_like(reward),
        )
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1)
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)

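# Sketch of the not_terminal derivation used in the max-Q branch above: an
# example is terminal exactly when it has no possible next actions, so
# lengths > 0 yields the not_terminal indicator (illustrative values only).
import torch

_lengths = torch.tensor([2, 0, 3])
_not_terminal = (_lengths > 0).float().reshape(-1, 1)
assert _not_terminal.tolist() == [[1.0], [0.0], [1.0]]
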
def as_parametric_maxq_training_batch(self):
    state_dim = self.states.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=rlt.PreprocessedFeatureVector(float_features=self.actions),
            next_state=rlt.PreprocessedFeatureVector(
                float_features=self.next_states
            ),
            next_action=rlt.PreprocessedFeatureVector(
                float_features=self.next_actions
            ),
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, :state_dim
                ]
            ),
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, state_dim:
                ]
            ),
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )

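# Minimal sketch of the column layout that the slicing above assumes:
# possible_next_actions_state_concat stores each tiled next-state row
# concatenated with a candidate-action row, so the first state_dim columns
# recover the states and the rest recover the actions. Tensor names here are
# illustrative, not the actual class fields.
import torch

_state_dim, _action_dim, _rows = 3, 2, 4
_tiled_next_states = torch.randn(_rows, _state_dim)
_candidate_actions = torch.randn(_rows, _action_dim)
_concat = torch.cat([_tiled_next_states, _candidate_actions], dim=1)

assert torch.equal(_concat[:, :_state_dim], _tiled_next_states)
assert torch.equal(_concat[:, _state_dim:], _candidate_actions)
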
def as_discrete_maxq_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.MaxQLearningInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=self.next_actions,
            tiled_next_state=None,
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=None,
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )

def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    idxs = torch.tensor(idxs)
    possible_actions_mask = torch.tensor(possible_actions_mask).float()
    log_prob = torch.tensor(log_prob)
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )

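# Usage sketch for the policy-network preprocess_batch above. Assumption: the
# replay buffer yields a tuple of numpy arrays in the unpacking order shown,
# with observations shaped (batch, obs_dim, 1); all shapes below are
# illustrative, not part of the original code.
import numpy as np

_batch, _obs_dim, _act_dim = 4, 3, 2
_fake_batch = (
    np.random.randn(_batch, _obs_dim, 1).astype(np.float32),  # obs
    np.random.randn(_batch, _act_dim).astype(np.float32),     # action
    np.random.randn(_batch).astype(np.float32),               # reward
    np.random.randn(_batch, _obs_dim, 1).astype(np.float32),  # next_obs
    np.random.randn(_batch, _act_dim).astype(np.float32),     # next_action
    np.random.randn(_batch).astype(np.float32),               # next_reward
    np.zeros(_batch, dtype=np.float32),                       # terminal
    np.arange(_batch),                                        # idxs
    np.ones((_batch, _act_dim), dtype=np.float32),            # possible_actions_mask
    np.zeros(_batch, dtype=np.float32),                       # log_prob
)
_tb = preprocess_batch(_fake_batch)
assert _tb.training_input.state.float_features.shape == (_batch, _obs_dim)
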
def setup_extra_data(self, ws, input_record):
    extra_data = rlt.ExtraData(
        action_probability=np.array([0.11, 0.21, 0.13], dtype=np.float32)
    )
    ws.feed_blob(
        str(input_record.action_probability()), extra_data.action_probability
    )
    return extra_data

def as_parametric_sarsa_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.SARSAInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=rlt.FeatureVector(float_features=self.actions),
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=rlt.FeatureVector(float_features=self.next_actions),
            reward=self.rewards,
            not_terminal=self.not_terminals,
        ),
        extras=rlt.ExtraData(),
    )

def as_discrete_sarsa_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.SARSAInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=self.next_actions,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )

def as_policy_network_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.PolicyNetworkInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=rlt.FeatureVector(float_features=self.actions),
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=rlt.FeatureVector(float_features=self.next_actions),
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )

def preprocess(self, batch) -> rlt.RawTrainingBatch:
    state_features_dense, state_features_dense_presence = self.sparse_to_dense_processor(
        batch["state_features"]
    )
    next_state_features_dense, next_state_features_dense_presence = self.sparse_to_dense_processor(
        batch["next_state_features"]
    )
    mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
    sequence_numbers = torch.tensor(
        batch["sequence_number"], dtype=torch.int32
    ).reshape(-1, 1)
    rewards = torch.tensor(batch["reward"], dtype=torch.float32).reshape(-1, 1)
    time_diffs = torch.tensor(batch["time_diff"], dtype=torch.int32).reshape(-1, 1)
    if "action_probability" in batch:
        propensities = torch.tensor(
            batch["action_probability"], dtype=torch.float32
        ).reshape(-1, 1)
    else:
        propensities = torch.ones(rewards.shape, dtype=torch.float32)
    return rlt.RawTrainingBatch(
        training_input=rlt.RawBaseInput(  # type: ignore
            state=rlt.FeatureVector(
                float_features=rlt.ValuePresence(
                    value=state_features_dense,
                    presence=state_features_dense_presence,
                )
            ),
            next_state=rlt.FeatureVector(
                float_features=rlt.ValuePresence(
                    value=next_state_features_dense,
                    presence=next_state_features_dense_presence,
                )
            ),
            reward=rewards,
            time_diff=time_diffs,
            step=None,
            not_terminal=None,
        ),
        extras=rlt.ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        ),
    )

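# Illustrative sketch (not the actual SparseToDenseProcessor API) of the
# sparse-to-dense step assumed above: each example maps feature_id -> value,
# and the output is a dense value matrix plus a 0/1 presence matrix over a
# fixed, sorted feature list. The helper name and signature are hypothetical.
import torch

def _sparse_to_dense_sketch(sparse_batch, sorted_feature_ids):
    n, d = len(sparse_batch), len(sorted_feature_ids)
    values = torch.zeros(n, d)
    presence = torch.zeros(n, d)
    col = {fid: j for j, fid in enumerate(sorted_feature_ids)}
    for i, example in enumerate(sparse_batch):
        for fid, val in example.items():
            j = col.get(fid)
            if j is not None:
                values[i, j] = val
                presence[i, j] = 1.0
    return values, presence

# Example: two examples over features [10, 20, 30]; missing entries stay 0
# in `values` and are marked absent in `presence`.
_vals, _pres = _sparse_to_dense_sketch(
    [{10: 0.5, 30: 1.5}, {20: 2.0}], sorted_feature_ids=[10, 20, 30]
)
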
def as_slate_q_training_batch(self):
    batch_size, state_dim = self.states.shape
    action_dim = self.actions.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedSlateQInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            next_state=rlt.PreprocessedFeatureVector(
                float_features=self.next_states
            ),
            tiled_state=rlt.PreprocessedTiledFeatureVector(
                float_features=self.possible_actions_state_concat[
                    :, :state_dim
                ].view(batch_size, -1, state_dim)
            ),
            tiled_next_state=rlt.PreprocessedTiledFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, :state_dim
                ].view(batch_size, -1, state_dim)
            ),
            action=rlt.PreprocessedSlateFeatureVector(
                float_features=self.possible_actions_state_concat[
                    :, state_dim:
                ].view(batch_size, -1, action_dim),
                item_mask=self.possible_actions_mask,
                item_probability=self.propensities,
            ),
            next_action=rlt.PreprocessedSlateFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, state_dim:
                ].view(batch_size, -1, action_dim),
                item_mask=self.possible_next_actions_mask,
                item_probability=self.next_propensities,
            ),
            reward=self.rewards,
            reward_mask=self.rewards_mask,
            time_diff=self.time_diffs,
            step=self.step,
            not_terminal=self.not_terminal,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )

def as_discrete_sarsa_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.SARSAInput(
            state=rlt.FeatureVector(float_features=self.states),
            reward=self.rewards,
            time_diff=self.time_diffs,
            action=self.actions,
            next_action=self.next_actions,
            not_terminal=self.not_terminal,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            step=self.step,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )

def as_cem_training_batch(self, batch_first=False):
    """
    Generate one-step samples needed by the CEM trainer. The samples will be
    used to train an ensemble of world models used by CEM.

    If batch_first = True:
        state/next_state shape: batch_size x 1 x state_dim
        action shape: batch_size x 1 x action_dim
        reward/terminal shape: batch_size x 1
    else (default):
        state/next_state shape: 1 x batch_size x state_dim
        action shape: 1 x batch_size x action_dim
        reward/terminal shape: 1 x batch_size
    """
    if batch_first:
        seq_len_dim = 1
        reward, not_terminal = self.rewards, self.not_terminal
    else:
        seq_len_dim = 0
        reward, not_terminal = transpose(self.rewards, self.not_terminal)
    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.PreprocessedFeatureVector(
            float_features=self.states.unsqueeze(seq_len_dim)
        ),
        action=self.actions.unsqueeze(seq_len_dim),
        next_state=rlt.PreprocessedFeatureVector(
            float_features=self.next_states.unsqueeze(seq_len_dim)
        ),
        reward=reward,
        not_terminal=not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
    )
    return rlt.PreprocessedTrainingBatch(
        training_input=training_input,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )

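# Shape sketch for as_cem_training_batch (illustrative tensors only): with
# batch_first=False the sequence-length dimension of size 1 is prepended,
# otherwise it is inserted after the batch dimension.
import torch

_batch_size, _state_dim = 4, 3
_states = torch.randn(_batch_size, _state_dim)

assert _states.unsqueeze(0).shape == (1, _batch_size, _state_dim)  # default
assert _states.unsqueeze(1).shape == (_batch_size, 1, _state_dim)  # batch_first

# Rewards of shape (batch_size, 1) likewise become (1, batch_size) in the
# default layout; a simple transpose achieves that.
_rewards = torch.randn(_batch_size, 1)
assert _rewards.transpose(0, 1).shape == (1, _batch_size)
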
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    batch_size = obs.shape[0]
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action).to(torch.float32)
    reward = torch.tensor(reward).unsqueeze(1)
    not_terminal = 1 - torch.tensor(terminal).unsqueeze(1).to(torch.uint8)
    possible_actions_mask = torch.ones_like(action).to(torch.bool)
    # num_actions is expected to come from the enclosing scope.
    tiled_next_state = torch.repeat_interleave(
        next_obs, repeats=num_actions, axis=0
    )
    possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
    possible_next_actions_mask = not_terminal.repeat(1, num_actions).to(torch.bool)
    return rlt.PreprocessedTrainingBatch(
        rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            possible_actions=None,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=possible_next_actions
            ),
            possible_next_actions_mask=possible_next_actions_mask,
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=tiled_next_state
            ),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )

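# Sketch of the tiling pattern assumed above: repeat_interleave repeats each
# next state num_actions times in a row, while torch.eye(...).repeat cycles
# through the one-hot actions, so row k pairs state k // num_actions with
# action k % num_actions (illustrative sizes only).
import torch

_num_actions, _batch = 2, 3
_next_obs = torch.arange(_batch).float().unsqueeze(1)  # states 0, 1, 2
_tiled = torch.repeat_interleave(_next_obs, repeats=_num_actions, dim=0)
_one_hot = torch.eye(_num_actions).repeat(_batch, 1)

# _tiled rows:   0, 0, 1, 1, 2, 2   (each state repeated per candidate action)
# _one_hot rows: a0, a1, a0, a1, a0, a1   (actions cycled per state)
assert _tiled.shape[0] == _one_hot.shape[0] == _batch * _num_actions
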
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action)
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    possible_actions_mask = torch.tensor(possible_actions_mask)
    # num_actions is expected to come from the enclosing scope.
    next_possible_actions_mask = not_terminal.repeat(1, num_actions)
    log_prob = torch.tensor(log_prob)
    assert (
        action.size(1) == num_actions
    ), f"action size(1) is {action.size(1)} while num_actions is {num_actions}"
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedDiscreteDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=action,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=next_action,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions_mask=next_possible_actions_mask,
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )

def test_seq2slate_eval_data_page(self):
    """
    Create 3 slate-ranking logs and evaluate them with the Direct Method,
    Inverse Propensity Scores, and Doubly Robust estimators.

    The logs are as follows:
        state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
        indices in logged slates: [3, 2], [3, 2], [3, 2]
        model output indices: [2, 3], [3, 2], [2, 3]
        logged reward: 4, 5, 7
        logged propensities: 0.2, 0.5, 0.4
        predicted rewards on logged slates: 2, 4, 6
        predicted rewards on model outputted slates: 1, 4, 5

    The Direct Method uses the predicted rewards on the model outputted
    slates, so its result is expected to be (1 + 4 + 5) / 3.

    Inverse Propensity Scores scales the logged reward by
    1.0 / logged_propensity whenever the model output slate matches the
    logged slate. Since only the second log matches the model output, the
    IPS result is expected to be 5 / 0.5 / 3.

    Doubly Robust is the sum of the Direct Method result and the
    propensity-scaled reward difference; the latter is defined as
    1.0 / logged_propensity * (logged reward - predicted reward on the
    logged slate) * Indicator(model slate == logged slate). Since only the
    second logged slate matches the model outputted slate, the DR result is
    expected to be (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3.
    """
    batch_size = 3
    state_dim = 3
    src_seq_len = 2
    tgt_seq_len = 2
    candidate_dim = 2

    reward_net = FakeSeq2SlateRewardNetwork()
    seq2slate_net = FakeSeq2SlateTransformerNet()
    baseline_net = nn.Linear(1, 1)
    trainer = Seq2SlateTrainer(
        seq2slate_net,
        baseline_net,
        parameters=None,
        minibatch_size=3,
        use_gpu=False,
    )

    src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
    tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])
    tgt_out_seq = src_seq[
        torch.arange(batch_size).repeat_interleave(tgt_seq_len),  # type: ignore
        tgt_out_idx.flatten() - 2,
    ].reshape(batch_size, tgt_seq_len, candidate_dim)

    ptb = rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedRankingInput(
            state=rlt.PreprocessedFeatureVector(
                float_features=torch.eye(state_dim)
            ),
            src_seq=rlt.PreprocessedFeatureVector(float_features=src_seq),
            tgt_out_seq=rlt.PreprocessedFeatureVector(float_features=tgt_out_seq),
            src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
            tgt_out_idx=tgt_out_idx,
            tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
            slate_reward=torch.tensor([4.0, 5.0, 7.0]),
        ),
        extras=rlt.ExtraData(
            sequence_number=torch.tensor([0, 0, 0]),
            mdp_id=np.array(["0", "1", "2"]),
        ),
    )
    edp = EvaluationDataPage.create_from_training_batch(ptb, trainer, reward_net)
    doubly_robust_estimator = DoublyRobustEstimator()
    direct_method, inverse_propensity, doubly_robust = doubly_robust_estimator.estimate(
        edp
    )
    logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

    avg_logged_reward = (4 + 5 + 7) / 3
    self.assertAlmostEqual(direct_method.raw, (1 + 4 + 5) / 3, delta=1e-6)
    self.assertAlmostEqual(
        direct_method.normalized,
        direct_method.raw / avg_logged_reward,
        delta=1e-6,
    )
    self.assertAlmostEqual(inverse_propensity.raw, 5 / 0.5 / 3, delta=1e-6)
    self.assertAlmostEqual(
        inverse_propensity.normalized,
        inverse_propensity.raw / avg_logged_reward,
        delta=1e-6,
    )
    self.assertAlmostEqual(
        doubly_robust.raw,
        direct_method.raw + 1 / 0.5 * (5 - 4) / 3,
        delta=1e-6,
    )
    self.assertAlmostEqual(
        doubly_robust.normalized,
        doubly_robust.raw / avg_logged_reward,
        delta=1e-6,
    )

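# Standalone arithmetic check of the expected estimates in the test above,
# using only the numbers stated in its docstring.
_predicted_on_model_slates = [1.0, 4.0, 5.0]
_logged_rewards = [4.0, 5.0, 7.0]
_logged_propensities = [0.2, 0.5, 0.4]
_predicted_on_logged_slates = [2.0, 4.0, 6.0]
_model_matches_logged = [False, True, False]  # only the second slate matches

_n = len(_logged_rewards)
_dm = sum(_predicted_on_model_slates) / _n
_ips = sum(
    m * r / p
    for m, r, p in zip(_model_matches_logged, _logged_rewards, _logged_propensities)
) / _n
_dr = _dm + sum(
    m * (r - q) / p
    for m, r, q, p in zip(
        _model_matches_logged,
        _logged_rewards,
        _predicted_on_logged_slates,
        _logged_propensities,
    )
) / _n

assert abs(_dm - (1 + 4 + 5) / 3) < 1e-6
assert abs(_ips - 5 / 0.5 / 3) < 1e-6
assert abs(_dr - ((1 + 4 + 5) / 3 + 1 / 0.5 * (5 - 4) / 3)) < 1e-6
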
def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    def fetch_possible_actions(b):
        if self.sorted_action_features is not None:
            return mt.FeatureVector(float_features=fetch(b))
        else:
            return None

    state = mt.FeatureVector(float_features=fetch(extract_record.state_features))
    next_state = mt.FeatureVector(
        float_features=fetch(extract_record.next_state_features)
    )
    action = fetch_action(extract_record.action)
    next_action = fetch_action(extract_record.next_action)
    if self.multi_steps is not None:
        step = fetch(input_record.step).reshape(-1, 1)
    else:
        step = None
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    not_terminal = fetch(input_record.not_terminal).reshape(-1, 1)
    time_diff = fetch(input_record.time_diff).reshape(-1, 1)

    if self.include_possible_actions:
        # TODO: this will need to be more complicated to support sparse features
        assert self.max_num_actions is not None, "Missing max_num_actions"
        possible_actions_mask = fetch(
            extract_record.possible_actions_mask
        ).reshape(-1, self.max_num_actions)
        possible_next_actions_mask = fetch(
            extract_record.possible_next_actions_mask
        ).reshape(-1, self.max_num_actions)
        if self.sorted_action_features is not None:
            possible_actions = fetch_possible_actions(
                extract_record.possible_actions
            )
            possible_next_actions = fetch_possible_actions(
                extract_record.possible_next_actions
            )
            tiled_next_state = mt.FeatureVector(
                float_features=next_state.float_features.repeat(
                    1, self.max_num_actions
                ).reshape(-1, next_state.float_features.shape[1])
            )
        else:
            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None

        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )
    else:
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1)
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)

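# Sketch of the state-tiling pattern used above: repeat(1, k) followed by
# reshape(-1, dim) repeats each next-state row k times in a row, which is
# equivalent to repeat_interleave along dim 0 (illustrative tensors only).
import torch

_next_states = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
_k = 3
_tiled = _next_states.repeat(1, _k).reshape(-1, _next_states.shape[1])
assert torch.equal(_tiled, _next_states.repeat_interleave(_k, dim=0))
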
def extract(self, ws, input_record, extract_record):
    def fetch(b, to_torch=True):
        data = ws.fetch_blob(str(b()))
        if not isinstance(data, np.ndarray):
            # Blob uninitialized, return None and handle downstream
            return None
        if to_torch:
            return torch.tensor(data)
        return data

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    def fetch_possible_actions(b):
        if self.sorted_action_features is not None:
            return mt.FeatureVector(float_features=fetch(b))
        else:
            return None

    state = mt.FeatureVector(float_features=fetch(extract_record.state_features))
    next_state = mt.FeatureVector(
        float_features=fetch(extract_record.next_state_features)
    )
    action = fetch_action(extract_record.action)
    next_action = fetch_action(extract_record.next_action)
    max_num_actions = None
    step = None
    if self.multi_steps is not None:
        step = fetch(input_record.step).reshape(-1, 1)
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    not_terminal = fetch(input_record.not_terminal).reshape(-1, 1)
    time_diff = fetch(input_record.time_diff).reshape(-1, 1)

    if self.include_possible_actions:
        # TODO: this will need to be more complicated to support sparse features
        assert self.max_num_actions is not None, "Missing max_num_actions"
        possible_actions_mask = (
            fetch(extract_record.possible_actions_mask)
            .reshape(-1, self.max_num_actions)
            .type(torch.FloatTensor)
        )
        possible_next_actions_mask = fetch(
            extract_record.possible_next_actions_mask
        ).reshape(-1, self.max_num_actions)
        if self.sorted_action_features is not None:
            possible_actions = fetch_possible_actions(
                extract_record.possible_actions
            )
            possible_next_actions = fetch_possible_actions(
                extract_record.possible_next_actions
            )
            tiled_next_state = mt.FeatureVector(
                float_features=next_state.float_features.repeat(
                    1, self.max_num_actions
                ).reshape(-1, next_state.float_features.shape[1])
            )
            max_num_actions = self.max_num_actions
        else:
            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None

        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )
    else:
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )

    mdp_id = fetch(input_record.mdp_id, to_torch=False)
    sequence_number = fetch(input_record.sequence_number)
    metrics = fetch(extract_record.metrics) if self.metrics_to_score else None
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1),
        sequence_number=sequence_number.reshape(-1, 1)
        if sequence_number is not None
        else None,
        mdp_id=mdp_id.reshape(-1, 1) if mdp_id is not None else None,
        max_num_actions=max_num_actions,
        metrics=metrics,
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)