def embed_state(self, state):
    """Embed state after either reset() or step()"""
    assert len(self.recent_states) == len(self.recent_actions)
    old_mdnrnn_mode = self.mdnrnn.mdnrnn.training
    self.mdnrnn.mdnrnn.eval()

    # Embed the state as the hidden layer's output
    # until the previous step + current state
    if len(self.recent_states) == 0:
        mdnrnn_state = np.zeros((1, self.raw_state_dim))
        mdnrnn_action = np.zeros((1, self.action_dim))
    else:
        mdnrnn_state = np.array(list(self.recent_states))
        mdnrnn_action = np.array(list(self.recent_actions))

    mdnrnn_state = torch.tensor(mdnrnn_state, dtype=torch.float).unsqueeze(1)
    mdnrnn_action = torch.tensor(mdnrnn_action, dtype=torch.float).unsqueeze(1)
    mdnrnn_output = self.mdnrnn(
        rlt.FeatureData(mdnrnn_state), rlt.FeatureData(mdnrnn_action)
    )
    hidden_embed = (
        mdnrnn_output.all_steps_lstm_hidden[-1].squeeze().detach().cpu().numpy()
    )
    state_embed = np.hstack((hidden_embed, state))
    self.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
    logger.debug(
        f"Embed_state\nrecent states: {np.array(self.recent_states)}\n"
        f"recent actions: {np.array(self.recent_actions)}\n"
        f"state_embed: {state_embed}\n"
    )
    return state_embed
def sample_memories(self, batch_size, use_gpu=False) -> rlt.MemoryNetworkInput:
    """
    :param batch_size: number of samples to return
    :param use_gpu: whether to put samples on gpu

    State's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
    By default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
    """
    sample_indices = np.random.randint(self.memory_size, size=batch_size)
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    # state/next_state shape: batch_size x seq_len x state_dim
    # action shape: batch_size x seq_len x action_dim
    # reward/not_terminal shape: batch_size x seq_len
    state, action, next_state, reward, not_terminal = map(
        lambda x: stack(x).float().to(device),
        zip(*self.deque_sample(sample_indices)),
    )
    # make shapes seq_len x batch_size x feature_dim
    state, action, next_state, reward, not_terminal = transpose(
        state, action, next_state, reward, not_terminal
    )
    return rlt.MemoryNetworkInput(
        state=rlt.FeatureData(float_features=state),
        reward=reward,
        time_diff=torch.ones_like(reward).float(),
        action=action,
        next_state=rlt.FeatureData(float_features=next_state),
        not_terminal=not_terminal,
        step=None,
    )
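# A minimal sketch (not from the source) of the batch-first -> seq-first layout
# conversion that the `transpose` helper above is assumed to perform per tensor:
import torch

batch_first = torch.randn(32, 10, 4)     # batch_size x seq_len x state_dim
seq_first = batch_first.transpose(0, 1)  # seq_len x batch_size x state_dim
assert seq_first.shape == (10, 32, 4)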
def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.ParametricDqnInput:
    batch = batch_to_device(batch, self.device)
    # first preprocess state and action
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    preprocessed_action = self.action_preprocessor(
        batch["action"], batch["action_presence"]
    )
    preprocessed_next_action = self.action_preprocessor(
        batch["next_action"], batch["next_action_presence"]
    )
    return rlt.ParametricDqnInput(
        state=rlt.FeatureData(preprocessed_state),
        next_state=rlt.FeatureData(preprocessed_next_state),
        action=rlt.FeatureData(preprocessed_action),
        next_action=rlt.FeatureData(preprocessed_next_action),
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=batch["not_terminal"].unsqueeze(1),
        possible_actions=batch["possible_actions"],
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions=batch["possible_next_actions"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
def test_parametric_wrapper(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    action_normalization_parameters = {i: _cont_norm() for i in range(5, 9)}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_preprocessor = Preprocessor(action_normalization_parameters, False)
    dqn = models.FullyConnectedCritic(
        state_dim=len(state_normalization_parameters),
        action_dim=len(action_normalization_parameters),
        sizes=[16],
        activations=["relu"],
    )
    dqn_with_preprocessor = ParametricDqnWithPreprocessor(
        dqn,
        state_preprocessor=state_preprocessor,
        action_preprocessor=action_preprocessor,
    )
    wrapper = ParametricDqnPredictorWrapper(dqn_with_preprocessor)
    input_prototype = dqn_with_preprocessor.input_prototype()
    output_action_names, q_value = wrapper(*input_prototype)
    self.assertEqual(output_action_names, ["Q"])
    self.assertEqual(q_value.shape, (1, 1))
    expected_output = dqn(
        rlt.FeatureData(state_preprocessor(*input_prototype[0])),
        rlt.FeatureData(action_preprocessor(*input_prototype[1])),
    )
    self.assertTrue((expected_output == q_value).all())
def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.DiscreteDqnInput:
    batch = batch_to_device(batch, self.device)
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    # not terminal iff at least one next action is possible
    not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
    action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
    # next_action can take the sentinel value self.num_actions when no next
    # action is available, so encode with one extra class and slice it off
    next_action = F.one_hot(
        batch["next_action"].to(torch.int64), self.num_actions + 1
    )[:, : self.num_actions]
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(preprocessed_state),
        next_state=rlt.FeatureData(preprocessed_next_state),
        action=action,
        next_action=next_action,
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=not_terminal.unsqueeze(1),
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
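# A minimal sketch (assumed values, not from the source) of the next_action
# trick above: encoding with num_actions + 1 classes and slicing off the last
# column maps the sentinel index num_actions to an all-zero row.
import torch
import torch.nn.functional as F

num_actions = 3
next_action = torch.tensor([0, 2, 3])  # 3 = "no next action available"
encoded = F.one_hot(next_action, num_actions + 1)[:, :num_actions]
# rows: [1, 0, 0], [0, 0, 1], [0, 0, 0]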
def get_Q(
    seq2reward_network: Seq2RewardNetwork,
    cur_state: torch.Tensor,
    all_permut: torch.Tensor,
) -> torch.Tensor:
    """
    Input:
        cur_state: the current state from where we start planning.
            shape: batch_size x state_dim
        all_permut: all action sequences (sorted in lexical order) for enumeration
            shape: seq_len x num_permut x action_dim
    """
    batch_size = cur_state.shape[0]
    _, num_permut, num_action = all_permut.shape
    num_permut_per_action = int(num_permut / num_action)

    preprocessed_state = cur_state.unsqueeze(0).repeat_interleave(num_permut, dim=1)
    state_feature_vector = rlt.FeatureData(preprocessed_state)

    # expand action to match the expanded state sequence
    action = rlt.FeatureData(all_permut.repeat(1, batch_size, 1))
    acc_reward = seq2reward_network(state_feature_vector, action).acc_reward.reshape(
        batch_size, num_action, num_permut_per_action
    )

    # The permutations are generated in lexical order, so all sequences sharing
    # the same first action are contiguous. After reshaping to
    # (batch_size, num_action, num_permut_per_action), taking the max over the
    # last dim gives the best accumulated reward per first action, which we
    # return with shape (batch_size, num_action).
    max_acc_reward = (
        torch.max(acc_reward, dim=2).values.detach().reshape(batch_size, num_action)
    )

    return max_acc_reward
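# A hedged sketch of how the `all_permut` argument could be built; the helper
# name `gen_permutations` is hypothetical. itertools.product enumerates action
# index sequences in lexical order, matching the assumption in get_Q.
import itertools
import torch
import torch.nn.functional as F

def gen_permutations(seq_len: int, num_action: int) -> torch.Tensor:
    # (num_action ** seq_len, seq_len) index sequences, in lexical order
    all_seqs = torch.tensor(
        list(itertools.product(range(num_action), repeat=seq_len))
    )
    # one-hot encode, then reorder to seq_len x num_permut x action_dim
    return F.one_hot(all_seqs, num_action).float().transpose(0, 1)

all_permut = gen_permutations(seq_len=2, num_action=3)
assert all_permut.shape == (2, 9, 3)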
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    assert (
        len(batch.state.shape) == 2
    ), f"{batch.state.shape} is not (batch_size, state_dim)."
    batch_size, _ = batch.state.shape
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )
    possible_actions = get_possible_actions_for_gym(batch_size, self.num_actions)
    possible_next_actions = possible_actions.clone()
    possible_actions_mask = torch.ones((batch_size, self.num_actions))
    possible_next_actions_mask = possible_actions_mask.clone()
    return rlt.ParametricDqnInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=rlt.FeatureData(float_features=action),
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=rlt.FeatureData(float_features=next_action),
        possible_actions=possible_actions,
        possible_actions_mask=possible_actions_mask,
        possible_next_actions=possible_next_actions,
        possible_next_actions_mask=possible_next_actions_mask,
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def __call__(self, batch):
    action = batch.action
    if self.num_actions is not None:
        assert len(action.shape) == 2, f"{action.shape}"
        # one-hot encoding makes shape (batch_size, stack_size, feature_dim)
        action = F.one_hot(batch.action, self.num_actions).float()
        # make shape (batch_size, feature_dim, stack_size)
        action = action.transpose(1, 2)

    # For (1-dimensional) vector fields, RB returns (batch_size, state_dim)
    # or (batch_size, state_dim, stack_size).
    # We want these to all be (stack_size, batch_size, state_dim), so
    # unsqueeze the former case; note this only happens for stack_size = 1.
    # Then, permute.
    permutation = [2, 0, 1]
    vector_fields = {
        "state": batch.state,
        "action": action,
        "next_state": batch.next_state,
    }
    for name, tensor in vector_fields.items():
        if len(tensor.shape) == 2:
            tensor = tensor.unsqueeze(2)
        assert len(tensor.shape) == 3, f"{name} has shape {tensor.shape}"
        vector_fields[name] = tensor.permute(permutation)

    # For scalar fields, RB returns (batch_size,) or (batch_size, stack_size).
    # Do the same as above, except transpose instead.
    scalar_fields = {
        "reward": batch.reward,
        "not_terminal": 1.0 - batch.terminal.float(),
    }
    for name, tensor in scalar_fields.items():
        if len(tensor.shape) == 1:
            tensor = tensor.unsqueeze(1)
        assert len(tensor.shape) == 2, f"{name} has shape {tensor.shape}"
        scalar_fields[name] = tensor.transpose(0, 1)

    # If stack_size > 1, pad not_terminal with 1's, since previous states
    # couldn't have been terminal.
    if scalar_fields["reward"].shape[0] > 1:
        batch_size = scalar_fields["reward"].shape[1]
        assert scalar_fields["not_terminal"].shape == (
            1,
            batch_size,
        ), f"{scalar_fields['not_terminal'].shape}"
        stacked_not_terminal = torch.ones_like(scalar_fields["reward"])
        stacked_not_terminal[-1] = scalar_fields["not_terminal"]
        scalar_fields["not_terminal"] = stacked_not_terminal

    return rlt.MemoryNetworkInput(
        state=rlt.FeatureData(float_features=vector_fields["state"]),
        next_state=rlt.FeatureData(float_features=vector_fields["next_state"]),
        action=vector_fields["action"],
        reward=scalar_fields["reward"],
        not_terminal=scalar_fields["not_terminal"],
        step=None,
        time_diff=None,
    )
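# A small numeric sketch (assumed shapes) of the not_terminal padding above:
# with stack_size > 1 only the newest frame can be terminal, so earlier rows
# are filled with ones.
import torch

reward = torch.zeros(3, 4)        # seq_len=3, batch_size=4
not_terminal = torch.zeros(1, 4)  # observed flags for the latest step only
stacked = torch.ones_like(reward)
stacked[-1] = not_terminal
# stacked[:-1] is all ones; stacked[-1] holds the observed terminal flags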
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    action_with_presence: Tuple[torch.Tensor, torch.Tensor],
):
    preprocessed_state = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    )
    preprocessed_action = self.action_preprocessor(
        action_with_presence[0], action_with_presence[1]
    )
    state = rlt.FeatureData(preprocessed_state)
    action = rlt.FeatureData(preprocessed_action)
    q_value = self.model(state, action)
    return q_value
def acc_rewards_of_one_solution(
    self, init_state: torch.Tensor, solution: torch.Tensor, solution_idx: int
):
    """
    ensemble_pop_size trajectories will be sampled to evaluate a CEM solution.
    Each trajectory is generated by one world model.

    :param init_state: its shape is (state_dim, )
    :param solution: its shape is (plan_horizon_length, action_dim)
    :param solution_idx: the index of the solution
    :return reward: Reward of each of ensemble_pop_size trajectories
    """
    reward_matrix = np.zeros((self.ensemble_pop_size, self.plan_horizon_length))

    for i in range(self.ensemble_pop_size):
        state = init_state
        mem_net_idx = np.random.randint(0, len(self.mem_net_list))
        for j in range(self.plan_horizon_length):
            # state shape: (1, 1, state_dim)
            # action shape: (1, 1, action_dim)
            (
                reward,
                next_state,
                not_terminal,
                not_terminal_prob,
            ) = self.sample_reward_next_state_terminal(
                state=rlt.FeatureData(state.reshape((1, 1, self.state_dim))),
                action=rlt.FeatureData(
                    solution[j, :].reshape((1, 1, self.action_dim))
                ),
                mem_net=self.mem_net_list[mem_net_idx],
            )
            reward_matrix[i, j] = reward * (self.gamma ** j)

            if not not_terminal:
                logger.debug(
                    f"Solution {solution_idx}: predict terminal at step {j}"
                    f" with prob. {1.0 - not_terminal_prob}"
                )
                break

            state = next_state

    return np.sum(reward_matrix, axis=1)
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]],
):
    assert (
        len(candidate_with_presence_list) == self.num_candidates
    ), f"{len(candidate_with_presence_list)} != {self.num_candidates}"
    preprocessed_state = self.state_preprocessor(*state_with_presence)
    # each is batch_size x candidate_dim,
    # result is batch_size x num_candidates x candidate_dim
    preprocessed_candidates = torch.stack(
        [self.candidate_preprocessor(*x) for x in candidate_with_presence_list],
        dim=1,
    )
    input = rlt.FeatureData(
        float_features=preprocessed_state,
        candidate_docs=rlt.DocList(
            float_features=preprocessed_candidates,
            mask=torch.tensor(-1),
            value=torch.tensor(-1),
        ),
    )
    input = rlt._embed_states(input)
    action = self.model(input).action
    if self.action_postprocessor is not None:
        # pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
        action = self.action_postprocessor(action)
    return action
def create_sequence_data(state_dim, action_dim, seq_len, batch_size, num_batches):
    SCALE = 2
    weight = SCALE * torch.randn(state_dim + action_dim)
    data = [None for _ in range(num_batches)]

    for i in range(num_batches):
        state = SCALE * torch.randn(seq_len, batch_size, state_dim)
        action = SCALE * torch.randn(seq_len, batch_size, action_dim)
        # random valid step
        valid_step = torch.randint(1, seq_len + 1, (batch_size, 1))
        feature_mask = torch.arange(seq_len).repeat(batch_size, 1)
        feature_mask = (feature_mask >= (seq_len - valid_step)).float()
        assert feature_mask.shape == (batch_size, seq_len), feature_mask.shape
        feature_mask = feature_mask.transpose(0, 1).unsqueeze(-1)
        assert feature_mask.shape == (seq_len, batch_size, 1), feature_mask.shape

        feature = torch.cat((state, action), dim=2)
        masked_feature = feature * feature_mask

        # seq_len, batch_size, state_dim + action_dim
        left_shifted = torch.cat(
            (
                masked_feature.narrow(0, 1, seq_len - 1),
                torch.zeros(1, batch_size, state_dim + action_dim),
            ),
            dim=0,
        )
        # seq_len, batch_size, state_dim + action_dim
        right_shifted = torch.cat(
            (
                torch.zeros(1, batch_size, state_dim + action_dim),
                masked_feature.narrow(0, 0, seq_len - 1),
            ),
            dim=0,
        )
        # reward_matrix shape: batch_size x seq_len
        reward_matrix = torch.matmul(
            left_shifted + right_shifted, weight
        ).transpose(0, 1)
        mask = torch.arange(seq_len).repeat(batch_size, 1)
        mask = (mask >= (seq_len - valid_step)).float()
        reward = (reward_matrix * mask).sum(dim=1).reshape(-1, 1)
        data[i] = rlt.MemoryNetworkInput(
            state=rlt.FeatureData(state),
            action=action,
            valid_step=valid_step,
            reward=reward,
            # the remaining fields will not be used
            next_state=torch.tensor([]),
            step=torch.tensor([]),
            not_terminal=torch.tensor([]),
            time_diff=torch.tensor([]),
        )

    return weight, data
def __call__(self, obs):
    user = torch.tensor(obs["user"]).float().unsqueeze(0)
    doc_obs = obs["doc"]

    if self.discrete_keys or self.box_keys:
        # Dict space
        discrete_features: List[torch.Tensor] = []
        for k, n in self.discrete_keys:
            vals = torch.tensor([v[k] for v in doc_obs.values()])
            assert vals.shape == (self.num_docs,)
            discrete_features.append(F.one_hot(vals, n).float())

        box_features: List[torch.Tensor] = []
        for k, d in self.box_keys:
            vals = np.vstack([v[k] for v in doc_obs.values()])
            assert vals.shape == (self.num_docs, d)
            box_features.append(torch.tensor(vals).float())

        doc_features = torch.cat(discrete_features + box_features, dim=1).unsqueeze(0)
    else:
        # Simply a Box space
        vals = np.vstack(list(doc_obs.values()))
        doc_features = torch.tensor(vals).float().unsqueeze(0)

    # This comes from ValueWrapper
    value = (
        torch.tensor([v["value"] for v in obs["augmentation"].values()])
        .float()
        .unsqueeze(0)
    )

    candidate_docs = rlt.DocList(float_features=doc_features, value=value)
    return rlt.FeatureData(float_features=user, candidate_docs=candidate_docs)
def forward(
    self,
    state: torch.Tensor,
    src_seq: torch.Tensor,
    tgt_out_seq: torch.Tensor,
    src_src_mask: torch.Tensor,
    tgt_out_idx: torch.Tensor,
) -> torch.Tensor:
    return self.model(
        rlt.PreprocessedRankingInput(
            state=rlt.FeatureData(float_features=state),
            src_seq=rlt.FeatureData(float_features=src_seq),
            tgt_out_seq=rlt.FeatureData(float_features=tgt_out_seq),
            src_src_mask=src_src_mask,
            tgt_out_idx=tgt_out_idx,
        )
    ).predicted_reward
def test_actor_wrapper(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    action_normalization_parameters = {
        i: _cont_action_norm() for i in range(101, 105)
    }
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    postprocessor = Postprocessor(action_normalization_parameters, False)

    # Test with FullyConnectedActor to make behavior deterministic
    actor = models.FullyConnectedActor(
        state_dim=len(state_normalization_parameters),
        action_dim=len(action_normalization_parameters),
        sizes=[16],
        activations=["relu"],
    )
    actor_with_preprocessor = ActorWithPreprocessor(
        actor, state_preprocessor, postprocessor
    )
    wrapper = ActorPredictorWrapper(actor_with_preprocessor)
    input_prototype = actor_with_preprocessor.input_prototype()
    action, _log_prob = wrapper(*input_prototype)
    self.assertEqual(action.shape, (1, len(action_normalization_parameters)))

    expected_output = postprocessor(
        actor(rlt.FeatureData(state_preprocessor(*input_prototype[0]))).action
    )
    self.assertTrue((expected_output == action).all())
def test_discrete_wrapper(self):
    ids = range(1, 5)
    state_normalization_parameters = {i: _cont_norm() for i in ids}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_dim = 2
    dqn = models.FullyConnectedDQN(
        state_dim=len(state_normalization_parameters),
        action_dim=action_dim,
        sizes=[16],
        activations=["relu"],
    )
    state_feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=[
            rlt.FloatFeatureInfo(feature_id=i, name=f"feat_{i}") for i in ids
        ]
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        dqn, state_preprocessor, state_feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, state_feature_config
    )
    input_prototype = dqn_with_preprocessor.input_prototype()[0]
    output_action_names, q_values = wrapper(input_prototype)
    self.assertEqual(action_names, output_action_names)
    self.assertEqual(q_values.shape, (1, 2))
    state_with_presence = input_prototype.float_features_with_presence
    expected_output = dqn(
        rlt.FeatureData(state_preprocessor(*state_with_presence))
    )
    self.assertTrue((expected_output == q_values).all())
def forward(self, training_batch: rlt.MemoryNetworkInput):
    # state shape: seq_len, batch_size, state_dim
    state = training_batch.state
    # action shape: seq_len, batch_size, action_dim
    action = rlt.FeatureData(float_features=training_batch.action)

    # shape: seq_len, batch_size, state_dim + action_dim
    cat_input = torch.cat((state.float_features, action.float_features), dim=-1)

    # shape: seq_len, batch_size, (state_dim + action_dim) * context_size
    ngram = self._ngram(cat_input)

    # shape: batch_size, 1
    valid_step = training_batch.valid_step
    seq_len, batch_size, _ = training_batch.action.shape

    # output shape: batch_size, seq_len
    output = self.fc(ngram).squeeze(2).transpose(0, 1)
    assert valid_step is not None
    mask = _gen_mask(valid_step, batch_size, seq_len)
    output_masked = output * mask

    pred_reward = output_masked.sum(dim=1, keepdim=True)
    return rlt.RewardNetworkOutput(predicted_reward=pred_reward)
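# A hedged re-implementation sketch of the `_gen_mask` helper assumed above,
# mirroring the masking pattern used in create_data below: the last
# `valid_step` positions of each sequence are marked valid.
import torch

def _gen_mask(valid_step: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
    mask = torch.arange(seq_len).repeat(batch_size, 1)
    return (mask >= (seq_len - valid_step)).float()

mask = _gen_mask(torch.tensor([[2], [3]]), batch_size=2, seq_len=4)
# tensor([[0., 0., 1., 1.],
#         [0., 1., 1., 1.]])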
def create_data(state_dim, action_dim, seq_len, batch_size, num_batches):
    SCALE = 2
    weight = SCALE * torch.randn(state_dim + action_dim)
    data = [None for _ in range(num_batches)]
    for i in range(num_batches):
        state = SCALE * torch.randn(seq_len, batch_size, state_dim)
        action = SCALE * torch.randn(seq_len, batch_size, action_dim)
        # random valid step
        valid_step = torch.randint(1, seq_len + 1, (batch_size, 1))
        # reward_matrix shape: batch_size x seq_len
        reward_matrix = torch.matmul(
            torch.cat((state, action), dim=2), weight
        ).transpose(0, 1)
        mask = torch.arange(seq_len).repeat(batch_size, 1)
        mask = (mask >= (seq_len - valid_step)).float()
        reward = (reward_matrix * mask).sum(dim=1).reshape(-1, 1)
        data[i] = rlt.MemoryNetworkInput(
            state=rlt.FeatureData(state),
            action=action,
            valid_step=valid_step,
            reward=reward,
            # the remaining fields will not be used
            next_state=torch.tensor([]),
            step=torch.tensor([]),
            not_terminal=torch.tensor([]),
            time_diff=torch.tensor([]),
        )
    return weight, data
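# A hedged usage sketch for create_data with made-up dimensions; each batch's
# reward is the masked sum of a fixed linear function of (state, action).
weight, batches = create_data(
    state_dim=4, action_dim=2, seq_len=5, batch_size=8, num_batches=3
)
assert weight.shape == (6,)
assert len(batches) == 3
assert batches[0].reward.shape == (8, 1)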
def test_forward_pass(self):
    torch.manual_seed(123)
    state_dim = 1
    action_dim = 2
    state = rlt.FeatureData(torch.tensor([[2.0]]))
    bcq_drop_threshold = 0.20

    q_network = FullyConnectedDQN(
        state_dim, action_dim, sizes=[2], activations=["relu"]
    )
    init.constant_(q_network.fc.dnn[-2].bias, 3.0)
    imitator_network = FullyConnectedNetwork(
        layers=[state_dim, 2, action_dim], activations=["relu", "linear"]
    )
    imitator_probs = torch.nn.functional.softmax(
        imitator_network(state.float_features), dim=1
    )
    bcq_mask = imitator_probs < bcq_drop_threshold
    npt.assert_array_equal(bcq_mask.detach(), [[True, False]])

    model = BatchConstrainedDQN(
        state_dim=state_dim,
        q_network=q_network,
        imitator_network=imitator_network,
        bcq_drop_threshold=bcq_drop_threshold,
    )
    final_q_values = model(state)
    npt.assert_array_equal(final_q_values.detach(), [[-1e10, 3.0]])
def test_discrete_wrapper_with_id_list(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_dim = 2
    state_feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=[
            rlt.FloatFeatureInfo(name=str(i), feature_id=i) for i in range(1, 5)
        ],
        id_list_feature_configs=[
            rlt.IdListFeatureConfig(
                name="A", feature_id=10, id_mapping_name="A_mapping"
            )
        ],
        id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
    )
    embedding_concat = models.EmbeddingBagConcat(
        state_dim=len(state_normalization_parameters),
        model_feature_config=state_feature_config,
        embedding_dim=8,
    )
    dqn = models.Sequential(
        embedding_concat,
        rlt.TensorFeatureData(),
        models.FullyConnectedDQN(
            embedding_concat.output_dim,
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        ),
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        dqn, state_preprocessor, state_feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, state_feature_config
    )
    input_prototype = dqn_with_preprocessor.input_prototype()[0]
    output_action_names, q_values = wrapper(input_prototype)
    self.assertEqual(action_names, output_action_names)
    self.assertEqual(q_values.shape, (1, 2))

    feature_id_to_name = {
        config.feature_id: config.name
        for config in state_feature_config.id_list_feature_configs
    }
    state_id_list_features = {
        feature_id_to_name[k]: v
        for k, v in input_prototype.id_list_features.items()
    }
    state_with_presence = input_prototype.float_features_with_presence
    expected_output = dqn(
        rlt.FeatureData(
            float_features=state_preprocessor(*state_with_presence),
            id_list_features=state_id_list_features,
        )
    )
    self.assertTrue((expected_output == q_values).all())
def _stack(slates):
    obs = rlt.FeatureData(
        float_features=torch.from_numpy(
            np.stack([slate["user"] for slate in slates])
        ),
        candidate_docs=rlt.DocList(
            float_features=torch.from_numpy(
                np.stack([slate["doc"] for slate in slates])
            )
        ),
    )
    return obs
def _get_values(
    self, state_action: Tuple[rlt.FeatureData, rlt.FeatureData]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    state, action = state_action
    shared_state = rlt.FeatureData(self.shared_network(state))
    value = self.value_network(shared_state)
    advantage = self.advantage_network(shared_state, action)
    q_value = value + advantage
    return advantage, value, q_value
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )
    if self.trainer_preprocessor is not None:
        state = self.trainer_preprocessor(batch.state)
        next_state = self.trainer_preprocessor(batch.next_state)
    else:
        state = rlt.FeatureData(float_features=batch.state)
        next_state = rlt.FeatureData(float_features=batch.next_state)

    try:
        possible_actions_mask = batch.possible_actions_mask.float()
    except AttributeError:
        possible_actions_mask = torch.ones_like(action).float()

    try:
        possible_next_actions_mask = batch.next_possible_actions_mask.float()
    except AttributeError:
        possible_next_actions_mask = torch.ones_like(next_action).float()

    return rlt.DiscreteDqnInput(
        state=state,
        action=action,
        next_state=next_state,
        next_action=next_action,
        possible_actions_mask=possible_actions_mask,
        possible_next_actions_mask=possible_next_actions_mask,
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def get_possible_actions_for_gym(batch_size: int, num_actions: int) -> rlt.FeatureData:
    """
    tiled_actions should be (batch_size * num_actions, num_actions).
    For all i in [batch_size], tiled_actions[i * num_actions : (i + 1) * num_actions]
    should be I[num_actions], where I[n] is the n-dimensional identity matrix.
    NOTE: this is only the case when we convert a discrete action to a
    parametric action via one-hot encoding.
    """
    possible_actions = torch.eye(num_actions).repeat(batch_size, 1)
    return rlt.FeatureData(float_features=possible_actions)
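# A quick hedged check of the tiling property stated in the docstring above:
import torch

num_actions, batch_size = 3, 2
tiled = torch.eye(num_actions).repeat(batch_size, 1)
for i in range(batch_size):
    assert torch.equal(
        tiled[i * num_actions : (i + 1) * num_actions], torch.eye(num_actions)
    )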
def forward(self, obs):
    if self.log_transform:
        obs = rlt.FeatureData(
            float_features=obs.float_features.clip(EPS).log(),
            candidate_docs=rlt.DocList(
                float_features=obs.candidate_docs.float_features.clip(EPS).log(),
            ),
        )
    mlp_input = self._concat_features(obs)
    scores = self.mlp(mlp_input)
    return scores.squeeze(-1)
def input_prototype(self):
    # Sample config for input
    batch_size = 2
    state_dim = 5
    num_docs = 3
    candidate_dim = 4
    return rlt.FeatureData(
        float_features=torch.randn((batch_size, state_dim)),
        candidate_docs=rlt.DocList(
            float_features=torch.randn(batch_size, num_docs, candidate_dim)
        ),
    )
def _get_values(
    self, state: rlt.FeatureData
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    shared_state = rlt.FeatureData(self.shared_network(state))
    value = self.value_network(shared_state)
    raw_advantage = self.advantage_network(shared_state)
    reduce_over = tuple(range(1, raw_advantage.dim()))
    advantage = raw_advantage - raw_advantage.mean(dim=reduce_over, keepdim=True)
    q_value = value + advantage
    return value, raw_advantage, advantage, q_value
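# A hedged numeric sketch of the mean-centering above: subtracting the mean
# advantage makes advantages sum to zero per sample, which keeps the
# value/advantage decomposition identifiable (standard dueling-network trick).
import torch

raw_advantage = torch.tensor([[1.0, 3.0]])
advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)
# advantage == tensor([[-1., 1.]]) and sums to zero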
def serving_to_feature_data(
    serving: rlt.ServingFeatureData,
    dense_preprocessor: Preprocessor,
    sparse_preprocessor: SparsePreprocessor,
) -> rlt.FeatureData:
    float_features_with_presence, id_list_features, id_score_list_features = serving
    return rlt.FeatureData(
        float_features=dense_preprocessor(*float_features_with_presence),
        id_list_features=sparse_preprocessor.preprocess_id_list(id_list_features),
        id_score_list_features=sparse_preprocessor.preprocess_id_score_list(
            id_score_list_features
        ),
    )
def __call__(self, trajectory: Trajectory):
    action = torch.from_numpy(np.stack(trajectory.action).squeeze())
    if self.num_actions is not None:
        # one-hot encoding makes shape (batch_size, num_actions)
        action = F.one_hot(action, self.num_actions).float()
        assert len(action.shape) == 2, f"{action.shape}"
    state = (
        self._get_recsim_state(trajectory.observation)
        if self.recsim_obs
        else rlt.FeatureData(
            torch.from_numpy(np.stack(trajectory.observation)).float()
        )
    )
    return rlt.PolicyGradientInput(
        state=state,
        action=action,
        reward=torch.tensor(trajectory.reward),
        log_prob=torch.tensor(trajectory.log_prob),
    )
def forward(self, state_vp, candidate_vp):
    batch_size, num_candidates, candidate_dim = candidate_vp[0].shape
    state_feats = self.state_preprocessor(*state_vp)
    candidate_feats = self.candidate_preprocessor(
        candidate_vp[0].view(
            batch_size * num_candidates,
            len(self.candidate_preprocessor.sorted_features),
        ),
        candidate_vp[1].view(
            batch_size * num_candidates,
            len(self.candidate_preprocessor.sorted_features),
        ),
    ).view(batch_size, num_candidates, -1)
    input = rlt.FeatureData(
        float_features=state_feats,
        candidate_docs=rlt.DocList(candidate_feats),
    )
    scores = self.mlp(input).view(batch_size, num_candidates)
    return scores