import torch
import torch.nn as nn

# BatchLinear, BatchMLP, init_sequential_weights, and MeanPooling are helper
# modules assumed to be defined elsewhere in this codebase.


class DotProdAttention(nn.Module):
    def __init__(self, embedding_dim, values_dim, linear_transform=False):
        super(DotProdAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.values_dim = values_dim
        self.linear_transform = linear_transform

        if self.linear_transform:
            # Optional learned projections applied to keys, queries, and values
            # before the dot-product attention is computed.
            self.key_transform = BatchLinear(self.embedding_dim, self.embedding_dim,
                                             bias=False)
            self.query_transform = BatchLinear(self.embedding_dim, self.embedding_dim,
                                               bias=False)
            self.value_transform = BatchLinear(self.values_dim, self.values_dim,
                                               bias=False)
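# The forward pass of DotProdAttention is not shown in this excerpt. For
# reference, the sketch below is a minimal, standalone implementation of
# scaled dot-product attention over batched point sets; the function name and
# tensor shapes are illustrative assumptions, not taken from the source.


def scaled_dot_product_attention(queries, keys, values):
    """queries: (batch, num_queries, embedding_dim)
    keys:    (batch, num_keys, embedding_dim)
    values:  (batch, num_keys, values_dim)
    returns: (batch, num_queries, values_dim)
    """
    d_k = keys.shape[-1]
    # Scale the scores by sqrt(d_k) to keep the softmax in a well-behaved range.
    scores = torch.matmul(queries, keys.transpose(-2, -1)) / d_k ** 0.5
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, values)


# Example with the default sizes used elsewhere in this file:
# out = scaled_dot_product_attention(torch.randn(4, 10, 128),
#                                    torch.randn(4, 20, 128),
#                                    torch.randn(4, 20, 128))  # -> (4, 10, 128)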
class StandardDecoder(nn.Module):
    def __init__(self, input_dim, latent_dim, output_dim):
        super(StandardDecoder, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.output_dim = output_dim

        # MLP applied to the pooled representation concatenated with the target
        # inputs; the final layer outputs 2 * output_dim channels (mean and
        # pre-activation standard deviation).
        post_pooling_fn = nn.Sequential(
            BatchLinear(self.latent_dim + self.input_dim, self.latent_dim),
            nn.ReLU(),
            BatchLinear(self.latent_dim, self.latent_dim),
            nn.ReLU(),
            BatchLinear(self.latent_dim, 2 * self.output_dim),
        )
        self.post_pooling_fn = init_sequential_weights(post_pooling_fn)
        self.sigma_fn = nn.functional.softplus
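# StandardDecoder's forward method is not reproduced here. Conventionally, the
# 2 * output_dim channels produced by post_pooling_fn are split in half into a
# predictive mean and a pre-activation scale that sigma_fn (softplus) maps to a
# positive standard deviation. The sketch below illustrates that split; the
# helper name and the mean-then-sigma channel ordering are assumptions.


def split_mean_sigma(raw, output_dim):
    """raw: (batch, num_targets, 2 * output_dim) decoder output."""
    mean = raw[..., :output_dim]
    sigma = nn.functional.softplus(raw[..., output_dim:])  # same role as self.sigma_fn
    return mean, sigma


# e.g. mean, sigma = split_mean_sigma(torch.randn(4, 10, 2), output_dim=1)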
class StandardEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim, use_attention=False):
        super(StandardEncoder, self).__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.use_attention = use_attention

        # Point-wise MLP applied to each context point before pooling.
        pre_pooling_fn = nn.Sequential(
            BatchLinear(self.input_dim, self.latent_dim),
            nn.ReLU(),
            BatchLinear(self.latent_dim, self.latent_dim),
            nn.ReLU(),
            BatchLinear(self.latent_dim, self.latent_dim),
        )
        self.pre_pooling_fn = init_sequential_weights(pre_pooling_fn)

        # Aggregate over the context dimension either with cross-attention or
        # with a simple mean.
        if self.use_attention:
            self.pooling_fn = CrossAttention()
        else:
            self.pooling_fn = MeanPooling(pooling_dim=1)
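# A minimal sketch of the encoder's data flow when use_attention=False: each
# context point is embedded by the point-wise MLP, then the embeddings are
# averaged over the points dimension (the role of MeanPooling(pooling_dim=1)).
# Plain nn.Linear layers stand in for BatchLinear; shapes are assumptions.


def _mean_pool_example():
    batch, num_context, input_dim, latent_dim = 4, 15, 2, 128
    context = torch.randn(batch, num_context, input_dim)
    embed = nn.Sequential(                      # stand-in for pre_pooling_fn
        nn.Linear(input_dim, latent_dim),
        nn.ReLU(),
        nn.Linear(latent_dim, latent_dim),
    )
    h = embed(context)                          # (batch, num_context, latent_dim)
    return h.mean(dim=1, keepdim=True)          # pooled representation, (batch, 1, latent_dim)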
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, value_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.value_dim = value_dim
        # Each head operates on an equal slice of the embedding.
        self.head_size = self.embedding_dim // self.num_heads

        self.key_transform = BatchLinear(self.embedding_dim, self.embedding_dim,
                                         bias=False)
        self.query_transform = BatchLinear(self.embedding_dim, self.embedding_dim,
                                           bias=False)
        self.value_transform = BatchLinear(self.embedding_dim, self.embedding_dim,
                                           bias=False)
        self.attention = DotProdAttention(embedding_dim=self.embedding_dim,
                                          values_dim=self.embedding_dim,
                                          linear_transform=False)
        # Linear map that mixes the concatenated head outputs back together.
        self.head_combine = BatchLinear(self.embedding_dim, self.embedding_dim)
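# MultiHeadAttention's forward pass is not included above. The standalone
# sketch below shows the standard head split / attend / recombine pattern such
# a module performs (learned key/query/value projections omitted for brevity);
# the function name and shapes are illustrative assumptions.


def multi_head_attention(queries, keys, values, num_heads):
    """queries: (batch, num_queries, embedding_dim);
    keys, values: (batch, num_keys, embedding_dim)."""
    batch, num_queries, dim = queries.shape
    head_size = dim // num_heads  # mirrors self.head_size above

    def split_heads(x):
        # (batch, points, dim) -> (batch, num_heads, points, head_size)
        return x.view(batch, -1, num_heads, head_size).transpose(1, 2)

    q, k, v = split_heads(queries), split_heads(keys), split_heads(values)
    scores = torch.matmul(q, k.transpose(-2, -1)) / head_size ** 0.5
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    # Recombine heads: (batch, num_heads, num_queries, head_size) -> (batch, num_queries, dim)
    return out.transpose(1, 2).reshape(batch, num_queries, dim)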
class CrossAttention(nn.Module):
    def __init__(self, input_dim=1, embedding_dim=128, values_dim=128, num_heads=8):
        super(CrossAttention, self).__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.values_dim = values_dim
        self.num_heads = num_heads

        self._attention = MultiHeadAttention(embedding_dim=self.embedding_dim,
                                             value_dim=self.values_dim,
                                             num_heads=self.num_heads)
        # Shared embedding applied to the raw inputs before attending.
        self.embedding = BatchMLP(in_features=self.input_dim,
                                  out_features=self.embedding_dim)

        # Additional modules for transformer-style computations:
        self.ln1 = nn.LayerNorm(self.embedding_dim)
        self.ln2 = nn.LayerNorm(self.embedding_dim)
        self.ff = BatchLinear(self.embedding_dim, self.embedding_dim)
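# The ln1, ln2, and ff modules suggest a transformer-style update in the
# forward pass: attend the embedded target inputs to the embedded context,
# then apply two residual connections with layer normalisation. The sketch
# below illustrates that pattern using torch.nn.MultiheadAttention as a
# stand-in for the custom MultiHeadAttention above; it is an assumption about
# how these modules are combined, not the module's actual forward method.


def _transformer_style_update_example():
    embedding_dim, num_heads = 128, 8
    attn = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
    ln1, ln2 = nn.LayerNorm(embedding_dim), nn.LayerNorm(embedding_dim)
    ff = nn.Linear(embedding_dim, embedding_dim)        # stand-in for self.ff

    queries = torch.randn(4, 10, embedding_dim)         # embedded target inputs
    keys = torch.randn(4, 20, embedding_dim)            # embedded context inputs
    values = torch.randn(4, 20, embedding_dim)          # context representations

    attended, _ = attn(queries, keys, values)
    out = ln1(queries + attended)                       # first residual + LayerNorm
    return ln2(out + ff(out))                           # feed-forward, second residual + LayerNorm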