def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    action = fetch_action(extract_record.action)
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    if self.max_q_learning:
        if self.sorted_action_features is not None:
            next_state = None
            tiled_next_state = mt.FeatureVector(
                float_features=fetch(extract_record.tiled_next_state)
            )
        else:
            next_state = mt.FeatureVector(
                float_features=fetch(extract_record.next_state)
            )
            tiled_next_state = None
        possible_next_actions = mt.PossibleActions(
            lengths=fetch(extract_record.possible_next_actions["lengths"]),
            actions=fetch_action(extract_record.possible_next_actions["values"]),
        )
        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_next_actions=possible_next_actions,
            reward=reward,
            not_terminal=(possible_next_actions.lengths > 0).float().reshape(-1, 1),
        )
    else:
        next_state = mt.FeatureVector(
            float_features=fetch(extract_record.next_state)
        )
        next_action = fetch_action(extract_record.next_action)
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            # HACK: Need a better way to check this
            not_terminal=torch.ones_like(reward),
        )
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1)
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)
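# A minimal, hypothetical sketch (not part of the extractor above) of how the
# `not_terminal` flag in the max-Q branch falls out of the possible-next-action
# lengths: a transition is treated as non-terminal whenever at least one
# possible next action was logged for it. The values below are illustrative
# assumptions only.
import torch

lengths = torch.tensor([2, 0, 3])                    # possible_next_actions.lengths
not_terminal = (lengths > 0).float().reshape(-1, 1)  # -> tensor([[1.], [0.], [1.]])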
def as_discrete_maxq_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.MaxQLearningInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=self.next_actions,
            tiled_next_state=None,
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=None,
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    next_state = mt.FeatureVector(
        float_features=fetch(extract_record.next_state)
    )
    action = fetch_action(extract_record.action)
    reward = fetch(input_record.reward)
    # is_terminal should be filled by preprocessor
    if self.max_q_learning:
        possible_next_actions = mt.PossibleActions(
            lengths=fetch(extract_record.possible_next_actions["lengths"]),
            actions=fetch_action(extract_record.possible_next_actions["values"]),
        )
        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            possible_next_actions=possible_next_actions,
            reward=reward,
            is_terminal=None,
        )
    else:
        next_action = fetch_action(extract_record.next_action)
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            is_terminal=None,
        )
    # TODO: stuff other fields in here
    extras = None
    return mt.TrainingBatch(training_input=training_input, extras=extras)
def as_parametric_maxq_training_batch(self):
    state_dim = self.states.shape[1]
    return rlt.TrainingBatch(
        training_input=rlt.MaxQLearningInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=rlt.FeatureVector(float_features=self.actions),
            next_state=None,
            next_action=None,
            tiled_next_state=rlt.FeatureVector(
                float_features=self.possible_next_actions_state_concat[:, :state_dim]
            ),
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=rlt.FeatureVector(
                float_features=self.possible_next_actions_state_concat[:, state_dim:]
            ),
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
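# Hypothetical sketch of the column split used in as_parametric_maxq_training_batch
# above: possible_next_actions_state_concat is assumed to hold rows of
# [next_state | candidate_action], so the first state_dim columns recover the
# tiled next state and the remaining columns recover the candidate actions.
# Shapes and values are illustrative assumptions, not taken from the class.
import torch

state_dim, action_dim = 3, 2
concat = torch.randn(6, state_dim + action_dim)   # 6 (state, action) candidate rows
tiled_next_state = concat[:, :state_dim]          # shape (6, 3)
possible_next_actions = concat[:, state_dim:]     # shape (6, 2)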
def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    def fetch_possible_actions(b):
        if self.sorted_action_features is not None:
            return mt.FeatureVector(float_features=fetch(b))
        else:
            return None

    state = mt.FeatureVector(
        float_features=fetch(extract_record.state_features)
    )
    next_state = mt.FeatureVector(
        float_features=fetch(extract_record.next_state_features)
    )
    action = fetch_action(extract_record.action)
    next_action = fetch_action(extract_record.next_action)
    if self.multi_steps is not None:
        step = fetch(input_record.step).reshape(-1, 1)
    else:
        step = None
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    not_terminal = fetch(input_record.not_terminal).reshape(-1, 1)
    time_diff = fetch(input_record.time_diff).reshape(-1, 1)

    if self.include_possible_actions:
        # TODO: this will need to be more complicated to support sparse features
        assert self.max_num_actions is not None, "Missing max_num_actions"
        possible_actions_mask = fetch(
            extract_record.possible_actions_mask
        ).reshape(-1, self.max_num_actions)
        possible_next_actions_mask = fetch(
            extract_record.possible_next_actions_mask
        ).reshape(-1, self.max_num_actions)
        if self.sorted_action_features is not None:
            possible_actions = fetch_possible_actions(
                extract_record.possible_actions
            )
            possible_next_actions = fetch_possible_actions(
                extract_record.possible_next_actions
            )
            tiled_next_state = mt.FeatureVector(
                float_features=next_state.float_features.repeat(
                    1, self.max_num_actions
                ).reshape(-1, next_state.float_features.shape[1])
            )
        else:
            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None

        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )
    else:
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1)
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)
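# Hypothetical sketch of the tiling above: repeat(1, max_num_actions) followed
# by a reshape duplicates each next-state row once per candidate action, so row
# i of the batch becomes rows i*max_num_actions .. (i+1)*max_num_actions - 1 of
# the tiled tensor. Shapes and values are illustrative assumptions only.
import torch

max_num_actions = 3
next_states = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # (batch=2, state_dim=2)
tiled = next_states.repeat(1, max_num_actions).reshape(-1, next_states.shape[1])
# tiled has shape (6, 2): [[1,2],[1,2],[1,2],[3,4],[3,4],[3,4]]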
def extract(self, ws, input_record, extract_record):
    def fetch(b, to_torch=True):
        data = ws.fetch_blob(str(b()))
        if not isinstance(data, np.ndarray):
            # Blob uninitialized, return None and handle downstream
            return None
        if to_torch:
            return torch.tensor(data)
        return data

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    def fetch_possible_actions(b):
        if self.sorted_action_features is not None:
            return mt.FeatureVector(float_features=fetch(b))
        else:
            return None

    state = mt.FeatureVector(
        float_features=fetch(extract_record.state_features)
    )
    next_state = mt.FeatureVector(
        float_features=fetch(extract_record.next_state_features)
    )
    action = fetch_action(extract_record.action)
    next_action = fetch_action(extract_record.next_action)
    max_num_actions = None
    step = None
    if self.multi_steps is not None:
        step = fetch(input_record.step).reshape(-1, 1)
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    not_terminal = fetch(input_record.not_terminal).reshape(-1, 1)
    time_diff = fetch(input_record.time_diff).reshape(-1, 1)

    if self.include_possible_actions:
        # TODO: this will need to be more complicated to support sparse features
        assert self.max_num_actions is not None, "Missing max_num_actions"
        possible_actions_mask = (
            fetch(extract_record.possible_actions_mask)
            .reshape(-1, self.max_num_actions)
            .type(torch.FloatTensor)
        )
        possible_next_actions_mask = fetch(
            extract_record.possible_next_actions_mask
        ).reshape(-1, self.max_num_actions)
        if self.sorted_action_features is not None:
            possible_actions = fetch_possible_actions(
                extract_record.possible_actions
            )
            possible_next_actions = fetch_possible_actions(
                extract_record.possible_next_actions
            )
            tiled_next_state = mt.FeatureVector(
                float_features=next_state.float_features.repeat(
                    1, self.max_num_actions
                ).reshape(-1, next_state.float_features.shape[1])
            )
            max_num_actions = self.max_num_actions
        else:
            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None

        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )
    else:
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            not_terminal=not_terminal,
            step=step,
            time_diff=time_diff,
        )

    mdp_id = fetch(input_record.mdp_id, to_torch=False)
    sequence_number = fetch(input_record.sequence_number)
    metrics = fetch(extract_record.metrics) if self.metrics_to_score else None
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1),
        sequence_number=sequence_number.reshape(-1, 1)
        if sequence_number is not None
        else None,
        mdp_id=mdp_id.reshape(-1, 1) if mdp_id is not None else None,
        max_num_actions=max_num_actions,
        metrics=metrics,
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)
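# Hypothetical sketch of the optional-blob handling above: fetch() returns None
# for uninitialized blobs (and raw numpy when to_torch=False), so fields such
# as mdp_id and sequence_number are reshaped only when present. The helper
# below is an illustration of that guard pattern, not the extractor's fetch().
import torch

def maybe_reshape(value):
    # Mirrors the `x.reshape(-1, 1) if x is not None else None` pattern above.
    return value.reshape(-1, 1) if value is not None else None

sequence_number = torch.tensor([5, 6, 7])
mdp_id = None  # e.g. the blob was never initialized
extras = {
    "sequence_number": maybe_reshape(sequence_number),  # shape (3, 1)
    "mdp_id": maybe_reshape(mdp_id),                     # stays None
}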