Example #1
    def display_last_attention(self, outputs, tag='', step=None, fname=None):
        if step is None:
            step = self.global_step

        if fname is None:
            image = tight_grid(
                norm_tensor(
                    outputs['decoder_attention']
                    ['Decoder_LastBlock_CrossAttention']
                    [0]))  # dim 0 of image_batch is now number of heads
            batch_plot_path = f'{tag}_Decoder_Final_Attention'
            self.add_image(str(batch_plot_path),
                           tf.expand_dims(tf.expand_dims(image, 0), -1),
                           step=step)
        else:
            for j, file in enumerate(fname):
                image = tight_grid(
                    norm_tensor(
                        outputs['decoder_attention']
                        ['Decoder_LastBlock_CrossAttention']
                        [j]))  # dim 0 of image_batch is now number of heads
                batch_plot_path = f'{tag}_Decoder_Final_Attention/{file.numpy().decode("utf-8")}'
                self.add_image(str(batch_plot_path),
                               tf.expand_dims(tf.expand_dims(image, 0), -1),
                               step=step)
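
Example #1 relies on two helpers that are not shown, norm_tensor and tight_grid, plus TensorFlow imported as tf. The sketch below is only an assumption inferred from how they are called: norm_tensor rescales the attention weights into [0, 1] and tight_grid flattens the per-head maps into a single 2D image (the real project may lay the heads out differently).

import tensorflow as tf

def norm_tensor(tensor):
    # Rescale values into [0, 1] so the attention maps render as an image.
    t_min = tf.reduce_min(tensor)
    t_max = tf.reduce_max(tensor)
    return (tensor - t_min) / (t_max - t_min + 1e-8)

def tight_grid(image_batch):
    # image_batch has shape (heads, q_len, k_len); stack the head maps
    # vertically into one (heads * q_len, k_len) image.
    heads, q_len, k_len = tf.unstack(tf.shape(image_batch))
    return tf.reshape(image_batch, [heads * q_len, k_len])

Whatever the exact grid layout, the result is a single 2D tensor, which is why the callers wrap it in two tf.expand_dims calls (batch and channel axes) before logging.
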
Example #2
    def display_attention_heads(self,
                                outputs: dict,
                                tag='',
                                step: int = None,
                                fname: list = None):
        if step is None:
            step = self.global_step
        for layer in ['encoder_attention']:
            for k in outputs[layer].keys():
                if fname is None:
                    image = tight_grid(norm_tensor(
                        outputs[layer][k]
                        [0]))  # dim 0 of image_batch is now number of heads
                    if k == 'Decoder_LastBlock_CrossAttention':
                        batch_plot_path = f'{tag}_Decoder_Final_Attention'
                    else:
                        batch_plot_path = f'{tag}_{layer}/{k}'
                    self.add_image(str(batch_plot_path),
                                   tf.expand_dims(tf.expand_dims(image, 0),
                                                  -1),
                                   step=step)
                else:
                    for j, file in enumerate(fname):
                        image = tight_grid(
                            norm_tensor(outputs[layer][k][j])
                        )  # dim 0 of image_batch is now number of heads
                        if k == 'Decoder_LastBlock_CrossAttention':
                            batch_plot_path = f'{tag}_Decoder_Final_Attention/{file.numpy().decode("utf-8")}'
                        else:
                            batch_plot_path = f'{tag}_{layer}/{k}/{file.numpy().decode("utf-8")}'
                        self.add_image(str(batch_plot_path),
                                       tf.expand_dims(tf.expand_dims(image, 0),
                                                      -1),
                                       step=step)
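
In both examples, add_image and global_step come from the enclosing summary-manager class, which is not shown. Assuming it wraps a tf.summary file writer, a minimal stand-in could look like the following (the class name and constructor arguments here are hypothetical):

import tensorflow as tf

class SummaryManagerSketch:
    # Hypothetical stand-in for the class the methods above belong to.
    def __init__(self, log_dir: str, global_step: int = 0):
        self.writer = tf.summary.create_file_writer(log_dir)
        self.global_step = global_step

    def add_image(self, tag, image, step=None):
        # tf.summary.image expects (batch, height, width, channels),
        # hence the double tf.expand_dims at the call sites above.
        if step is None:
            step = self.global_step
        with self.writer.as_default():
            tf.summary.image(tag, image, step=step)
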
Example #3
    def display_attention_heads(self, outputs, tag=''):
        for layer in ['encoder_attention', 'decoder_attention']:
            for k in outputs[layer].keys():
                image = tight_grid(norm_tensor(outputs[layer][k][0]))
                # dim 0 of image_batch is now number of heads
                batch_plot_path = f'{tag}/{layer}/{k}'
                self.add_image(str(batch_plot_path),
                               tf.expand_dims(tf.expand_dims(image, 0), -1))
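
For reference, a hypothetical call to Example #3's method; the key names and shapes below are illustrative only, chosen to match how the method indexes outputs (a per-layer dict of (batch, heads, query_len, key_len) attention tensors):

import tensorflow as tf

outputs = {
    'encoder_attention': {
        'Encoder_Block1_SelfAttention': tf.random.uniform((2, 4, 50, 50)),
    },
    'decoder_attention': {
        'Decoder_LastBlock_CrossAttention': tf.random.uniform((2, 4, 80, 50)),
    },
}
# summary_manager.display_attention_heads(outputs, tag='train')
# -> one image per attention module, logged under tags such as
#    'train/encoder_attention/Encoder_Block1_SelfAttention'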