Example #1
    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        trunk_config = copy.deepcopy(
            model_config.TRUNK.TRUNK_PARAMS.VISION_TRANSFORMERS)

        logging.info("Building model: Vision Transformer from yaml config")
        # Hacky workaround
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

        img_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        in_chans = 3
        embed_dim = trunk_config.hidden_dim
        depth = trunk_config.num_layers
        num_heads = trunk_config.num_heads
        mlp_ratio = 4.0
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        drop_rate = trunk_config.dropout_rate
        attn_drop_rate = trunk_config.attention_dropout_rate
        drop_path_rate = trunk_config.drop_path_rate
        hybrid_backbone_string = None
        # TODO Implement hybrid backbones
        if "HYBRID" in trunk_config.keys():
            hybrid_backbone_string = trunk_config.HYBRID
        norm_layer = nn.LayerNorm

        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        # TODO : Enable Hybrid Backbones
        if hybrid_backbone_string:
            self.patch_embed = globals()[hybrid_backbone_string](
                out_dim=embed_dim, img_size=img_size)
        # if hybrid_backbone is not None:
        #     self.patch_embed = HybridEmbed(
        #         hybrid_backbone,
        #         img_size=img_size,
        #         in_chans=in_chans,
        #         embed_dim=embed_dim,
        #     )
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        num_patches = self.patch_embed.num_patches

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits
        # representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        trunc_normal_(self.pos_embedding, std=0.02)
        trunc_normal_(self.class_token, std=0.02)
        self.apply(self._init_weights)
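
The constructor above reads its hyperparameters from TRUNK.TRUNK_PARAMS.VISION_TRANSFORMERS after lower-casing the keys. The following is a minimal standalone sketch of a matching parameter dict and of the stochastic-depth schedule the constructor derives from it; the key names mirror what the code reads, but the concrete values are hypothetical and the dict stands in for the AttrDict built from the yaml config.

import torch

# Hypothetical VISION_TRANSFORMERS params mirroring the keys the constructor reads.
vit_params = {
    "IMAGE_SIZE": 224,
    "PATCH_SIZE": 16,
    "HIDDEN_DIM": 768,
    "NUM_LAYERS": 12,
    "NUM_HEADS": 12,
    "QKV_BIAS": True,
    "QK_SCALE": None,
    "DROPOUT_RATE": 0.0,
    "ATTENTION_DROPOUT_RATE": 0.0,
    "DROP_PATH_RATE": 0.1,
}

# Same lower-casing workaround as in the constructor above.
trunk_config = {k.lower(): v for k, v in vit_params.items()}

# Stochastic depth decay rule: the drop-path rate grows linearly from 0
# to drop_path_rate across the transformer blocks.
depth = trunk_config["num_layers"]
dpr = [x.item() for x in torch.linspace(0, trunk_config["drop_path_rate"], depth)]
print([round(r, 3) for r in dpr])  # 0.0 for the first block up to 0.1 for the last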
Example #2
    def __init__(self, model_config, model_name):
        super().__init__()
        trunk_config = copy.deepcopy(model_config.TRUNK.TRUNK_PARAMS.CONVIT)
        trunk_config.update(model_config.TRUNK.TRUNK_PARAMS.VISION_TRANSFORMERS)

        logging.info("Building model: ConViT from yaml config")
        # Hacky workaround
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

        image_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        classifier = trunk_config.classifier
        assert image_size % patch_size == 0, "Input shape indivisible by patch size"
        assert classifier in ["token", "gap"], "Unexpected classifier mode"
        n_gpsa_layers = trunk_config.n_gpsa_layers
        class_token_in_local_layers = trunk_config.class_token_in_local_layers
        mlp_dim = trunk_config.mlp_dim
        embed_dim = trunk_config.hidden_dim
        locality_dim = trunk_config.locality_dim
        attention_dropout_rate = trunk_config.attention_dropout_rate
        dropout_rate = trunk_config.dropout_rate
        drop_path_rate = trunk_config.drop_path_rate
        num_layers = trunk_config.num_layers
        locality_strength = trunk_config.locality_strength
        num_heads = trunk_config.num_heads
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        use_local_init = trunk_config.use_local_init

        hybrid_backbone = None
        if "hybrid" in trunk_config.keys():
            hybrid_backbone = trunk_config.hybrid

        in_chans = 3
        # TODO: Make this configurable
        norm_layer = nn.LayerNorm

        self.classifier = classifier
        self.n_gpsa_layers = n_gpsa_layers
        self.class_token_in_local_layers = class_token_in_local_layers
        # For consistency with other models
        self.num_features = self.embed_dim = self.hidden_dim = embed_dim
        self.locality_dim = locality_dim

        # Hybrid backbones not tested
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone,
                img_size=image_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        else:
            self.patch_embed = PatchEmbed(
                img_size=image_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        seq_length = (image_size // patch_size) ** 2
        self.seq_length = seq_length

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_length, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout_rate)

        if class_token_in_local_layers:
            seq_length += 1

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]

        layers = []
        for i in range(num_layers):
            if i < self.n_gpsa_layers:
                if locality_strength > 0:
                    layer_locality_strength = locality_strength
                else:
                    layer_locality_strength = 1 / (i + 1)
                layers.append(
                    AttentionBlock(
                        attention_module=GPSA,
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        mlp_dim=mlp_dim,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        dropout_rate=dropout_rate,
                        attention_dropout_rate=attention_dropout_rate,
                        drop_path_rate=dpr[i],
                        norm_layer=norm_layer,
                        locality_strength=layer_locality_strength,
                        locality_dim=self.locality_dim,
                        use_local_init=use_local_init,
                    )
                )
            else:
                layers.append(
                    AttentionBlock(
                        attention_module=SelfAttention,
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        mlp_dim=mlp_dim,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        dropout_rate=dropout_rate,
                        attention_dropout_rate=attention_dropout_rate,
                        drop_path_rate=dpr[i],
                        norm_layer=norm_layer,
                    )
                )
        self.blocks = nn.ModuleList(layers)
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embedding, std=0.02)
        trunc_normal_(self.class_token, std=0.02)
        self.apply(self._init_weights)
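
The layer loop in this constructor gives the first n_gpsa_layers blocks a GPSA attention module (with a fixed locality strength if one is configured, otherwise the 1 / (i + 1) fallback) and plain self-attention to the remaining blocks. Below is a small standalone sketch of that per-layer schedule; the numeric settings are hypothetical, the real values come from the CONVIT and VISION_TRANSFORMERS trunk params.

import torch

# Hypothetical settings standing in for the yaml config values.
num_layers = 12
n_gpsa_layers = 10
locality_strength = 0.0  # <= 0 triggers the 1 / (i + 1) fallback used above
drop_path_rate = 0.1

# Same stochastic depth decay rule as in the constructor.
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]

for i in range(num_layers):
    if i < n_gpsa_layers:
        # GPSA layers: fixed strength if configured, else decaying 1 / (i + 1).
        strength = locality_strength if locality_strength > 0 else 1 / (i + 1)
        print(f"layer {i:2d}: GPSA           locality={strength:.3f}  drop_path={dpr[i]:.3f}")
    else:
        # Remaining layers use plain self-attention.
        print(f"layer {i:2d}: SelfAttention                  drop_path={dpr[i]:.3f}")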