Example #1
    def init_weights(self):
        # Pre-logits projection (if present): LeCun normal weights, zero bias
        if hasattr(self.layers, "pre_logits"):
            lecun_normal_init(self.layers.pre_logits.weight,
                              fan_in=self.layers.pre_logits.in_features)
            nn.init.zeros_(self.layers.pre_logits.bias)
        # Classification head: truncated-normal weights, zero bias
        trunc_normal_(self.layers.head.weight, std=0.02)
        nn.init.zeros_(self.layers.head.bias)
Example #2
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
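This hook is normally handed to nn.Module.apply, which walks the module tree and calls it once per submodule. A minimal, self-contained sketch of that usage (assuming trunc_normal_ is importable from timm, as the snippets here imply):

import torch.nn as nn
from timm.models.layers import trunc_normal_  # assumption: timm supplies trunc_normal_

def _init_weights(m):
    # Same policy as above: truncated-normal Linear weights, zero biases,
    # LayerNorm reset to weight=1, bias=0.
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))
model.apply(_init_weights)  # visits the Sequential itself and every child module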
Example #3
    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)
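The relative-position bookkeeping above is easiest to verify on a tiny window. The sketch below (illustrative only, window_size = (2, 2)) repeats the constructor's index computation and prints the table; every entry lands in [0, (2*Wh-1)*(2*Ww-1) - 1] = [0, 8], so each (query, key) offset selects one row of relative_position_bias_table.

import torch

window_size = (2, 2)  # Wh, Ww, tiny window for illustration
coords = torch.stack(torch.meshgrid(
    torch.arange(window_size[0]), torch.arange(window_size[1]),
    indexing="ij",  # same "ij" behaviour as the older list-form call above (PyTorch >= 1.10)
))                                                                  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                           # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()     # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1                      # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)                   # Wh*Ww, Wh*Ww

print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])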
Example #4
    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        trunk_config = model_config.TRUNK.XCIT

        logging.info("Building model: XCiT from yaml config")
        self.model_config = model_config
        self.img_size = trunk_config.IMAGE_SIZE
        self.patch_size = trunk_config.PATCH_SIZE
        self.embed_dim = trunk_config.HIDDEN_DIM
        self.depth = trunk_config.NUM_LAYERS
        self.num_heads = trunk_config.NUM_HEADS
        self.mlp_ratio = trunk_config.MLP_RATIO
        self.qkv_bias = trunk_config.QKV_BIAS
        self.qk_scale = trunk_config.QK_SCALE
        self.drop_rate = trunk_config.DROPOUT_RATE
        self.attn_drop_rate = trunk_config.ATTENTION_DROPOUT_RATE
        self.drop_path_rate = trunk_config.DROP_PATH_RATE
        self.eta = trunk_config.ETA
        self.tokens_norm = trunk_config.TOKENS_NORM

        # num_features for consistency with other models
        self.num_features = self.embed_dim

        self.patch_embed = ConvPatchEmbed(
            img_size=self.img_size, embed_dim=self.embed_dim, patch_size=self.patch_size
        )
        self.num_patches = self.patch_embed.num_patches

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=self.drop_rate)

        self.blocks = self._create_xca_blocks(norm_layer)

        self.cls_attn_blocks = self._create_cls_attn_blocks(norm_layer)
        self.norm = norm_layer(self.embed_dim)

        self.pos_embeder = PositionalEncodingFourier(dim=self.embed_dim)
        self.use_pos = True

        # Initialize weights
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
Example #5
    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        trunk_config = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS)

        logging.info("Building model: Vision Transformer from yaml config")
        # Hacky workaround
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

        img_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        in_chans = 3
        embed_dim = trunk_config.hidden_dim
        depth = trunk_config.num_layers
        num_heads = trunk_config.num_heads
        mlp_ratio = 4.0
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        drop_rate = trunk_config.dropout_rate
        attn_drop_rate = trunk_config.attention_dropout_rate
        drop_path_rate = trunk_config.drop_path_rate
        hybrid_backbone_string = None
        # TODO Implement hybrid backbones
        if "HYBRID" in trunk_config.keys():
            hybrid_backbone_string = trunk_config.HYBRID
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        # TODO : Enable Hybrid Backbones
        if hybrid_backbone_string:
            self.patch_embed = globals()[hybrid_backbone_string](
                out_dim=embed_dim, img_size=img_size
            )
        # if hybrid_backbone is not None:
        #     self.patch_embed = HybridEmbed(
        #         hybrid_backbone,
        #         img_size=img_size,
        #         in_chans=in_chans,
        #         embed_dim=embed_dim,
        #     )
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        num_patches = self.patch_embed.num_patches

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits
        # representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        trunc_normal_(self.pos_embedding, std=0.02)
        trunc_normal_(self.class_token, std=0.02)
        self.apply(self._init_weights)
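A quick sanity check on the size of pos_embedding above, with illustrative numbers rather than values from the yaml config: a 224x224 input with 16x16 patches gives a 14x14 grid, so the table holds 196 + 1 = 197 positions, the extra row being the class token.

import torch

img_size, patch_size, embed_dim = 224, 16, 768   # illustrative values only
num_patches = (img_size // patch_size) ** 2      # 14 * 14 = 196
pos_embedding = torch.zeros(1, num_patches + 1, embed_dim)
print(pos_embedding.shape)                       # torch.Size([1, 197, 768])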
Example #6
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:  # conv/linear layers may be created with bias=False
                nn.init.constant_(m.bias, 0)
Example #7
    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb",
                                           "bgr"], "Input type not supported"
        trunk_config = copy.deepcopy(model_config.TRUNK.XCIT)

        logging.info("Building model: XCiT from yaml config")
        # Hacky workaround
        trunk_config = AttrDict(
            {k.lower(): v
             for k, v in trunk_config.items()})
        img_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        embed_dim = trunk_config.hidden_dim
        depth = trunk_config.num_layers
        num_heads = trunk_config.num_heads
        mlp_ratio = trunk_config.mlp_ratio
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        drop_rate = trunk_config.dropout_rate
        attn_drop_rate = trunk_config.attention_dropout_rate
        drop_path_rate = trunk_config.drop_path_rate
        eta = trunk_config.eta
        tokens_norm = trunk_config.tokens_norm
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = ConvPatchEmbed(img_size=img_size,
                                          embed_dim=embed_dim,
                                          patch_size=patch_size)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList([
            XCABlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                num_tokens=num_patches,
                eta=eta,
            ) for i in range(depth)
        ])

        cls_attn_layers = 2
        self.cls_attn_blocks = nn.ModuleList([
            ClassAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                eta=eta,
                tokens_norm=tokens_norm,
            ) for i in range(cls_attn_layers)
        ])
        self.norm = norm_layer(embed_dim)

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        self.use_pos = True

        # Classifier head
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
Example #8
    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths: Optional[List[int]] = None,
        num_heads: Optional[List[int]] = None,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        **kwargs,
    ):
        super().__init__()
        depths = depths or [2, 2, 6, 2]
        num_heads = num_heads or [3, 6, 12, 24]

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim)
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(
                    patches_resolution[0] // (2 ** i_layer),
                    patches_resolution[1] // (2 ** i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.apply(self._init_weights)
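The stochastic-depth schedule above is easier to read with concrete numbers. With the default depths [2, 2, 6, 2] and drop_path_rate = 0.1, torch.linspace ramps the rate linearly over all 12 blocks, and the slice dpr[sum(depths[:i_layer]) : sum(depths[:i_layer + 1])] hands each stage its contiguous chunk. A small illustrative sketch:

import torch

depths, drop_path_rate = [2, 2, 6, 2], 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i in range(len(depths)):
    stage = dpr[sum(depths[:i]):sum(depths[:i + 1])]
    print(i, [round(r, 3) for r in stage])
# 0 [0.0, 0.009]
# 1 [0.018, 0.027]
# 2 [0.036, 0.045, 0.055, 0.064, 0.073, 0.082]
# 3 [0.091, 0.1]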
Example #9
    def __init__(self, model_config, model_name):
        super().__init__()
        trunk_config = copy.deepcopy(model_config.TRUNK.TRUNK_PARAMS.CONVIT)
        trunk_config.update(model_config.TRUNK.TRUNK_PARAMS.VISION_TRANSFORMERS)

        logging.info("Building model: ConViT from yaml config")
        # Hacky workaround
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

        image_size = trunk_config.image_size
        patch_size = trunk_config.patch_size
        classifier = trunk_config.classifier
        assert image_size % patch_size == 0, "Input shape indivisible by patch size"
        assert classifier in ["token", "gap"], "Unexpected classifier mode"
        n_gpsa_layers = trunk_config.n_gpsa_layers
        class_token_in_local_layers = trunk_config.class_token_in_local_layers
        mlp_dim = trunk_config.mlp_dim
        embed_dim = trunk_config.hidden_dim
        locality_dim = trunk_config.locality_dim
        attention_dropout_rate = trunk_config.attention_dropout_rate
        dropout_rate = trunk_config.dropout_rate
        drop_path_rate = trunk_config.drop_path_rate
        num_layers = trunk_config.num_layers
        locality_strength = trunk_config.locality_strength
        num_heads = trunk_config.num_heads
        qkv_bias = trunk_config.qkv_bias
        qk_scale = trunk_config.qk_scale
        use_local_init = trunk_config.use_local_init

        hybrid_backbone = None
        if "hybrid" in trunk_config.keys():
            hybrid_backbone = trunk_config.hybrid

        in_chans = 3
        # TODO: Make this configurable
        norm_layer = nn.LayerNorm

        self.classifier = classifier
        self.n_gpsa_layers = n_gpsa_layers
        self.class_token_in_local_layers = class_token_in_local_layers
        # For consistency with other models
        self.num_features = self.embed_dim = self.hidden_dim = embed_dim
        self.locality_dim = locality_dim

        # Hybrid backbones not tested
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone,
                img_size=image_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        else:
            self.patch_embed = PatchEmbed(
                img_size=image_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        seq_length = (image_size // patch_size) ** 2
        self.seq_length = seq_length

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_length, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout_rate)

        if class_token_in_local_layers:
            seq_length += 1

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]

        layers = []
        for i in range(num_layers):
            if i < self.n_gpsa_layers:
                if locality_strength > 0:
                    layer_locality_strength = locality_strength
                else:
                    layer_locality_strength = 1 / (i + 1)
                layers.append(
                    AttentionBlock(
                        attention_module=GPSA,
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        mlp_dim=mlp_dim,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        dropout_rate=dropout_rate,
                        attention_dropout_rate=attention_dropout_rate,
                        drop_path_rate=dpr[i],
                        norm_layer=norm_layer,
                        locality_strength=layer_locality_strength,
                        locality_dim=self.locality_dim,
                        use_local_init=use_local_init,
                    )
                )
            else:
                layers.append(
                    AttentionBlock(
                        attention_module=SelfAttention,
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        mlp_dim=mlp_dim,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        dropout_rate=dropout_rate,
                        attention_dropout_rate=attention_dropout_rate,
                        drop_path_rate=dpr[i],
                        norm_layer=norm_layer,
                    )
                )
        self.blocks = nn.ModuleList(layers)
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embedding, std=0.02)
        trunc_normal_(self.class_token, std=0.02)
        self.apply(self._init_weights)
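One detail of the GPSA loop above: when locality_strength is not positive, layer i falls back to a strength of 1 / (i + 1), so the locality prior decays with depth. A tiny illustration with a hypothetical n_gpsa_layers = 4:

n_gpsa_layers = 4  # hypothetical value, not taken from the config above
print([round(1 / (i + 1), 3) for i in range(n_gpsa_layers)])  # [1.0, 0.5, 0.333, 0.25]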
Example #10
    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()

        assert model_config.INPUT_TYPE in ["rgb",
                                           "bgr"], "Input type not supported"
        logging.info("Building model: Vision Transformer from yaml config")

        self.model_config = model_config
        self.trunk_config = model_config.TRUNK.VISION_TRANSFORMERS
        self.img_size = self.trunk_config.IMAGE_SIZE
        self.patch_size = self.trunk_config.PATCH_SIZE
        self.in_chans = 3
        self.embed_dim = self.trunk_config.HIDDEN_DIM
        self.depth = self.trunk_config.NUM_LAYERS
        self.num_heads = self.trunk_config.NUM_HEADS
        self.mlp_ratio = 4.0
        self.qkv_bias = self.trunk_config.QKV_BIAS
        self.qk_scale = self.trunk_config.QK_SCALE
        self.drop_rate = self.trunk_config.DROPOUT_RATE
        self.attn_drop_rate = self.trunk_config.ATTENTION_DROPOUT_RATE
        self.drop_path_rate = self.trunk_config.DROP_PATH_RATE

        # TODO Implement hybrid backbones
        hybrid_backbone_string = None
        if "HYBRID" in self.trunk_config.keys():
            hybrid_backbone_string = self.trunk_config.HYBRID

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # num_features for consistency with other models
        self.num_features = self.embed_dim

        # TODO : Enable Hybrid Backbones
        if hybrid_backbone_string:
            self.patch_embed = globals()[hybrid_backbone_string](
                out_dim=self.embed_dim, img_size=self.img_size)
        # if hybrid_backbone is not None:
        #     self.patch_embed = HybridEmbed(
        #         hybrid_backbone,
        #         img_size=img_size,
        #         in_chans=in_chans,
        #         embed_dim=embed_dim,
        #     )
        else:
            self.patch_embed = PatchEmbed(
                img_size=self.img_size,
                patch_size=self.patch_size,
                in_chans=self.in_chans,
                embed_dim=self.embed_dim,
            )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=self.drop_rate)

        self.blocks = self._build_blocks(norm_layer)
        self.norm = norm_layer(self.embed_dim)

        # NOTE as per official impl, we could have a pre-logits
        # representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)