Exemplo n.º 1
0
 def __init__(self,
              backbone,
              img_size=224,
              feature_size=None,
              in_chans=3,
              embed_dim=768):
     """Wrap a backbone and project its last feature map to ``embed_dim``.

     Args:
         backbone: ``nn.Module`` used as the feature extractor. When
             ``feature_size`` is None its call result is indexed with
             ``[-1]`` to obtain the feature map that is measured.
         img_size: Input image size; normalised with ``_pair``.
         feature_size: Spatial size of the backbone's output feature map.
             When None it is measured by running the backbone once on a
             zero image of shape ``(1, in_chans, img_size[0], img_size[1])``.
         in_chans: Channel count of the zero image used for probing.
         embed_dim: Output width of the linear projection ``self.proj``.
     """
     super(HybridEmbed, self).__init__()
     assert isinstance(backbone, nn.Module)
     img_size = _pair(img_size)
     self.img_size = img_size
     self.backbone = backbone
     if feature_size is None:
         with jt.no_grad():
             # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
             # map for all networks, the feature metadata has reliable channel and stride info, but using
             # stride to calc feature dim requires info about padding of each stage that isn't captured.
             training = backbone.is_training()
             if training:
                 backbone.eval()
             o = self.backbone(
                 jt.zeros((1, in_chans, img_size[0], img_size[1])))[-1]
             # Assumes the feature map keeps the probe's (N, C, H, W) layout:
             # last two dims are spatial, dim 1 is channels.
             feature_size = o.shape[-2:]
             feature_dim = o.shape[1]
             # Restore the mode the backbone was in before probing; an
             # unconditional train() would flip an eval-mode backbone.
             if training:
                 backbone.train()
     else:
         feature_size = _pair(feature_size)
         # No probing here: the channel count is read from the backbone's
         # feature metadata instead.
         feature_dim = self.backbone.feature_info.channels()[-1]
     self.num_patches = feature_size[0] * feature_size[1]
     self.proj = nn.Linear(feature_dim, embed_dim)
Exemplo n.º 2
0
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Compute a 2-D convolution with reindex / broadcast / reduce ops.

    Args:
        x: Input whose ``shape`` unpacks to ``(N, C, H, W)``.
        weight: Kernel; ``shape[0]`` is the output channel count and
            ``shape[-2:]`` the kernel height and width.
        bias: Optional per-output-channel bias, broadcast over dims
            0, 2 and 3 of the result.
        stride, padding, dilation: Int or pair, normalised with ``_pair``.
        groups: Number of channel groups; ``1`` takes the dense path.

    Returns:
        The convolution result with dims ``(N, out_channels, oh, ow)``.
    """
    padding = _pair(padding)
    stride = _pair(stride)
    dilation = _pair(dilation)
    n_out = weight.shape[0]

    batch, n_in, in_h, in_w = x.shape
    k_h, k_w = weight.shape[-2:]

    def _out_len(size, kernel, axis):
        # Standard conv output length along one spatial axis.
        span = size + padding[axis] * 2 - kernel * dilation[axis] + dilation[axis] - 1
        return span // stride[axis] + 1

    out_h = _out_len(in_h, k_h, 0)
    out_w = _out_len(in_w, k_w, 1)

    if groups == 1:
        # Gather every input element each output position needs into
        # [N, out_channels, C, oh, ow, Kh, Kw].
        row_expr = f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}'  # Hid+Khid
        col_expr = f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}'  # Wid+KWid
        gathered = x.reindex(
            [batch, n_out, n_in, out_h, out_w, k_h, k_w],
            ['i0', 'i2', row_expr, col_expr])
        kernel = weight.broadcast(gathered.shape, [0, 3, 4])
        # Reduce over input channel and both kernel axes.
        y = (gathered * kernel).sum([2, 5, 6])
    else:
        per_group_in = n_in // groups  # input channels per group
        per_group_out = n_out // groups  # output channels per group
        work_shape = [batch, groups, per_group_out, per_group_in,
                      out_h, out_w, k_h, k_w]
        row_expr = f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}'  # Hid+Khid
        col_expr = f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}'  # Wid+KWid
        gathered = x.reindex(
            work_shape,
            ['i0', f'i1*{per_group_in}+i3', row_expr, col_expr])
        gathered.compile_options = {"G": groups}
        # weight is indexed as [oc, CpG, Kh, Kw]
        kernel = weight.reindex(
            work_shape,
            [f'i1*{per_group_out}+i2', 'i3', 'i6', 'i7'])
        # Scatter-add the products back into [N, oc, oh, ow].
        y = (gathered * kernel).reindex_reduce(
            'add',
            [batch, n_out, out_h, out_w],
            ['i0', f'i1*{per_group_out}+i2', 'i4', 'i5'])

    if bias is not None:
        y = y + bias.broadcast(y.shape, [0, 2, 3])
    return y
Exemplo n.º 3
0
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        """Set up a patch-embedding layer built on a strided convolution.

        Args:
            img_size: Image size; normalised with ``_pair``.
            patch_size: Patch size; normalised with ``_pair`` and used as
                both kernel size and stride of ``self.proj``.
            in_chans: Input channel count of ``self.proj``.
            embed_dim: Output channel count of ``self.proj``.
        """
        super(PatchEmbed, self).__init__()
        image_hw = _pair(img_size)
        patch_hw = _pair(patch_size)

        # Whole patches that fit along each axis; any remainder is dropped
        # by the floor division.
        patches_along_1 = image_hw[1] // patch_hw[1]
        patches_along_0 = image_hw[0] // patch_hw[0]

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.num_patches = patches_along_1 * patches_along_0

        self.proj = nn.Conv(
            in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
Exemplo n.º 4
0
def roi_pool(input,rois,output_size,spatial_scale):
    """Run the RoI pooling CUDA op defined by the module-level CUDA sources.

    Args:
        input: Feature map; ``input.shape[1]`` becomes dim 1 of the output.
        rois: Regions of interest; ``rois.shape[0]`` becomes dim 0 of the
            output.
        output_size: Pooled size; normalised with ``_pair``.
        spatial_scale: Scalar passed to the kernel wrapped in a
            one-element array.

    Returns:
        An array of shape ``(rois.shape[0], input.shape[1],
        output_size[0], output_size[1])`` with ``input``'s dtype. With no
        rois an all-zero array of that shape is returned and the kernel is
        not invoked.
    """
    output_size = _pair(output_size)
    output_shape = (rois.shape[0], input.shape[1], output_size[0], output_size[1])
    # Same guard as roi_align: skip the CUDA op when there is nothing to pool.
    if rois.shape[0] == 0:
        return jt.zeros(output_shape, input.dtype)
    spatial_scale = jt.array([spatial_scale])
    inputs = [input,rois,spatial_scale]
    # The op has a second int32 output of the same shape (presumably the
    # arg-max indices consumed by the gradient kernel); callers only get
    # the pooled values.
    output_shapes = [output_shape]*2
    output_types = [input.dtype,'int32']
    output,_ = jt.code(output_shapes,output_types,inputs,cuda_header=CUDA_HEADER,cuda_src=CUDA_SRC,cuda_grad_src=CUDA_GRAD_SRC)
    return output
Exemplo n.º 5
0
def roi_align(input, rois, output_size, spatial_scale, sampling_ratio):
    """Run the RoI-align CUDA op defined by the module-level CUDA sources.

    Args:
        input: Feature map; ``input.shape[1]`` becomes dim 1 of the output.
        rois: Regions of interest; ``rois.shape[0]`` becomes dim 0 of the
            output.
        output_size: Pooled size; normalised with ``_pair``.
        spatial_scale: Scalar forwarded to the kernel.
        sampling_ratio: Scalar forwarded to the kernel.

    Returns:
        An array of shape ``(rois.shape[0], input.shape[1],
        output_size[0], output_size[1])`` with ``input``'s dtype; all
        zeros when ``rois`` is empty.
    """
    pooled_hw = _pair(output_size)
    # Both scalars are packed into one array and passed as the third input.
    kernel_params = jt.array([spatial_scale, sampling_ratio])
    result_shape = (rois.shape[0], input.shape[1], pooled_hw[0], pooled_hw[1])

    # Nothing to pool: return an empty-batch result without invoking the op.
    if rois.shape[0] == 0:
        return jt.zeros(result_shape, input.dtype)

    return jt.code(
        result_shape,
        input.dtype,
        [input, rois, kernel_params],
        cuda_header=CUDA_HEADER,
        cuda_src=CUDA_SRC,
        cuda_grad_src=CUDA_GRAD_SRC,
    )