Example #1
0
 def _absolute_position_to_relative_position(self, x):
     """Re-index the last axis from absolute positions to relative offsets.

     Uses the pad / flatten / skew / reshape trick, so no gather or
     explicit indexing op is needed.

     Args:
         x: tensor of shape [b, h, l, l].

     Returns:
         Tensor of shape [b, h, l, 2*l-1].
     """
     b, h, seq_len, _ = x.size()
     # Right-pad the column axis with seq_len - 1 entries, giving
     # [b, h, l, 2*l-1] (the view below relies on exactly that size).
     col_pad = [[0, 0], [0, 0], [0, 0], [0, seq_len - 1]]
     padded = F.pad(x, commons.convert_pad_shape(col_pad))
     # Flatten each [l, 2*l-1] matrix into a single row.
     flat_len = seq_len * seq_len + seq_len * (seq_len - 1)
     flat = padded.view([b, h, flat_len])
     # Prepend seq_len zeros; after the reshape below they skew each row
     # one position further than the row above it.
     front_pad = [[0, 0], [0, 0], [seq_len, 0]]
     flat = F.pad(flat, commons.convert_pad_shape(front_pad))
     # l*(2*l-1) + l == l*(2*l), so this view is exact; drop column 0.
     skewed = flat.view([b, h, seq_len, 2 * seq_len])
     return skewed[:, :, :, 1:]
Example #2
0
 def _same_padding(self, x):
     if self.kernel_size == 1:
         return x
     pad_l = (self.kernel_size - 1) // 2
     pad_r = self.kernel_size // 2
     padding = [[0, 0], [0, 0], [pad_l, pad_r]]
     x = F.pad(x, commons.convert_pad_shape(padding))
     return x
Example #3
0
 def _causal_padding(self, x):
     if self.kernel_size == 1:
         return x
     pad_l = self.kernel_size - 1
     pad_r = 0
     padding = [[0, 0], [0, 0], [pad_l, pad_r]]
     x = F.pad(x, commons.convert_pad_shape(padding))
     return x
Example #4
0
    def _relative_position_to_absolute_position(self, x):
        """Re-index the last axis from relative offsets to absolute positions.

        The inverse-direction counterpart of the skew trick: pad, flatten,
        reshape with a narrower row width, then slice.

        Args:
            x: tensor of shape [b, h, l, 2*l-1].

        Returns:
            Tensor of shape [b, h, l, l].
        """
        b, h, seq_len, _ = x.size()
        # Append one column: [b, h, l, 2*l-1] -> [b, h, l, 2*l].
        one_col = [[0, 0], [0, 0], [0, 0], [0, 1]]
        widened = F.pad(x, commons.convert_pad_shape(one_col))
        # Flatten each [l, 2*l] matrix, then append l - 1 zeros so the total,
        # 2*l*l + l - 1, equals (l + 1) * (2*l - 1).
        flat = widened.view([b, h, seq_len * 2 * seq_len])
        tail = [[0, 0], [0, 0], [0, seq_len - 1]]
        flat = F.pad(flat, commons.convert_pad_shape(tail))
        # With row width 2*l-1 each successive original row lands one more
        # element to the left; keep the first l rows and last l columns.
        grid = flat.view([b, h, seq_len + 1, 2 * seq_len - 1])
        return grid[:, :, :seq_len, seq_len - 1:]
Example #5
0
 def _get_relative_embeddings(self, relative_embeddings, length):
     max_relative_position = 2 * self.window_size + 1
     # Pad first before slice to avoid using cond ops.
     pad_length = max(length - (self.window_size + 1), 0)
     slice_start_position = max((self.window_size + 1) - length, 0)
     slice_end_position = slice_start_position + 2 * length - 1
     if pad_length > 0:
         padded_relative_embeddings = F.pad(
             relative_embeddings,
             commons.convert_pad_shape([[0, 0], [pad_length, pad_length],
                                        [0, 0]]))
     else:
         padded_relative_embeddings = relative_embeddings
     used_relative_embeddings = padded_relative_embeddings[:,
                                                           slice_start_position:
                                                           slice_end_position]
     return used_relative_embeddings