def __init__(self, cfg):
    """
    Set up the single-pathway X3D network from ``cfg``.

    Subclasses overriding ``__init__`` should accept the same arguments.

    Args:
        cfg (CfgNode): model building configs, details are in the comments
            of the config file.
    """
    super(X3D, self).__init__()
    self.norm_module = get_norm(cfg)
    self.enable_detection = cfg.DETECTION.ENABLE
    self.num_pathways = 1

    # Each residual stage widens the previous one by this factor.
    expansion = 2.0

    self.dim_c1 = cfg.X3D.DIM_C1
    if cfg.X3D.SCALE_RES2:
        self.dim_res2 = round_width(self.dim_c1, expansion, divisor=8)
    else:
        self.dim_res2 = self.dim_c1
    self.dim_res3 = round_width(self.dim_res2, expansion, divisor=8)
    self.dim_res4 = round_width(self.dim_res3, expansion, divisor=8)
    self.dim_res5 = round_width(self.dim_res4, expansion, divisor=8)

    # One entry per stage: [number of blocks, channels, stride].
    stage_repeats = (1, 2, 5, 3)
    stage_widths = (self.dim_res2, self.dim_res3, self.dim_res4, self.dim_res5)
    self.block_basis = [
        [repeats, width, 2]
        for repeats, width in zip(stage_repeats, stage_widths)
    ]

    self._construct_network(cfg)
    init_helper.init_weights(
        self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
    )
def __init__(self, cfg):
    """
    Build an MViT video model, optionally with a reversible backbone.

    When ``cfg.MVIT.REV.ENABLE`` is set, the transformer stack is created by
    ``RevMViT(cfg)`` followed by a ``ResPathFusion`` module; otherwise an
    ``nn.ModuleList`` of ``MultiScaleBlock`` modules is built here. A
    ``TransformerBasicHead`` is created in both cases.

    Args:
        cfg (CfgNode): model building configs, details are in the comments
            of the config file.

    Raises:
        AssertionError: if the train and test crop sizes differ, if
            ``cfg.MVIT.USE_ABS_POS`` is false, or if the input is not square.
        NotImplementedError: if ``cfg.MVIT.NORM`` is not ``"layernorm"``.

    Note:
        When ``cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE`` is not None, this
        constructor overwrites ``cfg.MVIT.POOL_KV_STRIDE`` in place.
    """
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    pool_first = cfg.MVIT.POOL_FIRST
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
    if use_2d_patch:
        # Prepend a stride of 1 for the temporal axis so the stride lines up
        # with input_dims = [temporal, spatial, spatial] below.
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    self.use_abs_pos = cfg.MVIT.USE_ABS_POS
    # This variant only supports absolute positional embeddings.
    assert self.use_abs_pos
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes
    self.patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    # Patch-grid size along each axis (input size // patch stride).
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
    num_patches = math.prod(self.patch_dims)

    # Drop-path rate rises linearly from 0 to drop_path_rate over the blocks.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
           ]  # stochastic depth decay rule

    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One extra position for the class token.
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches

    if self.sep_pos_embed:
        # Factorized tables: one over the spatial grid, one over time.
        self.pos_embed_spatial = nn.Parameter(
            torch.zeros(1, self.patch_dims[1] * self.patch_dims[2], embed_dim))
        self.pos_embed_temporal = nn.Parameter(
            torch.zeros(1, self.patch_dims[0], embed_dim))
        if self.cls_embed_on:
            self.pos_embed_class = nn.Parameter(
                torch.zeros(1, 1, embed_dim))
    else:
        # One joint table covering every patch (plus the class token, if any).
        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_embed_dim, embed_dim))

    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)

    # Per-block width / head multipliers, default 1.0, overridden at the
    # indices given by cfg entries of the form [block_index, multiplier].
    # Length is depth + 1 so index i + 1 is valid for the last block.
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

    # Per-block pooling kernels / strides; blocks not listed in the cfg keep
    # an empty list.
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]

    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            # Default kernel: stride + 1 where the stride exceeds 1,
            # otherwise the stride itself.
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]

    # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
    # Starting from the adaptive stride, divide by each Q stride as it
    # occurs (never below 1) and record an entry for every block. This
    # writes back into cfg.
    if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
        _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
        cfg.MVIT.POOL_KV_STRIDE = []
        for i in range(cfg.MVIT.DEPTH):
            if len(stride_q[i]) > 0:
                _stride_kv = [
                    max(_stride_kv[d] // stride_q[i][d], 1)
                    for d in range(len(_stride_kv))
                ]
            cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)

    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]

    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

    if self.cfg.MVIT.REV.ENABLE:
        # Builds the post-fusion norm: 2 * embed_dim wide for "concat"
        # fusion, embed_dim otherwise.
        # NOTE(review): the `mode` argument is ignored (the lambda reads
        # cfg.MVIT.REV.RESPATH_FUSE directly). `embed_dim` is looked up when
        # the lambda is called (closure late binding), so it sees the value
        # reassigned just below, not the stem width.
        handle_fuse = lambda mode: norm_layer(
            2 * embed_dim
        ) if "concat" in cfg.MVIT.REV.RESPATH_FUSE else norm_layer(
            embed_dim)
        self.encoder = RevMViT(cfg)
        # check on this for rev VT -> this should be fine
        # Final width after every DIM_MUL stage (product of all multipliers).
        # NOTE(review): the divisor is the initial num_heads; HEAD_MUL is not
        # applied on this path.
        embed_dim = round_width(embed_dim, dim_mul.prod(), divisor=num_heads)
        # Presumably fuses the two reversible residual streams; it is built
        # with dim = 2 * embed_dim.
        self.fuse = ResPathFusion(cfg.MVIT.REV.RESPATH_FUSE, dim=2 * embed_dim)
        if self.cfg.MVIT.REV.EXTRA_LN:
            # NOTE(review): this branch defines norm1 / norm2 but no
            # self.norm; confirm the forward pass uses these instead.
            self.norm1 = norm_layer(2 * embed_dim)
            self.norm2 = handle_fuse(cfg.MVIT.REV.RESPATH_FUSE)
        else:
            if cfg.MVIT.REV.LN_BEFORE_FUSE:
                if self.cfg.MVIT.REV.HALF_C:
                    self.norm = norm_layer(embed_dim)
                else:
                    if "concat" in self.cfg.MVIT.REV.RESPATH_FUSE:
                        self.norm = norm_layer(2 * embed_dim)
                    else:
                        self.norm = norm_layer(embed_dim)
            else:
                self.norm = handle_fuse(cfg.MVIT.REV.RESPATH_FUSE)
    else:
        if cfg.MODEL.ACT_CHECKPOINT:
            validate_checkpoint_wrapper_import(checkpoint_wrapper)
        # NOTE(review): input_size is assigned but never read in this
        # variant (MultiScaleBlock below is not given an input_size).
        input_size = self.patch_dims
        self.blocks = nn.ModuleList()
        for i in range(depth):
            num_heads = round_width(num_heads, head_mul[i])
            # The running width is rescaled by dim_mul[i] at the start of
            # every block, so it is cumulative across iterations.
            embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads)
            dim_out = round_width(
                embed_dim,
                dim_mul[i + 1],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )
            attention_block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_rate=self.drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                kernel_q=pool_q[i] if len(pool_q) > i else [],
                kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                stride_q=stride_q[i] if len(stride_q) > i else [],
                stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                mode=mode,
                has_cls_embed=self.cls_embed_on,
                pool_first=pool_first)
            if cfg.MODEL.ACT_CHECKPOINT:
                attention_block = checkpoint_wrapper(attention_block)
            self.blocks.append(attention_block)
        # Set only after the loop: embed_dim was already advanced by
        # dim_mul[i] each iteration, so here it just takes the last block's
        # output width.
        embed_dim = dim_out
        self.norm = norm_layer(embed_dim)

    # The head input is doubled only for un-halved "concat" fusion in the
    # reversible path.
    self.head = head_helper.TransformerBasicHead(
        2 * embed_dim
        if ("concat" in cfg.MVIT.REV.RESPATH_FUSE
            and cfg.MVIT.REV.ENABLE
            and not cfg.MVIT.REV.HALF_C)
        else embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
    )
    if self.sep_pos_embed:
        trunc_normal_(self.pos_embed_spatial, std=0.02)
        trunc_normal_(self.pos_embed_temporal, std=0.02)
        if self.cls_embed_on:
            trunc_normal_(self.pos_embed_class, std=0.02)
    else:
        trunc_normal_(self.pos_embed, std=0.02)
    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)
def _construct_network(self, cfg):
    """
    Builds a single pathway X3D model.

    Registers the stem as ``self.s1``, one ``resnet_helper.ResStage`` per
    entry of ``self.block_basis`` as ``s2``, ``s3``, ..., and the
    classification head as ``self.head``.

    Args:
        cfg (CfgNode): model building configs, details are in the comments
            of the config file.

    Raises:
        NotImplementedError: if ``self.enable_detection`` is true; no
            detection head is implemented for X3D.
    """
    assert cfg.MODEL.ARCH in _POOL1
    assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH

    # Fail fast, before any module is built. Without the raise, the
    # detection path would silently produce a model with no head.
    if self.enable_detection:
        raise NotImplementedError("Detection is not supported for X3D.")

    num_groups = cfg.RESNET.NUM_GROUPS
    w_mul = cfg.X3D.WIDTH_FACTOR
    d_mul = cfg.X3D.DEPTH_FACTOR
    dim_res1 = round_width(self.dim_c1, w_mul)

    temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

    self.s1 = stem_helper.VideoModelStem(
        dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
        dim_out=[dim_res1],
        kernel=[temp_kernel[0][0] + [3, 3]],
        stride=[[1, 2, 2]],
        padding=[[temp_kernel[0][0][0] // 2, 1, 1]],
        norm_module=self.norm_module,
        stem_func_name="x3d_stem",
    )

    dim_in = dim_res1
    for stage, block in enumerate(self.block_basis):
        num_blocks, stage_width, stage_stride = block[0], block[1], block[2]
        dim_out = round_width(stage_width, w_mul)
        dim_inner = int(cfg.X3D.BOTTLENECK_FACTOR * dim_out)
        n_rep = self._round_repeats(num_blocks, d_mul)
        # Stages are named from s2 to follow the ResNet naming convention.
        prefix = f"s{stage + 2}"

        s = resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=[stage_stride],
            num_blocks=[n_rep],
            # Channel-wise (depthwise) 3x3x3 uses one group per channel.
            num_groups=[dim_inner]
            if cfg.X3D.CHANNELWISE_3x3x3
            else [num_groups],
            num_block_temp_kernel=[n_rep],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            norm_module=self.norm_module,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[stage],
            drop_connect_rate=cfg.MODEL.DROPCONNECT_RATE
            * (stage + 2)
            / (len(self.block_basis) + 1),
        )
        dim_in = dim_out
        self.add_module(prefix, s)

    # Assumes a total spatial stride of 32 (stem x2 and four stride-2
    # stages in the default block_basis).
    spat_sz = int(math.ceil(cfg.DATA.TRAIN_CROP_SIZE / 32.0))
    self.head = head_helper.X3DHead(
        dim_in=dim_out,
        # Inner width of the last stage's bottleneck.
        dim_inner=dim_inner,
        dim_out=cfg.X3D.DIM_C5,
        num_classes=cfg.MODEL.NUM_CLASSES,
        pool_size=[cfg.DATA.NUM_FRAMES, spat_sz, spat_sz],
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
        bn_lin5_on=cfg.X3D.BN_LIN5,
    )
def __init__(self, cfg):
    """
    Build an MViT video model from ``cfg``.

    Supports optional absolute positional embeddings (one joint table, or
    factorized spatial / temporal tables). Relative-position and pooling
    options are forwarded to every ``MultiScaleBlock``, and activation
    checkpointing can wrap the patch embedding and each block.

    Args:
        cfg (CfgNode): model building configs, details are in the comments
            of the config file.

    Raises:
        AssertionError: if the train and test crop sizes differ or the
            input is not square.
        NotImplementedError: if ``cfg.MVIT.NORM`` is not ``"layernorm"``.

    Note:
        When ``cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE`` is not None, this
        constructor overwrites ``cfg.MVIT.POOL_KV_STRIDE`` in place.
    """
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    pool_first = cfg.MVIT.POOL_FIRST
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
    if use_2d_patch:
        # Prepend a stride of 1 for the temporal axis so the stride lines up
        # with input_dims = [temporal, spatial, spatial] below.
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    # Params for positional embedding
    self.use_abs_pos = cfg.MVIT.USE_ABS_POS
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    self.rel_pos_spatial = cfg.MVIT.REL_POS_SPATIAL
    self.rel_pos_temporal = cfg.MVIT.REL_POS_TEMPORAL
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes

    # Validate the optional activation-checkpointing dependency before its
    # first use (wrapping the patch embedding just below). Previously the
    # check ran only later, before the transformer blocks.
    if cfg.MODEL.ACT_CHECKPOINT:
        validate_checkpoint_wrapper_import(checkpoint_wrapper)

    patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    if cfg.MODEL.ACT_CHECKPOINT:
        patch_embed = checkpoint_wrapper(patch_embed)
    self.patch_embed = patch_embed

    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    # Patch-grid size along each axis (input size // patch stride).
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
    num_patches = math.prod(self.patch_dims)

    # Stochastic depth decay rule: the drop-path rate rises linearly from 0
    # to drop_path_rate across the blocks.
    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, depth)
    ]

    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One extra position for the class token.
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches

    if self.use_abs_pos:
        if self.sep_pos_embed:
            # Factorized tables: one over the spatial grid, one over time.
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(
                    1, self.patch_dims[1] * self.patch_dims[2], embed_dim
                )
            )
            self.pos_embed_temporal = nn.Parameter(
                torch.zeros(1, self.patch_dims[0], embed_dim)
            )
            if self.cls_embed_on:
                self.pos_embed_class = nn.Parameter(
                    torch.zeros(1, 1, embed_dim)
                )
        else:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, pos_embed_dim, embed_dim)
            )

    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)

    # Per-block width / head multipliers, default 1.0, overridden at the
    # indices given by cfg entries of the form [block_index, multiplier].
    # Length is depth + 1 so index i + 1 is valid for the last block.
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

    # Per-block pooling kernels / strides; blocks not listed in the cfg keep
    # an empty list.
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]

    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][
            1:
        ]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            # Default kernel: stride + 1 where the stride exceeds 1,
            # otherwise the stride itself.
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]

    # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
    # Starting from the adaptive stride, divide by each Q stride as it
    # occurs (never below 1) and record an entry for every block. This
    # writes back into cfg.
    if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
        _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
        cfg.MVIT.POOL_KV_STRIDE = []
        for i in range(cfg.MVIT.DEPTH):
            if len(stride_q[i]) > 0:
                _stride_kv = [
                    max(_stride_kv[d] // stride_q[i][d], 1)
                    for d in range(len(_stride_kv))
                ]
            cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)

    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[
            i
        ][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[
                cfg.MVIT.POOL_KV_STRIDE[i][0]
            ] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]

    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

    # Token-grid size seen by the current block; shrunk after every block
    # that has a Q stride.
    input_size = self.patch_dims
    self.blocks = nn.ModuleList()

    for i in range(depth):
        num_heads = round_width(num_heads, head_mul[i])
        if cfg.MVIT.DIM_MUL_IN_ATT:
            dim_out = round_width(
                embed_dim,
                dim_mul[i],
                divisor=round_width(num_heads, head_mul[i]),
            )
        else:
            dim_out = round_width(
                embed_dim,
                dim_mul[i + 1],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )
        attention_block = MultiScaleBlock(
            dim=embed_dim,
            dim_out=dim_out,
            num_heads=num_heads,
            input_size=input_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=self.drop_rate,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            kernel_q=pool_q[i] if len(pool_q) > i else [],
            kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
            stride_q=stride_q[i] if len(stride_q) > i else [],
            stride_kv=stride_kv[i] if len(stride_kv) > i else [],
            mode=mode,
            has_cls_embed=self.cls_embed_on,
            pool_first=pool_first,
            rel_pos_spatial=self.rel_pos_spatial,
            rel_pos_temporal=self.rel_pos_temporal,
            rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
            residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
            dim_mul_in_att=cfg.MVIT.DIM_MUL_IN_ATT,
            separate_qkv=cfg.MVIT.SEPARATE_QKV,
        )
        if cfg.MODEL.ACT_CHECKPOINT:
            attention_block = checkpoint_wrapper(attention_block)
        self.blocks.append(attention_block)
        if len(stride_q[i]) > 0:
            input_size = [
                size // stride
                for size, stride in zip(input_size, stride_q[i])
            ]
        # The next block consumes this block's output width.
        embed_dim = dim_out

    self.norm = norm_layer(embed_dim)

    self.head = head_helper.TransformerBasicHead(
        embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
        cfg=cfg,
    )
    if self.use_abs_pos:
        if self.sep_pos_embed:
            trunc_normal_(self.pos_embed_spatial, std=0.02)
            trunc_normal_(self.pos_embed_temporal, std=0.02)
            if self.cls_embed_on:
                trunc_normal_(self.pos_embed_class, std=0.02)
        else:
            trunc_normal_(self.pos_embed, std=0.02)

    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)
def __init__(self, cfg):
    """
    Build an MViT video model (basic variant: absolute positional
    embeddings, fixed KV pooling strides).

    Args:
        cfg (CfgNode): model building configs, details are in the comments
            of the config file.

    Raises:
        AssertionError: if the train and test crop sizes differ or the
            input is not square.
        NotImplementedError: if ``cfg.MVIT.NORM`` is not ``"layernorm"``.
    """
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
    if use_2d_patch:
        # Prepend a stride of 1 for the temporal axis so the stride lines up
        # with input_dims = [temporal, spatial, spatial] below.
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes
    self.patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    # Patch-grid size along each axis (input size // patch stride).
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
    num_patches = math.prod(self.patch_dims)

    # Drop-path rate rises linearly from 0 to drop_path_rate over the blocks.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
           ]  # stochastic depth decay rule

    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One extra position for the class token.
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches

    if self.sep_pos_embed:
        # Factorized tables: one over the spatial grid, one over time.
        self.pos_embed_spatial = nn.Parameter(
            torch.zeros(1, self.patch_dims[1] * self.patch_dims[2], embed_dim))
        self.pos_embed_temporal = nn.Parameter(
            torch.zeros(1, self.patch_dims[0], embed_dim))
        # NOTE(review): created (and initialized below) regardless of
        # cls_embed_on; confirm the forward pass expects this parameter when
        # no class token is used.
        self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
    else:
        # One joint table covering every patch (plus the class token, if any).
        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_embed_dim, embed_dim))

    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)

    # Per-block pooling kernels / strides; blocks not listed in the cfg keep
    # an empty list.
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            # Default kernel: stride + 1 where the stride exceeds 1,
            # otherwise the stride itself.
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]
    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[i][1:]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]

    # Per-block width / head multipliers, default 1.0, overridden at the
    # indices given by cfg entries of the form [block_index, multiplier].
    # Length is depth + 1 so index i + 1 is valid for the last block.
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

    self.blocks = nn.ModuleList()
    for i in range(depth):
        num_heads = round_width(num_heads, head_mul[i])
        # The running width is rescaled by dim_mul[i] at the start of every
        # block, so it is cumulative across iterations.
        embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads)
        dim_out = round_width(
            embed_dim,
            dim_mul[i + 1],
            divisor=round_width(num_heads, head_mul[i + 1]),
        )
        self.blocks.append(
            MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_rate=self.drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                kernel_q=pool_q[i] if len(pool_q) > i else [],
                kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                stride_q=stride_q[i] if len(stride_q) > i else [],
                stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                mode=mode,
                has_cls_embed=self.cls_embed_on,
            ))

    # Set only after the loop: embed_dim was already advanced by dim_mul[i]
    # each iteration, so here it just takes the last block's output width.
    embed_dim = dim_out
    self.norm = norm_layer(embed_dim)

    self.head = head_helper.TransformerBasicHead(
        embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
    )
    if self.sep_pos_embed:
        trunc_normal_(self.pos_embed_spatial, std=0.02)
        trunc_normal_(self.pos_embed_temporal, std=0.02)
        trunc_normal_(self.pos_embed_class, std=0.02)
    else:
        trunc_normal_(self.pos_embed, std=0.02)
    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)