Example #1
def _content_focus(self, memory_vb):
    """
    variables needed:
        key_vb:    [batch_size x num_heads x mem_wid]
                -> similarity key vector, to compare to each row in memory
                -> by cosine similarity
        beta_vb:   [batch_size x num_heads x 1]
                -> NOTE: refer here: https://github.com/deepmind/dnc/issues/9
                -> \in (1, +inf) after oneplus(); similarity key strength
                -> amplify or attenuate the precision of the focus
        memory_vb: [batch_size x mem_hei   x mem_wid]
    returns:
        wc_vb:     [batch_size x num_heads x mem_hei]
                -> the attention weight by content focus
    """
    # Key Similarity: cosine similarity between each head's key and every
    # memory row, timed for profiling
    s = time.time()
    K_vb = batch_cosine_sim(
        self.key_vb, memory_vb)  # [batch_size x num_heads x mem_hei]
    e = time.time()
    self.key_similarity_time += e - s
    # Content-based Weighting: scale the similarities by the key strength
    # beta, then normalize over the memory rows with a softmax
    s = time.time()
    self.wc_vb = K_vb * self.beta_vb.expand_as(
        K_vb)  # [batch_size x num_heads x mem_hei]
    # softmax over mem_hei (dim 0 after the transpose); the explicit dim
    # avoids F.softmax's deprecated implicit-dimension behavior
    self.wc_vb = F.softmax(self.wc_vb.transpose(0, 2), dim=0).transpose(0, 2)
    e = time.time()
    self.content_weighting_time += e - s
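
Both versions call a `batch_cosine_sim` helper (and assume module-level `import time` and `import torch.nn.functional as F`) that is not shown here. A minimal sketch of what such a helper might look like, assuming it compares keys of shape [batch_size x num_heads x mem_wid] against memory of shape [batch_size x mem_hei x mem_wid]; the `eps` guard against zero-norm rows is my addition:

import torch

def batch_cosine_sim(key_vb, memory_vb, eps=1e-8):
    # key_vb:    [batch_size x num_heads x mem_wid]
    # memory_vb: [batch_size x mem_hei   x mem_wid]
    # returns:   [batch_size x num_heads x mem_hei]
    # dot product between every (key, memory row) pair via batched matmul
    dot = torch.bmm(key_vb, memory_vb.transpose(1, 2))
    # L2 norms of keys and memory rows, broadcast against `dot`
    key_norm = key_vb.norm(dim=2, keepdim=True)     # [batch_size x num_heads x 1]
    mem_norm = memory_vb.norm(dim=2, keepdim=True)  # [batch_size x mem_hei x 1]
    return dot / (key_norm * mem_norm.transpose(1, 2) + eps)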
Example #2
def _content_focus(self, memory_vb):
    """
    variables needed:
        key_vb:    [batch_size x num_heads x mem_wid]
                -> similarity key vector, to compare to each row in memory
                -> by cosine similarity
        beta_vb:   [batch_size x num_heads x 1]
                -> NOTE: refer here: https://github.com/deepmind/dnc/issues/9
                -> \in (1, +inf) after oneplus(); similarity key strength
                -> amplify or attenuate the precision of the focus
        memory_vb: [batch_size x mem_hei   x mem_wid]
    returns:
        wc_vb:     [batch_size x num_heads x mem_hei]
                -> the attention weight by content focus
    """
    # additionally cache the similarities from the pre-norm helper variant
    self.similarities = batch_cosine_sim_pre_norm(self.key_vb, memory_vb)
    K_vb = batch_cosine_sim(
        self.key_vb, memory_vb)  # [batch_size x num_heads x mem_hei]
    self.wc_vb = K_vb * self.beta_vb.expand_as(
        K_vb)  # [batch_size x num_heads x mem_hei]
    # NOTE: replaces the old version
    #   self.wc_vb = F.softmax(self.wc_vb.transpose(0, 2)).transpose(0, 2)
    # with an equivalent softmax taken directly over the memory dimension
    self.wc_vb = F.softmax(self.wc_vb, dim=2)
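
The docstring notes that `beta_vb` has already passed through `oneplus()` before reaching this method. The DNC paper defines oneplus(x) = 1 + log(1 + e^x), which maps any real input into (1, +inf); a minimal PyTorch version (not taken from this repository) is:

import torch.nn.functional as F

def oneplus(x):
    # 1 + softplus(x) = 1 + log(1 + exp(x)), so the output lies in (1, +inf)
    return 1 + F.softplus(x)

A strength near 1 leaves the cosine similarities almost unsharpened, while a large beta drives the softmax toward the single best-matching memory row.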