def __init__(
    self,
    embed_dim,
    fixed_embed_dim,
    num_heads,
    mlp_ratio=4.,
    qkv_bias=False,
    qk_scale=None,
    rpe=False,
    drop_rate=0.,
    attn_drop=0.,
    proj_drop=0.,
    drop_path=0.,
    pre_norm=True,
    rpe_length=14,
):
    super().__init__()
    self.normalize_before = pre_norm
    # Stochastic depth: randomly drop the residual branch with probability `drop_path`.
    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    self.dropout = drop_rate
    self.attn = RelativePositionAttention(
        embed_dim=embed_dim,
        fixed_embed_dim=fixed_embed_dim,
        num_heads=num_heads,
        attn_drop=attn_drop,
        proj_drop=proj_drop,
        rpe=rpe,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        rpe_length=rpe_length)

    self.attn_layer_norm = nn.LayerNorm(embed_dim)
    self.ffn_layer_norm = nn.LayerNorm(embed_dim)
    self.activation_fn = nn.GELU()

    # Feed-forward network; the hidden width scales with the (searchable) mlp_ratio.
    self.fc1 = nn.Linear(
        cast(int, embed_dim),
        cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)))
    self.fc2 = nn.Linear(
        cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)),
        cast(int, embed_dim))
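# Illustrative usage (a sketch, not from the original file): when the layer is
# built outside a search space, the searchable arguments can be plain Python
# numbers. The concrete values below are assumptions chosen for the example.
#
#   layer = TransformerEncoderLayer(
#       embed_dim=216, fixed_embed_dim=240, num_heads=4,
#       mlp_ratio=4.0, rpe=True, rpe_length=14)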
def __init__(self):
    super().__init__()
    self.m = nn.LayerNorm([10, 10])
def __init__(
    self,
    search_embed_dim: Tuple[int, ...] = (192, 216, 240),
    search_mlp_ratio: Tuple[float, ...] = (3.5, 4.0),
    search_num_heads: Tuple[int, ...] = (3, 4),
    search_depth: Tuple[int, ...] = (12, 13, 14),
    img_size: int = 224,
    patch_size: int = 16,
    in_chans: int = 3,
    num_classes: int = 1000,
    qkv_bias: bool = False,
    drop_rate: float = 0.,
    attn_drop_rate: float = 0.,
    drop_path_rate: float = 0.,
    pre_norm: bool = True,
    global_pool: bool = False,
    abs_pos: bool = True,
    qk_scale: Optional[float] = None,
    rpe: bool = True,
):
    super().__init__()

    # Searchable dimensions; choices sharing a label are bound to the same value.
    embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
    fixed_embed_dim = nn.ModelParameterChoice(list(search_embed_dim), label="embed_dim")
    depth = nn.ValueChoice(list(search_depth), label="depth")

    # Patch embedding: non-overlapping convolution that tokenizes the image.
    self.patch_embed = nn.Conv2d(
        in_chans, cast(int, embed_dim),
        kernel_size=patch_size, stride=patch_size)
    self.patches_num = int((img_size // patch_size) ** 2)
    self.global_pool = global_pool

    self.cls_token = nn.Parameter(torch.zeros(1, 1, cast(int, fixed_embed_dim)))
    trunc_normal_(self.cls_token, std=.02)

    # Stochastic depth decay rule: drop-path rate grows linearly with layer index.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, max(search_depth))]

    self.abs_pos = abs_pos
    if self.abs_pos:
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.patches_num + 1, cast(int, fixed_embed_dim)))
        trunc_normal_(self.pos_embed, std=.02)

    # Each of the `depth` repeated blocks picks one (mlp_ratio, num_heads) combination.
    self.blocks = nn.Repeat(lambda index: nn.LayerChoice([
        TransformerEncoderLayer(
            embed_dim=embed_dim,
            fixed_embed_dim=fixed_embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[index],
            rpe_length=img_size // patch_size,
            qk_scale=qk_scale,
            rpe=rpe,
            pre_norm=pre_norm,
        )
        for mlp_ratio, num_heads in itertools.product(search_mlp_ratio, search_num_heads)
    ], label=f'layer{index}'), depth)

    self.pre_norm = pre_norm
    if self.pre_norm:
        self.norm = nn.LayerNorm(cast(int, embed_dim))
    self.head = nn.Linear(
        cast(int, embed_dim), num_classes) if num_classes > 0 else nn.Identity()
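# Illustrative usage (a sketch, not from the original file). Assumes the
# enclosing class is named `AutoFormer` and that `nn` refers to NNI's mutable
# module namespace (e.g. `import nni.retiarii.nn.pytorch as nn`); both names
# are assumptions made for this example.
#
#   space = AutoFormer(search_embed_dim=(192, 216, 240),
#                      search_depth=(12, 13, 14),
#                      num_classes=1000)
#
# Every nn.ValueChoice / nn.LayerChoice above is a mutation point. Choices that
# share a `label` (e.g. "embed_dim") are bound to one concrete value by the
# search strategy, so the embedding width stays consistent across the model.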