def sample_memories(self, batch_size, use_gpu=False) -> rlt.PreprocessedMemoryNetworkInput:
    """
    :param batch_size: number of samples to return
    :param use_gpu: whether to put samples on gpu

    State's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
    By default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
    """
    sample_indices = np.random.randint(self.memory_size, size=batch_size)
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    # state/next state shape: batch_size x seq_len x state_dim
    # action shape: batch_size x seq_len x action_dim
    # reward/not_terminal shape: batch_size x seq_len
    state, action, next_state, reward, not_terminal = map(
        lambda x: stack(x).float().to(device),
        zip(*self.deque_sample(sample_indices)),
    )
    # make shapes seq_len x batch_size x feature_dim
    state, action, next_state, reward, not_terminal = transpose(
        state, action, next_state, reward, not_terminal
    )
    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.FeatureData(float_features=state),
        reward=reward,
        time_diff=torch.ones_like(reward).float(),
        action=action,
        next_state=rlt.FeatureData(float_features=next_state),
        not_terminal=not_terminal,
        step=None,
    )
    return training_input
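
# Illustrative sketch (not part of the original code) of the batch-first ->
# seq-first reshaping described in the docstring above. The `stack`/`transpose`
# helpers used there belong to the surrounding module; this sketch uses plain
# PyTorch only to show the shape convention the MDN-RNN consumer expects.
import torch

batch_size, seq_len, state_dim = 2, 3, 4
state_batch_first = torch.randn(batch_size, seq_len, state_dim)
# MDN-RNN style consumers expect SEQ_LEN as the leading dimension.
state_seq_first = state_batch_first.transpose(0, 1)
assert state_seq_first.shape == (seq_len, batch_size, state_dim)
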
def embed_state(self, state):
    """ Embed state after either reset() or step() """
    assert len(self.recent_states) == len(self.recent_actions)
    old_mdnrnn_mode = self.mdnrnn.mdnrnn.training
    self.mdnrnn.mdnrnn.eval()

    # Embed the state as the hidden layer's output
    # until the previous step + current state
    if len(self.recent_states) == 0:
        mdnrnn_state = np.zeros((1, self.raw_state_dim))
        mdnrnn_action = np.zeros((1, self.action_dim))
    else:
        mdnrnn_state = np.array(list(self.recent_states))
        mdnrnn_action = np.array(list(self.recent_actions))

    mdnrnn_state = torch.tensor(mdnrnn_state, dtype=torch.float).unsqueeze(1)
    mdnrnn_action = torch.tensor(mdnrnn_action, dtype=torch.float).unsqueeze(1)
    mdnrnn_output = self.mdnrnn(
        rlt.FeatureData(mdnrnn_state), rlt.FeatureData(mdnrnn_action)
    )
    hidden_embed = (
        mdnrnn_output.all_steps_lstm_hidden[-1].squeeze().detach().cpu().numpy()
    )
    state_embed = np.hstack((hidden_embed, state))
    self.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
    logger.debug(
        f"Embed_state\nrecent states: {np.array(self.recent_states)}\n"
        f"recent actions: {np.array(self.recent_actions)}\n"
        f"state_embed: {state_embed}\n"
    )
    return state_embed
def top_k_policy(
    q_network,
    obs: Tuple[torch.Tensor, torch.Tensor, DocumentFeature],
    recsim: RecSim,
):
    active_user_idxs, user_features, candidate_features = obs
    slate_with_null = recsim.select(
        candidate_features,
        torch.repeat_interleave(
            torch.arange(recsim.m).unsqueeze(dim=0),
            active_user_idxs.shape[0],
            dim=0,
        ),
        add_null=True,
    )
    _user_choice, interest = recsim.compute_user_choice(slate_with_null)
    propensity = F.softmax(interest, dim=1)[:, : recsim.m]
    tiled_user_features = torch.repeat_interleave(user_features, recsim.m, dim=0)
    candidate_feature_vector = candidate_features.as_vector()
    action_dim = candidate_feature_vector.shape[2]
    flatten_candidate_features = candidate_feature_vector.view(-1, action_dim)
    q_network_input = (
        rlt.FeatureData(tiled_user_features),
        rlt.FeatureData(flatten_candidate_features),
    )
    q_values = q_network(*q_network_input).view(-1, recsim.m)
    values = q_values * propensity
    top_values, item_idxs = torch.topk(values, recsim.k, dim=1)
    return item_idxs
def test_get_Q(self):
    NUM_ACTION = 2
    MULTI_STEPS = 3
    BATCH_SIZE = 2
    STATE_DIM = 4
    all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
    seq2reward_network = FakeSeq2RewardNetwork()
    batch = rlt.MemoryNetworkInput(
        state=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
        ),
        next_state=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
        ),
        action=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, NUM_ACTION)
        ),
        reward=torch.zeros(1),
        time_diff=torch.zeros(1),
        step=torch.zeros(1),
        not_terminal=torch.zeros(1),
    )
    q_values = get_Q(seq2reward_network, batch, all_permut)
    expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
    logger.info(f"q_values: {q_values}")
    assert torch.all(expected_q_values == q_values)
def __call__(self, batch: Dict[str, torch.Tensor]) -> rlt.PolicyNetworkInput:
    batch = batch_to_device(batch, self.device)
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    preprocessed_action = self.action_preprocessor(
        batch["action"], batch["action_presence"]
    )
    preprocessed_next_action = self.action_preprocessor(
        batch["next_action"], batch["next_action_presence"]
    )
    return rlt.PolicyNetworkInput(
        state=rlt.FeatureData(preprocessed_state),
        next_state=rlt.FeatureData(preprocessed_next_state),
        action=rlt.FeatureData(preprocessed_action),
        next_action=rlt.FeatureData(preprocessed_next_action),
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=batch["not_terminal"].unsqueeze(1),
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    assert (
        len(batch.state.shape) == 2
    ), f"{batch.state.shape} is not (batch_size, state_dim)."
    batch_size, _ = batch.state.shape
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )
    possible_actions = get_possible_actions_for_gym(batch_size, self.num_actions)
    possible_next_actions = possible_actions.clone()
    possible_actions_mask = torch.ones((batch_size, self.num_actions))
    possible_next_actions_mask = possible_actions_mask.clone()
    return rlt.ParametricDqnInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=rlt.FeatureData(float_features=action),
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=rlt.FeatureData(float_features=next_action),
        possible_actions=possible_actions,
        possible_actions_mask=possible_actions_mask,
        possible_next_actions=possible_next_actions,
        possible_next_actions_mask=possible_next_actions_mask,
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def test_parametric_wrapper(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    action_normalization_parameters = {i: _cont_norm() for i in range(5, 9)}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_preprocessor = Preprocessor(action_normalization_parameters, False)
    dqn = models.FullyConnectedCritic(
        state_dim=len(state_normalization_parameters),
        action_dim=len(action_normalization_parameters),
        sizes=[16],
        activations=["relu"],
    )
    dqn_with_preprocessor = ParametricDqnWithPreprocessor(
        dqn,
        state_preprocessor=state_preprocessor,
        action_preprocessor=action_preprocessor,
    )
    wrapper = ParametricDqnPredictorWrapper(dqn_with_preprocessor)
    input_prototype = dqn_with_preprocessor.input_prototype()
    output_action_names, q_value = wrapper(*input_prototype)
    self.assertEqual(output_action_names, ["Q"])
    self.assertEqual(q_value.shape, (1, 1))
    expected_output = dqn(
        rlt.FeatureData(state_preprocessor(*input_prototype[0])),
        rlt.FeatureData(action_preprocessor(*input_prototype[1])),
    )
    self.assertTrue((expected_output == q_value).all())
def __call__(self, batch):
    # RB's state is (batch_size, state_dim, stack_size) whereas
    # we want (stack_size, batch_size, state_dim).
    # For scalar fields like reward and terminal,
    # RB returns (batch_size, stack_size), whereas
    # we want (stack_size, batch_size).
    # Also convert action to float.
    if len(batch.state.shape) == 2:
        # this is stack_size = 1 (i.e. we squeezed in RB)
        state = batch.state.unsqueeze(2)
        next_state = batch.next_state.unsqueeze(2)
    else:
        # shapes are already (batch_size, state_dim, stack_size)
        state = batch.state
        next_state = batch.next_state

    # Now shapes should be (batch_size, state_dim, stack_size).
    # Turn shapes into (stack_size, batch_size, feature_dim) where
    # feature \in {state, action}; also, make action a float.
    permutation = [2, 0, 1]
    not_terminal = 1.0 - batch.terminal.transpose(0, 1).float()
    batch_action = batch.action
    if batch_action.ndim == 3:
        batch_action = batch_action.squeeze(1)
    action = F.one_hot(batch_action, self.num_actions).transpose(1, 2).float()
    return rlt.PreprocessedMemoryNetworkInput(
        state=rlt.FeatureData(state.permute(permutation)),
        next_state=rlt.FeatureData(next_state.permute(permutation)),
        action=action.permute(permutation).float(),
        reward=batch.reward.transpose(0, 1),
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
    )
def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.ParametricDqnInput:
    batch = batch_to_device(batch, self.device)
    # first preprocess state and action
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    preprocessed_action = self.action_preprocessor(
        batch["action"], batch["action_presence"]
    )
    preprocessed_next_action = self.action_preprocessor(
        batch["next_action"], batch["next_action_presence"]
    )
    return rlt.ParametricDqnInput(
        state=rlt.FeatureData(preprocessed_state),
        next_state=rlt.FeatureData(preprocessed_next_state),
        action=rlt.FeatureData(preprocessed_action),
        next_action=rlt.FeatureData(preprocessed_next_action),
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=batch["not_terminal"].unsqueeze(1),
        possible_actions=batch["possible_actions"],
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions=batch["possible_next_actions"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
def __call__(self, batch: Dict[str, torch.Tensor]) -> rlt.DiscreteDqnInput:
    batch = batch_to_device(batch, self.device)
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    # not terminal iff at least one possible next action
    not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
    action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
    # next action can take the sentinel value self.num_actions if not available
    next_action = F.one_hot(
        batch["next_action"].to(torch.int64), self.num_actions + 1
    )[:, : self.num_actions]
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(preprocessed_state),
        next_state=rlt.FeatureData(preprocessed_next_state),
        action=action,
        next_action=next_action,
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=not_terminal.unsqueeze(1),
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
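
# Illustrative sketch (not from the original code) of the one-hot trick used
# above: a "next action" at a terminal step is encoded as the sentinel value
# num_actions, one-hot encoded over num_actions + 1 classes, and the sentinel
# column is then sliced off so terminal rows become all-zero vectors.
import torch
import torch.nn.functional as F

num_actions = 3
next_action = torch.tensor([0, 2, num_actions])  # last entry marks "no next action"
one_hot = F.one_hot(next_action, num_actions + 1)[:, :num_actions]
# one_hot[2] is [0, 0, 0]: the terminal transition contributes no next action.
assert one_hot[2].sum() == 0
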
def as_slate_q_training_batch(self):
    batch_size, state_dim = self.states.shape
    action_dim = self.actions.shape[1]
    return rlt.PreprocessedSlateQInput(
        state=rlt.FeatureData(float_features=self.states),
        next_state=rlt.FeatureData(float_features=self.next_states),
        action=rlt.PreprocessedSlateFeatureVector(
            float_features=self.possible_actions_state_concat[:, state_dim:].view(
                batch_size, -1, action_dim
            ),
            item_mask=self.possible_actions_mask,
            item_probability=self.propensities,
        ),
        next_action=rlt.PreprocessedSlateFeatureVector(
            float_features=self.possible_next_actions_state_concat[
                :, state_dim:
            ].view(batch_size, -1, action_dim),
            item_mask=self.possible_next_actions_mask,
            item_probability=self.next_propensities,
        ),
        reward=self.rewards,
        reward_mask=self.rewards_mask,
        time_diff=self.time_diffs,
        step=self.step,
        not_terminal=self.not_terminal,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def __call__(self, batch):
    action = batch.action
    if self.num_actions is not None:
        assert len(action.shape) == 2, f"{action.shape}"
        # one-hot makes shape (batch_size, stack_size, feature_dim)
        action = F.one_hot(batch.action, self.num_actions).float()
        # make shape (batch_size, feature_dim, stack_size)
        action = action.transpose(1, 2)

    # For (1-dimensional) vector fields, RB returns (batch_size, state_dim)
    # or (batch_size, state_dim, stack_size).
    # We want these to all be (stack_size, batch_size, state_dim), so
    # unsqueeze the former case; note this only happens for stack_size = 1.
    # Then, permute.
    permutation = [2, 0, 1]
    vector_fields = {
        "state": batch.state,
        "action": action,
        "next_state": batch.next_state,
    }
    for name, tensor in vector_fields.items():
        if len(tensor.shape) == 2:
            tensor = tensor.unsqueeze(2)
        assert len(tensor.shape) == 3, f"{name} has shape {tensor.shape}"
        vector_fields[name] = tensor.permute(permutation)

    # For scalar fields, RB returns (batch_size,) or (batch_size, stack_size).
    # Do the same as above, except transpose instead.
    scalar_fields = {
        "reward": batch.reward,
        "not_terminal": 1.0 - batch.terminal.float(),
    }
    for name, tensor in scalar_fields.items():
        if len(tensor.shape) == 1:
            tensor = tensor.unsqueeze(1)
        assert len(tensor.shape) == 2, f"{name} has shape {tensor.shape}"
        scalar_fields[name] = tensor.transpose(0, 1)

    # stack_size > 1, so pad not_terminal with 1's, since
    # previous states couldn't have been terminal.
    if scalar_fields["reward"].shape[0] > 1:
        batch_size = scalar_fields["reward"].shape[1]
        assert scalar_fields["not_terminal"].shape == (
            1,
            batch_size,
        ), f"{scalar_fields['not_terminal'].shape}"
        stacked_not_terminal = torch.ones_like(scalar_fields["reward"])
        stacked_not_terminal[-1] = scalar_fields["not_terminal"]
        scalar_fields["not_terminal"] = stacked_not_terminal

    return rlt.MemoryNetworkInput(
        state=rlt.FeatureData(float_features=vector_fields["state"]),
        next_state=rlt.FeatureData(float_features=vector_fields["next_state"]),
        action=vector_fields["action"],
        reward=scalar_fields["reward"],
        not_terminal=scalar_fields["not_terminal"],
        step=None,
        time_diff=None,
    )
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )
    if self.trainer_preprocessor is not None:
        state = self.trainer_preprocessor(batch.state)
        next_state = self.trainer_preprocessor(batch.next_state)
    else:
        state = rlt.FeatureData(float_features=batch.state)
        next_state = rlt.FeatureData(float_features=batch.next_state)
    return rlt.DiscreteDqnInput(
        state=state,
        action=action,
        next_state=next_state,
        next_action=next_action,
        possible_actions_mask=torch.ones_like(action).float(),
        possible_next_actions_mask=torch.ones_like(next_action).float(),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def get_Q(
    seq2reward_network,
    batch: rlt.MemoryNetworkInput,
    all_permut: torch.Tensor,
) -> torch.Tensor:
    batch_size = batch.state.float_features.shape[1]
    _, num_permut, num_action = all_permut.shape
    num_permut_per_action = int(num_permut / num_action)
    preprocessed_state = (
        batch.state.float_features[0]
        .unsqueeze(0)
        .repeat_interleave(num_permut, dim=1)
    )
    state_feature_vector = rlt.FeatureData(preprocessed_state)
    # expand action to match the expanded state sequence
    action = rlt.FeatureData(all_permut.repeat(1, batch_size, 1))
    acc_reward = seq2reward_network(state_feature_vector, action).acc_reward.reshape(
        batch_size, num_action, num_permut_per_action
    )
    # The permutations are generated in lexical order, so after reshaping to
    # (batch_size, num_action, num_permut_per_action) we can take the max
    # accumulated reward over the last dimension and end up with a
    # (BATCH_SIZE, ACT_DIM) tensor.
    max_acc_reward = (
        # pyre-fixme[16]: `Tuple` has no attribute `values`.
        torch.max(acc_reward, dim=2).values.detach().reshape(batch_size, num_action)
    )
    return max_acc_reward
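
# Illustrative sketch (not part of the original code) of how the `all_permut`
# tensor consumed by get_Q can be built: every length-seq_len action sequence,
# one-hot encoded, with SEQ_LEN as the leading dimension. This mirrors the
# gen_permutations helper defined further below.
import torch
import torch.nn.functional as F

def gen_permutations_example(seq_len: int, num_action: int) -> torch.Tensor:
    # all action index sequences of length seq_len, in lexical order
    all_permut = torch.cartesian_prod(*[torch.arange(num_action)] * seq_len)
    # one-hot encode and move SEQ_LEN to the front: (SEQ_LEN, PERM_NUM, ACTION_DIM)
    all_permut = F.one_hot(all_permut, num_action).transpose(0, 1)
    return all_permut.float()

# For seq_len=2 and num_action=2 there are 4 sequences, so the shape is (2, 4, 2).
assert gen_permutations_example(2, 2).shape == (2, 4, 2)
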
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    action = F.one_hot(batch.action, self.num_actions).squeeze(1).float()
    # next action is garbage for terminal transitions (so just zero them)
    next_action = torch.zeros_like(action)
    non_terminal_indices = (batch.terminal == 0).squeeze(1)
    next_action[non_terminal_indices] = (
        F.one_hot(batch.next_action[non_terminal_indices], self.num_actions)
        .squeeze(1)
        .float()
    )
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=action,
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=next_action,
        possible_actions_mask=torch.ones_like(action).float(),
        possible_next_actions_mask=torch.ones_like(next_action).float(),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def internal_reward_estimation(self, state, action):
    """ Only used by Gym """
    self.reward_network.eval()
    reward_estimates = self.reward_network(
        rlt.FeatureData(state), rlt.FeatureData(action)
    )
    self.reward_network.train()
    return reward_estimates.cpu()
def internal_prediction(self, state, action):
    """ Only used by Gym """
    self.q_network.eval()
    q_values = self.q_network(rlt.FeatureData(state), rlt.FeatureData(action))
    self.q_network.train()
    return q_values.cpu()
def as_policy_network_training_batch(self):
    return rlt.PolicyNetworkInput(
        state=rlt.FeatureData(float_features=self.states),
        action=rlt.FeatureData(float_features=self.actions),
        next_state=rlt.FeatureData(float_features=self.next_states),
        next_action=rlt.FeatureData(float_features=self.next_actions),
        reward=self.rewards,
        not_terminal=self.not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
        extras=rlt.ExtraData(),
    )
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    action_with_presence: Tuple[torch.Tensor, torch.Tensor],
):
    preprocessed_state = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    )
    preprocessed_action = self.action_preprocessor(
        action_with_presence[0], action_with_presence[1]
    )
    state = rlt.FeatureData(preprocessed_state)
    action = rlt.FeatureData(preprocessed_action)
    q_value = self.model(state, action)
    return q_value
def get_Q(
    self,
    batch: rlt.MemoryNetworkInput,
    batch_size: int,
    seq_len: int,
    num_action: int,
) -> torch.Tensor:
    if not self.view_q_value:
        return torch.zeros(batch_size, num_action)
    try:
        # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `all_permut`.
        self.all_permut
    except AttributeError:

        def gen_permutations(seq_len: int, num_action: int) -> torch.Tensor:
            """
            generate all seq_len permutations for a given action set
            the return shape is (SEQ_LEN, PERM_NUM, ACTION_DIM)
            """
            all_permut = torch.cartesian_prod(*[torch.arange(num_action)] * seq_len)
            all_permut = F.one_hot(all_permut, num_action).transpose(0, 1)
            return all_permut.float()

        self.all_permut = gen_permutations(seq_len, num_action)
        # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `num_permut`.
        self.num_permut = self.all_permut.size(1)

    preprocessed_state = batch.state.float_features.repeat(1, self.num_permut, 1)
    state_feature_vector = rlt.FeatureData(preprocessed_state)
    # expand action to match the expanded state sequence
    action = self.all_permut.repeat(1, batch_size, 1)
    reward = self.seq2reward_network(
        state_feature_vector, rlt.FeatureData(action)
    ).acc_reward.reshape(batch_size, num_action, self.num_permut // num_action)
    # The permutations are generated in lexical order, so after reshaping to
    # (batch_size, num_action, num_permut // num_action) we can take the max
    # accumulated reward over the last dimension to get a
    # (BATCH_SIZE, ACT_DIM) tensor of Q-values.
    max_reward = (
        # pyre-fixme[16]: `Tuple` has no attribute `values`.
        torch.max(reward, 2).values.cpu().detach().reshape(batch_size, num_action)
    )
    return max_reward
def acc_rewards_of_one_solution(
    self, init_state: torch.Tensor, solution: torch.Tensor, solution_idx: int
):
    """
    ensemble_pop_size trajectories will be sampled to evaluate a CEM solution.
    Each trajectory is generated by one world model

    :param init_state: its shape is (state_dim, )
    :param solution: its shape is (plan_horizon_length, action_dim)
    :param solution_idx: the index of the solution
    :return reward: Reward of each of ensemble_pop_size trajectories
    """
    reward_matrix = np.zeros((self.ensemble_pop_size, self.plan_horizon_length))

    for i in range(self.ensemble_pop_size):
        state = init_state
        mem_net_idx = np.random.randint(0, len(self.mem_net_list))
        for j in range(self.plan_horizon_length):
            # state shape: (1, 1, state_dim)
            # action shape: (1, 1, action_dim)
            (
                reward,
                next_state,
                not_terminal,
                not_terminal_prob,
            ) = self.sample_reward_next_state_terminal(
                state=rlt.FeatureData(state.reshape((1, 1, self.state_dim))),
                action=rlt.FeatureData(
                    solution[j, :].reshape((1, 1, self.action_dim))
                ),
                mem_net=self.mem_net_list[mem_net_idx],
            )
            reward_matrix[i, j] = reward * (self.gamma ** j)

            if not not_terminal:
                logger.debug(
                    f"Solution {solution_idx}: predict terminal at step {j}"
                    f" with prob. {1.0 - not_terminal_prob}"
                )

            if not not_terminal:
                break

            state = next_state

    return np.sum(reward_matrix, axis=1)
def __call__(self, obs): user = torch.tensor(obs["user"]).float().unsqueeze(0) doc_obs = obs["doc"] if self.discrete_keys or self.box_keys: # Dict space discrete_features: List[torch.Tensor] = [] for k, n in self.discrete_keys: vals = torch.tensor([v[k] for v in doc_obs.values()]) assert vals.shape == (self.num_docs, ) discrete_features.append(F.one_hot(vals, n).float()) box_features: List[torch.Tensor] = [] for k, d in self.box_keys: vals = np.vstack([v[k] for v in doc_obs.values()]) assert vals.shape == (self.num_docs, d) box_features.append(torch.tensor(vals).float()) doc_features = torch.cat(discrete_features + box_features, dim=1).unsqueeze(0) else: # Simply a Box space vals = np.vstack(list(doc_obs.values())) doc_features = torch.tensor(vals).float().unsqueeze(0) # This comes from ValueWrapper value = (torch.tensor([ v["value"] for v in obs["augmentation"].values() ]).float().unsqueeze(0)) candidate_docs = rlt.DocList(float_features=doc_features, value=value) return rlt.FeatureData(float_features=user, candidate_docs=candidate_docs)
def obs_preprocessor(self, obs: np.ndarray) -> rlt.FeatureData:
    # pyre-fixme[16]: `Gym` has no attribute `observation_space`.
    obs_space = self.observation_space
    if isinstance(obs_space, spaces.Box):
        return rlt.FeatureData(torch.tensor(obs).float().unsqueeze(0))
    else:
        raise NotImplementedError(f"{obs_space} obs space not supported for Gym.")
def get_loss(self, training_batch: rlt.MemoryNetworkInput):
    """
    Compute losses: MSE(predicted_acc_reward, target_acc_reward)

    :param training_batch:
        training_batch has these fields:
        - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
        - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
        - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

    :returns: mse loss on reward
    """
    seq2reward_output = self.seq2reward_network(
        training_batch.state, rlt.FeatureData(training_batch.action)
    )
    predicted_acc_reward = seq2reward_output.acc_reward

    target_rewards = training_batch.reward
    target_acc_reward = torch.sum(target_rewards, 0).unsqueeze(1)

    # make sure the prediction and target tensors have the same size;
    # the size should both be (BATCH_SIZE, 1) in this case.
    assert predicted_acc_reward.size() == target_acc_reward.size()
    mse = F.mse_loss(predicted_acc_reward, target_acc_reward)
    return mse
def forward(
    self,
    state: torch.Tensor,
    src_seq: torch.Tensor,
    tgt_out_seq: torch.Tensor,
    src_src_mask: torch.Tensor,
    tgt_out_idx: torch.Tensor,
) -> torch.Tensor:
    return self.model(
        rlt.PreprocessedRankingInput(
            state=rlt.FeatureData(float_features=state),
            src_seq=rlt.FeatureData(float_features=src_seq),
            tgt_out_seq=rlt.FeatureData(float_features=tgt_out_seq),
            src_src_mask=src_src_mask,
            tgt_out_idx=tgt_out_idx,
        )
    ).predicted_reward
def forward(self, state_with_presence: Tuple[torch.Tensor, torch.Tensor]):
    preprocessed_state = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    )
    state_feature_vector = rlt.FeatureData(preprocessed_state)
    q_values = self.model(state_feature_vector)
    return q_values
def get_loss(self, training_batch: rlt.MemoryNetworkInput):
    """
    Compute losses: MSE(predicted_acc_reward, target_acc_reward)

    :param training_batch:
        training_batch has these fields:
        - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
        - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
        - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

    :returns: mse loss on reward
    """
    seq2reward_output = self.seq2reward_network(
        training_batch.state, rlt.FeatureData(training_batch.action)
    )
    predicted_acc_reward = seq2reward_output.acc_reward

    target_rewards = training_batch.reward
    seq_len, batch_size = target_rewards.size()
    gamma = self.params.gamma
    gamma_mask = torch.Tensor(
        [[gamma ** i for i in range(seq_len)] for _ in range(batch_size)]
    ).transpose(0, 1)
    target_acc_reward = torch.sum(target_rewards * gamma_mask, 0).unsqueeze(1)

    # make sure the prediction and target tensors have the same size;
    # the size should both be (BATCH_SIZE, 1) in this case.
    assert (
        predicted_acc_reward.size() == target_acc_reward.size()
    ), f"{predicted_acc_reward.size()}!={target_acc_reward.size()}"
    mse = F.mse_loss(predicted_acc_reward, target_acc_reward)
    return mse
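
# Illustrative sketch (not from the original code) of the discounted target
# construction above: a (SEQ_LEN, BATCH_SIZE) reward tensor is multiplied by a
# gamma mask and summed over time, giving a (BATCH_SIZE, 1) accumulated reward.
import torch

gamma = 0.9
rewards = torch.tensor([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]])  # (seq_len=3, batch=2)
seq_len, batch_size = rewards.shape
gamma_mask = torch.tensor(
    [[gamma ** i for i in range(seq_len)] for _ in range(batch_size)]
).transpose(0, 1)
target_acc_reward = torch.sum(rewards * gamma_mask, 0).unsqueeze(1)  # (2, 1)
# column 0: 1 + 0.9 + 0.81 = 2.71; column 1: 2 * 2.71 = 5.42
assert torch.allclose(target_acc_reward, torch.tensor([[2.71], [5.42]]))
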
def test_discrete_wrapper(self):
    ids = range(1, 5)
    state_normalization_parameters = {i: _cont_norm() for i in ids}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_dim = 2
    dqn = models.FullyConnectedDQN(
        state_dim=len(state_normalization_parameters),
        action_dim=action_dim,
        sizes=[16],
        activations=["relu"],
    )
    state_feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=[
            rlt.FloatFeatureInfo(feature_id=i, name=f"feat_{i}") for i in ids
        ]
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        dqn, state_preprocessor, state_feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, state_feature_config
    )
    input_prototype = dqn_with_preprocessor.input_prototype()[0]
    output_action_names, q_values = wrapper(input_prototype)
    self.assertEqual(action_names, output_action_names)
    self.assertEqual(q_values.shape, (1, 2))
    state_with_presence = input_prototype.float_features_with_presence
    expected_output = dqn(
        rlt.FeatureData(state_preprocessor(*state_with_presence))
    )
    self.assertTrue((expected_output == q_values).all())
def test_actor_wrapper(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    action_normalization_parameters = {
        i: _cont_action_norm() for i in range(101, 105)
    }
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    postprocessor = Postprocessor(action_normalization_parameters, False)

    # Test with FullyConnectedActor to make behavior deterministic
    actor = models.FullyConnectedActor(
        state_dim=len(state_normalization_parameters),
        action_dim=len(action_normalization_parameters),
        sizes=[16],
        activations=["relu"],
    )
    actor_with_preprocessor = ActorWithPreprocessor(
        actor, state_preprocessor, postprocessor
    )
    wrapper = ActorPredictorWrapper(actor_with_preprocessor)
    input_prototype = actor_with_preprocessor.input_prototype()
    action = wrapper(*input_prototype)
    self.assertEqual(action.shape, (1, len(action_normalization_parameters)))

    expected_output = postprocessor(
        actor(rlt.FeatureData(state_preprocessor(*input_prototype[0]))).action
    )
    self.assertTrue((expected_output == action).all())
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]],
):
    assert (
        len(candidate_with_presence_list) == self.num_candidates
    ), f"{len(candidate_with_presence_list)} != {self.num_candidates}"
    preprocessed_state = self.state_preprocessor(*state_with_presence)
    # each is batch_size x candidate_dim, result is batch_size x num_candidates x candidate_dim
    preprocessed_candidates = torch.stack(
        [self.candidate_preprocessor(*x) for x in candidate_with_presence_list],
        dim=1,
    )
    input = rlt.FeatureData(
        float_features=preprocessed_state,
        candidate_docs=rlt.DocList(
            float_features=preprocessed_candidates,
            mask=torch.tensor(-1),
            value=torch.tensor(-1),
        ),
    )
    input = rlt._embed_states(input)
    action = self.model(input).action
    if self.action_postprocessor is not None:
        # pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
        action = self.action_postprocessor(action)
    return action