Example #1
# Note: `transformers` in these snippets refers to a project-local module
# providing TransformerEncoder and the positional-embedding layers, not the
# Hugging Face package.
import tensorflow as tf


class BertEncoder(tf.keras.layers.Layer):  # enclosing class (base assumed)

    def __init__(
            self,
            # transformer parameters
            d_model=512,
            d_kv=64,
            d_ff=2048,
            num_layers=6,
            num_heads=8,
            pre_norm=False,
            use_bias=True,
            activation="gelu",
            dropout_rate=0.1,
            layer_norm_epsilon=1e-6,
            # masking parameters
            use_masking=False,
            mask_rate=0.2,
            # positional embedding parameters
            max_temporal_buckets=16,
            name="bert",
            **kwargs):
        super(BertEncoder, self).__init__(name=name)
        self.d_model = d_model
        # masking parameters
        self.use_masking = use_masking
        self.mask_rate = mask_rate

        self.pos_embedding_lookup = transformers.TemporalEmbeddings(
            hidden_size=self.d_model,
            max_temporal_buckets=max_temporal_buckets,
        )

        # define transformer head
        self.tx = transformers.TransformerEncoder(
            d_model=d_model,
            d_kv=d_kv,
            d_ff=d_ff,
            num_layers=num_layers,
            num_heads=num_heads,
            pre_norm=pre_norm,
            use_bias=use_bias,
            activation=activation,
            dropout_rate=dropout_rate,
            layer_norm_epsilon=layer_norm_epsilon,
            name="transformer",
        )
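
A hypothetical usage sketch (not part of the original source): the defaults
above give a 6-layer encoder, while BERT-Base sizes would be d_model=768,
d_ff=3072, num_layers=12, num_heads=12 (keeping d_kv = 768 / 12 = 64).

encoder = BertEncoder(
    d_model=768,
    d_ff=3072,
    num_layers=12,
    num_heads=12,
    use_masking=True,  # randomly mask inputs at the default mask_rate=0.2
)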
Example #2
class AuTx1D(tf.keras.layers.Layer):  # enclosing class (base assumed)

    def __init__(
            self,
            # patching parameters
            temporal_patch_size=512,
            temporal_patch_stride=512,
            random_patch_sampling=False,
            patch_sampling_rate=0.5,
            # pre-projection parameters
            pre_projection=False,
            # transformer parameters
            d_model=768,
            d_kv=64,
            d_ff=3072,
            num_layers=2,
            num_heads=12,
            pre_norm=False,
            use_bias=True,
            activation="gelu",
            dropout_rate=0.1,
            layer_norm_epsilon=1e-6,
            # masking parameters
            use_masking=True,
            mask_rate=0.2,
            # positional embedding parameters
            max_temporal_buckets=300,
            # post-projection parameters
            post_projection=False,
            d_post_proj=1024,
            post_proj_activation="gelu",
            name="au_tx_1d",
            **kwargs):
        super(AuTx1D, self).__init__(name=name)
        self.pre_projection = pre_projection
        self.post_projection = post_projection
        self.d_model = d_model
        # masking parameters
        self.use_masking = use_masking
        self.mask_rate = mask_rate

        # define waveform to patch module
        self.wave_to_patch = tf.keras.layers.Conv1D(
            filters=d_model,
            kernel_size=temporal_patch_size,
            strides=temporal_patch_stride,
            padding="valid",
            name="waveform_to_patch",
        )

        if self.pre_projection:
            self.pre_proj = tf.keras.layers.Dense(
                d_model,
                activation=activation,
                name="pre_tx_projection",
            )
        else:
            self.pre_proj = tf.identity  # no-op: pass patches through unchanged

        self.use_random_patches = random_patch_sampling
        # define positional embedding module
        self.max_num_patches = int(patch_sampling_rate * max_temporal_buckets)
        assert max_temporal_buckets > self.max_num_patches, (
            "Max number of positional buckets should be bigger than max"
            " number of input patches")
        self.pos_embedding_lookup = transformers.TemporalEmbeddings(
            hidden_size=self.d_model,
            max_temporal_buckets=max_temporal_buckets,
        )

        # define transformer head
        self.tx = transformers.TransformerEncoder(
            d_model=d_model,
            d_kv=d_kv,
            d_ff=d_ff,
            num_layers=num_layers,
            num_heads=num_heads,
            pre_norm=pre_norm,
            use_bias=use_bias,
            activation=activation,
            dropout_rate=dropout_rate,
            layer_norm_epsilon=layer_norm_epsilon,
            name="transformer",
        )

        if self.post_projection:
            self.post_proj = tf.keras.layers.Dense(
                d_post_proj,
                activation=post_proj_activation,
                name="post_tx_projection",
            )
        else:
            self.post_proj = tf.identity  # no-op: pass features through unchanged
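
Because wave_to_patch uses kernel_size == strides with "valid" padding, it
slices the waveform into non-overlapping patches: T samples yield
floor(T / 512) patches with the defaults above. A standalone sketch (the input
shape [batch, num_samples, 1] and the sample count are assumptions chosen to
match max_temporal_buckets=300):

import tensorflow as tf

wave_to_patch = tf.keras.layers.Conv1D(
    filters=768, kernel_size=512, strides=512, padding="valid")

waveform = tf.random.normal([2, 153600, 1])  # 2 clips of 153600 mono samples
patches = wave_to_patch(waveform)
print(patches.shape)  # (2, 300, 768): 153600 / 512 = 300 patches <= 300 buckets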
Example #3
class UniversalVATT(tf.keras.layers.Layer):  # enclosing class (base assumed)

    def __init__(
            self,
            # pre-transformer parameters
            vid_temporal_patch_size=4,
            vid_spatial_patch_size=16,
            aud_temporal_patch_size=128,
            txt_vocab_size=2**16,
            txt_embedding_dim=300,
            txt_embedding_trainable=False,
            # video & audio input sampling
            random_patch_sampling=False,
            patch_sampling_rate=0.5,
            # transformer head parameters
            d_model=1024,
            d_kv=64,
            d_ff=4096,
            num_layers=24,
            num_heads=16,
            pre_norm=True,
            use_bias=True,
            activation="gelu",
            dropout_rate=0.1,
            layer_norm_epsilon=1e-6,
            # positional embedding parameters
            max_vid_temporal_buckets=8,
            max_vid_spatial_buckets=14,
            max_aud_temporal_buckets=1200,
            max_txt_temporal_buckets=16,
            # final head parameters
            d_post_proj=1024,
            post_proj_activation="gelu",
            name="unified_vat_transformer",
            **kwargs):
        super(UniversalVATT, self).__init__(name=name)
        self.d_model = d_model
        # define raw-input to embedding modules (patching / token lookup)
        self.raw_to_embeddings = {
            "video":
            tf.keras.layers.Conv3D(
                filters=d_model,
                kernel_size=(vid_temporal_patch_size, vid_spatial_patch_size,
                             vid_spatial_patch_size),
                strides=(vid_temporal_patch_size, vid_spatial_patch_size,
                         vid_spatial_patch_size),
                padding="valid",
                name="voxel_to_patch",
            ),
            "audio":
            tf.keras.layers.Conv1D(
                filters=d_model,
                kernel_size=aud_temporal_patch_size,
                strides=aud_temporal_patch_size,
                padding="valid",
                name="waveform_to_patch",
            ),
            "text":
            tf.keras.layers.Embedding(txt_vocab_size,
                                      txt_embedding_dim,
                                      trainable=txt_embedding_trainable,
                                      name="text_embedding")
        }
        # define per-modality pre-tx projections
        self.pre_proj = {
            "video":
            tf.keras.layers.Dense(d_model,
                                  activation=activation,
                                  name="video_pre_tx_projection"),
            "audio":
            tf.keras.layers.Dense(d_model,
                                  activation=activation,
                                  name="audio_pre_tx_projection"),
            "text":
            tf.keras.layers.Dense(d_model,
                                  activation=activation,
                                  name="text_pre_tx_projection"),
        }

        # define sampling-related params
        self.use_random_patches = random_patch_sampling
        self.patch_sampling_rate = patch_sampling_rate
        self.max_buckets = {
            "video": max_vid_temporal_buckets * (max_vid_spatial_buckets**2),
            "audio": max_aud_temporal_buckets,
        }
        self.max_num_patches = {
            "video": int(self.patch_sampling_rate * self.max_buckets["video"]),
            "audio": int(self.patch_sampling_rate * self.max_buckets["audio"]),
        }
        assert self.max_buckets["video"] > self.max_num_patches["video"], (
            "Max number of video positional buckets should be bigger than max"
            " number of video input patches")
        assert self.max_buckets["audio"] > self.max_num_patches["audio"], (
            "Max number of audio positional buckets should be bigger than max"
            " number of audio input patches")

        # define positional embedding module
        self.pos_embedding_lookup = {
            "video":
            transformers.SpatioTemporalEmbeddings(
                hidden_size=self.d_model,
                max_temporal_buckets=max_vid_temporal_buckets,
                max_vertical_buckets=max_vid_spatial_buckets,
                max_horizontal_buckets=max_vid_spatial_buckets,
                name="video_spatio_temporal_embeddings",
            ),
            "audio":
            transformers.TemporalEmbeddings(
                hidden_size=self.d_model,
                max_temporal_buckets=max_aud_temporal_buckets,
                name="audio_temporal_embeddings",
            ),
            "text":
            transformers.TemporalEmbeddings(
                hidden_size=self.d_model,
                max_temporal_buckets=max_txt_temporal_buckets,
                name="text_temporal_embeddings",
            ),
        }

        # define transformer head
        self.tx = transformers.TransformerEncoder(
            d_model=d_model,
            d_kv=d_kv,
            d_ff=d_ff,
            num_layers=num_layers,
            num_heads=num_heads,
            pre_norm=pre_norm,
            use_bias=use_bias,
            activation=activation,
            dropout_rate=dropout_rate,
            layer_norm_epsilon=layer_norm_epsilon,
            name="transformer",
        )

        # define post-tx projection head - it could map to logits or an
        # embedding space
        self.post_proj = {
            "video":
            tf.keras.layers.Dense(d_post_proj,
                                  activation=post_proj_activation,
                                  name="video_post_tx_projection"),
            "audio":
            tf.keras.layers.Dense(d_post_proj,
                                  activation=post_proj_activation,
                                  name="audio_post_tx_projection"),
            "text":
            tf.keras.layers.Dense(d_post_proj,
                                  activation=post_proj_activation,
                                  name="text_post_tx_projection"),
        }
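
The video bucket count above is max_vid_temporal_buckets *
max_vid_spatial_buckets**2 = 8 * 14 * 14 = 1568, which matches a 32x224x224
clip tokenized with 4x16x16 voxel patches (32 / 4 = 8, 224 / 16 = 14). A
standalone sketch (the clip shape is an assumption consistent with those
defaults):

import tensorflow as tf

voxel_to_patch = tf.keras.layers.Conv3D(
    filters=1024, kernel_size=(4, 16, 16), strides=(4, 16, 16),
    padding="valid")

clip = tf.random.normal([1, 32, 224, 224, 3])  # [batch, T, H, W, channels]
tokens = voxel_to_patch(clip)
print(tokens.shape)  # (1, 8, 14, 14, 1024) -> 8 * 14 * 14 = 1568 video tokens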