def generate_messages(self, params: ExpertHiddenStateParams) -> AttentionMessages:
    """Build attention messages from an expert's hidden state and id.

    A tensor ``k`` is computed by ``self._m`` from ``params.hidden_state`` and
    ``params.expert_id`` concatenated along dim 2. It is then passed through
    ``expand_to_batch_size`` with ``self.batch_size``. That expanded tensor
    fills both the first and the third ``AttentionMessages`` field (presumably
    key and query -- confirm against ``AttentionMessages``). The second field
    is ``expert_id`` with a constant ``-1.0`` entry prepended on its last
    dimension.

    Args:
        params: Supplies the ``hidden_state`` and ``expert_id`` tensors.

    Returns:
        ``AttentionMessages(k_expanded, values, k_expanded)``.

    Raises:
        ValueError: If the last dimension of the built values is not
            ``self.value_size``, i.e. ``expert_id``'s last dimension is not
            ``self.value_size - 1``.
    """
    expert_id = params.expert_id
    k = self._m(torch.cat([params.hidden_state, expert_id], 2))
    k_expanded = expand_to_batch_size(k, self.batch_size)

    # One constant -1.0 slot per position, prepended to expert_id on the last
    # dim; this is why expert_id's last dim must be value_size - 1.
    fixed_value = torch.full(
        (*expert_id.shape[:-1], 1), -1.0, device=self.device
    )
    values = torch.cat([fixed_value, expert_id], -1)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if values.shape[-1] != self.value_size:
        raise ValueError(
            f"value size mismatch: got {values.shape[-1]}, expected "
            f"{self.value_size} (expert_id last dim must be value_size - 1)"
        )
    return AttentionMessages(k_expanded, values, k_expanded)
Example #2
0
 def generate_messages(self, params: SearchAgentParams) -> AttentionMessages:
     """Return the stored keys and values as attention messages.

     ``params`` is not used. ``self._keys`` fills both the first and the third
     field of the result; ``self._values`` fills the second.
     """
     stored_keys = self._keys
     stored_values = self._values
     return AttentionMessages(stored_keys, stored_values, stored_keys)
Example #3
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Return a message populated only in its third field.

     ``params`` goes through ``self._param_extractor`` and then ``self.module``.
     The result becomes the third ``AttentionMessages`` field; the first two
     fields are ``None``.
     """
     extracted = self._param_extractor(params)
     encoded = self.module(extracted)
     return AttentionMessages(None, None, encoded)
Example #4
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Return key and value messages computed from the extracted parameters.

     ``self._param_extractor(params)`` is evaluated once. The result is fed to
     ``self.module.key`` (first field) and ``self.module.value`` (second
     field); the third field is ``None``.
     """
     # Extract once and share, instead of running the extractor separately
     # for the key and the value on identical input.
     extracted = self._param_extractor(params)
     return AttentionMessages(
         self.module.key(extracted),
         self.module.value(extracted),
         None,
     )
Example #5
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Return ``self._query`` expanded to ``self.batch_size`` in the third field.

     ``params`` is ignored; the first two ``AttentionMessages`` fields are
     ``None``.
     """
     batched_query = expand_to_batch_size(self._query, self.batch_size)
     return AttentionMessages(None, None, batched_query)
Example #6
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Return a fixed key and a parameter-dependent value.

     The first field is ``self._key`` passed through ``expand_to_batch_size``
     with ``self.batch_size``. The second is ``self.module`` applied to
     ``self._param_extractor(params)``. The third field is ``None``.
     """
     batched_key = expand_to_batch_size(self._key, self.batch_size)
     module_value = self.module(self._param_extractor(params))
     return AttentionMessages(batched_key, module_value, None)
Example #7
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Delegate to ``self.unit`` and keep only its key and value.

     The wrapped unit's ``generate_messages(params)`` result is re-packed with
     its ``key`` and ``value`` preserved and the third field set to ``None``.
     """
     inner = self.unit.generate_messages(params)
     key, value = inner.key, inner.value
     return AttentionMessages(key, value, None)