def res_stage_nonlocal(
    model, block_fn, blob_in, dim_in, dim_out, stride, num_blocks, prefix,
    dim_inner=None, group=None, use_temp_convs=None, temp_strides=None,
    batch_size=None, pool_stride=None, spatial_dim=None,
    nonlocal_name=None, nonlocal_mod=1000,
):
    """Build one ResNet stage (prefix is something like res2, res3, etc.)
    of `num_blocks` stacked residual blocks, optionally interleaving
    non-local blocks.

    A non-local block is inserted after block `idx` whenever
    `idx % nonlocal_mod == nonlocal_mod - 1`; the default of 1000
    effectively disables them.

    Returns:
        (blob_in, dim_in): the stage's output blob and its channel count.
    """
    # Defaults: no temporal convs, unit temporal stride, one entry per block.
    if use_temp_convs is None:
        use_temp_convs = np.zeros(num_blocks).astype(int)
    if temp_strides is None:
        temp_strides = np.ones(num_blocks).astype(int)
    # Pad short configs with "no temporal conv / stride 1".  Copy to lists
    # first so caller-supplied sequences are not mutated in place (and so
    # ndarray inputs, which have no .append, also work).
    if len(use_temp_convs) < num_blocks:
        use_temp_convs = list(use_temp_convs)
        temp_strides = list(temp_strides)
        for _ in range(num_blocks - len(use_temp_convs)):
            use_temp_convs.append(0)
            temp_strides.append(1)

    for idx in range(num_blocks):
        block_prefix = "{}_{}".format(prefix, idx)
        # Only the first block of the stage performs spatial downsampling.
        block_stride = 2 if (idx == 0 and stride == 2) else 1
        blob_in = _generic_residual_block_3d(
            model, blob_in, dim_in, dim_out, block_stride, block_prefix,
            dim_inner, group, use_temp_convs[idx], temp_strides[idx])
        dim_in = dim_out

        if idx % nonlocal_mod == nonlocal_mod - 1:
            blob_in = nonlocal_helper.add_nonlocal(
                model, blob_in, dim_in, dim_in, batch_size, pool_stride,
                spatial_dim, spatial_dim,
                nonlocal_name + '_{}'.format(idx), int(dim_in / 2))

    return blob_in, dim_in
def res_stage_nonlocal(
    model, block_fn, blob_in, dim_in, dim_out, stride, num_blocks, prefix,
    dim_inner=None, group=None, use_temp_convs=None, temp_strides=None,
    batch_size=None, nonlocal_name=None, nonlocal_mod=1000,
):
    """ResNet stage with optional non-local blocks.

    Prefix is something like: res2, res3, etc.  A non-local block is
    inserted after block `idx` whenever
    `idx % nonlocal_mod == nonlocal_mod - 1` (default 1000 disables them).

    When cfg.AVABOX.CONCAT_GLOBAL_MID_NL is set, the batch is treated as
    B pairs of clips stacked along the batch axis (batch_size == 2 * B);
    each pair is folded into the temporal axis before the non-local block
    and unfolded afterwards, so attention spans both clips of a pair.

    NOTE(review): this definition shadows an earlier `res_stage_nonlocal`
    in the same file — confirm which one is intended to be in effect.

    Returns:
        (blob_in, dim_in): the stage's output blob and its channel count.
    """
    # Defaults: no temporal convs, unit temporal stride, one entry per block.
    if use_temp_convs is None:
        use_temp_convs = np.zeros(num_blocks).astype(int)
    if temp_strides is None:
        temp_strides = np.ones(num_blocks).astype(int)
    # Pad short configs with "no temporal conv / stride 1".  Copy to lists
    # first so caller-supplied sequences are not mutated in place (and so
    # ndarray inputs, which have no .append, also work).
    if len(use_temp_convs) < num_blocks:
        use_temp_convs = list(use_temp_convs)
        temp_strides = list(temp_strides)
        for _ in range(num_blocks - len(use_temp_convs)):
            use_temp_convs.append(0)
            temp_strides.append(1)

    for idx in range(num_blocks):
        block_prefix = "{}_{}".format(prefix, idx)
        # Only the first block of the stage performs spatial downsampling.
        block_stride = 2 if (idx == 0 and stride == 2) else 1
        blob_in = _generic_residual_block_3d(
            model, blob_in, dim_in, dim_out, block_stride, block_prefix,
            dim_inner, group, use_temp_convs[idx], temp_strides[idx])
        dim_in = dim_out

        if idx % nonlocal_mod == nonlocal_mod - 1:
            if cfg.AVABOX.CONCAT_GLOBAL_MID_NL:
                B = batch_size // 2
                # Hard-coded feature-map geometry for this stage.
                # Presumably T/H/W match the model's input clip length and
                # resolution at res3/res4 — TODO confirm against the loader.
                T = 4
                if prefix == 'res3':
                    H, W = (32, 32) if cfg.TEST_MODE else (28, 28)
                elif prefix == 'res4':
                    H, W = (16, 16) if cfg.TEST_MODE else (14, 14)
                else:
                    # Was `assert False`; raise explicitly so the check
                    # survives python -O.
                    raise ValueError(
                        'CONCAT_GLOBAL_MID_NL not supported for stage '
                        + str(prefix))
                # (B*2, C, T, H, W) -> (B, 2, C, T, H, W)
                blob_in, _ = model.Reshape(
                    blob_in, [blob_in + '_re', blob_in + '_re_shape'],
                    shape=(B, 2, dim_in, T, H, W))
                # -> (B, C, 2, T, H, W)
                blob_in = model.Transpose(
                    blob_in, blob_in + '_tr', axes=(0, 2, 1, 3, 4, 5))
                # -> (B, C, 2*T, H, W): pair folded into the temporal axis.
                blob_in, _ = model.Reshape(
                    blob_in, [blob_in + '_re', blob_in + '_re_shape'],
                    shape=(B, dim_in, 2 * T, H, W))
                batch_size = B

            blob_in = nonlocal_helper.add_nonlocal(
                model, blob_in, dim_in, dim_in, batch_size,
                nonlocal_name + '_{}'.format(idx), int(dim_in / 2))

            if cfg.AVABOX.CONCAT_GLOBAL_MID_NL:
                # Undo the fold: (B, C, 2*T, H, W) -> (B, C, 2, T, H, W)
                blob_in, _ = model.Reshape(
                    blob_in, [blob_in + '_re', blob_in + '_re_shape'],
                    shape=(B, dim_in, 2, T, H, W))
                # -> (B, 2, C, T, H, W)
                blob_in = model.Transpose(
                    blob_in, blob_in + '_tr', axes=(0, 2, 1, 3, 4, 5))
                # -> (B*2, C, T, H, W): back to pairs stacked on batch axis.
                blob_in, _ = model.Reshape(
                    blob_in, [blob_in + '_re', blob_in + '_re_shape'],
                    shape=(B * 2, dim_in, T, H, W))
                batch_size = B * 2

    return blob_in, dim_in