Example 1
    def forward(self, x):
        # Separate offset fields for the small branch and the enlarged branch.
        offset_3 = self.conv_offset_3(x)
        offset_5 = self.conv_offset_5(x)

        splited = list()
        splited.append(
            deform_conv2d(x, offset_3, self.weight, self.stride, self.padding,
                          self.dilation, self.groups, self.deform_groups))
        for i in range(1, self.M):
            # Temporarily enlarge padding and dilation so the extra branches
            # cover a wider receptive field; both are restored below.
            self.padding = tuple(2 + p for p in self.padding)
            self.dilation = tuple(2 + d for d in self.dilation)
            weight = self.weight + self.weight_diff
            splited.append(
                deform_conv2d(x, offset_5, weight, self.stride, self.padding,
                              self.dilation, self.groups, self.deform_groups))
            self.padding = (1, 1)
            self.dilation = (1, 1)

        feats = sum(splited)
        # Channel attention: softmax over the M branches for every channel.
        att_c = self.att_c(feats.contiguous())
        att_c = att_c.reshape(x.size(0), self.M, x.size(1))
        att_c = att_c.softmax(dim=1)
        att_c = att_c.reshape(x.size(0), -1, 1, 1)
        att_c = torch.split(att_c, x.size(1), dim=1)

        att_c = sum([w * s for w, s in zip(att_c, splited)])

        # Spatial attention: softmax over the branches at every location.
        att_s = self.att_s(torch.max(feats, dim=1, keepdim=True)[0])
        att_s = att_s.softmax(dim=1)
        att_s = torch.split(att_s, 1, dim=1)

        att_s = sum([w * s for w, s in zip(att_s, splited)])

        # return (att_c + att_s) / 2
        # Keep the element-wise stronger of the two attention-fused maps.
        return torch.where(att_c > att_s, att_c, att_s)
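The reshape/split/softmax dance in the channel-attention fusion is easy to lose track of. Below is a minimal, self-contained sketch of just that fusion step, with random tensors standing in for `splited` and for the output of `self.att_c`; the names `branch_feats` and `att_logits` are illustrative, not from the original module.

import torch

# Two branch outputs of shape (N, C, H, W), standing in for `splited`.
n, c, h, w = 2, 8, 16, 16
branch_feats = [torch.randn(n, c, h, w) for _ in range(2)]
m = len(branch_feats)

feats = sum(branch_feats)                     # element-wise sum over branches
# Pretend attention logits of shape (N, M*C, 1, 1); in the original they
# come from `self.att_c(feats)`.
att_logits = torch.randn(n, m * c, 1, 1)

att = att_logits.reshape(n, m, c).softmax(dim=1)  # softmax across the M branches
att = att.reshape(n, -1, 1, 1)
att = torch.split(att, c, dim=1)                  # one (N, C, 1, 1) weight per branch

fused = sum(w_ * s_ for w_, s_ in zip(att, branch_feats))
print(fused.shape)  # torch.Size([2, 8, 16, 16])

The spatial-attention branch in the example follows the same pattern, only the softmax and split run over per-pixel maps of shape (N, 1, H, W) instead of per-channel weights.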
Example 2
    def forward(self, x):
        # pre-context
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(self.out_channels,
                                device=weight.device,
                                dtype=weight.dtype)

        if self.use_deform:
            offset = self.offset_s(avg_x)
            out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            if (TORCH_VERSION == 'parrots'
                    or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
                out_s = super().conv2d_forward(x, weight)
            elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
                # bias is a required argument of _conv_forward in torch 1.8.0
                out_s = super()._conv_forward(x, weight, zero_bias)
            else:
                out_s = super()._conv_forward(x, weight)
        ori_p = self.padding
        ori_d = self.dilation
        # Switch to the 3x padding/dilation for the large receptive-field branch;
        # the originals are restored after the blend below.
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        if self.use_deform:
            offset = self.offset_l(avg_x)
            out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            if (TORCH_VERSION == 'parrots'
                    or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
                out_l = super().conv2d_forward(x, weight)
            elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
                # bias is a required argument of _conv_forward in torch 1.8.0
                out_l = super()._conv_forward(x, weight, zero_bias)
            else:
                out_l = super()._conv_forward(x, weight)

        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out
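Stripped of the deformable path and the mmcv version checks, the core of this forward is two convolutions that share a weight (plus a learned `weight_diff`), one at dilation 1 and one at dilation 3, blended per pixel by `switch`. A minimal sketch of that blend with plain `F.conv2d` and made-up shapes (not the mmcv `deform_conv2d` used above):

import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 32, 32)
weight = torch.randn(16, 16, 3, 3)
weight_diff = torch.zeros_like(weight)             # learned residual on the weights
switch = torch.sigmoid(torch.randn(1, 1, 32, 32))  # per-pixel gate in [0, 1]

# Small receptive field: 3x3 kernel, dilation 1, padding 1 keeps the spatial size.
out_s = F.conv2d(x, weight, padding=1, dilation=1)
# Large receptive field: same kernel, dilation 3, padding 3 keeps the spatial size.
out_l = F.conv2d(x, weight + weight_diff, padding=3, dilation=3)

out = switch * out_s + (1 - switch) * out_l
print(out.shape)  # torch.Size([1, 16, 32, 32])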
Example 3
    def forward(self, i, x):
        if i < self.start_level or not self.part_deform:
            return torch.nn.functional.conv2d(
                x,
                self.weight,
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups)

        offset = self.conv_offset(x)

        # padding is needed to avoid the `input image is smaller than kernel` error
        input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
                                                          self.kernel_size[1])
        if input_pad:
            pad_h = max(self.kernel_size[0] - x.size(2), 0)
            pad_w = max(self.kernel_size[1] - x.size(3), 0)
            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
            offset = offset.contiguous()

        out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deform_groups)
        if input_pad:
            out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
                      pad_w].contiguous()
        return out + self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
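The pad-then-crop guard above keeps `deform_conv2d` from failing on inputs smaller than the kernel, and then discards the rows and columns that exist only because of the padding. A minimal sketch of the same guard around an ordinary convolution (a hypothetical 3x3 'same' `F.conv2d` stands in for the deformable one):

import torch
import torch.nn.functional as F

kh, kw = 3, 3
weight = torch.randn(4, 4, kh, kw)
x = torch.randn(1, 4, 2, 2)            # smaller than the 3x3 kernel

pad_h = max(kh - x.size(2), 0)         # 1
pad_w = max(kw - x.size(3), 0)         # 1
if pad_h or pad_w:
    # Zero-pad on the bottom/right so the kernel fits.
    x = F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0)

out = F.conv2d(x, weight, padding=1)   # 'same' conv on the padded input
if pad_h or pad_w:
    # Crop away the rows/columns produced only by the padding.
    out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w]
print(out.shape)  # torch.Size([1, 4, 2, 2])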
Example 4
    def forward(self, x):
        # pre-context
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        if self.use_deform:
            offset = self.offset_s(avg_x)
            out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            # check 'parrots' first: LooseVersion cannot compare it with '1.5.0'
            if (TORCH_VERSION == 'parrots'
                    or LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')):
                out_s = super().conv2d_forward(x, weight)
            else:
                out_s = super()._conv_forward(x, weight)
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        if self.use_deform:
            offset = self.offset_l(avg_x)
            out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            # check 'parrots' first: LooseVersion cannot compare it with '1.5.0'
            if (TORCH_VERSION == 'parrots'
                    or LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')):
                out_l = super().conv2d_forward(x, weight)
            else:
                out_l = super()._conv_forward(x, weight)
        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out
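This is an older variant of Example 2. Its pre-/post-context steps add a globally average-pooled, 1x1-convolved descriptor back onto the feature map. A minimal sketch of that additive global-context step, with a hypothetical `nn.Conv2d` standing in for `self.pre_context` / `self.post_context`:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 16, 32, 32)
context_conv = nn.Conv2d(16, 16, kernel_size=1)   # stands in for self.pre_context

avg = F.adaptive_avg_pool2d(x, output_size=1)     # (2, 16, 1, 1) global descriptor
avg = context_conv(avg)                           # channel mixing on the descriptor
x = x + avg.expand_as(x)                          # broadcast-add back onto the map
print(x.shape)  # torch.Size([2, 16, 32, 32])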
Example 5
    def forward(self, i, x):
        if i < self.start_level or not self.part_deform:
            return torch.nn.functional.conv2d(x,
                                              self.weight,
                                              bias=self.bias,
                                              stride=self.stride,
                                              padding=self.padding,
                                              dilation=self.dilation,
                                              groups=self.groups)

        offset = self.conv_offset(x)
        out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deform_groups)
        return out + self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
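The trailing `unsqueeze` chain only reshapes the 1-D bias so it broadcasts over an NCHW output. A small check of that reshape against the equivalent `view(1, -1, 1, 1)` form, with assumed shapes:

import torch

out = torch.randn(2, 8, 16, 16)
bias = torch.randn(8)

a = out + bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # as in the example
b = out + bias.view(1, -1, 1, 1)                          # equivalent reshape
print(torch.allclose(a, b))  # True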