import tensorflow as tf

# Note: `transformers` below refers to this project's local module providing
# TemporalEmbeddings, SpatioTemporalEmbeddings, and TransformerEncoder
# (not the Hugging Face library); its import path depends on the repo layout.


def __init__(self,
             # transformer parameters
             d_model=512,
             d_kv=64,
             d_ff=2048,
             num_layers=6,
             num_heads=8,
             pre_norm=False,
             use_bias=True,
             activation="gelu",
             dropout_rate=0.1,
             layer_norm_epsilon=1e-6,
             # masking parameters
             use_masking=False,
             mask_rate=0.2,
             # positional embedding parameters
             max_temporal_buckets=16,
             name="bert",
             **kwargs):
  super(BertEncoder, self).__init__(name=name)
  self.d_model = d_model

  # masking parameters
  self.use_masking = use_masking
  self.mask_rate = mask_rate

  # define positional embedding module
  self.pos_embedding_lookup = transformers.TemporalEmbeddings(
      hidden_size=self.d_model,
      max_temporal_buckets=max_temporal_buckets,
  )

  # define transformer head
  self.tx = transformers.TransformerEncoder(
      d_model=d_model,
      d_kv=d_kv,
      d_ff=d_ff,
      num_layers=num_layers,
      num_heads=num_heads,
      pre_norm=pre_norm,
      use_bias=use_bias,
      activation=activation,
      dropout_rate=dropout_rate,
      layer_norm_epsilon=layer_norm_epsilon,
      name="transformer",
  )
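# Hypothetical illustration of what `use_masking` / `mask_rate` imply:
# BERT-style random masking, where each input position is selected for
# masking with probability `mask_rate`. This sketches the idea only; the
# encoder's actual masking logic is not shown in this excerpt.

def random_token_mask(batch_size, seq_len, mask_rate=0.2):
  """Returns a boolean [batch, seq_len] mask; True marks masked positions."""
  return tf.random.uniform([batch_size, seq_len]) < mask_rate

mask = random_token_mask(batch_size=2, seq_len=16, mask_rate=0.2)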
def __init__(self,
             # patching parameters
             temporal_patch_size=512,
             temporal_patch_stride=512,
             random_patch_sampling=False,
             patch_sampling_rate=0.5,
             # pre-projection parameters
             pre_projection=False,
             # transformer parameters
             d_model=768,
             d_kv=64,
             d_ff=3072,
             num_layers=2,
             num_heads=12,
             pre_norm=False,
             use_bias=True,
             activation="gelu",
             dropout_rate=0.1,
             layer_norm_epsilon=1e-6,
             # masking parameters
             use_masking=True,
             mask_rate=0.2,
             # positional embedding parameters
             max_temporal_buckets=300,
             # post-projection parameters
             post_projection=False,
             d_post_proj=1024,
             post_proj_activation="gelu",
             name="au_tx_1d",
             **kwargs):
  super(AuTx1D, self).__init__(name=name)
  self.pre_projection = pre_projection
  self.post_projection = post_projection
  self.d_model = d_model

  # masking parameters
  self.use_masking = use_masking
  self.mask_rate = mask_rate

  # define waveform to patch module
  self.wave_to_patch = tf.keras.layers.Conv1D(
      filters=d_model,
      kernel_size=temporal_patch_size,
      strides=temporal_patch_stride,
      padding="valid",
      name="waveform_to_patch",
  )

  if self.pre_projection:
    self.pre_proj = tf.keras.layers.Dense(
        d_model,
        activation=activation,
        name="pre_tx_projection",
    )
  else:
    self.pre_proj = tf.identity

  self.use_random_patches = random_patch_sampling

  # define positional embedding module
  self.max_num_patches = int(patch_sampling_rate * max_temporal_buckets)
  assert max_temporal_buckets > self.max_num_patches, (
      "Max number of positional buckets should be bigger than max"
      " number of input patches")
  self.pos_embedding_lookup = transformers.TemporalEmbeddings(
      hidden_size=self.d_model,
      max_temporal_buckets=max_temporal_buckets,
  )

  # define transformer head
  self.tx = transformers.TransformerEncoder(
      d_model=d_model,
      d_kv=d_kv,
      d_ff=d_ff,
      num_layers=num_layers,
      num_heads=num_heads,
      pre_norm=pre_norm,
      use_bias=use_bias,
      activation=activation,
      dropout_rate=dropout_rate,
      layer_norm_epsilon=layer_norm_epsilon,
      name="transformer",
  )

  if self.post_projection:
    self.post_proj = tf.keras.layers.Dense(
        d_post_proj,
        activation=post_proj_activation,
        name="post_tx_projection",
    )
  else:
    self.post_proj = tf.identity
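# Shape sketch for the waveform-to-patch Conv1D above. With
# kernel_size == strides == 512 and padding="valid", a waveform of length T
# yields floor((T - 512) / 512) + 1 = T // 512 non-overlapping patches, each
# linearly projected to d_model channels. The layer config mirrors the
# constructor defaults; the input tensor is arbitrary dummy data.

wave_to_patch = tf.keras.layers.Conv1D(
    filters=768, kernel_size=512, strides=512, padding="valid")
waveform = tf.random.normal([2, 48000, 1])  # [batch, samples, channels]
patches = wave_to_patch(waveform)
print(patches.shape)  # (2, 93, 768): 48000 // 512 = 93 patches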
def __init__(self,
             # pre-transformer parameters
             vid_temporal_patch_size=4,
             vid_spatial_patch_size=16,
             aud_temporal_patch_size=128,
             txt_vocab_size=2**16,
             txt_embedding_dim=300,
             txt_embedding_trainable=False,
             # video & audio input sampling
             random_patch_sampling=False,
             patch_sampling_rate=0.5,
             # transformer head parameters
             d_model=1024,
             d_kv=64,
             d_ff=4096,
             num_layers=24,
             num_heads=16,
             pre_norm=True,
             use_bias=True,
             activation="gelu",
             dropout_rate=0.1,
             layer_norm_epsilon=1e-6,
             # positional embedding parameters
             max_vid_temporal_buckets=8,
             max_vid_spatial_buckets=14,
             max_aud_temporal_buckets=1200,
             max_txt_temporal_buckets=16,
             # final head parameters
             d_post_proj=1024,
             post_proj_activation="gelu",
             name="unified_vat_transformer",
             **kwargs):
  super(UniversalVATT, self).__init__(name=name)
  self.d_model = d_model

  # define pre-tx projection: raw inputs to patch/token embeddings
  self.raw_to_embeddings = {
      "video": tf.keras.layers.Conv3D(
          filters=d_model,
          kernel_size=(vid_temporal_patch_size,
                       vid_spatial_patch_size,
                       vid_spatial_patch_size),
          strides=(vid_temporal_patch_size,
                   vid_spatial_patch_size,
                   vid_spatial_patch_size),
          padding="valid",
          name="voxel_to_patch",
      ),
      "audio": tf.keras.layers.Conv1D(
          filters=d_model,
          kernel_size=aud_temporal_patch_size,
          strides=aud_temporal_patch_size,
          padding="valid",
          name="waveform_to_patch",
      ),
      "text": tf.keras.layers.Embedding(
          txt_vocab_size,
          txt_embedding_dim,
          trainable=txt_embedding_trainable,
          name="text_embedding",
      ),
  }
  self.pre_proj = {
      "video": tf.keras.layers.Dense(d_model,
                                     activation=activation,
                                     name="video_pre_tx_projection"),
      "audio": tf.keras.layers.Dense(d_model,
                                     activation=activation,
                                     name="audio_pre_tx_projection"),
      "text": tf.keras.layers.Dense(d_model,
                                    activation=activation,
                                    name="text_pre_tx_projection"),
  }

  # define sampling-related params
  self.use_random_patches = random_patch_sampling
  self.patch_sampling_rate = patch_sampling_rate
  self.max_buckets = {
      "video": max_vid_temporal_buckets * (max_vid_spatial_buckets**2),
      "audio": max_aud_temporal_buckets,
  }
  self.max_num_patches = {
      "video": int(self.patch_sampling_rate * self.max_buckets["video"]),
      "audio": int(self.patch_sampling_rate * self.max_buckets["audio"]),
  }
  assert self.max_buckets["video"] > self.max_num_patches["video"], (
      "Max number of video positional buckets should be bigger than max"
      " number of video input patches")
  assert self.max_buckets["audio"] > self.max_num_patches["audio"], (
      "Max number of audio positional buckets should be bigger than max"
      " number of audio input patches")

  # define positional embedding module
  self.pos_embedding_lookup = {
      "video": transformers.SpatioTemporalEmbeddings(
          hidden_size=self.d_model,
          max_temporal_buckets=max_vid_temporal_buckets,
          max_vertical_buckets=max_vid_spatial_buckets,
          max_horizontal_buckets=max_vid_spatial_buckets,
          name="video_spatio_temporal_embeddings",
      ),
      "audio": transformers.TemporalEmbeddings(
          hidden_size=self.d_model,
          max_temporal_buckets=max_aud_temporal_buckets,
          name="audio_temporal_embeddings",
      ),
      "text": transformers.TemporalEmbeddings(
          hidden_size=self.d_model,
          max_temporal_buckets=max_txt_temporal_buckets,
          name="text_temporal_embeddings",
      ),
  }

  # define transformer head
  self.tx = transformers.TransformerEncoder(
      d_model=d_model,
      d_kv=d_kv,
      d_ff=d_ff,
      num_layers=num_layers,
      num_heads=num_heads,
      pre_norm=pre_norm,
      use_bias=use_bias,
      activation=activation,
      dropout_rate=dropout_rate,
      layer_norm_epsilon=layer_norm_epsilon,
      name="transformer",
  )

  # define post-tx projection head - it could map to logits or an
  # embedding space
  self.post_proj = {
      "video": tf.keras.layers.Dense(d_post_proj,
                                     activation=post_proj_activation,
                                     name="video_post_tx_projection"),
      "audio": tf.keras.layers.Dense(d_post_proj,
                                     activation=post_proj_activation,
                                     name="audio_post_tx_projection"),
      "text": tf.keras.layers.Dense(d_post_proj,
                                    activation=post_proj_activation,
                                    name="text_post_tx_projection"),
  }
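# Shape sketch for the video tokenizer defined above (input layout assumed;
# the class's call() is not shown in this excerpt). With
# kernel_size == strides and padding="valid", the Conv3D cuts a clip into
# non-overlapping voxel patches projected to d_model. Constructor defaults:

voxel_to_patch = tf.keras.layers.Conv3D(
    filters=1024,
    kernel_size=(4, 16, 16),
    strides=(4, 16, 16),
    padding="valid")
video = tf.random.normal([1, 32, 224, 224, 3])  # [batch, T, H, W, RGB]
video_patches = voxel_to_patch(video)
print(video_patches.shape)  # (1, 8, 14, 14, 1024): 8 * 14 * 14 = 1568 tokens

# That token count matches the default positional-bucket budget,
# max_vid_temporal_buckets * max_vid_spatial_buckets**2 = 8 * 14**2 = 1568.
# With patch_sampling_rate=0.5, at most int(0.5 * 1568) = 784 video patches
# survive when random patch sampling (DropToken in the VATT paper) is on.
# One hypothetical way to draw such a sample:
flat = tf.reshape(video_patches, [1, -1, 1024])  # [batch, 1568, d_model]
idx = tf.random.shuffle(tf.range(tf.shape(flat)[1]))[:784]
sampled = tf.gather(flat, tf.sort(idx), axis=1)  # [batch, 784, d_model]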