def generate_messages(
        self, params: ExpertHiddenStateParams) -> AttentionMessages:
    """Build attention keys, values and queries from expert hidden states.

    The hidden state and the expert id are concatenated feature-wise into
    ``z``.
      - keys: ``self._m(z)``.
      - values: a constant ``-1`` channel followed by
        ``self._value_model(z)``.
      - queries: derived from the keys via ``self._enhance_queries``.

    Args:
        params: provides ``hidden_state`` and ``expert_id`` tensors that
            share every dimension except the last (feature) one.

    Returns:
        ``AttentionMessages(keys, values, queries)``.

    Raises:
        AssertionError: if the assembled values are not ``self.value_size``
            wide, i.e. ``self._value_model`` did not output
            ``value_size - 1`` features.
    """
    # Feature-wise concat. For the [batch_size, n_keys, feature] inputs this
    # was written for, dim -1 is the same as dim 2, and -1 also works for
    # other ranks.
    z = torch.cat([params.hidden_state, params.expert_id], -1)
    keys = self._m(z)
    model_values = self._value_model(z)
    # Constant -1 channel prepended to every value vector. Create it with the
    # model output's dtype/device so the concatenation below cannot fail on a
    # dtype mismatch (an uninitialised torch.empty defaults to the global
    # default dtype). The leading shape comes from z so that a value model
    # returning a wrong leading shape still makes torch.cat raise.
    fixed_value = torch.full((*z.shape[:-1], 1), -1.0,
                             dtype=model_values.dtype,
                             device=model_values.device)
    values = torch.cat([fixed_value, model_values], -1)
    queries = self._enhance_queries(keys)
    assert values.shape[-1] == self.value_size
    return AttentionMessages(keys, values, queries)
Example #2
0
 def generate_messages(self, params: SearchAgentParams) -> AttentionMessages:
     """Emit the stored keys and values, reusing the keys as the queries.

     ``params`` is not read; the messages come entirely from ``self._keys``
     and ``self._values``.
     """
     keys = self._keys
     values = self._values
     queries = self._keys
     return AttentionMessages(keys, values, queries)
Example #3
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Emit a query-only message computed from the extracted params.

     The query is ``self.module`` applied to ``self._param_extractor(params)``;
     key and value are left as ``None``.
     """
     extracted = self._param_extractor(params)
     query = self.module(extracted)
     return AttentionMessages(None, None, query)
Example #4
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Emit key and value messages (no query) from the extracted params.

     ``self._param_extractor`` is run once and its result is shared by the
     ``self.module.key`` and ``self.module.value`` heads. Previously it was
     evaluated separately for each head, doing the same work twice.

     Returns:
         ``AttentionMessages(key, value, None)``.
     """
     extracted = self._param_extractor(params)
     return AttentionMessages(
         self.module.key(extracted),
         self.module.value(extracted),
         None)
Example #5
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Emit ``self._query`` as a query-only message.

     The stored query is passed through
     ``expand_to_batch_size(self._query, self.batch_size)``. ``params`` is
     not read, and key and value are ``None``.
     """
     batched_query = expand_to_batch_size(self._query, self.batch_size)
     return AttentionMessages(None, None, batched_query)
Example #6
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Emit a stored key with a computed value; no query.

     - key: ``expand_to_batch_size(self._key, self.batch_size)``.
     - value: ``self.module`` applied to ``self._param_extractor(params)``.
     """
     # Same evaluation order as before: key expansion first, then the module.
     batched_key = expand_to_batch_size(self._key, self.batch_size)
     extracted = self._param_extractor(params)
     computed_value = self.module(extracted)
     return AttentionMessages(batched_key, computed_value, None)
Example #7
0
 def generate_messages(self, params: T_P) -> AttentionMessages:
     """Forward the wrapped unit's key and value, dropping its query."""
     wrapped_messages = self.unit.generate_messages(params)
     key, value = wrapped_messages.key, wrapped_messages.value
     return AttentionMessages(key, value, None)