Example #1
0
    def __init__(self, model_config: AttrDict, model_name: str):
        """Wrap a ``ClassyVisionTransformer`` configured from the model config.

        The ``TRUNK.VISION_TRANSFORMERS`` section is deep-copied, its keys are
        lower-cased, and the selected entries are forwarded as keyword
        arguments of the same name to ``ClassyVisionTransformer``. The
        resulting module is stored on ``self.model``.

        Args:
            model_config: model configuration; ``INPUT_TYPE`` must be "rgb"
                or "bgr" (checked with ``assert``).
            model_name: accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.model_config = model_config

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        vit_section = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS)

        logging.info("Building model: Vision Transformer from yaml config")
        # Lower-case the keys so each option can be fetched as an attribute.
        cfg = AttrDict({name.lower(): val for name, val in vit_section.items()})

        # Config entries forwarded verbatim (same keyword name) to the trunk,
        # in the order they are read.
        forwarded_options = (
            "image_size",
            "patch_size",
            "num_layers",
            "num_heads",
            "hidden_dim",
            "mlp_dim",
            "dropout_rate",
            "attention_dropout_rate",
            "classifier",
        )
        self.model = ClassyVisionTransformer(
            **{option: getattr(cfg, option) for option in forwarded_options}
        )
Example #2
0
    def __init__(self, model_config: AttrDict, model_name: str):
        """Build a Vision Transformer trunk from ``TRUNK.VISION_TRANSFORMERS``.

        Creates the patch embedding (or an optional hybrid backbone), the class
        token and positional-embedding parameters, ``num_layers`` transformer
        ``Block``s with a linearly increasing drop-path rate, and a final norm.

        Args:
            model_config: model configuration; ``INPUT_TYPE`` must be "rgb" or
                "bgr" (checked with ``assert``), and
                ``TRUNK.VISION_TRANSFORMERS`` supplies the hyper-parameters.
            model_name: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``HYBRID`` names a backbone that is not defined in
                this module.
        """
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        trunk_config = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS)

        logging.info("Building model: Vision Transformer from yaml config")
        # Hacky workaround: lower-case every key so entries can be read as
        # attributes below (e.g. ``trunk_config.image_size``).
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

        img_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        in_chans = 3  # both accepted input types ("rgb"/"bgr") are 3-channel
        embed_dim = trunk_config.hidden_dim
        depth = trunk_config.num_layers
        num_heads = trunk_config.num_heads
        mlp_ratio = 4.0  # fixed; the config's MLP_DIM entry is not read here
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        drop_rate = trunk_config.dropout_rate
        attn_drop_rate = trunk_config.attention_dropout_rate
        drop_path_rate = trunk_config.drop_path_rate

        # Keys were lower-cased above, so the optional hybrid backbone must be
        # looked up as "hybrid". Checking for "HYBRID" could never match and
        # silently ignored the option.
        hybrid_backbone_string = None
        if "hybrid" in trunk_config:
            hybrid_backbone_string = trunk_config.hybrid
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        if hybrid_backbone_string:
            # TODO: hybrid backbones are not fully implemented. The configured
            # name must refer to a callable defined in this module that accepts
            # ``out_dim`` and ``img_size`` and whose result has ``num_patches``.
            backbone_cls = globals().get(hybrid_backbone_string)
            if backbone_cls is None:
                raise ValueError(
                    f"Unknown hybrid backbone {hybrid_backbone_string!r} "
                    "in TRUNK.VISION_TRANSFORMERS.HYBRID"
                )
            self.patch_embed = backbone_cls(out_dim=embed_dim, img_size=img_size)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        num_patches = self.patch_embed.num_patches

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # num_patches + 1 positions: one per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: the drop-path rate grows linearly from 0
        # to drop_path_rate across the blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                )
                for block_drop_path in dpr
            ]
        )
        self.norm = norm_layer(embed_dim)

        # NOTE: as per the official implementation, a pre-logits representation
        # layer (nn.Linear(embed_dim, representation_size) followed by nn.Tanh)
        # could be added here; it is intentionally omitted.

        trunc_normal_(self.pos_embedding, std=0.02)
        trunc_normal_(self.class_token, std=0.02)
        # nn.Module.apply calls self._init_weights on every submodule recursively.
        self.apply(self._init_weights)
Example #3
0
    def __init__(self, model_config: AttrDict, model_name: str):
        """Build an XCiT trunk from the ``TRUNK.XCIT`` section of the config.

        Creates a ``ConvPatchEmbed`` patch embedding, ``num_layers``
        ``XCABlock``s, a stack of ``ClassAttentionBlock``s, a final norm and a
        ``PositionalEncodingFourier`` module.

        Args:
            model_config: model configuration; ``INPUT_TYPE`` must be "rgb" or
                "bgr" (checked with ``assert``), and ``TRUNK.XCIT`` supplies the
                hyper-parameters. The optional ``CLS_ATTN_LAYERS`` entry sets
                the number of class-attention blocks (default 2).
            model_name: accepted for interface compatibility; not used here.
        """
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        trunk_config = copy.deepcopy(model_config.TRUNK.XCIT)

        logging.info("Building model: XCiT from yaml config")
        # Hacky workaround: lower-case every key so entries can be read as
        # attributes below (e.g. ``trunk_config.image_size``).
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})
        img_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        embed_dim = trunk_config.hidden_dim
        depth = trunk_config.num_layers
        num_heads = trunk_config.num_heads
        mlp_ratio = trunk_config.mlp_ratio
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        drop_rate = trunk_config.dropout_rate
        attn_drop_rate = trunk_config.attention_dropout_rate
        drop_path_rate = trunk_config.drop_path_rate
        eta = trunk_config.eta
        tokens_norm = trunk_config.tokens_norm
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = ConvPatchEmbed(
            img_size=img_size, embed_dim=embed_dim, patch_size=patch_size
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Same drop-path rate for every block (no per-depth decay schedule).
        dpr = [drop_path_rate] * depth
        self.blocks = nn.ModuleList(
            [
                XCABlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    num_tokens=num_patches,
                    eta=eta,
                )
                for i in range(depth)
            ]
        )

        # Number of class-attention blocks built after the XCA blocks. Defaults
        # to 2 (the previously hard-coded value) unless CLS_ATTN_LAYERS is set.
        cls_attn_layers = trunk_config.get("cls_attn_layers", 2)
        self.cls_attn_blocks = nn.ModuleList(
            [
                ClassAttentionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                    eta=eta,
                    tokens_norm=tokens_norm,
                )
                for _ in range(cls_attn_layers)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        # NOTE(review): presumably read by forward() (not visible here) to
        # toggle use of pos_embeder — confirm.
        self.use_pos = True

        # Weight initialisation: truncated-normal class token, then
        # self._init_weights applied recursively to every submodule.
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)