def over_add(x, pad_len):
    """Merge half-overlapping frames back into one sequence per batch item.

    Overlap-adds frames with hop ``k // 2`` and then trims
    ``k // 2 - k // 4`` steps from the front and ``pad_len + k // 4`` steps
    from the back. These are the same amounts ``segmentation`` pads, so
    ``pad_len`` is intended to be the ``gap`` returned there.

    Args:
        x: 4-D tensor unpacked as ``(batch, n, k, hidden)``; presumably
            ``(batch, num_frames, frame_len, features)`` as produced by
            ``segmentation`` -- TODO confirm with callers.
        pad_len: extra right padding to strip (``segmentation``'s ``gap``).

    Returns:
        Tensor laid out as ``(batch, length, hidden)``.
    """
    _, _, k, hidden = x.shape
    x = Rearrange("b n k h->(b h) n k")(x)
    x = tf.signal.overlap_and_add(x, frame_step=k // 2)
    start = k // 2 - k // 4
    trim = pad_len + k // 4
    # A trailing trim of 0 must become ``None``: ``x[:, start:-0]`` would mean
    # ``x[:, start:0]`` and silently yield an empty tensor (k < 4, pad_len == 0).
    x = x[:, start:(-trim if trim else None)]
    x = Rearrange("(b h) l->b l h", h=hidden)(x)
    return x
def segmentation(x, k):
    """Split a sequence into half-overlapping frames of length ``k``.

    The time axis is padded with ``k // 2 - k // 4`` steps on the left and
    ``gap + k // 4`` on the right. It is then framed with ``tf.signal.frame``
    using ``frame_length=k`` and ``frame_step=k // 2``.

    Args:
        x: 3-D tensor unpacked as ``(batch, length, hidden)``. ``length`` must
            be a Python int because it is used in integer arithmetic, and
            ``batch`` is handed to einops as the ``b`` axis size.
        k: frame length.

    Returns:
        ``(frames, gap)``, where:

        - ``frames`` is laid out as ``(batch, num_frames, k, hidden)``.
        - ``gap`` is the non-negative extra right padding. It is intended to be
          passed back to ``over_add`` as ``pad_len``.
    """
    bs, length, _ = x.shape
    # Ceiling division keeps gap >= 0. The previous floor made gap negative
    # whenever 2 * length was not a multiple of k, and tf.pad rejects negative
    # paddings. Results are unchanged when 2 * length % k == 0.
    target_length = (2 * length + k - 1) // k
    gap = (target_length * k // 2) - length
    x = tf.pad(x, ((0, 0), (k // 2 - k // 4, gap + k // 4), (0, 0)))
    x = Rearrange("b l f->(b f) l")(x)
    x = tf.signal.frame(x, frame_length=k, frame_step=k // 2)
    x = Rearrange("(b a) t f -> b t f a", b=bs)(x)
    return x, gap
Example #3
0
    def __init__(self, dim, heads=8):
        """Create the sublayers for multi-head self-attention.

        Args:
            dim: width of the output projection; ``self.scale`` is
                ``dim ** -0.5``.
            heads: number of heads the fused q/k/v projection is split into.
        """
        super().__init__()
        self.heads = heads
        self.scale = pow(dim, -0.5)

        # A single bias-free projection yields q, k and v concatenated.
        self.to_qkv = tf.keras.layers.Dense(units=3 * dim, use_bias=False)
        self.to_out = tf.keras.layers.Dense(units=dim)

        # Split the fused projection per head, and merge heads back afterwards.
        self.rearrange_qkv = Rearrange(
            'b n (qkv h d) -> qkv b h n d', qkv=3, h=heads)
        self.rearrange_out = Rearrange('b h n d -> b n (h d)')
    def __init__(self, dimension, heads=8, dropout_rate=0.0):
        """Create the sublayers of a multi-head self-attention block.

        Args:
            dimension: feature width. The fused q/k/v projection is
                ``3 * dimension`` wide and ``self.scale`` is
                ``dimension ** -0.5``.
            heads: number of attention heads.
            dropout_rate: rate given to both ``Dropout`` layers.
        """
        super(SelfAttention, self).__init__()

        self.heads = heads
        self.scale = pow(dimension, -0.5)

        # Fused q/k/v projection, split into per-head tensors by einops.
        self.qkv = Dense(3 * dimension, use_bias=False)
        self.rearrange_attention = Rearrange(
            'b n (qkv h d) -> qkv b h n d', qkv=3, h=heads)
        self.attn_dropout = Dropout(dropout_rate)

        # Heads are merged back, then projected to the model width.
        self.rearrange_output = Rearrange('b h n d -> b n (h d)')
        self.proj = Dense(dimension)
        self.proj_dropout = Dropout(dropout_rate)
Example #5
0
    def __init__(self,
                 dim=512,
                 levels=6,
                 image_size=224,
                 patch_size=14,
                 consensus_self=False,
                 local_consensus_radius=0):
        """Build the components of a GLOM-style model.

        Args:
            dim: embedding width used at every level.
            levels: number of levels; sizes ``init_levels`` and the
                feed-forward groups.
            image_size: input side length in pixels. The patch count assumes
                a square image.
            patch_size: side length of each square patch.
            consensus_self: forwarded to ``ConsensusAttention`` as
                ``attend_self``.
            local_consensus_radius: forwarded to ``ConsensusAttention``.
        """
        super(Glom, self).__init__()
        side = (image_size // patch_size)
        self.levels = levels

        # Channels-first patchify, then project each flattened patch to dim.
        # input_dim assumes 3 input channels.
        patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                             p1=patch_size,
                             p2=patch_size)
        self.image_to_tokens = tf.keras.Sequential([
            patchify,
            tf.keras.layers.Dense(dim, input_dim=3 * patch_size**2),
        ])
        self.pos_emb = tf.keras.layers.Embedding(side * side, dim)
        self.init_levels = tf.Variable(tf.random.normal([levels, dim]))

        # Bottom-up spans all levels; top-down spans one fewer.
        self.bottom_up = GroupedFeedForward(dim=dim, groups=levels)
        self.top_down = GroupedFeedForward(dim=dim, groups=levels - 1)

        # Attention between patch columns, sized by the patch grid side.
        self.attention = ConsensusAttention(
            side,
            attend_self=consensus_self,
            local_consensus_radius=local_consensus_radius)
Example #6
0
def VisionTransformer(input_shape, n_classes, patch_size, patch_dim, n_encoder_layers, n_heads, ff_dim,
                      dropout_rate=0.0):
    """Build a functional Keras Vision Transformer classifier.

    Pipeline:
      1. Cut the channels-last image into ``patch_size`` x ``patch_size``
         patches and project each flattened patch to ``patch_dim``.
      2. Prepend one learned token (``add_cls_token``).
      3. Apply the ``pos_embedding`` layer.
      4. Run ``Encoder``.
      5. Keep only the first token and map it through a two-layer MLP head.

    Args:
        input_shape: per-example shape given to ``Input``, (height, width,
            channels).
        n_classes: width of the final, activation-free Dense layer.
        patch_size: side length of each square patch.
        patch_dim: token embedding width.
        n_encoder_layers: forwarded to ``Encoder`` as ``num_layers``.
        n_heads: forwarded to ``Encoder`` as ``num_heads``.
        ff_dim: forwarded to ``Encoder``; also the hidden width of the head.
        dropout_rate: forwarded to ``Encoder``.

    Returns:
        An uncompiled ``tf.keras.Model`` that outputs ``n_classes``
        unnormalized scores.
    """
    image = tf.keras.layers.Input(input_shape)
    patch_pattern = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
    tokens = Rearrange(patch_pattern, p1=patch_size, p2=patch_size)(image)
    tokens = tf.keras.layers.Dense(patch_dim)(tokens)

    tokens = ConcatEmbedding(1, patch_dim,
                             side="left",
                             axis=1,
                             initializer=tf.keras.initializers.RandomNormal(),
                             name="add_cls_token")(tokens)
    sequence_length = tokens.shape[1]
    tokens = LearnedEmbedding1D(sequence_length, patch_dim,
                                initializer=tf.keras.initializers.RandomNormal(),
                                name="pos_embedding")(tokens)

    encoded = Encoder(embed_dim=patch_dim,
                      num_heads=n_heads,
                      ff_dim=ff_dim,
                      num_layers=n_encoder_layers,
                      dropout_rate=dropout_rate)(tokens)

    # Cropping the trailing (length - 1) positions leaves only token 0, which
    # Reshape([-1]) then flattens per example.
    first_token = tf.keras.layers.Cropping1D((0, encoded.shape[1] - 1))(encoded)
    first_token = tf.keras.layers.Reshape([-1])(first_token)

    mlp_head = tf.keras.Sequential([
        tf.keras.layers.Dense(ff_dim, activation=tfa.activations.gelu),
        tf.keras.layers.Dense(n_classes)],
        name="mlp_head")
    scores = mlp_head(first_token)

    return tf.keras.models.Model(image, scores)
Example #7
0
def VisionTransformerOS(input_shape, patch_size, patch_dim, n_encoder_layers, n_heads, ff_dim, dropout_rate=0.0):
    """Build a two-input ViT that scores an image pair with one sigmoid unit.

    Both channels-last images are patchified and projected to ``patch_dim``.
    Each stream has its own, unshared Dense layer. The sequence
    ``[cls] + tokens(x1) + [sep] + tokens(x2)`` then goes through
    ``pos_embedding`` and ``Encoder``. The first output token feeds a Dense
    layer of width ``ff_dim`` followed by a 1-unit sigmoid Dense layer.

    Args:
        input_shape: per-example shape of each input, (height, width,
            channels).
        patch_size: side length of each square patch.
        patch_dim: token embedding width.
        n_encoder_layers: forwarded to ``Encoder`` as ``num_layers``.
        n_heads: forwarded to ``Encoder`` as ``num_heads``.
        ff_dim: forwarded to ``Encoder``; also the hidden width of the head.
        dropout_rate: forwarded to ``Encoder``.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs named ``x1`` and ``x2``.
    """
    first = tf.keras.layers.Input(input_shape, name="x1")
    second = tf.keras.layers.Input(input_shape, name="x2")

    patch_pattern = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
    patches = [Rearrange(patch_pattern, p1=patch_size, p2=patch_size)(image)
               for image in (first, second)]
    tokens_1, tokens_2 = [tf.keras.layers.Dense(patch_dim)(p) for p in patches]

    sequence = ConcatEmbedding(n_embeddings=1,
                               embedding_dim=patch_dim,
                               side="left",
                               axis=1,
                               initializer=tf.keras.initializers.RandomNormal(),
                               name="add_cls_token")(tokens_1)
    sequence = ConcatEmbedding(n_embeddings=1,
                               embedding_dim=patch_dim,
                               side="right",
                               axis=1,
                               initializer=tf.keras.initializers.RandomNormal(),
                               name="add_sep_token")(sequence)
    sequence = tf.keras.layers.Concatenate(axis=1)([sequence, tokens_2])

    # NOTE(review): VisionTransformer in this file calls
    # LearnedEmbedding1D(x.shape[1], patch_dim, ...), but here only patch_dim
    # is passed positionally. Confirm which signature LearnedEmbedding1D expects.
    sequence = LearnedEmbedding1D(patch_dim,
                                  initializer=tf.keras.initializers.RandomNormal(),
                                  name="pos_embedding")(sequence)
    encoded = Encoder(embed_dim=patch_dim,
                      num_heads=n_heads,
                      ff_dim=ff_dim,
                      num_layers=n_encoder_layers,
                      dropout_rate=dropout_rate)(sequence)

    # Keep only token 0 (the prepended cls token) and flatten it.
    head_input = tf.keras.layers.Cropping1D((0, encoded.shape[1] - 1))(encoded)
    head_input = tf.keras.layers.Reshape([-1])(head_input)

    hidden = tf.keras.layers.Dense(ff_dim, activation=tfa.activations.gelu)(head_input)
    probability = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)

    return tf.keras.models.Model([first, second], probability)
Example #8
0
def VisionTransformerOSv2(input_shape, patch_size, patch_dim, n_encoder_layers, n_decoder_layers, n_heads, ff_dim,
                          dropout_rate=0.0):
    """Build an encoder/decoder ViT that scores an image pair with a sigmoid.

    Image ``x1`` is patchified, projected, position-embedded and run through
    ``Encoder``. Image ``x2`` receives the same treatment up to the position
    embedding. It then gets a learned token prepended and goes through a
    non-causal ``Decoder`` called as ``Decoder(...)([decoder_tokens,
    encoder_output])``. The decoder output at position 0 feeds a two-layer
    head that ends in a single sigmoid unit.

    Args:
        input_shape: per-example shape of each input, (height, width,
            channels).
        patch_size: side length of each square patch.
        patch_dim: token embedding width.
        n_encoder_layers: forwarded to ``Encoder`` as ``num_layers``.
        n_decoder_layers: forwarded to ``Decoder`` as ``num_layers``.
        n_heads: attention heads for both ``Encoder`` and ``Decoder``.
        ff_dim: forwarded to both stacks; also the hidden width of the head.
        dropout_rate: forwarded to both stacks.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs named ``x1`` and ``x2``.
    """
    enc_image = tf.keras.layers.Input(input_shape, name="x1")
    dec_image = tf.keras.layers.Input(input_shape, name="x2")

    # Each stage is applied to the encoder stream first, then the decoder
    # stream, with separate (unshared) layers per stream.
    patch_pattern = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
    enc_tokens, dec_tokens = [
        Rearrange(patch_pattern, p1=patch_size, p2=patch_size)(image)
        for image in (enc_image, dec_image)
    ]
    enc_tokens, dec_tokens = [
        tf.keras.layers.Dense(patch_dim)(tokens)
        for tokens in (enc_tokens, dec_tokens)
    ]
    enc_tokens, dec_tokens = [
        PositionalEmbedding1D(patch_dim)(tokens)
        for tokens in (enc_tokens, dec_tokens)
    ]

    memory = Encoder(embed_dim=patch_dim,
                     num_heads=n_heads,
                     ff_dim=ff_dim,
                     num_layers=n_encoder_layers,
                     dropout_rate=dropout_rate)(enc_tokens)

    dec_tokens = ConcatEmbedding(n_embeddings=1,
                                 embedding_dim=patch_dim,
                                 side="left",
                                 axis=1,
                                 initializer=tf.keras.initializers.RandomNormal(),
                                 name="add_cls_token")(dec_tokens)
    decoded = Decoder(embed_dim=patch_dim,
                      num_heads=n_heads,
                      ff_dim=ff_dim,
                      num_layers=n_decoder_layers,
                      dropout_rate=dropout_rate,
                      norm=False,
                      causal=False)([dec_tokens, memory])

    cls_out = tf.keras.layers.Lambda(lambda t: t[:, 0, :], name="cls_token_out")(decoded)

    hidden = tf.keras.layers.Dense(ff_dim, activation=tfa.activations.gelu)(cls_out)
    probability = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)

    return tf.keras.models.Model([enc_image, dec_image], probability)
    def call(self, x, **kwargs):
        """Combine a content term and a position term over a channels-last map.

        Steps:
          1. ``x`` is unpacked as (batch, height, width, channels).
          2. Queries are split into ``self.heads`` heads; keys and values are
             split into ``self.u`` slices.
          3. Keys go through ``nn.softmax`` (default axis, presumably the
             flattened position axis -- confirm).
          4. The content term contracts the keys with the values and applies
             the result to every query.
          5. The position term uses ``self.pos_conv`` when
             ``self.local_contexts`` is truthy. Otherwise it uses embeddings
             gathered from ``self.rel_pos_emb`` at indices ``self.rel_pos``.

        Returns:
            The sum of both terms, laid out as ``b hh ww (h v)``.
        """
        _, height, width, _, intra, heads = (*x.get_shape().as_list(), self.u, self.heads)

        queries = self.to_q(x)
        keys = self.to_k(x)
        values = self.to_v(x)

        queries = self.norm_q(queries)
        values = self.norm_v(values)

        queries = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=heads)(queries)
        keys = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=intra)(keys)
        values = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=intra)(values)

        keys = nn.softmax(keys)

        content_lambda = einsum('b u k m, b u v m -> b k v', keys, values)
        content_out = einsum('b h k n, b k v -> b n h v', queries, content_lambda)

        if not self.local_contexts:
            rel_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
            position_lambda = einsum('n m k u, b u v m -> b n k v', rel_emb, values)
            position_out = einsum('b h k n, b n k v -> b n h v', queries, position_lambda)
        else:
            values = Rearrange('b u v (hh ww) -> b v hh ww u', hh=height, ww=width)(values)
            position_lambda = self.pos_conv(values)
            position_lambda = Rearrange('b v h w k -> b v k (h w)')(position_lambda)
            position_out = einsum('b h k n, b v k n -> b n h v', queries, position_lambda)

        combined = content_out + position_out
        return Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = height, ww = width)(combined)
Example #10
0
    def __init__(self, dim, groups, mult=4):
        """Per-group feed-forward network for tensors laid out as ``b n l d``.

        Each of the ``groups`` slices along axis ``l`` gets its own pointwise
        two-layer MLP (``d -> d * mult -> d``, GELU in between). It is
        implemented as grouped 1x1 ``Conv1D`` layers over the flattened
        ``(l d)`` channel axis.

        Args:
            dim: width ``d`` of each group.
            groups: number of groups; the input's ``l`` axis is presumably
                expected to equal this -- confirm with callers.
            mult: hidden-width multiplier of each per-group MLP.
        """
        super(GroupedFeedForward, self).__init__()
        total_dim = dim * groups
        self.net = tf.keras.Sequential(
            [
                # Keras Conv1D is channels-last, so tokens stay on the length
                # axis and the (l d) groups become channels. The previous
                # 'b (l d) n' layout was channels-first.
                Rearrange('b n l d -> b n (l d)'),
                # Keyword args: the old positional call
                # Conv1D(total_dim, total_dim * mult, 1) made total_dim * mult
                # the kernel size instead of the filter count.
                tf.keras.layers.Conv1D(
                    filters=total_dim * mult,
                    kernel_size=1,
                    groups=groups,
                    activation='gelu'),
                tf.keras.layers.Conv1D(
                    filters=total_dim,
                    kernel_size=1,
                    groups=groups),
                Rearrange('b n (l d) -> b n l d', l=groups),
            ])

    def call(self, inputs):
        """Apply the grouped feed-forward stack to ``inputs``."""
        # This used to be nested inside __init__, so Keras never dispatched to it.
        return self.net(inputs)
Example #11
0
 def __init__(
     self,
     num_classes,
     num_blocks,
     patch_size,
     hidden_dim,
     tokens_mlp_dim,
     channels_mlp_dim,
 ):
     """Assemble the sublayers of an MLP-Mixer style classifier.

     Args:
         num_classes: width of the final Dense layer (zero-initialised kernel).
         num_blocks: number of ``NdMixerBlock`` layers.
         patch_size: kernel size and stride of the patch-embedding Conv2D.
         hidden_dim: number of Conv2D filters (token width).
         tokens_mlp_dim: first entry of the dims list given to each mixer
             block.
         channels_mlp_dim: second entry of the dims list given to each mixer
             block.
     """
     super().__init__()
     patch = (patch_size, patch_size)
     # Kernel == stride, so patches do not overlap.
     self.make_tokens = layers.Conv2D(hidden_dim, patch, patch)
     # Flatten the spatial grid into a single token axis.
     self.rearrange = Rearrange("n h w c -> n (h w) c")
     blocks = []
     for _ in range(num_blocks):
         # A fresh dims list per block, as before.
         blocks.append(NdMixerBlock([tokens_mlp_dim, channels_mlp_dim]))
     self.mixers = blocks
     self.batchnorm = layers.BatchNormalization()
     self.clf = layers.Dense(num_classes, kernel_initializer="zeros")
Example #12
0
    def __init__(self,
                 *,
                 image_size,
                 patch_size,
                 num_classes,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 channels=3):
        """Create weights and sublayers of a ViT for square channels-first images.

        Args:
            image_size: image side length; must be divisible by
                ``patch_size``.
            patch_size: side length of each square patch.
            num_classes: width of the final Dense layer in ``mlp_head``.
            dim: token embedding width.
            depth: forwarded to ``Transformer``.
            heads: forwarded to ``Transformer``.
            mlp_dim: forwarded to ``Transformer``; also the hidden width of
                ``mlp_head``.
            channels: accepted but not used here; ``patch_to_embedding``
                infers its input width.

        Raises:
            AssertionError: if ``image_size`` is not a multiple of
                ``patch_size``.
        """
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size)**2

        self.patch_size = patch_size
        self.dim = dim
        normal = tf.keras.initializers.RandomNormal
        # One extra row for the prepended cls token.
        self.pos_embedding = self.add_weight(
            "position_embeddings",
            shape=[num_patches + 1, dim],
            initializer=normal(),
            dtype=tf.float32)
        self.patch_to_embedding = tf.keras.layers.Dense(dim)
        self.cls_token = self.add_weight(
            "cls_token",
            shape=[1, 1, dim],
            initializer=normal(),
            dtype=tf.float32)

        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                                   p1=patch_size,
                                   p2=patch_size)

        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = tf.identity

        gelu = get_activation('gelu')
        hidden_layer = tf.keras.layers.Dense(mlp_dim, activation=gelu)
        logits_layer = tf.keras.layers.Dense(num_classes)
        self.mlp_head = tf.keras.Sequential([hidden_layer, logits_layer])
    def __init__(self,
                 image_size,
                 patch_size,
                 n_classes,
                 batch_size,
                 dimension,
                 depth,
                 heads,
                 mlp_dimension,
                 channels=3):
        """Create weights and sublayers of an image transformer classifier.

        Args:
            image_size: image side length; must be divisible by
                ``patch_size``.
            patch_size: side length of each square patch (channels-first
                input).
            n_classes: width of the final Dense layer.
            batch_size: stored as ``self.batch_size``.
            dimension: token embedding width.
            depth: forwarded to ``TransformerEncoder``.
            heads: forwarded to ``TransformerEncoder``.
            mlp_dimension: forwarded to ``TransformerEncoder``; also the width
                of ``mlp_1``.
            channels: accepted but not used here.

        Raises:
            AssertionError: if ``image_size`` is not a multiple of
                ``patch_size``.
        """
        super(ImageTransformer, self).__init__()
        assert image_size % patch_size == 0, 'invalid patch size for image size'

        num_patches = (image_size // patch_size)**2
        self.patch_size = patch_size
        self.dimension = dimension
        self.batch_size = batch_size

        normal = tf.keras.initializers.RandomNormal
        # One extra row for the prepended classification token.
        self.positional_embedding = self.add_weight(
            "position_embeddings",
            shape=[num_patches + 1, dimension],
            initializer=normal(),
            dtype=tf.float32)
        self.embedding_mlp = tf.keras.layers.Dense(dimension)
        self.classification_token = self.add_weight(
            "classification_token",
            shape=[1, 1, dimension],
            initializer=normal(),
            dtype=tf.float32)

        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                                   p1=patch_size,
                                   p2=patch_size)
        self.transformer = TransformerEncoder(dimension, depth, heads,
                                              mlp_dimension)
        self.classification_identity = tf.identity
        self.mlp_1 = tf.keras.layers.Dense(mlp_dimension)
        self.gelu = GELU()
        # NOTE(review): tf.keras layers/models expose a read-only ``output``
        # property, so this assignment likely raises AttributeError. Confirm,
        # and if so rename it here and in ``call`` together.
        self.output = tf.keras.layers.Dense(n_classes)
Example #14
0
def calc_rel_pos(n):
    """Return pairwise (row, col) offsets between all cells of an n x n grid.

    The result has shape ``[n*n, n*n, 2]``. Entry ``[a, b]`` is
    ``coords[b] - coords[a]``, shifted by ``n - 1`` so that each component
    lies in ``[0, 2n - 2]`` (usable as a non-negative index).
    """
    grid = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing = 'ij'))
    coords = Rearrange('n i j -> (i j) n')(grid)   # [n*n, 2]; coords[p] = (i, j)
    return (coords[None, :] - coords[:, None]) + (n - 1)
 def __init__(self, p1, p2):
     """Wrap an einops layer turning channels-last images into patch tokens.

     Args:
         p1: patch height (size of the ``p1`` factor of the height axis).
         p2: patch width (size of the ``p2`` factor of the width axis).
     """
     super(Rearrange3d, self).__init__()
     pattern = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
     self.rearrange = Rearrange(pattern, p1=p1, p2=p2)
 def __init__(self, num_patches=20, embed_dim=20):
     """Wrap an einops layer mapping ``b c n w`` to ``b n (c w)``.

     ``num_patches`` and ``embed_dim`` are accepted but not used here.
     """
     super(RearrangeCh, self).__init__()
     pattern = 'b c n w -> b n (c w)'
     self.rearrange = Rearrange(pattern)