def over_add(x, pad_len):
    """Overlap-add framed chunks back into a contiguous sequence.

    Inverts a 50%-overlap framing (see `segmentation` in this file): frames
    are summed with `tf.signal.overlap_and_add`, then the head margin
    (k//2 - k//4) and tail margin (pad_len + k//4) added before framing are
    trimmed off.

    Args:
        x: tensor of shape (batch, n_frames, k, hidden).
        pad_len: amount of tail padding to remove — presumably the `gap`
            returned by `segmentation`; confirm at the call site.

    Returns:
        Tensor of shape (batch, length, hidden).
    """
    bs, n, k, hidden = x.shape
    # Fold the feature axis into the batch so overlap_and_add sees 2-D input.
    chunks = Rearrange("b n k h->(b h) n k")(x)
    signal = tf.signal.overlap_and_add(chunks, frame_step=k // 2)
    # NOTE(review): assumes pad_len + k // 4 > 0; a zero total would make the
    # negative slice bound equal 0 and yield an empty tensor.
    signal = signal[:, k // 2 - k // 4:-(pad_len + k // 4)]
    return Rearrange("(b h) l->b l h", h=hidden)(signal)
def segmentation(x, k):
    """Split a (batch, length, hidden) tensor into 50%-overlapping chunks.

    Pads the sequence (a quarter-window margin on each side plus `gap` to
    even out the length), then frames it into windows of length `k` with
    step `k // 2`.

    Args:
        x: tensor of shape (batch, length, hidden).
        k: chunk (frame) length.

    Returns:
        Tuple of the framed tensor, shape (batch, frames, k, hidden), and
        `gap`, the tail padding that `over_add` must later remove.
    """
    bs, length, hidden = x.shape
    target_length = int(np.floor(length * 2 / k))
    # NOTE(review): gap can be negative for some (length, k) combinations
    # (e.g. length=11, k=4), which would make tf.pad fail — confirm callers
    # guarantee divisibility.
    gap = (target_length * k // 2) - length
    padded = tf.pad(x, ((0, 0), (k // 2 - k // 4, gap + k // 4), (0, 0)))
    # Fold features into the batch so tf.signal.frame works on 2-D input.
    flat = Rearrange("b l f->(b f) l")(padded)
    frames = tf.signal.frame(flat, frame_length=k, frame_step=k // 2)
    framed = Rearrange("(b a) t f -> b t f a", b=bs)(frames)
    return framed, gap
def __init__(self, dim, heads=8):
    """Multi-head attention sub-modules.

    Creates the fused QKV projection, the output projection, and the einops
    reshapes used to split and re-merge attention heads.

    Args:
        dim: embedding dimension.
        heads: number of attention heads.
    """
    super().__init__()
    self.heads = heads
    # 1/sqrt(d) attention scaling factor.
    self.scale = dim**-0.5
    self.rearrange_qkv = Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
    self.rearrange_out = Rearrange('b h n d -> b n (h d)')
    # One bias-free projection produces Q, K and V in a single pass.
    self.to_qkv = tf.keras.layers.Dense(dim * 3, use_bias=False)
    self.to_out = tf.keras.layers.Dense(dim)
def __init__(self, dimension, heads=8, dropout_rate=0.0):
    """Multi-head self-attention with dropout.

    Args:
        dimension: embedding dimension.
        heads: number of attention heads.
        dropout_rate: rate applied to both attention weights and the output
            projection.
    """
    super(SelfAttention, self).__init__()
    self.heads = heads
    # 1/sqrt(d) attention scaling factor.
    self.scale = dimension ** -0.5
    # One bias-free projection produces Q, K and V in a single pass.
    self.qkv = Dense(dimension * 3, use_bias=False)
    self.rearrange_attention = Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
    self.rearrange_output = Rearrange('b h n d -> b n (h d)')
    self.attn_dropout = Dropout(dropout_rate)
    # Output projection and its own dropout.
    self.proj = Dense(dimension)
    self.proj_dropout = Dropout(dropout_rate)
def __init__(self, dim=512, levels=6, image_size=224, patch_size=14, consensus_self=False, local_consensus_radius=0):
    """GLOM model components.

    Builds the patch tokenizer, positional and per-level embeddings, the
    bottom-up/top-down grouped feed-forward units and the consensus
    attention unit.

    Args:
        dim: embedding dimension of every level.
        levels: number of levels in the hierarchy.
        image_size: input image side length (assumed square).
        patch_size: patch side length.
        consensus_self: whether consensus attention attends to itself.
        local_consensus_radius: radius of local consensus; 0 means global.
    """
    super(Glom, self).__init__()
    side = image_size // patch_size
    self.levels = levels
    # Patchify a channel-first image and project each flattened patch to `dim`.
    self.image_to_tokens = tf.keras.Sequential([
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
        tf.keras.layers.Dense(dim, input_dim=patch_size**2 * 3)])
    self.pos_emb = tf.keras.layers.Embedding(side**2, dim)
    # Learned initial state for every level.
    self.init_levels = tf.Variable(tf.random.normal([levels, dim]))
    # Bottom-up and top-down units; top-down has one fewer group since the
    # top level receives no top-down signal.
    self.bottom_up = GroupedFeedForward(dim=dim, groups=levels)
    self.top_down = GroupedFeedForward(dim=dim, groups=levels - 1)
    # Consensus attention unit.
    self.attention = ConsensusAttention(
        side, attend_self=consensus_self, local_consensus_radius=local_consensus_radius)
def VisionTransformer(input_shape, n_classes, patch_size, patch_dim, n_encoder_layers, n_heads, ff_dim, dropout_rate=0.0):
    """Build a ViT-style classifier.

    Pipeline: patchify -> linear patch embedding -> prepend CLS token ->
    learned positional embedding -> transformer encoder -> CLS output ->
    two-layer MLP head.

    Args:
        input_shape: shape of one channels-last input image.
        n_classes: number of output classes (logits, no softmax).
        patch_size: side length of the square patches.
        patch_dim: embedding dimension of each patch token.
        n_encoder_layers: number of encoder blocks.
        n_heads: attention heads per block.
        ff_dim: feed-forward (and MLP-head hidden) width.
        dropout_rate: encoder dropout rate.

    Returns:
        A compiled-free tf.keras Model mapping images to class logits.
    """
    inputs = tf.keras.layers.Input(input_shape)
    # Cut the image into patch_size x patch_size patches and flatten each.
    tokens = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
                       p1=patch_size, p2=patch_size)(inputs)
    tokens = tf.keras.layers.Dense(patch_dim)(tokens)
    tokens = ConcatEmbedding(1, patch_dim, side="left", axis=1,
                             initializer=tf.keras.initializers.RandomNormal(),
                             name="add_cls_token")(tokens)
    tokens = LearnedEmbedding1D(tokens.shape[1], patch_dim,
                                initializer=tf.keras.initializers.RandomNormal(),
                                name="pos_embedding")(tokens)
    encoded = Encoder(embed_dim=patch_dim, num_heads=n_heads, ff_dim=ff_dim,
                      num_layers=n_encoder_layers, dropout_rate=dropout_rate)(tokens)
    # Keep only the CLS token (position 0) and flatten it.
    cls = tf.keras.layers.Cropping1D((0, encoded.shape[1] - 1))(encoded)
    cls = tf.keras.layers.Reshape([-1])(cls)
    logits = tf.keras.Sequential([
        tf.keras.layers.Dense(ff_dim, activation=tfa.activations.gelu),
        tf.keras.layers.Dense(n_classes)], name="mlp_head")(cls)
    return tf.keras.models.Model(inputs, logits)
def VisionTransformerOS(input_shape, patch_size, patch_dim, n_encoder_layers, n_heads, ff_dim, dropout_rate=0.0):
    """Build a one-shot comparison ViT over a pair of images.

    Both images are patchified and embedded, joined into one sequence as
    [CLS] img1 [SEP] img2 (BERT-style), run through a shared encoder, and
    the CLS output is mapped to a single sigmoid score.

    Args:
        input_shape: shape of one channels-last input image.
        patch_size: side length of the square patches.
        patch_dim: embedding dimension of each patch token.
        n_encoder_layers: number of encoder blocks.
        n_heads: attention heads per block.
        ff_dim: feed-forward (and head hidden) width.
        dropout_rate: encoder dropout rate.

    Returns:
        A tf.keras Model mapping {"x1", "x2"} to a scalar in (0, 1).
    """
    inputs1 = tf.keras.layers.Input(input_shape, name="x1")
    inputs2 = tf.keras.layers.Input(input_shape, name="x2")
    # Rearrange is stateless, so one instance safely patchifies both inputs;
    # each branch keeps its own Dense projection (separate weights).
    patchify = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
                         p1=patch_size, p2=patch_size)
    tokens1 = tf.keras.layers.Dense(patch_dim)(patchify(inputs1))
    tokens2 = tf.keras.layers.Dense(patch_dim)(patchify(inputs2))
    # Assemble [CLS] img1 [SEP] img2.
    seq = ConcatEmbedding(n_embeddings=1, embedding_dim=patch_dim, side="left", axis=1,
                          initializer=tf.keras.initializers.RandomNormal(),
                          name="add_cls_token")(tokens1)
    seq = ConcatEmbedding(n_embeddings=1, embedding_dim=patch_dim, side="right", axis=1,
                          initializer=tf.keras.initializers.RandomNormal(),
                          name="add_sep_token")(seq)
    seq = tf.keras.layers.Concatenate(axis=1)([seq, tokens2])
    # NOTE(review): unlike VisionTransformer in this file, no sequence length
    # is passed here — confirm LearnedEmbedding1D supports this one-argument
    # form.
    seq = LearnedEmbedding1D(patch_dim,
                             initializer=tf.keras.initializers.RandomNormal(),
                             name="pos_embedding")(seq)
    seq = Encoder(embed_dim=patch_dim, num_heads=n_heads, ff_dim=ff_dim,
                  num_layers=n_encoder_layers, dropout_rate=dropout_rate)(seq)
    # Keep only the CLS position and flatten it.
    cls = tf.keras.layers.Cropping1D((0, seq.shape[1] - 1))(seq)
    cls = tf.keras.layers.Reshape([-1])(cls)
    # MLP head -> sigmoid score.
    out = tf.keras.layers.Dense(ff_dim, activation=tfa.activations.gelu)(cls)
    out = tf.keras.layers.Dense(1, activation="sigmoid")(out)
    return tf.keras.models.Model([inputs1, inputs2], out)
def VisionTransformerOSv2(input_shape, patch_size, patch_dim, n_encoder_layers, n_decoder_layers, n_heads, ff_dim, dropout_rate=0.0):
    """Build an encoder-decoder variant of the one-shot comparison ViT.

    Image 1 is encoded into a memory; image 2 (with a CLS token prepended)
    attends to that memory through a non-causal decoder; the decoder's CLS
    output is mapped to a single sigmoid score.

    Args:
        input_shape: shape of one channels-last input image.
        patch_size: side length of the square patches.
        patch_dim: embedding dimension of each patch token.
        n_encoder_layers: number of encoder blocks.
        n_decoder_layers: number of decoder blocks.
        n_heads: attention heads per block.
        ff_dim: feed-forward (and head hidden) width.
        dropout_rate: encoder/decoder dropout rate.

    Returns:
        A tf.keras Model mapping {"x1", "x2"} to a scalar in (0, 1).
    """
    inputs1 = tf.keras.layers.Input(input_shape, name="x1")
    inputs2 = tf.keras.layers.Input(input_shape, name="x2")
    # Rearrange is stateless, so one instance safely patchifies both inputs;
    # each branch keeps its own Dense projection (separate weights).
    patchify = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
                         p1=patch_size, p2=patch_size)
    enc_tokens = tf.keras.layers.Dense(patch_dim)(patchify(inputs1))
    dec_tokens = tf.keras.layers.Dense(patch_dim)(patchify(inputs2))
    enc_tokens = PositionalEmbedding1D(patch_dim)(enc_tokens)
    dec_tokens = PositionalEmbedding1D(patch_dim)(dec_tokens)
    memory = Encoder(embed_dim=patch_dim, num_heads=n_heads, ff_dim=ff_dim,
                     num_layers=n_encoder_layers, dropout_rate=dropout_rate)(enc_tokens)
    dec_tokens = ConcatEmbedding(n_embeddings=1, embedding_dim=patch_dim, side="left", axis=1,
                                 initializer=tf.keras.initializers.RandomNormal(),
                                 name="add_cls_token")(dec_tokens)
    decoded = Decoder(embed_dim=patch_dim, num_heads=n_heads, ff_dim=ff_dim,
                      num_layers=n_decoder_layers, dropout_rate=dropout_rate,
                      norm=False, causal=False)([dec_tokens, memory])
    # Extract the CLS position from the decoder output.
    cls = tf.keras.layers.Lambda(lambda t: t[:, 0, :], name="cls_token_out")(decoded)
    # MLP head -> sigmoid score.
    out = tf.keras.layers.Dense(ff_dim, activation=tfa.activations.gelu)(cls)
    out = tf.keras.layers.Dense(1, activation="sigmoid")(out)
    return tf.keras.models.Model([inputs1, inputs2], out)
def call(self, x, **kwargs):
    """Lambda-style attention over a 2-D feature map.

    Builds a content "lambda" shared across all positions, plus a position
    "lambda" computed either by a convolution over local context or from
    relative-position embeddings, and applies both to the queries.

    Assumes `x` has shape (batch, height, width, channels) — the rearrange
    patterns below rely on that layout. TODO confirm with callers.
    """
    # u: intra-depth of keys/values; h: number of query heads.
    b, hh, ww, c, u, h = *x.get_shape().as_list(), self.u, self.heads
    q = self.to_q(x)
    k = self.to_k(x)
    v = self.to_v(x)
    # Queries and values are normalized; keys are normalized by the softmax
    # below instead.
    q = self.norm_q(q)
    v = self.norm_v(v)
    # Flatten the spatial grid into one axis and split heads / intra-depth.
    q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
    k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
    v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)
    # Softmax over the last axis, i.e. the flattened spatial positions.
    k = nn.softmax(k)
    # Content lambda: a single (k x v) summary of the whole feature map,
    # then applied to every query position.
    Lc = einsum('b u k m, b u v m -> b k v', k, v)
    Yc = einsum('b h k n, b k v -> b n h v', q, Lc)
    if self.local_contexts:
        # Position lambdas from a convolution over a local neighbourhood.
        v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
        Lp = self.pos_conv(v)
        Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
        Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
    else:
        # Position lambdas from learned relative-position embeddings gathered
        # per position pair.
        rel_pos_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
        Lp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
        Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)
    # Combine content and position contributions, restore spatial layout.
    Y = Yc + Yp
    out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
    return out
def __init__(self, dim, groups, mult=4):
    """Per-level (grouped) position-wise feed-forward.

    Each of the `groups` levels gets its own two-layer MLP (expand by
    `mult`, GELU, project back), implemented as grouped pointwise
    convolutions so every level is processed in a single op.

    Args:
        dim: feature dimension of a single level.
        groups: number of independent levels/groups.
        mult: hidden-layer expansion factor.
    """
    super(GroupedFeedForward, self).__init__()
    total_dim = dim * groups
    # FIX: the previous version passed PyTorch-style positional arguments
    # (in_channels, out_channels, kernel_size) to tf.keras.layers.Conv1D,
    # whose signature is (filters, kernel_size, strides) — that made the
    # kernel size total_dim * mult instead of 1. It also moved channels to
    # the middle axis, while Keras Conv1D is channels-last by default, so
    # the layout is kept channels-last here.
    self.net = tf.keras.Sequential([
        # (batch, n, levels, dim) -> (batch, n, levels*dim), channels last.
        Rearrange('b n l d -> b n (l d)'),
        # Grouped pointwise expansion with GELU.
        tf.keras.layers.Conv1D(
            total_dim * mult, 1, groups=groups, activation='gelu'),
        # Grouped pointwise projection back to total_dim.
        tf.keras.layers.Conv1D(total_dim, 1, groups=groups),
        Rearrange('b n (l d) -> b n l d', l=groups)])

def call(self, inputs):
    """Apply the grouped feed-forward network to (batch, n, levels, dim)."""
    return self.net(inputs)
def __init__(
    self,
    num_classes,
    num_blocks,
    patch_size,
    hidden_dim,
    tokens_mlp_dim,
    channels_mlp_dim,
):
    """MLP-Mixer-style classifier components.

    Args:
        num_classes: number of output classes.
        num_blocks: number of mixer blocks.
        patch_size: side length of the square patches.
        hidden_dim: patch embedding (channel) dimension.
        tokens_mlp_dim: hidden width of the token-mixing MLP.
        channels_mlp_dim: hidden width of the channel-mixing MLP.
    """
    super().__init__()
    patch = (patch_size, patch_size)
    # Non-overlapping patch embedding: stride equals kernel size.
    self.make_tokens = layers.Conv2D(hidden_dim, patch, patch)
    self.rearrange = Rearrange("n h w c -> n (h w) c")
    self.mixers = [
        NdMixerBlock([tokens_mlp_dim, channels_mlp_dim])
        for _ in range(num_blocks)
    ]
    self.batchnorm = layers.BatchNormalization()
    # Zero-initialized head: training starts from uniform logits.
    self.clf = layers.Dense(num_classes, kernel_initializer="zeros")
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
    """Vision Transformer components.

    Args:
        image_size: input image side length (assumed square).
        patch_size: patch side length; must divide image_size.
        num_classes: number of output classes.
        dim: token embedding dimension.
        depth: number of transformer blocks.
        heads: attention heads per block.
        mlp_dim: hidden width of the transformer / head MLPs.
        channels: number of image channels.
    """
    super().__init__()
    assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
    num_patches = (image_size // patch_size)**2
    # Flattened patch length; unused below since Dense infers its input size.
    patch_dim = channels * patch_size**2
    self.patch_size = patch_size
    self.dim = dim
    # Learned positions for all patches plus the CLS token.
    self.pos_embedding = self.add_weight(
        "position_embeddings",
        shape=[num_patches + 1, dim],
        initializer=tf.keras.initializers.RandomNormal(),
        dtype=tf.float32)
    self.patch_to_embedding = tf.keras.layers.Dense(dim)
    self.cls_token = self.add_weight(
        "cls_token",
        shape=[1, 1, dim],
        initializer=tf.keras.initializers.RandomNormal(),
        dtype=tf.float32)
    # Channel-first patch extraction and flattening.
    self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                               p1=self.patch_size,
                               p2=self.patch_size)
    self.transformer = Transformer(dim, depth, heads, mlp_dim)
    self.to_cls_token = tf.identity
    self.mlp_head = tf.keras.Sequential([
        tf.keras.layers.Dense(mlp_dim, activation=get_activation('gelu')),
        tf.keras.layers.Dense(num_classes)
    ])
def __init__(self, image_size, patch_size, n_classes, batch_size, dimension, depth, heads, mlp_dimension, channels=3):
    """ViT-style image classifier components.

    Args:
        image_size: input image side length (assumed square).
        patch_size: patch side length; must divide image_size.
        n_classes: number of output classes.
        batch_size: batch size stored for later use.
        dimension: token embedding dimension.
        depth: number of transformer encoder blocks.
        heads: attention heads per block.
        mlp_dimension: hidden width of the MLPs.
        channels: number of image channels.
    """
    super(ImageTransformer, self).__init__()
    assert image_size % patch_size == 0, 'invalid patch size for image size'
    num_patches = (image_size // patch_size)**2
    self.patch_size = patch_size
    self.dimension = dimension
    self.batch_size = batch_size
    # Learned positions for all patches plus the classification token.
    self.positional_embedding = self.add_weight(
        "position_embeddings",
        shape=[num_patches + 1, dimension],
        initializer=tf.keras.initializers.RandomNormal(),
        dtype=tf.float32)
    self.embedding_mlp = tf.keras.layers.Dense(dimension)
    self.classification_token = self.add_weight(
        "classification_token",
        shape=[1, 1, dimension],
        initializer=tf.keras.initializers.RandomNormal(),
        dtype=tf.float32)
    # Channel-first patch extraction and flattening.
    self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                               p1=self.patch_size,
                               p2=self.patch_size)
    self.transformer = TransformerEncoder(dimension, depth, heads, mlp_dimension)
    self.classification_identity = tf.identity
    self.mlp_1 = tf.keras.layers.Dense(mlp_dimension)
    self.gelu = GELU()
    # NOTE(review): assigning to `self.output` shadows the read-only
    # `tf.keras.layers.Layer.output` property and may raise AttributeError —
    # consider renaming (which would also require updating `call`).
    self.output = tf.keras.layers.Dense(n_classes)
def calc_rel_pos(n):
    """Pairwise relative positions for an n x n grid.

    Returns an (n*n, n*n, 2) integer tensor whose entry [a, b] holds the
    (row, col) offset of cell b relative to cell a, shifted by n - 1 so all
    values lie in [0, 2n-2] and can index an embedding table.
    """
    coords = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing='ij'))
    # (n*n, 2): coords[a] = (i, j) of cell a.
    coords = Rearrange('n i j -> (i j) n')(coords)
    # (n*n, n*n, 2): all pairwise differences.
    rel = coords[None, :] - coords[:, None]
    # Shift value range from [-n+1, n-1] to [0, 2n-2].
    return rel + (n - 1)
def __init__(self, p1, p2):
    """Wrap an einops rearrange that splits a channels-last image into
    p1 x p2 patches and flattens each patch together with its channels.

    Args:
        p1: patch height.
        p2: patch width.
    """
    super(Rearrange3d, self).__init__()
    self.rearrange = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
                               p1=p1, p2=p2)
def __init__(self, num_patches=20, embed_dim=20):
    """Wrap an einops rearrange that folds the channel axis into the last
    (feature) axis.

    Args:
        num_patches: accepted for interface compatibility; not used by the
            rearrange itself.
        embed_dim: accepted for interface compatibility; not used by the
            rearrange itself.
    """
    super(RearrangeCh, self).__init__()
    self.rearrange = Rearrange('b c n w -> b n (c w)')