Пример #1
0
def res_stage_nonlocal(
    model,
    block_fn,
    blob_in,
    dim_in,
    dim_out,
    stride,
    num_blocks,
    prefix,
    dim_inner=None,
    group=None,
    use_temp_convs=None,
    temp_strides=None,
    batch_size=None,
    pool_stride=None,
    spatial_dim=None,
    nonlocal_name=None,
    nonlocal_mod=1000,
):
    """Build one ResNet stage of stacked 3D residual blocks, optionally
    interleaved with non-local blocks.

    Args:
        model: Model helper the ops are added to; passed through to
            ``_generic_residual_block_3d`` and ``nonlocal_helper.add_nonlocal``.
        block_fn: Unused; kept for signature compatibility. Every block is
            built with ``_generic_residual_block_3d``.
        blob_in: Input blob of the stage.
        dim_in: Channel count of ``blob_in``.
        dim_out: Channel count produced by every block of the stage.
        stride: If 2, the first block is built with stride 2; every other
            block uses stride 1.
        num_blocks: Number of residual blocks in the stage.
        prefix: Name prefix such as ``'res2'``; block ``i`` is named
            ``'{prefix}_{i}'``.
        dim_inner, group: Forwarded unchanged to every residual block.
        use_temp_convs: Per-block value forwarded to the residual block
            (presumably a 0/1 "use temporal conv" flag). ``None`` means all
            zeros; a sequence shorter than ``num_blocks`` is padded with 0.
            The caller's sequence is never modified.
        temp_strides: Per-block temporal stride forwarded to the residual
            block. ``None`` means all ones; a sequence shorter than
            ``num_blocks`` is padded with 1. The caller's sequence is never
            modified.
        batch_size, pool_stride: Forwarded to ``nonlocal_helper.add_nonlocal``.
        spatial_dim: Forwarded to ``nonlocal_helper.add_nonlocal`` for both
            of its spatial-size arguments.
        nonlocal_name: Name prefix for inserted non-local blocks; must be a
            string whenever a non-local block is inserted.
        nonlocal_mod: A non-local block is added after block ``idx`` whenever
            ``idx % nonlocal_mod == nonlocal_mod - 1``, i.e. after every
            ``nonlocal_mod``-th block. The default of 1000 effectively disables
            non-local blocks for ordinary stage sizes.

    Returns:
        Tuple ``(blob_out, dim)``: the stage's output blob and its channel
        count (``dim_out``, or the original ``dim_in`` if ``num_blocks`` is 0).
    """
    if use_temp_convs is None:
        use_temp_convs = np.zeros(num_blocks).astype(int)
    if temp_strides is None:
        temp_strides = np.ones(num_blocks).astype(int)

    # Work on fresh lists so the caller's sequences are never mutated, and pad
    # each one independently: the two may differ in length, and numpy arrays
    # have no ``append``.
    use_temp_convs = list(use_temp_convs)
    temp_strides = list(temp_strides)
    use_temp_convs += [0] * (num_blocks - len(use_temp_convs))
    temp_strides += [1] * (num_blocks - len(temp_strides))

    for idx in range(num_blocks):
        block_prefix = "{}_{}".format(prefix, idx)
        # Only the first block of the stage may downsample spatially.
        block_stride = 2 if (idx == 0 and stride == 2) else 1
        blob_in = _generic_residual_block_3d(model, blob_in, dim_in, dim_out,
                                             block_stride, block_prefix,
                                             dim_inner, group,
                                             use_temp_convs[idx],
                                             temp_strides[idx])
        dim_in = dim_out

        if idx % nonlocal_mod == nonlocal_mod - 1:
            blob_in = nonlocal_helper.add_nonlocal(
                model, blob_in, dim_in, dim_in, batch_size, pool_stride,
                spatial_dim, spatial_dim, nonlocal_name + '_{}'.format(idx),
                int(dim_in / 2))

    return blob_in, dim_in
Пример #2
0
def res_stage_nonlocal(model,
                       block_fn,
                       blob_in,
                       dim_in,
                       dim_out,
                       stride,
                       num_blocks,
                       prefix,
                       dim_inner=None,
                       group=None,
                       use_temp_convs=None,
                       temp_strides=None,
                       batch_size=None,
                       nonlocal_name=None,
                       nonlocal_mod=1000):
    """ResNet stage with optional non-local blocks.

    Stacks ``num_blocks`` 3D residual blocks named ``'{prefix}_{i}'``
    (prefix is something like ``'res2'``, ``'res3'``) and inserts a non-local
    block after block ``idx`` whenever ``idx % nonlocal_mod == nonlocal_mod - 1``.

    When ``cfg.AVABOX.CONCAT_GLOBAL_MID_NL`` is set, the batch axis is treated
    as ``B`` pairs of samples (``B = batch_size // 2``). The paired axis is
    folded into the time axis before the non-local block, so it attends over
    both samples of a pair. The fold is undone afterwards. This mode hard-codes
    ``T = 4`` and the spatial size per stage (``res3``: 32 in test mode, else
    28; ``res4``: 16 in test mode, else 14) and supports only those prefixes.

    Args:
        model: Model helper the ops are added to.
        block_fn: Unused; kept for signature compatibility.
        blob_in: Input blob of the stage.
        dim_in: Channel count of ``blob_in``.
        dim_out: Channel count produced by every block.
        stride: If 2, the first block uses stride 2; all others use 1.
        num_blocks: Number of residual blocks.
        prefix: Stage name prefix.
        dim_inner, group: Forwarded unchanged to every residual block.
        use_temp_convs: Per-block value forwarded to the residual block.
            ``None`` means all zeros; a short sequence is padded with 0.
            The caller's sequence is never modified.
        temp_strides: Per-block temporal stride. ``None`` means all ones; a
            short sequence is padded with 1. The caller's sequence is never
            modified.
        batch_size: Batch size forwarded to ``nonlocal_helper.add_nonlocal``
            (halved in CONCAT_GLOBAL_MID_NL mode).
        nonlocal_name: Name prefix for inserted non-local blocks.
        nonlocal_mod: Insert a non-local block after every
            ``nonlocal_mod``-th block.

    Returns:
        Tuple ``(blob_out, dim)``: the output blob and its channel count.

    Raises:
        ValueError: In CONCAT_GLOBAL_MID_NL mode, if ``prefix`` is neither
            ``'res3'`` nor ``'res4'``.
    """

    def _reshape_swap_reshape(blob, shape_6d, shape_out):
        """Reshape ``blob`` to the 6-D ``shape_6d``, swap axes 1 and 2, then
        reshape to ``shape_out``."""
        blob, _ = model.Reshape(
            blob, [blob + '_re', blob + '_re_shape'], shape=shape_6d)
        blob = model.Transpose(blob, blob + '_tr', axes=(0, 2, 1, 3, 4, 5))
        blob, _ = model.Reshape(
            blob, [blob + '_re', blob + '_re_shape'], shape=shape_out)
        return blob

    if use_temp_convs is None:
        use_temp_convs = np.zeros(num_blocks).astype(int)
    if temp_strides is None:
        temp_strides = np.ones(num_blocks).astype(int)

    # Work on fresh lists so the caller's sequences are never mutated, and pad
    # each one independently: the two may differ in length, and numpy arrays
    # have no ``append``.
    use_temp_convs = list(use_temp_convs)
    temp_strides = list(temp_strides)
    use_temp_convs += [0] * (num_blocks - len(use_temp_convs))
    temp_strides += [1] * (num_blocks - len(temp_strides))

    for idx in range(num_blocks):
        block_prefix = "{}_{}".format(prefix, idx)
        block_stride = 2 if (idx == 0 and stride == 2) else 1
        blob_in = _generic_residual_block_3d(model, blob_in, dim_in, dim_out,
                                             block_stride, block_prefix,
                                             dim_inner, group,
                                             use_temp_convs[idx],
                                             temp_strides[idx])
        dim_in = dim_out

        if idx % nonlocal_mod != nonlocal_mod - 1:
            continue

        concat_mid = cfg.AVABOX.CONCAT_GLOBAL_MID_NL
        nl_batch_size = batch_size
        if concat_mid:
            B = batch_size // 2
            T = 4
            if prefix == 'res3':
                H = W = 32 if cfg.TEST_MODE else 28
            elif prefix == 'res4':
                H = W = 16 if cfg.TEST_MODE else 14
            else:
                # Explicit error instead of `assert False`, which is stripped
                # under `python -O` and would leave H/W unbound.
                raise ValueError(
                    'CONCAT_GLOBAL_MID_NL supports only res3/res4, got '
                    '{!r}'.format(prefix))

            # (B*2, C, T, H, W) -> (B, 2, C, T, H, W) -> (B, C, 2, T, H, W)
            # -> (B, C, 2*T, H, W): fold the paired axis into time.
            blob_in = _reshape_swap_reshape(
                blob_in, (B, 2, dim_in, T, H, W), (B, dim_in, 2 * T, H, W))
            nl_batch_size = B

        blob_in = nonlocal_helper.add_nonlocal(
            model, blob_in, dim_in, dim_in, nl_batch_size,
            nonlocal_name + '_{}'.format(idx), int(dim_in / 2))

        if concat_mid:
            # (B, C, 2*T, H, W) -> (B, C, 2, T, H, W) -> (B, 2, C, T, H, W)
            # -> (B*2, C, T, H, W): undo the fold.
            blob_in = _reshape_swap_reshape(
                blob_in, (B, dim_in, 2, T, H, W), (B * 2, dim_in, T, H, W))

    return blob_in, dim_in