Example #1
def encoder_block(x,
                  att_dim=512,
                  num_heads=12,
                  mlp_dim=2048,
                  attn_drop=0.,
                  ffn_drop=0.1,
                  residual_scale=2.,
                  residual_drop=0.):
    # MSA
    inpt = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(att_dim, num_heads, attn_drop,
                           ffn_drop)([x, x, x])  # self-attention
    x = Dropout(residual_drop,
                noise_shape=(None, 1, 1))(x)  # stochastic-depth by sample
    x = Lambda(lambda x: x[0] + x[1] / residual_scale)([inpt, x])

    # FFN
    inpt = x
    x = LayerNormalization()(x)
    x = FeedForwardNetwork(mlp_dim,
                           att_dim,
                           activation=gelu,
                           drop_rate=ffn_drop)(x)
    x = Dropout(residual_drop, noise_shape=(None, 1, 1))(x)
    x = Lambda(lambda x: x[0] + x[1] / residual_scale)([inpt, x])

    return x
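
A minimal usage sketch, assuming the repository's custom MultiHeadAttention, FeedForwardNetwork and gelu helpers are importable; it simply stacks a few of these blocks on a (batch, tokens, dim) input:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

tokens = Input((197, 512))                    # e.g. 196 patches + 1 class token
h = tokens
for _ in range(4):                            # depth chosen arbitrarily for the sketch
    h = encoder_block(h, att_dim=512, num_heads=8, mlp_dim=2048)
toy_encoder = Model(tokens, h)                # (b, 197, 512) -> (b, 197, 512)
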
Example #2
def encoder_block(x,
                  hidden_dim=768,
                  att_drop_rate=0.,
                  num_heads=12,
                  mlp_dim=3072,
                  drop_rate=0.1):
    # MSA
    inpt = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(hidden_dim, num_heads)([x, x, x])  # self-attention
    x = Dropout(drop_rate)(x)
    x = add([inpt, x])

    # FFN
    inpt = x
    out_dim = K.int_shape(x)[-1]
    x = LayerNormalization()(x)
    x = FeedForwardNetwork(mlp_dim,
                           out_dim,
                           activation=gelu,
                           drop_rate=drop_rate)(x)
    x = add([inpt, x])

    return x
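
Both examples above pass an undefined gelu callable into FeedForwardNetwork. A minimal sketch of such a helper using the tanh approximation of GELU (an assumed stand-in, not the repository's definition):

import numpy as np
from keras import backend as K

def gelu(x):
    # tanh approximation of the Gaussian Error Linear Unit
    return 0.5 * x * (1.0 + K.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * K.pow(x, 3))))
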
Example #3
class SentenceEncoderBlock(Layer):
    def __init__(self,
                 output_dim,
                 attention_dim,
                 n_heads,
                 dropout=0.3,
                 **kwargs):
        self.output_dim = output_dim  # output dimension of the encoder after the fully-connected layers
        self.n_heads = n_heads
        self.attention_dim = attention_dim  # dimension of dq/dk/dv in the multi-head attention
        self.activation = "relu"
        self.dropout = dropout
        super(SentenceEncoderBlock, self).__init__(**kwargs)

    def build(self, input_shape):

        # "Two linear transformations with a ReLU activation in between" #
        self.dense_1 = Dense(self.output_dim, activation=self.activation)
        self.dense_1.build(input_shape)
        self._trainable_weights += self.dense_1.trainable_weights

        self.dense_2 = Dense(self.output_dim)
        self.dense_2.build(input_shape)
        self._trainable_weights += self.dense_2.trainable_weights

        # MultiHeadAttention #
        self.multihead_attention = MultiHeadAttention(self.attention_dim,
                                                      self.n_heads)
        self.multihead_attention.build(input_shape)
        self._trainable_weights += self.multihead_attention.trainable_weights

        # LayerNorm #
        self.layer_normalization = LayerNormalization()
        self.layer_normalization.build(input_shape)
        self._trainable_weights += self.layer_normalization.trainable_weights

        super(SentenceEncoderBlock, self).build(input_shape)

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer
        return mask

    def call(self, x, mask=None):

        z, all_attns = self.multihead_attention(x)
        z = K.in_train_phase(K.dropout(z, self.dropout), z)  # dropout active only at train time
        xz = self.layer_normalization(x + z)
        h_xz = self.dense_1(xz)
        h_xz = self.dense_2(h_xz)
        h_xz = K.in_train_phase(K.dropout(h_xz, self.dropout), h_xz)  # dropout active only at train time
        h_xz = self.layer_normalization(h_xz + xz)
        return [h_xz, all_attns]

    def compute_output_shape(self, input_shape):
        return [(input_shape[0], input_shape[1], self.output_dim),
                (input_shape[0], self.n_heads, input_shape[1], input_shape[1])]
Example #4
def transformer_encoder(inputs, num_heads=4, dropout_rate=0.1):
    in_dim = K.int_shape(inputs)[-1]
    x = MultiHeadAttention(num_heads, in_dim)([inputs, inputs])
    x = Dropout(dropout_rate)(x)
    x = add([inputs, x])
    x1 = LayerNormalization()(x)
    x = Dense(in_dim * 2, activation='relu')(x1)
    x = Dense(in_dim)(x)
    x = Dropout(dropout_rate)(x)
    x = add([x1, x])
    x = LayerNormalization()(x)
    return x
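
A short usage sketch, assuming the MultiHeadAttention layer in scope accepts a [query, value] list as shown above; it wraps the block into a small sequence classifier:

from tensorflow.keras.layers import Input, GlobalAveragePooling1D, Dense
from tensorflow.keras.models import Model

seq = Input((128, 256))                       # (batch, timesteps, features)
h = transformer_encoder(seq, num_heads=4, dropout_rate=0.1)
h = GlobalAveragePooling1D()(h)
clf = Model(seq, Dense(10, activation='softmax')(h))
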
Example #5
    def __init__(self,
                 d_model,
                 heads,
                 dim_q,
                 dim_v,
                 hidden_units,
                 dropout_rate,
                 name,
                 activation='relu',
                 **kwargs):
        self.dim_v = dim_v
        self.dim_q = dim_q
        self.hidden_units = hidden_units
        self.heads = heads

        self.attention_layer = MultiHeadedAttention(d_model=d_model,
                                                    heads=self.heads,
                                                    dim_q=self.dim_q,
                                                    dim_v=self.dim_v,
                                                    dropout_rate=dropout_rate,
                                                    name=name)
        self.normalization_layer = LayerNormalization()
        self.feedforward = PositionWiseFeedForward(d_model=d_model,
                                                   inner_dim=self.hidden_units,
                                                   dropout_rate=dropout_rate,
                                                   name=name)
Example #6
    def __init__(self,
                 emb_dim,
                 feature_shape,
                 n_heads,
                 window_size,
                 mlp_ratio=4,
                 qkv_bias=True,
                 attn_drop=0.,
                 ffn_drop=0.,
                 residual_drop=0.,
                 idx=None,
                 **kwargs):
        super(SwinTransformerBlock, self).__init__(name='STB_%d' % idx,
                                                   **kwargs)
        self.emb_dim = emb_dim
        self.feature_shape = feature_shape
        self.window_size = window_size
        self.shift_size = window_size // 2

        # W-MSA
        self.ln1 = LayerNormalization()
        self.wmsa = WindowMultiHeadAttention(emb_dim, n_heads, window_size,
                                             qkv_bias, attn_drop, ffn_drop)
        self.res_drop1 = Dropout(residual_drop, noise_shape=(None, 1, 1))

        self.ln2 = LayerNormalization()
        self.ffn = FeedForwardNetwork(emb_dim * mlp_ratio,
                                      emb_dim,
                                      activation=gelu,
                                      drop_rate=ffn_drop)
        self.res_drop2 = Dropout(residual_drop, noise_shape=(None, 1, 1))

        # SW-MSA
        self.ln3 = LayerNormalization()
        self.wmsa_s = WindowMultiHeadAttention(emb_dim, n_heads, window_size,
                                               qkv_bias, attn_drop, ffn_drop)
        self.res_drop3 = Dropout(residual_drop, noise_shape=(None, 1, 1))

        self.ln4 = LayerNormalization()
        self.ffn_s = FeedForwardNetwork(emb_dim * mlp_ratio,
                                        emb_dim,
                                        activation=gelu,
                                        drop_rate=ffn_drop)
        self.res_drop4 = Dropout(residual_drop, noise_shape=(None, 1, 1))
Example #7
    def build(self, input_shape):

        # "Two linear transformations with a ReLU activation in between" #
        self.dense_1 = Dense(self.output_dim, activation=self.activation)
        self.dense_1.build(input_shape)
        self._trainable_weights += self.dense_1.trainable_weights

        self.dense_2 = Dense(self.output_dim)
        self.dense_2.build(input_shape)
        self._trainable_weights += self.dense_2.trainable_weights

        # MultiHeadAttention #
        self.multihead_attention = MultiHeadAttention(self.attention_dim,
                                                      self.n_heads)
        self.multihead_attention.build(input_shape)
        self._trainable_weights += self.multihead_attention.trainable_weights

        # LayerNorm #
        self.layer_normalization = LayerNormalization()
        self.layer_normalization.build(input_shape)
        self._trainable_weights += self.layer_normalization.trainable_weights

        super(WordEncoderBlock, self).build(input_shape)
Example #8
    def __init__(self, feature_shape, emb_dim, stage_idx=None, **kwargs):
        super(PatchMerging, self).__init__(name='PatchMerging_%d' % stage_idx,
                                           **kwargs)
        h, w = feature_shape
        pad_h, pad_w = int(h % 2 == 1), int(w % 2 == 1)

        self.use_pad = False
        if pad_h or pad_w:
            self.use_pad = True
            self.pad = Pad_HW(0, pad_h, 0, pad_w)
        self.ln = LayerNormalization()
        self.dense = Dense(2 * emb_dim,
                           use_bias=False,
                           kernel_initializer=bias_init)

        self.feature_shape = feature_shape
        self.emb_dim = emb_dim
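
Only the constructor is shown above. As a rough guide, the forward pass of such a Swin patch-merging layer typically gathers each 2x2 neighbourhood, concatenates channels, normalizes and projects. The sketch below is an assumption about what call() might look like (it presumes a (b, H, W, C) feature map and `from keras import backend as K`), not the repository's implementation:

    def call(self, x):
        if self.use_pad:
            x = self.pad(x)                              # pad odd H/W to even
        x0 = x[:, 0::2, 0::2, :]                         # four 2x2 neighbours
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = K.concatenate([x0, x1, x2, x3], axis=-1)     # (b, H/2, W/2, 4C)
        x = self.ln(x)
        return self.dense(x)                             # (b, H/2, W/2, 2C)
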
Example #9
def visionTransformer(input_size=224,
                      patch_size=16,
                      drop_rate=0.1,
                      num_layers=12,
                      hidden_dim=768,
                      att_drop_rate=0.,
                      num_heads=12,
                      mlp_dim=3072,
                      out_dim=None):

    inpt = Input((input_size, input_size, 3))

    # linear project patches
    x = Conv2D(hidden_dim, patch_size, strides=patch_size,
               padding='valid')(inpt)  # [b,Np,Np,D]

    # reshape
    b, h, w, c = K.int_shape(x)
    x = Reshape((h * w, c))(x)  # [b,N,D]

    # prepend class token (zeros here; a learnable token would need a custom layer)
    x0 = Lambda(lambda t: K.zeros_like(t[:, :1, :]))(x)  # [b,1,D] class token
    x = Concatenate(axis=1)([x0, x])  # [b,N+1,D]

    b, sq_len, input_dim = K.int_shape(x)
    # add fixed/learnable positional embeddings
    pe = Lambda(lambda x: positional_embedding(sq_len, input_dim))(x)
    x = add([x, pe])  # [b,N+1,D]
    x = Dropout(drop_rate)(x)

    # transformer encoder
    for i in range(num_layers):
        x = encoder_block(x, hidden_dim, att_drop_rate, num_heads, mlp_dim,
                          drop_rate)
    x = LayerNormalization()(x)

    # take cls token
    x = Lambda(lambda x: x[:, 0, :])(x)  # [b,D]
    if out_dim:
        x = Dense(out_dim, activation='tanh')(x)

    model = Model(inpt, x)

    return model
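
This model (and LV_ViT in Example #12) calls an undefined positional_embedding(seq_len, dim) helper. A minimal sinusoidal version in the spirit of the original Transformer formulation, returning a (1, seq_len, dim) tensor as the tiling in Example #12 expects (an assumed stand-in, not the repository's code):

import numpy as np
import tensorflow as tf

def positional_embedding(seq_len, dim):
    # fixed sinusoidal positional encodings, shape (1, seq_len, dim)
    pos = np.arange(seq_len)[:, None]                        # (seq_len, 1)
    i = np.arange(dim)[None, :]                              # (1, dim)
    angle = pos / np.power(10000.0, (2 * (i // 2)) / np.float32(dim))
    pe = np.zeros((seq_len, dim), dtype=np.float32)
    pe[:, 0::2] = np.sin(angle[:, 0::2])
    pe[:, 1::2] = np.cos(angle[:, 1::2])
    return tf.constant(pe[None, ...])                        # broadcast over batch
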
Example #10
    def build(self):

        self.input_document = Input(shape=(self.max_words,
                                           self.embedding_dims))
        self.mask = Input(shape=(self.max_words, ))
        self.pos_encoding = Input(shape=(self.max_words, self.embedding_dims))

        # Padding #
        self.z_input = MyMasking()(self.input_document, mask=self.mask)

        # Positional Encoding#
        if self.pe:
            self.z_input = PositionalEncoding()(self.z_input,
                                                mask=self.pos_encoding)

        # Dropout at input (sentence level)
        self.z_input = SpatialDropout1D(self.dropout_input)(self.z_input)

        self.all_attns = []
        ant_layer = self.z_input
        for i in range(self.n_encoders):
            self.sentence_encoder = SentenceEncoderBlock(
                self.output_encoder_dims[i],
                self.attention_dims[i],
                self.n_heads[i],
                dropout=self.dropout_output)
            self.document_encoder = self.sentence_encoder(ant_layer)
            self.z_encoder = Lambda(lambda x: x[0])(self.document_encoder)
            self.attn_encoder = Lambda(lambda x: x[1])(self.document_encoder)

            self.all_attns.append(self.attn_encoder)

            # Masking between each layer #
            self.z_encoder = MyMasking()(self.z_encoder, mask=self.mask)

            ant_layer = self.z_encoder  # feed this encoder's output into the next block

        # Prepare all attentions #
        if self.n_encoders > 1:
            self.all_attns = [
                Lambda(lambda a: K.expand_dims(a, 1))(x)
                for x in self.all_attns
            ]
            self.all_attns = Concatenate(axis=1)(self.all_attns)

        ##########################

        if self.pool_mode == "max":
            self.z_encoder = GlobalMaxPooling1D()(self.z_encoder)

        else:
            self.z_encoder = GlobalAveragePooling1D()(self.z_encoder)

        #self.z_encoder = Dropout(0.3)(self.z_encoder)  # not present in the best run
        self.z_encoder = LayerNormalization()(
            self.z_encoder)  # enabled in the best run!

        self.h = self.z_encoder

        if self.final_h:
            self.h = Dense(self.dim_h, activation="relu")(self.h)
            #self.h = Dropout(0.3)(self.h)  # not in the best run
            self.h = LayerNormalization()(
                self.h)  # enabled in the best run!

        self.output = Dense(4, activation="softmax")(self.h)

        self.model = Model(
            inputs=[self.input_document, self.mask, self.pos_encoding],
            outputs=[self.output])

        self.attn_model = Model(
            inputs=[self.input_document, self.mask, self.pos_encoding],
            outputs=self.all_attns)
Example #11
    def build(self):

        self.input_article = Input(shape=(self.document_max_sents,
                                          self.document_max_words_per_sent))

        self.input_summary = Input(shape=(self.summary_max_sents,
                                          self.summary_max_words_per_sent))

        self.mask_word_article = Input(shape=(self.document_max_sents,
                                              self.document_max_words_per_sent))

        self.mask_word_summary = Input(shape=(self.summary_max_sents,
                                              self.summary_max_words_per_sent))

        self.mask_sent_article = Input(shape=(self.document_max_sents,))
        self.mask_sent_summary = Input(shape=(self.summary_max_sents,))


        self.pos_encoding_word_article = Input(shape=(self.document_max_sents,
                                                      self.document_max_words_per_sent, self.embedding_dims))

        self.pos_encoding_word_summary = Input(shape=(self.summary_max_sents,
                                                      self.summary_max_words_per_sent, self.embedding_dims))

        self.pos_encoding_sent_article = Input(shape=(self.document_max_sents, self.embedding_dims))
        self.pos_encoding_sent_summary = Input(shape=(self.summary_max_sents, self.embedding_dims))


        self.embedding = Embedding(self.max_vocabulary, self.embedding_dims, mask_zero=False)

        # Get Word Embeddings (shared between branches) #
        self.e_article = self.embedding(self.input_article)
        self.e_summary = self.embedding(self.input_summary)

        # Word masking #
        self.ep_article = TimeDistributed(MyMasking())(self.e_article, mask = self.mask_word_article)
        self.ep_summary = TimeDistributed(MyMasking())(self.e_summary, mask = self.mask_word_summary)

        # Adding Word embeddings and Positional embeddings #
        if self.pe_words:
            self.ep_article = TimeDistributed(PositionalEncoding())(self.ep_article, mask = self.pos_encoding_word_article)
            self.ep_summary = TimeDistributed(PositionalEncoding())(self.ep_summary, mask = self.pos_encoding_word_summary)

        # Dropout at input (word level)#
        #self.ep_article = TimeDistributed(SpatialDropout1D(self.dropout_word_input))(self.ep_article)
        #self.ep_summary = TimeDistributed(SpatialDropout1D(self.dropout_word_input))(self.ep_summary)

        # Word Encoders #

        ant_layers = (self.ep_article, self.ep_summary)
        for i in range(self.n_word_encoders):
            self.word_encoder = WordEncoderBlock(self.output_word_encoder_dims[i],
                                                 self.word_attention_dims[i],
                                                 self.n_word_heads[i], dropout = self.dropout_word_output)

            self.z_article_word_encoder = TimeDistributed(self.word_encoder)(ant_layers[0])
            self.z_summary_word_encoder = TimeDistributed(self.word_encoder)(ant_layers[1])

            self.z_article_word_encoder = TimeDistributed(MyMasking())(self.z_article_word_encoder, mask = self.mask_word_article) # Padding between each layer
            self.z_summary_word_encoder = TimeDistributed(MyMasking())(self.z_summary_word_encoder, mask = self.mask_word_summary)

            ant_layers = (self.z_article_word_encoder, self.z_summary_word_encoder)

        self.z_article_word_encoder = TimeDistributed(GlobalMaxPooling1D())(self.z_article_word_encoder)
        self.z_summary_word_encoder = TimeDistributed(GlobalMaxPooling1D())(self.z_summary_word_encoder)

        # Sentence Encoders #

        # Sentence padding #
        self.z_article_word_encoder = MyMasking()(self.z_article_word_encoder, mask = self.mask_sent_article)
        self.z_summary_word_encoder = MyMasking()(self.z_summary_word_encoder, mask = self.mask_sent_summary)

        # Positional encodings for ordering over sentences #
        if self.pe_sentences:
            self.z_article_word_encoder = PositionalEncoding()(self.z_article_word_encoder, mask = self.pos_encoding_sent_article)
            self.z_summary_word_encoder = PositionalEncoding()(self.z_summary_word_encoder, mask = self.pos_encoding_sent_summary)

        # Dropout at input (sentence level)
        #self.z_article_word_encoder = SpatialDropout1D(self.dropout_sent_input)(self.z_article_word_encoder)
        #self.z_summary_word_encoder = SpatialDropout1D(self.dropout_sent_input)(self.z_summary_word_encoder)

        self.all_article_attns = []
        ant_layers = (self.z_article_word_encoder, self.z_summary_word_encoder)

        for i in range(self.n_sentence_encoders):
            self.sentence_encoder = SentenceEncoderBlock(self.output_sentence_encoder_dims[i],
                                                         self.sentence_attention_dims[i],
                                                         self.n_sentence_heads[i], dropout = self.dropout_sent_output)

            self.article_sentence_encoder = self.sentence_encoder(ant_layers[0])
            self.summary_sentence_encoder = self.sentence_encoder(ant_layers[1])

            self.z_article_sentence_encoder = Lambda(lambda x: x[0])(self.article_sentence_encoder)
            self.z_summary_sentence_encoder = Lambda(lambda x: x[0])(self.summary_sentence_encoder)

            self.attn_article_sentence_encoder = Lambda(lambda x: x[1])(self.article_sentence_encoder)
            self.all_article_attns.append(self.attn_article_sentence_encoder)

            # Masking between each layer #
            self.z_article_sentence_encoder = MyMasking()(self.z_article_sentence_encoder, mask = self.mask_sent_article)
            self.z_summary_sentence_encoder = MyMasking()(self.z_summary_sentence_encoder, mask = self.mask_sent_summary)

            ant_layers = (self.z_article_sentence_encoder, self.z_summary_sentence_encoder)

        # Prepare all attentions #
        if self.n_sentence_encoders > 1:
            self.all_article_attns = [Lambda(lambda a: K.expand_dims(a, 1))(x) for x in self.all_article_attns]
            self.all_article_attns = Concatenate(axis=1)(self.all_article_attns)

        ##########################


        self.z_article_sentence_encoder = GlobalMaxPooling1D()(self.z_article_sentence_encoder)
        self.z_summary_sentence_encoder = GlobalMaxPooling1D()(self.z_summary_sentence_encoder)

        self.difference = Lambda(lambda x: K.abs(x[0] - x[1]))([self.z_article_sentence_encoder,
                                                                self.z_summary_sentence_encoder])

        self.collapsed = Concatenate(axis=-1)([self.z_article_sentence_encoder,
                                               self.z_summary_sentence_encoder,
                                               self.difference])

        self.collapsed = LayerNormalization()(self.collapsed)

        self.h = Dense(self.dim_h, activation="relu")(self.collapsed)
        self.h = LayerNormalization()(self.h)
        self.output = Dense(2, activation="softmax")(self.h)

        self.model = Model(inputs = [self.input_article, self.mask_word_article, self.mask_sent_article, self.pos_encoding_word_article, self.pos_encoding_sent_article,
                                     self.input_summary, self.mask_word_summary, self.mask_sent_summary, self.pos_encoding_word_summary, self.pos_encoding_sent_summary],
                           outputs = [self.output])

        self.attn_model = Model(inputs = [self.input_article, self.mask_word_article, self.mask_sent_article, self.pos_encoding_word_article, self.pos_encoding_sent_article],
                                outputs = self.all_article_attns)
Example #12
def LV_ViT(input_size=224,
           patch_size=16,
           emb_dim=384,
           mlp_dim=1152,
           out_dim=9,
           num_layers=16,
           num_heads=6,
           attn_drop=0.,
           ffn_drop=0.,
           residual_drop=0.1,
           residual_scale=2.,
           mix_token=False,
           aux_loss=False):

    inpt = Input((input_size, input_size, 3))

    # patch embedding
    x = ConvBN(inpt, 64, kernel_size=7, strides=2)
    x = ConvBN(x, 64, kernel_size=3, strides=1)
    x = ConvBN(x, 64, kernel_size=3, strides=1)
    x = Conv2D(emb_dim, 8, strides=8, padding='same')(x)
    N = input_size // patch_size

    # mixtoken
    if mix_token:
        mix_mask = Input((N, N, 1))  # random crop based on beta distribution
        x = Lambda(mix_tokens)([x, mix_mask])

    # cls token
    x = Reshape((N * N, emb_dim))(x)
    tmp = Lambda(lambda x: x[:, 0:1, :])(x)  # [b,1,D]
    x0 = Lambda(lambda x: K.zeros_like(x))(tmp)
    x = Concatenate(axis=1)([x0, x])  # [b,N+1,D]

    # positional embeddings
    pe = Lambda(lambda x: tf.tile(positional_embedding(N * N + 1, emb_dim),
                                  [tf.shape(x)[0], 1, 1]))(x)
    x = add([x, pe])  # [b,N+1,D]
    x = Dropout(ffn_drop)(x)

    # transformer blocks
    for i in range(num_layers):
        x = encoder_block(x, emb_dim, num_heads, mlp_dim, attn_drop, ffn_drop,
                          residual_scale, residual_drop)
    x = LayerNormalization()(x)

    # take cls token
    x_cls = Lambda(lambda x: x[:, 0, :])(x)  # [b,D]
    if out_dim:
        x_cls = Dense(out_dim, activation='softmax')(x_cls)

    # take aux tokens
    if aux_loss:
        x_aux = Lambda(lambda x: x[:, 1:, :])(x)  # [b,N,D]
        if mix_token:
            x_aux = Reshape((N, N, emb_dim))(x_aux)
            x_aux = Lambda(mix_tokens)([x_aux, mix_mask])
            x_aux = Reshape((N * N, emb_dim))(x_aux)
        x_aux = Dense(out_dim, activation='softmax')(x_aux)

    model_inputs = [inpt, mix_mask] if mix_token else inpt
    model_outputs = [x_cls, x_aux] if aux_loss else x_cls
    model = Model(model_inputs, model_outputs)

    return model
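
A hedged usage sketch, assuming the ConvBN, positional_embedding and mix_tokens helpers the function references are importable. With mix_token=True the model takes a second (N, N, 1) mask input; one conventional way to build such a CutMix-style mask from a Beta draw is shown below as an assumption, not the repository's recipe:

import numpy as np

model = LV_ViT(input_size=224, patch_size=16, mix_token=True, aux_loss=False)

def random_token_mask(n, alpha=1.0):
    # rectangular CutMix-style mask over the n x n token grid
    lam = np.random.beta(alpha, alpha)
    cut = int(n * np.sqrt(1.0 - lam))
    cy, cx = np.random.randint(n), np.random.randint(n)
    mask = np.zeros((n, n, 1), dtype=np.float32)
    mask[max(0, cy - cut // 2):cy + cut // 2, max(0, cx - cut // 2):cx + cut // 2] = 1.0
    return mask

images = np.random.rand(2, 224, 224, 3).astype(np.float32)
masks = np.stack([random_token_mask(224 // 16) for _ in range(2)])
preds = model.predict([images, masks])        # (2, out_dim)
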
Example #13
def SwinTransformer(
        input_shape=(224, 224, 3),
        patch_size=4,
        emb_dim=96,
        ape=False,
        n_classes=1000,  # in/out hypers
        num_layers=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],  # structual hypers
        window_size=7,
        qkv_bias=True,
        qk_scale=None,
        mlp_ratio=4,  # swin-block hypers
        attn_drop=0.,
        ffn_drop=0.,
        residual_drop=0.2):

    inpt = Input(input_shape)
    # assert input_size%7==0 and input_size%16==0, 'input_size can not be divided clean'

    # patch embedding
    x = Conv2D(emb_dim, patch_size, strides=patch_size,
               padding='same')(inpt)  # (b,h/4,w/4,C), autopad
    H, W = (math.ceil(input_shape[0] / patch_size),
            math.ceil(input_shape[1] / patch_size))
    x = LayerNormalization()(x)  # [b,H,W,D]

    # absolute positional embeddings
    if ape:
        pe = Lambda(lambda x: tf.truncated_normal(
            (H, W, emb_dim), mean=0.0, stddev=.02))(x)  # (1,H,W,D)
        x = add([x, pe])  # (b,H,W,D)

    x = Dropout(ffn_drop)(x)

    # transformer stages
    n_stages = len(num_layers)
    dbr = np.linspace(0, residual_drop, num=sum(num_layers))  # drop block rate
    block_idx = 0
    for i in range(n_stages):
        merging = True if i < n_stages - 1 else False
        # pad feature map (bottom/right) to a multiple of window_size
        pad_b, pad_r = (window_size - H % window_size) % window_size, (
            window_size - W % window_size) % window_size
        if pad_b or pad_r:
            pad_l = pad_t = 0
            # x = Lambda(lambda x: tf.pad(x, [[0,0],[pad_t, pad_b], [pad_l, pad_r], [0,0]]))(x)
            x = Pad_HW(pad_t, pad_b, pad_l, pad_r)(x)
            H += pad_b
            W += pad_r
        # WSA+SWSA: 2 blocks
        x = basicStage(
            x,
            emb_dim, (H, W),
            num_layers[i] // 2,
            num_heads[i],
            window_size,
            mlp_ratio,
            qkv_bias,
            attn_drop=attn_drop,
            ffn_drop=ffn_drop,
            residual_drop=dbr[sum(num_layers[:i]):sum(num_layers[:i + 1])],
            patch_merging=merging,
            idx=block_idx,
            stage=i)
        emb_dim *= 2
        block_idx += num_layers[i] // 2
        if merging:
            H = (H + 1) // 2
            W = (W + 1) // 2
    x = LayerNormalization()(x)  # [b,H/32,W/32,8C]

    # head
    x = GlobalAveragePooling2D(data_format='channels_last')(x)  # (b,8C)

    if n_classes:
        x = Dense(n_classes,
                  activation='softmax',
                  kernel_initializer=bias_init,
                  bias_initializer='zeros')(x)

    model = Model(inpt, x)

    return model
Example #14
def SwinTransformer(input_size=224,
                    patch_size=4,
                    emb_dim=96,
                    mlp_ratio=4,
                    out_dim=1000,
                    ape=False,
                    num_layers=[2, 2, 6, 2],
                    num_heads=[3, 6, 12, 24],
                    window_size=7,
                    qkv_bias=True,
                    attn_drop=0.,
                    ffn_drop=0.,
                    residual_drop=0.2):

    inpt = Input((input_size, input_size, 3))
    assert input_size % 7 == 0 and input_size % 16 == 0, 'input_size cannot be divided cleanly'

    # patch embedding
    x = Conv2D(emb_dim, patch_size, strides=patch_size, padding='same')(inpt)
    N = input_size // patch_size  # grid_size, s4
    x = Reshape((N * N, emb_dim))(x)
    x = LayerNormalization()(x)  # [b,T,D]

    # absolute positional embeddings
    if ape:
        pe = Lambda(lambda x: tf.truncated_normal(
            (1, tf.shape(x)[1], tf.shape(x)[2]), mean=0.0, stddev=.02))(
                x)  # [1,T,D]
        x = add([x, pe])  # [b,N+1,D]

    x = Dropout(ffn_drop)(x)

    # transformer stages
    n_stages = len(num_layers)
    dbr = np.linspace(0, residual_drop, num=sum(num_layers))
    feature_size = N  # start from s4
    block_idx = 0
    for i in range(n_stages):
        merging = True if i < n_stages - 1 else False
        x = swin_block(x,
                       emb_dim,
                       feature_size,
                       num_layers[i],
                       num_heads[i],
                       window_size,
                       mlp_ratio,
                       qkv_bias,
                       attn_drop=attn_drop,
                       ffn_drop=ffn_drop,
                       residual_drop=dbr[i],
                       patch_merging=merging,
                       idx=block_idx,
                       stage=i)
        emb_dim *= 2
        block_idx += num_layers[i]
        if merging:
            feature_size //= 2
    x = LayerNormalization()(x)  # [H/32*W/32,8C]

    # head
    x = GlobalAveragePooling1D(data_format='channels_last')(x)
    if out_dim:
        x = Dense(out_dim, activation='softmax')(x)

    model = Model(inpt, x)

    return model
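
A brief usage sketch for the function above, assuming the swin_block and other helpers it references are importable:

model = SwinTransformer(input_size=224, patch_size=4, emb_dim=96,
                        num_layers=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        out_dim=1000)
model.summary()   # (None, 224, 224, 3) input -> (None, 1000) softmax output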