def trunc_normal(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill trainable weights with values drawn from a truncated normal distribution.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds. The method used
    for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: either an ``nn.Module`` or an ``nn.Parameter``.
            * ``nn.Module``: every parameter that requires grad and whose
              name does not contain ``'bias'`` is filled.
            * ``nn.Parameter``: filled only if it requires grad.
            Any other input is returned unchanged.
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value (absolute, not in units of ``std``)
        b: the maximum cutoff value (absolute, not in units of ``std``)

    Returns:
        ``tensor`` itself, modified in place.

    Examples:
        >>> w = nn.Parameter(torch.empty(3, 5))
        >>> trunc_normal(w, std=0.02)
    """
    if isinstance(tensor, nn.Module):
        for name, weight in tensor.named_parameters():
            # Frozen parameters and biases keep their current values.
            if weight.requires_grad and 'bias' not in name:
                init.trunc_normal_(weight, mean=mean, std=std, a=a, b=b)
    elif isinstance(tensor, nn.Parameter):
        if tensor.requires_grad:
            init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return tensor
def init_weight(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    * ``Linear``: weight drawn from a truncated normal with std 0.02; bias
      (if any) set to 0.
    * ``LayerNorm``: weight (if any) set to 1; bias (if any) set to 0.

    Affine parameters a layer does not have (``bias=False`` or
    ``elementwise_affine=False``) are skipped instead of raising. All other
    module types are left untouched.
    """
    if isinstance(m, Linear):
        init.trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, LayerNorm):
        # LayerNorm registers weight/bias as None when they are disabled.
        if m.bias is not None:
            init.constant_(m.bias, 0)
        if m.weight is not None:
            init.constant_(m.weight, 1.0)
def init_weights_vit(model):
    """Performs ViT weight init.

    Walks ``model.named_modules()`` and picks an initialisation from each
    module's type and qualified name:

    * ``Conv2d`` named ``*patchify*``: truncated normal with
      ``std = sqrt(1 / fan_in)``, ``fan_in = in_channels * kh * kw``.
    * ``Conv2d`` named ``*cstem_last*``: normal with
      ``std = sqrt(2 / out_channels)``.
    * any other ``Conv2d`` named ``*cstem*``: PyTorch default init.
    * ``Linear`` named ``*self_attention*``: PyTorch default init.
    * ``Linear`` named ``*mlp_block*``: Xavier-uniform weight, bias drawn
      from N(0, (1e-6)^2).
    * ``Linear`` named ``*head_fc*``: weight and bias set to zero.
    * ``BatchNorm2d`` / ``LayerNorm`` and all other module types: untouched.

    Biases are only initialised when the layer has one, so layers built with
    ``bias=False`` are supported. Finally ``model.pos_embedding`` is filled
    from N(0, 0.02^2).

    Raises:
        NotImplementedError: if a ``Conv2d`` or ``Linear`` module's name
            matches none of the patterns above.
    """
    for k, m in model.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            if "patchify" in k:
                # ViT patchify stem init
                fan_in = m.in_channels * m.kernel_size[0] * m.kernel_size[1]
                init.trunc_normal_(m.weight, std=math.sqrt(1.0 / fan_in))
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif "cstem_last" in k:
                # The last 1x1 conv of the conv stem. Must be tested before
                # the generic "cstem" case since "cstem_last" contains "cstem".
                init.normal_(m.weight, mean=0.0,
                             std=math.sqrt(2.0 / m.out_channels))
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif "cstem" in k:
                # Use default pytorch init for other conv layers in the C-stem
                pass
            else:
                raise NotImplementedError(
                    f"init_weights_vit: no init rule for Conv2d module '{k}'")
        if isinstance(m, torch.nn.Linear):
            if "self_attention" in k:
                # Use default pytorch init for multi-head attention module
                pass
            elif "mlp_block" in k:
                # MLP block init
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.normal_(m.bias, std=1e-6)
            elif "head_fc" in k:
                # Head (classifier) init: start from an all-zero classifier.
                init.zeros_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)
            else:
                raise NotImplementedError(
                    f"init_weights_vit: no init rule for Linear module '{k}'")
        # BatchNorm2d / LayerNorm layers keep PyTorch's default init.
    # Pos-embedding init
    init.normal_(model.pos_embedding, mean=0.0, std=0.02)
def __init__(self,
             seq_pool=True,
             embedding_dim=768,
             num_layers=12,
             num_heads=12,
             mlp_ratio=4.0,
             num_classes=1000,
             dropout=0.1,
             attention_dropout=0.1,
             stochastic_depth=0.1,
             positional_embedding='sine',
             seq_len=None,
             *args, **kwargs):
    """Build the transformer encoder stack and the classification head.

    Args:
        seq_pool: if True, build ``self.attention_pool`` (a ``Linear`` to a
            single score per token); otherwise build a learnable class token
            ``self.class_emb`` and set ``self.num_tokens = 1``.
        embedding_dim: token embedding width.
        num_layers: number of ``MaskedTransformerEncoderLayer`` blocks.
        num_heads: attention heads per block.
        mlp_ratio: feed-forward width as a multiple of ``embedding_dim``.
        num_classes: output width of the final ``Linear`` (``self.fc``).
        dropout: dropout probability for ``self.dropout`` and each block.
        attention_dropout: attention dropout probability for each block.
        stochastic_depth: drop-path rate of the last block; rates grow
            linearly from 0 for the first block.
        positional_embedding: ``'sine'``, ``'learnable'`` or ``'none'``;
            any other value falls back to ``'sine'``.
        seq_len: number of input tokens. Required unless
            ``positional_embedding == 'none'``.
        *args, **kwargs: accepted and ignored.
    """
    super().__init__()
    positional_embedding = positional_embedding if \
        positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
    dim_feedforward = int(embedding_dim * mlp_ratio)
    self.embedding_dim = embedding_dim
    self.seq_len = seq_len
    self.seq_pool = seq_pool
    self.num_tokens = 0

    assert seq_len is not None or positional_embedding == 'none', \
        f"Positional embedding is set to {positional_embedding} and" \
        f" the sequence length was not specified."

    if not seq_pool:
        # One extra position for the prepended class token. seq_len may
        # legitimately be None here (positional_embedding == 'none'), in
        # which case no embedding needs sizing and incrementing would fail.
        if seq_len is not None:
            seq_len += 1
        self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
                                   requires_grad=True)
        self.num_tokens = 1
    else:
        self.attention_pool = Linear(self.embedding_dim, 1)

    if positional_embedding != 'none':
        if positional_embedding == 'learnable':
            seq_len += 1  # padding idx
            self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim),
                                            requires_grad=True)
            init.trunc_normal_(self.positional_emb, std=0.2)
        else:
            # Fixed (non-trainable) sinusoidal table.
            self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len,
                                                                      embedding_dim,
                                                                      padding_idx=True),
                                            requires_grad=False)
    else:
        self.positional_emb = None

    self.dropout = Dropout(p=dropout)
    # Per-block drop-path rates, linearly spaced in [0, stochastic_depth].
    dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
    self.blocks = ModuleList([
        MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                      dim_feedforward=dim_feedforward, dropout=dropout,
                                      attention_dropout=attention_dropout, drop_path_rate=dpr[i])
        for i in range(num_layers)])
    self.norm = LayerNorm(embedding_dim)

    self.fc = Linear(embedding_dim, num_classes)
    self.apply(self.init_weight)
def __init__(self, cfg):
    """Build an MViT video model from ``cfg``, optionally with a reversible encoder.

    If ``cfg.MVIT.REV.ENABLE`` is set, the transformer stages come from
    ``RevMViT(cfg)`` and a ``ResPathFusion`` module (``self.fuse``) plus the
    configured norm layer(s) are built here. Otherwise, a stack of
    ``MultiScaleBlock`` layers (``self.blocks``) and a final ``self.norm``
    are built.

    Args:
        cfg: config node; reads the ``DATA``, ``MVIT`` (including
            ``MVIT.REV``) and ``MODEL`` groups.

    Note:
        When ``cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE`` is not None, this method
        overwrites ``cfg.MVIT.POOL_KV_STRIDE`` on the passed-in config.
    """
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    pool_first = cfg.MVIT.POOL_FIRST
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
    if use_2d_patch:
        # 2D patches: prepend a stride of 1 for the temporal dimension.
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    self.use_abs_pos = cfg.MVIT.USE_ABS_POS
    # This variant only supports absolute positional embeddings.
    assert self.use_abs_pos
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes
    self.patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    # [T, H, W] of the input clip; the token grid divides each by its stride.
    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
    num_patches = math.prod(self.patch_dims)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
           ]  # stochastic depth decay rule

    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches

    if self.sep_pos_embed:
        # Separate embeddings: one per spatial token (H*W) and one per
        # temporal index (T), plus an optional one for the class token.
        self.pos_embed_spatial = nn.Parameter(
            torch.zeros(1, self.patch_dims[1] * self.patch_dims[2], embed_dim))
        self.pos_embed_temporal = nn.Parameter(
            torch.zeros(1, self.patch_dims[0], embed_dim))
        if self.cls_embed_on:
            self.pos_embed_class = nn.Parameter(
                torch.zeros(1, 1, embed_dim))
    else:
        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_embed_dim, embed_dim))

    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)

    # Per-layer width/head multipliers (default 1). Length depth + 1 so that
    # index i + 1 is valid for the last layer. Config entries are
    # [layer_idx, factor] pairs.
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

    # Per-layer pooling kernels/strides; an empty list means "no pooling".
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]

    # POOL_Q_STRIDE entries are [layer_idx, *strides]. Unless POOL_KVQ_KERNEL
    # is given, the kernel is stride + 1 for strides > 1, else the stride.
    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i]
                 [0]] = cfg.MVIT.POOL_Q_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]

    # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
    # Starting from the adaptive stride, each layer with a Q stride divides
    # the running KV stride by it (floor, minimum 1). The result is carried
    # forward, and every layer gets an entry written back into cfg.
    if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
        _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
        cfg.MVIT.POOL_KV_STRIDE = []
        for i in range(cfg.MVIT.DEPTH):
            if len(stride_q[i]) > 0:
                _stride_kv = [
                    max(_stride_kv[d] // stride_q[i][d], 1)
                    for d in range(len(_stride_kv))
                ]
            cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)

    # Same [layer_idx, *strides] decoding as for Q above.
    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i]
                  [0]] = cfg.MVIT.POOL_KV_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i]
                    [0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s
                for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]

    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

    if self.cfg.MVIT.REV.ENABLE:
        # Norm over 2*embed_dim when RESPATH_FUSE contains "concat", else
        # over embed_dim. NOTE: the lambda reads embed_dim when it is
        # *called*, i.e. after embed_dim is reassigned below. Its `mode`
        # argument is unused and shadows the outer `mode`.
        handle_fuse = lambda mode: norm_layer(
            2 * embed_dim
        ) if "concat" in cfg.MVIT.REV.RESPATH_FUSE else norm_layer(
            embed_dim)

        self.encoder = RevMViT(cfg)

        # check on this for rev VT -> this should be fine
        # Width after applying every dim multiplier.
        embed_dim = round_width(embed_dim, dim_mul.prod(), divisor=num_heads)

        self.fuse = ResPathFusion(cfg.MVIT.REV.RESPATH_FUSE, dim=2 * embed_dim)

        if self.cfg.MVIT.REV.EXTRA_LN:
            self.norm1 = norm_layer(2 * embed_dim)
            self.norm2 = handle_fuse(cfg.MVIT.REV.RESPATH_FUSE)
        else:
            if cfg.MVIT.REV.LN_BEFORE_FUSE:
                if self.cfg.MVIT.REV.HALF_C:
                    self.norm = norm_layer(embed_dim)
                else:
                    if "concat" in self.cfg.MVIT.REV.RESPATH_FUSE:
                        self.norm = norm_layer(2 * embed_dim)
                    else:
                        self.norm = norm_layer(embed_dim)
            else:
                self.norm = handle_fuse(cfg.MVIT.REV.RESPATH_FUSE)

        # embed_dim = embed_dim
    else:
        if cfg.MODEL.ACT_CHECKPOINT:
            validate_checkpoint_wrapper_import(checkpoint_wrapper)

        # NOTE(review): input_size is assigned but never used in this variant.
        input_size = self.patch_dims
        self.blocks = nn.ModuleList()

        for i in range(depth):
            # Heads and input width for layer i; output width uses the
            # multipliers of layer i + 1.
            num_heads = round_width(num_heads, head_mul[i])
            embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads)
            dim_out = round_width(
                embed_dim,
                dim_mul[i + 1],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )
            attention_block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_rate=self.drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                kernel_q=pool_q[i] if len(pool_q) > i else [],
                kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                stride_q=stride_q[i] if len(stride_q) > i else [],
                stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                mode=mode,
                has_cls_embed=self.cls_embed_on,
                pool_first=pool_first)

            if cfg.MODEL.ACT_CHECKPOINT:
                attention_block = checkpoint_wrapper(attention_block)
            self.blocks.append(attention_block)

            embed_dim = dim_out

        self.norm = norm_layer(embed_dim)

    # Head input is 2*embed_dim only for the reversible path with a "concat"
    # fusion and HALF_C disabled.
    self.head = head_helper.TransformerBasicHead(
        2 * embed_dim if ("concat" in cfg.MVIT.REV.RESPATH_FUSE and cfg.MVIT.REV.ENABLE and not cfg.MVIT.REV.HALF_C) else embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
    )
    if self.sep_pos_embed:
        trunc_normal_(self.pos_embed_spatial, std=0.02)
        trunc_normal_(self.pos_embed_temporal, std=0.02)
        if self.cls_embed_on:
            trunc_normal_(self.pos_embed_class, std=0.02)
    else:
        trunc_normal_(self.pos_embed, std=0.02)
    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)
def __init__(self, cfg):
    """Build an MViTv2 video model (``MultiScaleBlock`` stack) from ``cfg``.

    Supports optional absolute positional embeddings (joint or separated
    spatial/temporal), relative positional embeddings inside the blocks, and
    activation checkpointing of the patch embed and of each block.

    Args:
        cfg: config node; reads the ``DATA``, ``MVIT`` and ``MODEL`` groups.

    Note:
        When ``cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE`` is not None, this method
        overwrites ``cfg.MVIT.POOL_KV_STRIDE`` on the passed-in config.
    """
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    pool_first = cfg.MVIT.POOL_FIRST
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
    if use_2d_patch:
        # 2D patches: prepend a stride of 1 for the temporal dimension.
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    # Params for positional embedding
    self.use_abs_pos = cfg.MVIT.USE_ABS_POS
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    self.rel_pos_spatial = cfg.MVIT.REL_POS_SPATIAL
    self.rel_pos_temporal = cfg.MVIT.REL_POS_TEMPORAL
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes
    patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    if cfg.MODEL.ACT_CHECKPOINT:
        patch_embed = checkpoint_wrapper(patch_embed)
    self.patch_embed = patch_embed

    # [T, H, W] of the input clip; the token grid divides each by its stride.
    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
    num_patches = math.prod(self.patch_dims)

    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, depth)
    ]  # stochastic depth decay rule

    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches

    if self.use_abs_pos:
        if self.sep_pos_embed:
            # Separate embeddings: one per spatial token (H*W) and one per
            # temporal index (T), plus an optional one for the class token.
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(
                    1, self.patch_dims[1] * self.patch_dims[2], embed_dim
                )
            )
            self.pos_embed_temporal = nn.Parameter(
                torch.zeros(1, self.patch_dims[0], embed_dim)
            )
            if self.cls_embed_on:
                self.pos_embed_class = nn.Parameter(
                    torch.zeros(1, 1, embed_dim)
                )
        else:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, pos_embed_dim, embed_dim)
            )

    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)

    # Per-layer width/head multipliers (default 1). Length depth + 1 so that
    # index i + 1 is valid for the last layer. Config entries are
    # [layer_idx, factor] pairs.
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

    # Per-layer pooling kernels/strides; an empty list means "no pooling".
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]

    # POOL_Q_STRIDE entries are [layer_idx, *strides]. Unless POOL_KVQ_KERNEL
    # is given, the kernel is stride + 1 for strides > 1, else the stride.
    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][
            1:
        ]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]

    # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
    # Starting from the adaptive stride, each layer with a Q stride divides
    # the running KV stride by it (floor, minimum 1). The result is carried
    # forward, and every layer gets an entry written back into cfg.
    if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
        _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
        cfg.MVIT.POOL_KV_STRIDE = []
        for i in range(cfg.MVIT.DEPTH):
            if len(stride_q[i]) > 0:
                _stride_kv = [
                    max(_stride_kv[d] // stride_q[i][d], 1)
                    for d in range(len(_stride_kv))
                ]
            cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)

    # Same [layer_idx, *strides] decoding as for Q above.
    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[
            i
        ][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[
                cfg.MVIT.POOL_KV_STRIDE[i][0]
            ] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s
                for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]

    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

    # Token-grid size seen by the current block; shrinks after Q pooling.
    input_size = self.patch_dims
    self.blocks = nn.ModuleList()

    if cfg.MODEL.ACT_CHECKPOINT:
        validate_checkpoint_wrapper_import(checkpoint_wrapper)

    for i in range(depth):
        num_heads = round_width(num_heads, head_mul[i])
        # DIM_MUL_IN_ATT selects whether this block's output width uses the
        # multipliers of layer i or of layer i + 1 (passed to the block as
        # dim_mul_in_att as well).
        if cfg.MVIT.DIM_MUL_IN_ATT:
            dim_out = round_width(
                embed_dim,
                dim_mul[i],
                divisor=round_width(num_heads, head_mul[i]),
            )
        else:
            dim_out = round_width(
                embed_dim,
                dim_mul[i + 1],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )
        attention_block = MultiScaleBlock(
            dim=embed_dim,
            dim_out=dim_out,
            num_heads=num_heads,
            input_size=input_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=self.drop_rate,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            kernel_q=pool_q[i] if len(pool_q) > i else [],
            kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
            stride_q=stride_q[i] if len(stride_q) > i else [],
            stride_kv=stride_kv[i] if len(stride_kv) > i else [],
            mode=mode,
            has_cls_embed=self.cls_embed_on,
            pool_first=pool_first,
            rel_pos_spatial=self.rel_pos_spatial,
            rel_pos_temporal=self.rel_pos_temporal,
            rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
            residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
            dim_mul_in_att=cfg.MVIT.DIM_MUL_IN_ATT,
            separate_qkv=cfg.MVIT.SEPARATE_QKV,
        )

        if cfg.MODEL.ACT_CHECKPOINT:
            attention_block = checkpoint_wrapper(attention_block)
        self.blocks.append(attention_block)
        if len(stride_q[i]) > 0:
            # Q pooling downsamples the token grid for the following blocks.
            input_size = [
                size // stride
                for size, stride in zip(input_size, stride_q[i])
            ]

        embed_dim = dim_out

    self.norm = norm_layer(embed_dim)

    self.head = head_helper.TransformerBasicHead(
        embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
        cfg=cfg,
    )
    if self.use_abs_pos:
        if self.sep_pos_embed:
            trunc_normal_(self.pos_embed_spatial, std=0.02)
            trunc_normal_(self.pos_embed_temporal, std=0.02)
            if self.cls_embed_on:
                trunc_normal_(self.pos_embed_class, std=0.02)
        else:
            trunc_normal_(self.pos_embed, std=0.02)

    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)
def __init__(
    self,
    dim,
    dim_out,
    input_size,
    num_heads=8,
    qkv_bias=False,
    drop_rate=0.0,
    kernel_q=(1, 1, 1),
    kernel_kv=(1, 1, 1),
    stride_q=(1, 1, 1),
    stride_kv=(1, 1, 1),
    norm_layer=nn.LayerNorm,
    has_cls_embed=True,
    # Options include `conv`, `conv_unshared`, `avg`, and `max`.
    mode="conv",
    # If True, perform pool before projection.
    pool_first=False,
    rel_pos_spatial=False,
    rel_pos_temporal=False,
    rel_pos_zero_init=False,
    residual_pooling=False,
    separate_qkv=False,
):
    """Build a multiscale (pooling) attention layer.

    Args:
        dim: input channel width.
        dim_out: output channel width (also the Q/K/V projection width).
        input_size: token-grid size indexed as [0] temporal and [1], [2]
            spatial. [1] and [2] must be equal when ``rel_pos_spatial`` is set.
        num_heads: number of attention heads; ``head_dim = dim_out // num_heads``.
        qkv_bias: whether the Q/K/V projections have a bias.
        drop_rate: if > 0, adds ``self.proj_drop`` dropout.
        kernel_q, kernel_kv, stride_q, stride_kv: pooling kernels/strides for
            Q and for K/V. A kernel and stride that both multiply to 1
            disable that pooling.
        norm_layer: norm constructor applied after conv pooling.
        has_cls_embed: stored as ``self.has_cls_embed``.
        mode: ``"avg"``/``"max"`` use parameter-free 3D pooling.
            ``"conv"`` uses per-head depthwise ``Conv3d`` (channels
            ``dim // num_heads`` or ``dim_out // num_heads``).
            ``"conv_unshared"`` uses depthwise ``Conv3d`` over all channels.
        pool_first: build separate Q/K/V projections and size the conv
            pooling from ``dim`` rather than ``dim_out``.
        rel_pos_spatial, rel_pos_temporal: create relative positional
            tables ``rel_pos_h``/``rel_pos_w`` and ``rel_pos_t``.
        rel_pos_zero_init: keep the relative tables at zero instead of a
            truncated-normal init.
        residual_pooling: stored as ``self.residual_pooling``.
        separate_qkv: build separate Q/K/V projections even without
            ``pool_first``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the options above.
    """
    super().__init__()
    self.pool_first = pool_first
    self.separate_qkv = separate_qkv
    self.drop_rate = drop_rate
    self.num_heads = num_heads
    self.dim_out = dim_out
    head_dim = dim_out // num_heads
    self.scale = head_dim**-0.5
    self.has_cls_embed = has_cls_embed
    # "Same"-style padding for odd kernels. Computed before the kernels may
    # be cleared below.
    padding_q = [int(q // 2) for q in kernel_q]
    padding_kv = [int(kv // 2) for kv in kernel_kv]

    if pool_first or separate_qkv:
        self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
    else:
        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)

    self.proj = nn.Linear(dim_out, dim_out)
    if drop_rate > 0.0:
        self.proj_drop = nn.Dropout(drop_rate)

    # Skip pooling with kernel and stride size of (1, 1, 1).
    if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
        kernel_q = ()
    if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
        kernel_kv = ()
    self.mode = mode

    if mode in ("avg", "max"):
        pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
        self.pool_q = (
            pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
            if len(kernel_q) > 0
            else None
        )
        self.pool_k = (
            pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
            if len(kernel_kv) > 0
            else None
        )
        self.pool_v = (
            pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
            if len(kernel_kv) > 0
            else None
        )
    elif mode == "conv" or mode == "conv_unshared":
        # "conv" shares one depthwise filter set across heads (per-head
        # channels); "conv_unshared" convolves every channel.
        if pool_first:
            dim_conv = dim // num_heads if mode == "conv" else dim
        else:
            dim_conv = dim_out // num_heads if mode == "conv" else dim_out
        self.pool_q = (
            nn.Conv3d(
                dim_conv,
                dim_conv,
                kernel_q,
                stride=stride_q,
                padding=padding_q,
                groups=dim_conv,
                bias=False,
            )
            if len(kernel_q) > 0
            else None
        )
        self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
        self.pool_k = (
            nn.Conv3d(
                dim_conv,
                dim_conv,
                kernel_kv,
                stride=stride_kv,
                padding=padding_kv,
                groups=dim_conv,
                bias=False,
            )
            if len(kernel_kv) > 0
            else None
        )
        self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
        self.pool_v = (
            nn.Conv3d(
                dim_conv,
                dim_conv,
                kernel_kv,
                stride=stride_kv,
                padding=padding_kv,
                groups=dim_conv,
                bias=False,
            )
            if len(kernel_kv) > 0
            else None
        )
        self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
    else:
        raise NotImplementedError(f"Unsupported mode {mode}")

    self.rel_pos_spatial = rel_pos_spatial
    self.rel_pos_temporal = rel_pos_temporal
    if self.rel_pos_spatial:
        assert input_size[1] == input_size[2]
        size = input_size[1]
        q_size = size // stride_q[1] if len(stride_q) > 0 else size
        kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
        # Enough entries for every relative offset in [-(n-1), n-1].
        rel_sp_dim = 2 * max(q_size, kv_size) - 1

        self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
        if not rel_pos_zero_init:
            trunc_normal_(self.rel_pos_h, std=0.02)
            trunc_normal_(self.rel_pos_w, std=0.02)
    if self.rel_pos_temporal:
        self.rel_pos_t = nn.Parameter(
            torch.zeros(2 * input_size[0] - 1, head_dim))
        if not rel_pos_zero_init:
            trunc_normal_(self.rel_pos_t, std=0.02)

    self.residual_pooling = residual_pooling
def _init_weights(m: nn.Module) -> None:
    """Initialise a single module in place; meant to be passed to ``Module.apply``.

    Only ``nn.Linear`` layers are affected. The weight is drawn from a
    truncated normal with std 0.02, and the bias, when the layer has one, is
    set to zero. Every other module type is returned from untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    trunc_normal_(m.weight, std=0.02)
    bias = m.bias
    if bias is None:
        return
    nn.init.constant_(bias, 0)
def __init__(self, cfg):
    """Build an MViT video model (``MultiScaleBlock`` stack) from ``cfg``.

    Args:
        cfg: config node; reads the ``DATA``, ``MVIT`` and ``MODEL`` groups.
    """
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
    if use_2d_patch:
        # 2D patches: prepend a stride of 1 for the temporal dimension.
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes
    self.patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    # [T, H, W] of the input clip; the token grid divides each by its stride.
    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
    num_patches = math.prod(self.patch_dims)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
           ]  # stochastic depth decay rule

    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches

    if self.sep_pos_embed:
        # Separate embeddings: one per spatial token (H*W) and one per
        # temporal index (T). NOTE(review): pos_embed_class is created and
        # initialised here even when cls_embed_on is False.
        self.pos_embed_spatial = nn.Parameter(
            torch.zeros(1, self.patch_dims[1] * self.patch_dims[2], embed_dim))
        self.pos_embed_temporal = nn.Parameter(
            torch.zeros(1, self.patch_dims[0], embed_dim))
        self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
    else:
        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_embed_dim, embed_dim))

    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)

    # Per-layer pooling kernels/strides; an empty list means "no pooling".
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]

    # Stride config entries are [layer_idx, *strides]. Unless POOL_KVQ_KERNEL
    # is given, the kernel is stride + 1 for strides > 1, else the stride.
    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i]
                 [0]] = cfg.MVIT.POOL_Q_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]
    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i]
                  [0]] = cfg.MVIT.POOL_KV_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i]
                    [0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s
                for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]

    # Per-layer width/head multipliers (default 1). Length depth + 1 so that
    # index i + 1 is valid for the last layer. Config entries are
    # [layer_idx, factor] pairs.
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

    self.blocks = nn.ModuleList()
    for i in range(depth):
        # Heads and input width for layer i; output width uses the
        # multipliers of layer i + 1.
        num_heads = round_width(num_heads, head_mul[i])
        embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads)
        dim_out = round_width(
            embed_dim,
            dim_mul[i + 1],
            divisor=round_width(num_heads, head_mul[i + 1]),
        )

        self.blocks.append(
            MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_rate=self.drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                kernel_q=pool_q[i] if len(pool_q) > i else [],
                kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                stride_q=stride_q[i] if len(stride_q) > i else [],
                stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                mode=mode,
                has_cls_embed=self.cls_embed_on,
            ))

        embed_dim = dim_out

    self.norm = norm_layer(embed_dim)

    self.head = head_helper.TransformerBasicHead(
        embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
    )
    if self.sep_pos_embed:
        trunc_normal_(self.pos_embed_spatial, std=0.02)
        trunc_normal_(self.pos_embed_temporal, std=0.02)
        trunc_normal_(self.pos_embed_class, std=0.02)
    else:
        trunc_normal_(self.pos_embed, std=0.02)
    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)