def __init__(self, input_dim, output_dim,
             virtual_batch_size=128, momentum=0.02, device='cpu'):
    """
    Initialize an attention transformer.

    Parameters
    ----------
    - input_dim : int
        Input size
    - output_dim : int
        Output size
    - virtual_batch_size : int
        Virtual batch size used by Ghost Batch Normalization
    - momentum : float
        Float value between 0 and 1 which will be used for momentum in batch norm
    - device : str
        Device on which the module runs, e.g. 'cpu' or 'cuda'
    """
    super(AttentiveTransformer, self).__init__()
    self.fc = Linear(input_dim, output_dim, bias=False)
    initialize_non_glu(self.fc, input_dim, output_dim)
    self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size,
                  momentum=momentum, device=device)

    # Sparsemax
    self.sp_max = sparsemax.Sparsemax(dim=-1)
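# A hedged usage sketch for the constructor above: the feature width (54) and
# the equal input/output sizes are illustrative assumptions, not values taken
# from the source.
att = AttentiveTransformer(input_dim=54, output_dim=54,
                           virtual_batch_size=128, momentum=0.02,
                           device='cpu')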
def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02, mask_type="sparsemax"): """ Initialize an attention transformer. Parameters ---------- - input_dim : int Input size - output_dim : int Outpu_size - momentum : float Float value between 0 and 1 which will be used for momentum in batch norm - mask_type: str Either "sparsemax" or "entmax" : this is the masking function to use """ super(AttentiveTransformer, self).__init__() self.fc = Linear(input_dim, output_dim, bias=False) initialize_non_glu(self.fc, input_dim, output_dim) self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum) if mask_type == "sparsemax": # Sparsemax self.selector = sparsemax.Sparsemax(dim=-1) elif mask_type == "entmax": # Entmax self.selector = sparsemax.Entmax15(dim=-1) else: raise NotImplementedError("Please choose either sparsemax" + "or entmax as masktype")