Beispiel #1
0
    def __init__(self, config):
        """Build the Transformer decoder's sub-modules from ``config``.

        Args:
            config: mapping that must provide ``d_model``, ``nhead``,
                ``num_layers``, ``encoder_dim``, ``dim_feedforward``,
                ``vocab_size``, ``dropout_rate`` and ``activation``; each is
                copied onto the instance under the same name. A missing key
                raises ``KeyError``.
        """
        super(TransformerDecoder, self).__init__()
        self.config = config

        # Copy hyper-parameters in a fixed order so a missing key is reported
        # on the first absent name.
        hyperparams = (
            "d_model",
            "nhead",
            "num_layers",
            "encoder_dim",
            "dim_feedforward",
            "vocab_size",
            "dropout_rate",
            "activation",
        )
        for name in hyperparams:
            setattr(self, name, config[name])

        d_model = self.d_model

        # Token embedding; its weight is replaced by the output projection's
        # weight at the end of this method.
        self.emb = nn.Embedding(self.vocab_size, d_model)
        # sqrt(d_model); presumably applied to the embeddings in forward().
        self.emb_scale = d_model ** 0.5
        self.pe = modules.PositionalEncoding(d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

        # num_layers copies of a single decoder layer.
        layer = transformer.TransformerDecoderLayer(
            d_model=d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout_rate,
            activation=self.activation,
        )
        self.transformer_block = transformer.TransformerDecoder(
            layer, self.num_layers
        )

        # Output projection, Xavier-normal initialised, then tied to the input
        # embedding so both share a single (vocab_size x d_model) weight.
        self.output_affine = nn.Linear(d_model, self.vocab_size)
        tied_weight = self.output_affine.weight
        nn.init.xavier_normal_(tied_weight)
        self.emb.weight = tied_weight
Beispiel #2
0
    def __init__(self, config):
        """Build the encoder: an input subsampling module, positional
        encoding, dropout and a stack of Transformer encoder layers with a
        final ``LayerNorm``.

        Args:
            config: mapping with keys ``input_dim``, ``d_model``, ``nhead``,
                ``dim_feedforward``, ``num_layers``, ``dropout_rate``,
                ``activation`` and ``sub``. ``config["sub"]["type"]``
                selects the subsampling module:

                * ``"ConvV1"``: ``Conv2dSubsample``;
                * ``"ConvV2"``: ``Conv2dSubsampleV2`` (also reads
                  ``config["sub"]["layer_num"]``);
                * ``"Stack"``: ``Conv1dSubsample`` (also reads
                  ``config["context_width"]`` and ``config["subsample"]``).

        Raises:
            ValueError: if ``config["sub"]["type"]`` is not one of the
                supported subsampling types.
        """
        super(Transformer, self).__init__()
        self.config = config

        self.input_dim = config["input_dim"]
        self.d_model = config["d_model"]
        self.nhead = config["nhead"]
        self.dim_feedforward = config["dim_feedforward"]
        self.num_layers = config["num_layers"]
        self.dropout_rate = config["dropout_rate"]
        self.activation = config["activation"]
        self.subconf = config["sub"]

        sub_type = self.subconf["type"]
        if sub_type == "ConvV1":
            self.sub = Conv2dSubsample(self.input_dim, self.d_model)
        elif sub_type == "ConvV2":
            self.sub = Conv2dSubsampleV2(
                self.input_dim, self.d_model, self.subconf["layer_num"])
        elif sub_type == "Stack":
            self.context_width = config["context_width"]
            self.subsample = config["subsample"]
            self.sub = Conv1dSubsample(
                self.input_dim, self.d_model, self.context_width,
                self.subsample)
        else:
            # Fail fast: otherwise self.sub would silently stay unset and the
            # problem would presumably only surface later, in forward().
            raise ValueError(
                "Unknown subsampling type {!r}; expected 'ConvV1', 'ConvV2' "
                "or 'Stack'.".format(sub_type))

        # sqrt(d_model); presumably used to scale inputs in forward().
        self.scale = self.d_model ** 0.5

        self.pe = modules.PositionalEncoding(self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)
        encoder_norm = LayerNorm(self.d_model)
        encoder_layer = transformer.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout_rate,
            activation=self.activation)
        # Assumed to mirror torch.nn.TransformerEncoder(layer, num_layers,
        # norm): num_layers copies of encoder_layer, then encoder_norm.
        self.transformer_encoder = transformer.TransformerEncoder(
            encoder_layer, self.num_layers, encoder_norm)
Beispiel #3
0
    def __init__(self, config):
        """Build a Transformer language model from ``config``.

        Args:
            config: mapping that must provide ``vocab_size``, ``d_model``,
                ``nhead``, ``num_layers``, ``dim_feedforward``,
                ``activation`` and ``dropout_rate``; each is copied onto the
                instance under the same name. A missing key raises
                ``KeyError``.
        """
        super(TransformerLM, self).__init__()
        self.config = config

        for name in ("vocab_size", "d_model", "nhead", "num_layers",
                     "dim_feedforward", "activation", "dropout_rate"):
            setattr(self, name, config[name])

        d_model = self.d_model
        self.dropout = nn.Dropout(self.dropout_rate)
        # sqrt(d_model); presumably used to scale embeddings in forward().
        self.scale = d_model ** 0.5
        self.pe = modules.PositionalEncoding(d_model)
        self.emb = nn.Embedding(self.vocab_size, d_model)

        layer = transformer.TransformerEncoderLayer(
            d_model=d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout_rate,
            activation=self.activation,
        )
        self.transformer_encoder = transformer.TransformerEncoder(
            layer, self.num_layers
        )

        # Bias-free output projection; its (vocab_size x d_model) weight is
        # shared with the input embedding (weight tying).
        self.output_affine = nn.Linear(d_model, self.vocab_size, bias=False)
        self.emb.weight = self.output_affine.weight
Beispiel #4
0
    def __init__(self, in_features, out_features, num_conv_layers, dropout=0.1):
        """Build a convolutional encoder followed by self-attention.

        Args:
            in_features: first argument to ``Conv2dEncoder``.
            out_features: second argument to ``Conv2dEncoder``; also the
                feature size given to ``attention.QKVDotProductAttention``.
            num_conv_layers: third argument to ``Conv2dEncoder``
                (presumably the number of convolution layers).
            dropout: dropout probability for ``self.dropout``. Defaults to
                0.1, so existing callers are unaffected.
        """
        super().__init__()

        self.conv = Conv2dEncoder(in_features, out_features, num_conv_layers)
        self.encoding = modules.PositionalEncoding()
        self.dropout = nn.Dropout(dropout)
        self.self_attention = attention.QKVDotProductAttention(out_features)
Beispiel #5
0
    def __init__(self, features, vocab_size, dropout=0.1):
        """Build token embedding, self-attention, attention and output layers.

        Args:
            features: embedding width and feature size of both attention
                layers and the output projection's input.
            vocab_size: number of token ids; index 0 is the embedding's
                ``padding_idx``.
            dropout: dropout probability for ``self.dropout``. Defaults to
                0.1, so existing callers are unaffected.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, features, padding_idx=0)
        self.encoding = modules.PositionalEncoding()
        self.dropout = nn.Dropout(dropout)
        self.self_attention = attention.QKVDotProductAttention(features)
        # Second attention layer; presumably attends over encoder outputs.
        self.attention = attention.QKVDotProductAttention(features)
        self.output = nn.Linear(features, vocab_size)
Beispiel #6
0
 def __init__(
     self,
     *,
     # Embeddings
     capacity=128,
     dropout=0.1,
     num_health_history_features=13,
     health_history_embedding_dim=64,
     num_health_profile_features=14,
     health_profile_embedding_dim=32,
     time_embedding_dim=32,
     encounter_duration_embedding_dim=32,
     encounter_duration_embedding_mode="thermo",
     encounter_duration_thermo_range=(0.0, 6.0),
     encounter_duration_num_thermo_bins=32,
     num_encounter_partner_id_bits=16,
     encounter_partner_id_embedding_dim=32,
     message_dim=8,
     message_embedding_dim=128,
     # Attention
     num_heads=4,
     sab_capacity=128,
     num_sabs=2,
     # Meta config
     pool_latent_entities=False,
     use_logit_sink=False,
     # Output
     encounter_output_features=1,
     latent_variable_output_features=1,
 ):
     """Assemble every sub-module and pass them to the base-class constructor.

     All arguments are keyword-only. Sub-modules are created in a fixed
     order, so parameter initialisation is reproducible under a fixed seed.

     Args:
         capacity: hidden width passed to the health-history,
             health-profile, message and (``"thermo"`` mode) duration
             embeddings, and hidden width of the two output MLPs.
         dropout: dropout passed to those same embeddings.
         num_health_history_features, health_history_embedding_dim:
             input/output sizes of ``mods.HealthHistoryEmbedding``.
         num_health_profile_features, health_profile_embedding_dim:
             input/output sizes of ``mods.HealthProfileEmbedding``.
         time_embedding_dim: size of the time ``mods.PositionalEncoding``.
         encounter_duration_embedding_dim: size of the duration embedding.
         encounter_duration_embedding_mode: ``"thermo"`` builds a
             ``mods.DurationEmbedding`` (using
             ``encounter_duration_thermo_range`` and
             ``encounter_duration_num_thermo_bins``); ``"sines"`` builds a
             ``mods.PositionalEncoding``.
         num_encounter_partner_id_bits, encounter_partner_id_embedding_dim:
             settings of ``mods.PartnerIdEmbedding``.
         message_dim, message_embedding_dim: settings of
             ``mods.MessageEmbedding``.
         num_heads: heads for every ``attn.SAB`` and ``attn.PMA``.
         sab_capacity: output width of every self-attention block.
         num_sabs: number of self-attention blocks. With ``num_sabs < 1`` a
             three-layer ReLU MLP baseline replaces the attention stack.
         pool_latent_entities: if true, build an ``attn.PMA`` pooler for the
             latent-variable head; otherwise pass ``None``.
         use_logit_sink: if true, build an ``attn.PMA`` pooler and a linear
             layer for the encounter logit sink; otherwise pass ``None`` for
             both.
         encounter_output_features: output width of the encounter MLP.
         latent_variable_output_features: output width of the
             latent-variable MLP.

     Raises:
         ValueError: if ``encounter_duration_embedding_mode`` is neither
             ``"thermo"`` nor ``"sines"``.
     """
     # ------- Embeddings -------
     health_history_embedding = mods.HealthHistoryEmbedding(
         in_features=num_health_history_features,
         embedding_size=health_history_embedding_dim,
         capacity=capacity,
         dropout=dropout,
     )
     health_profile_embedding = mods.HealthProfileEmbedding(
         in_features=num_health_profile_features,
         embedding_size=health_profile_embedding_dim,
         capacity=capacity,
         dropout=dropout,
     )
     time_embedding = mods.PositionalEncoding(encoding_dim=time_embedding_dim)
     if encounter_duration_embedding_mode == "thermo":
         duration_embedding = mods.DurationEmbedding(
             num_thermo_bins=encounter_duration_num_thermo_bins,
             embedding_size=encounter_duration_embedding_dim,
             thermo_range=encounter_duration_thermo_range,
             capacity=capacity,
             dropout=dropout,
         )
     elif encounter_duration_embedding_mode == "sines":
         duration_embedding = mods.PositionalEncoding(
             encoding_dim=encounter_duration_embedding_dim
         )
     else:
         raise ValueError(
             "Unknown encounter_duration_embedding_mode {!r}; expected "
             "'thermo' or 'sines'.".format(encounter_duration_embedding_mode)
         )
     partner_id_embedding = mods.PartnerIdEmbedding(
         num_id_bits=num_encounter_partner_id_bits,
         embedding_size=encounter_partner_id_embedding_dim,
     )
     message_embedding = mods.MessageEmbedding(
         message_dim=message_dim,
         embedding_size=message_embedding_dim,
         capacity=capacity,
         dropout=dropout,
     )
     # ------- Attention -------
     # The first block sees the concatenation of every embedding.
     sab_in_dim = (
         time_embedding_dim
         + encounter_partner_id_embedding_dim
         + encounter_duration_embedding_dim
         + health_history_embedding_dim
         + message_embedding_dim
         + health_profile_embedding_dim
     )
     # Width of the per-encounter metadata (time, partner id, duration).
     sab_metadata_dim = (
         time_embedding_dim
         + encounter_partner_id_embedding_dim
         + encounter_duration_embedding_dim
     )
     # Later blocks see the previous block's output plus the metadata.
     sab_intermediate_in_dim = sab_capacity + sab_metadata_dim
     # Build the SABs
     if num_sabs >= 1:
         self_attention_blocks = [
             attn.SAB(dim_in=sab_in_dim, dim_out=sab_capacity, num_heads=num_heads)
         ]
     else:
         # This is a special code-path where we don't use any self-attention,
         # but just a plain-old MLP (as a baseline).
         self_attention_blocks = [
             nn.Sequential(
                 nn.Linear(sab_in_dim, sab_capacity),
                 nn.ReLU(),
                 nn.Linear(sab_capacity, sab_capacity),
                 nn.ReLU(),
                 nn.Linear(sab_capacity, sab_capacity),
                 nn.ReLU(),
             )
         ]
     for _ in range(num_sabs - 1):
         self_attention_blocks.append(
             attn.SAB(
                 dim_in=sab_intermediate_in_dim,
                 dim_out=sab_capacity,
                 num_heads=num_heads,
             )
         )
     self_attention_blocks = nn.ModuleList(self_attention_blocks)
     # Build the entity poolers
     if pool_latent_entities:
         self_latent_variable_pooler = attn.PMA(
             dim=sab_capacity + sab_metadata_dim, num_seeds=1, num_heads=num_heads
         )
     else:
         self_latent_variable_pooler = None
     # Logit sink: pooler and its linear layer are created together (in this
     # order) or not at all.
     if use_logit_sink:
         encounter_logit_sink_pooler = attn.PMA(
             dim=sab_capacity + sab_metadata_dim, num_seeds=1, num_heads=num_heads
         )
         logit_sink_mlp = nn.Linear(sab_capacity + sab_metadata_dim, sab_capacity)
     else:
         encounter_logit_sink_pooler = None
         logit_sink_mlp = None
     # ------- Output processors -------
     # Encounter
     encounter_mlp = nn.Sequential(
         nn.Linear(sab_capacity, capacity),
         nn.ReLU(),
         nn.Linear(capacity, encounter_output_features),
     )
     # Latent variables
     latent_variable_mlp = nn.Sequential(
         nn.Linear(sab_capacity + sab_metadata_dim, capacity),
         nn.ReLU(),
         nn.Linear(capacity, latent_variable_output_features),
     )
     # ------- Output placeholders -------
     # Learnable vectors, randomly initialised, one per embedding width.
     # noinspection PyArgumentList
     message_placeholder = nn.Parameter(torch.randn((message_embedding_dim,)))
     # noinspection PyArgumentList
     partner_id_placeholder = nn.Parameter(
         torch.randn((encounter_partner_id_embedding_dim,))
     )
     # noinspection PyArgumentList
     duration_placeholder = nn.Parameter(
         torch.randn((encounter_duration_embedding_dim,))
     )
     # ------- Masking -------
     entity_masker = mods.EntityMasker()
     # Done; init the super. The base class receives every sub-module as a
     # keyword argument (no attributes are set on self before this call).
     super(ContactTracingTransformer, self).__init__(
         health_history_embedding=health_history_embedding,
         health_profile_embedding=health_profile_embedding,
         time_embedding=time_embedding,
         duration_embedding=duration_embedding,
         partner_id_embedding=partner_id_embedding,
         message_embedding=message_embedding,
         self_attention_blocks=self_attention_blocks,
         self_latent_variable_pooler=self_latent_variable_pooler,
         latent_variable_mlp=latent_variable_mlp,
         encounter_logit_sink_pooler=encounter_logit_sink_pooler,
         logit_sink_mlp=logit_sink_mlp,
         encounter_mlp=encounter_mlp,
         entity_masker=entity_masker,
         message_placeholder=message_placeholder,
         partner_id_placeholder=partner_id_placeholder,
         duration_placeholder=duration_placeholder,
     )