Example #1
    def __init__(self,
                 input_dim,
                 output_dim,
                 virtual_batch_size=128,
                 momentum=0.02,
                 device='cpu'):
        """
        Initialize an attention transformer.

        Parameters
        ----------
        - input_dim : int
            Input size
        - output_dim : int
            Output size
        - virtual_batch_size : int
            Virtual batch size used by Ghost Batch Normalization (GBN)
        - momentum : float
            Float value between 0 and 1 which will be used for momentum in batch norm
        - device : str
            Device on which the Ghost Batch Normalization layer runs (e.g. 'cpu')
        """
        super(AttentiveTransformer, self).__init__()
        self.fc = Linear(input_dim, output_dim, bias=False)
        initialize_non_glu(self.fc, input_dim, output_dim)
        self.bn = GBN(output_dim,
                      virtual_batch_size=virtual_batch_size,
                      momentum=momentum,
                      device=device)

        # Sparsemax
        self.sp_max = sparsemax.Sparsemax(dim=-1)
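
A minimal usage sketch for this constructor. The argument values below are illustrative assumptions, and it is assumed that the AttentiveTransformer class from the snippet above is already in scope:

# Hypothetical usage sketch: build the attention transformer defined above.
# input_dim / output_dim values are illustrative, not taken from the source.
att = AttentiveTransformer(input_dim=64,
                           output_dim=32,
                           virtual_batch_size=128,
                           momentum=0.02,
                           device='cpu')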
Example #2
    def __init__(self, input_dim, output_dim,
                 virtual_batch_size=128,
                 momentum=0.02,
                 mask_type="sparsemax"):
        """
        Initialize an attention transformer.

        Parameters
        ----------
        - input_dim : int
            Input size
        - output_dim : int
            Output size
        - virtual_batch_size : int
            Virtual batch size used by Ghost Batch Normalization (GBN)
        - momentum : float
            Float value between 0 and 1 which will be used for momentum in batch norm
        - mask_type: str
            Either "sparsemax" or "entmax" : this is the masking function to use
        """
        super(AttentiveTransformer, self).__init__()
        self.fc = Linear(input_dim, output_dim, bias=False)
        initialize_non_glu(self.fc, input_dim, output_dim)
        self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size,
                      momentum=momentum)

        if mask_type == "sparsemax":
            # Sparsemax
            self.selector = sparsemax.Sparsemax(dim=-1)
        elif mask_type == "entmax":
            # Entmax
            self.selector = sparsemax.Entmax15(dim=-1)
        else:
            raise NotImplementedError("Please choose either sparsemax " +
                                      "or entmax as mask_type")