def __init__(self, model_config: AttrDict, model_name: str):
    super().__init__()

    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
    trunk_config = copy.deepcopy(model_config.TRUNK.TRUNK_PARAMS.VISION_TRANSFORMERS)

    logging.info("Building model: Vision Transformer from yaml config")
    # Hacky workaround: expose the uppercase YAML keys as lowercase attributes
    trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

    img_size = trunk_config.image_size
    patch_size = trunk_config.patch_size
    in_chans = 3
    embed_dim = trunk_config.hidden_dim
    depth = trunk_config.num_layers
    num_heads = trunk_config.num_heads
    mlp_ratio = 4.0
    qkv_bias = trunk_config.qkv_bias
    qk_scale = trunk_config.qk_scale
    drop_rate = trunk_config.dropout_rate
    attn_drop_rate = trunk_config.attention_dropout_rate
    drop_path_rate = trunk_config.drop_path_rate
    hybrid_backbone_string = None
    # TODO Implement hybrid backbones
    # (keys were lowercased above, so look up "hybrid", not "HYBRID")
    if "hybrid" in trunk_config.keys():
        hybrid_backbone_string = trunk_config.hybrid
    norm_layer = nn.LayerNorm

    # num_features for consistency with other models
    self.num_features = self.embed_dim = embed_dim

    # TODO: Enable Hybrid Backbones
    if hybrid_backbone_string:
        self.patch_embed = globals()[hybrid_backbone_string](
            out_dim=embed_dim, img_size=img_size
        )
        # if hybrid_backbone is not None:
        #     self.patch_embed = HybridEmbed(
        #         hybrid_backbone,
        #         img_size=img_size,
        #         in_chans=in_chans,
        #         embed_dim=embed_dim,
        #     )
    else:
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
    num_patches = self.patch_embed.num_patches

    self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    self.pos_drop = nn.Dropout(p=drop_rate)

    # stochastic depth decay rule
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.blocks = nn.ModuleList(
        [
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ]
    )
    self.norm = norm_layer(embed_dim)

    # NOTE: as per the official implementation, we could have a pre-logits
    # representation dense layer + tanh here
    # self.repr = nn.Linear(embed_dim, representation_size)
    # self.repr_act = nn.Tanh()

    trunc_normal_(self.pos_embedding, std=0.02)
    trunc_normal_(self.class_token, std=0.02)
    self.apply(self._init_weights)
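
# Hedged usage sketch (illustrative, not part of the constructor above): the nested
# YAML config shape this __init__ expects, inferred purely from the keys it reads.
# The uppercase key names, the ViT-B/16-style numeric values, and the variable names
# below are assumptions, not project defaults; AttrDict is assumed to be the same
# helper class already used above.
example_vit_params = AttrDict(
    {
        "IMAGE_SIZE": 224,
        "PATCH_SIZE": 16,
        "HIDDEN_DIM": 768,
        "NUM_LAYERS": 12,
        "NUM_HEADS": 12,
        "QKV_BIAS": True,
        "QK_SCALE": None,
        "DROPOUT_RATE": 0.0,
        "ATTENTION_DROPOUT_RATE": 0.0,
        "DROP_PATH_RATE": 0.0,
    }
)
# The "hacky workaround" in __init__ simply lowercases these keys so they can be
# read as attributes, e.g. trunk_config.hidden_dim and trunk_config.patch_size.
example_lowered = AttrDict({k.lower(): v for k, v in example_vit_params.items()})
assert example_lowered.hidden_dim == 768 and example_lowered.num_layers == 12
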
def __init__(self, model_config, model_name):
    super().__init__()

    trunk_config = copy.deepcopy(model_config.TRUNK.TRUNK_PARAMS.CONVIT)
    trunk_config.update(model_config.TRUNK.TRUNK_PARAMS.VISION_TRANSFORMERS)

    logging.info("Building model: ConViT from yaml config")
    # Hacky workaround: expose the uppercase YAML keys as lowercase attributes
    trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

    image_size = trunk_config.image_size
    patch_size = trunk_config.patch_size
    classifier = trunk_config.classifier

    assert image_size % patch_size == 0, "Input shape indivisible by patch size"
    assert classifier in ["token", "gap"], "Unexpected classifier mode"

    n_gpsa_layers = trunk_config.n_gpsa_layers
    class_token_in_local_layers = trunk_config.class_token_in_local_layers
    mlp_dim = trunk_config.mlp_dim
    embed_dim = trunk_config.hidden_dim
    locality_dim = trunk_config.locality_dim
    attention_dropout_rate = trunk_config.attention_dropout_rate
    dropout_rate = trunk_config.dropout_rate
    drop_path_rate = trunk_config.drop_path_rate
    num_layers = trunk_config.num_layers
    locality_strength = trunk_config.locality_strength
    num_heads = trunk_config.num_heads
    qkv_bias = trunk_config.qkv_bias
    qk_scale = trunk_config.qk_scale
    use_local_init = trunk_config.use_local_init
    hybrid_backbone = None
    if "hybrid" in trunk_config.keys():
        hybrid_backbone = trunk_config.hybrid
    in_chans = 3  # TODO: Make this configurable
    norm_layer = nn.LayerNorm

    self.classifier = classifier
    self.n_gpsa_layers = n_gpsa_layers
    self.class_token_in_local_layers = class_token_in_local_layers
    # For consistency with other models
    self.num_features = self.embed_dim = self.hidden_dim = embed_dim
    self.locality_dim = locality_dim

    # Hybrid backbones not tested
    if hybrid_backbone is not None:
        self.patch_embed = HybridEmbed(
            hybrid_backbone,
            img_size=image_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
    else:
        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
    seq_length = (image_size // patch_size) ** 2
    self.seq_length = seq_length

    self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embedding = nn.Parameter(torch.zeros(1, seq_length, embed_dim))
    self.pos_drop = nn.Dropout(p=dropout_rate)

    if class_token_in_local_layers:
        seq_length += 1

    # stochastic depth decay rule
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]

    layers = []
    for i in range(num_layers):
        if i < self.n_gpsa_layers:
            if locality_strength > 0:
                layer_locality_strength = locality_strength
            else:
                layer_locality_strength = 1 / (i + 1)
            layers.append(
                AttentionBlock(
                    attention_module=GPSA,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    dropout_rate=dropout_rate,
                    attention_dropout_rate=attention_dropout_rate,
                    drop_path_rate=dpr[i],
                    norm_layer=norm_layer,
                    locality_strength=layer_locality_strength,
                    locality_dim=self.locality_dim,
                    use_local_init=use_local_init,
                )
            )
        else:
            layers.append(
                AttentionBlock(
                    attention_module=SelfAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    dropout_rate=dropout_rate,
                    attention_dropout_rate=attention_dropout_rate,
                    drop_path_rate=dpr[i],
                    norm_layer=norm_layer,
                )
            )
    self.blocks = nn.ModuleList(layers)
    self.norm = norm_layer(embed_dim)

    trunc_normal_(self.pos_embedding, std=0.02)
    trunc_normal_(self.class_token, std=0.02)
    self.apply(self._init_weights)
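
# Hedged worked example (illustrative, not part of the constructor above): the
# per-layer schedule the loop in __init__ produces for a hypothetical 12-layer
# ConViT with 10 GPSA layers, locality_strength=0 (so the 1 / (i + 1) fallback
# applies) and drop_path_rate=0.1. All names and values here are assumptions;
# torch is assumed to be the module already imported above.
example_num_layers = 12
example_n_gpsa_layers = 10
example_locality_strength = 0.0
example_drop_path_rate = 0.1

# Stochastic depth decay rule: drop path probability grows linearly with depth.
example_dpr = [
    x.item() for x in torch.linspace(0, example_drop_path_rate, example_num_layers)
]

for i in range(example_num_layers):
    if i < example_n_gpsa_layers:
        # GPSA layer: uses the configured strength, or 1 / (i + 1) when unset.
        strength = (
            example_locality_strength
            if example_locality_strength > 0
            else 1 / (i + 1)
        )
        print(f"layer {i}: GPSA, locality_strength={strength:.3f}, drop_path={example_dpr[i]:.3f}")
    else:
        # Remaining layers fall back to plain self-attention blocks.
        print(f"layer {i}: SelfAttention, drop_path={example_dpr[i]:.3f}")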