Beispiel #1
0
 def forward(self, input, offset, mask):
     """Apply the modulated deformable convolution with caller-supplied
     offsets and mask.

     The channel dimension (dim 1) of ``offset`` must equal
     ``2 * deformable_groups * kernel_h * kernel_w`` and that of ``mask``
     must equal ``deformable_groups * kernel_h * kernel_w``.

     Args:
         input: tensor forwarded unchanged to the deformable conv function.
         offset: sampling offsets; only ``shape[1]`` is inspected here.
         mask: modulation mask; only ``shape[1]`` is inspected here.

     Returns:
         Whatever ``ModulatedDeformConvFunction.apply`` returns.

     Raises:
         AssertionError: if either channel count does not match.
     """
     kernel_h = self.kernel_size[0]
     kernel_w = self.kernel_size[1]
     # One mask value per kernel tap per deformable group; offsets need
     # two values per tap, hence the factor of 2 below.
     taps_per_location = self.deformable_groups * kernel_h * kernel_w
     assert 2 * taps_per_location == offset.shape[1]
     assert taps_per_location == mask.shape[1]

     conv_args = (
         input, offset, mask,
         self.weight, self.bias,
         self.stride, self.padding,
         self.dilation, self.groups,
         self.deformable_groups,
         self.im2col_step,
     )
     return ModulatedDeformConvFunction.apply(*conv_args)
Beispiel #2
0
 def forward(self, input):
     """Predict offsets and mask from ``input`` and apply the modulated
     deformable convolution.

     ``self.conv_offset_mask(input)`` is split into three equal chunks
     along dim 1: the first two are concatenated back together to form the
     offsets, and the third is passed through a sigmoid to form the mask.

     Args:
         input: tensor fed both to ``self.conv_offset_mask`` and to the
             deformable conv function.

     Returns:
         Whatever ``ModulatedDeformConvFunction.apply`` returns.
     """
     predicted = self.conv_offset_mask(input)
     offset_a, offset_b, mask_logits = predicted.chunk(3, dim=1)
     offset = torch.cat([offset_a, offset_b], dim=1)
     # Sigmoid squashes the raw mask prediction into (0, 1).
     mask = mask_logits.sigmoid()

     conv_args = (
         input, offset, mask,
         self.weight, self.bias,
         self.stride, self.padding,
         self.dilation, self.groups,
         self.deformable_groups,
         self.im2col_step,
     )
     return ModulatedDeformConvFunction.apply(*conv_args)
 def forward(self, input, t):
     """Apply the modulated deformable convolution with a mask that
     depends only on ``t``.

     Offsets are the full output of ``self.conv_offset_mask(input)``
     (split in two along dim 1 and concatenated back).  The mask is a
     tensor shaped like one of those halves, filled with
     ``sigmoid(weight_t[0] * t**2 + weight_t[1] * t + weight_t[2])``.

     Args:
         input: tensor fed to ``self.conv_offset_mask`` and to the
             deformable conv function.
         t: value the mask polynomial is evaluated at.
             NOTE(review): presumably a scalar or scalar tensor -- confirm
             against the caller.

     Returns:
         Whatever ``ModulatedDeformConvFunction.apply`` returns.
     """
     predicted = self.conv_offset_mask(input)
     first_half, second_half = torch.chunk(predicted, 2, dim=1)
     offset = torch.cat((first_half, second_half), dim=1)

     # Quadratic in t, squashed by a sigmoid.
     quadratic = (self.weight_t[0] * (t**2) +
                  self.weight_t[1] * t + self.weight_t[2])
     gate = torch.sigmoid(quadratic)
     # Broadcast the gate over a ones tensor shaped like one offset half,
     # created on the same device as the input.
     mask = torch.ones(first_half.shape, device=input.device) * gate

     conv_args = (
         input, offset, mask,
         self.weight, self.bias,
         self.stride, self.padding,
         self.dilation, self.groups,
         self.deformable_groups,
         self.im2col_step,
     )
     return ModulatedDeformConvFunction.apply(*conv_args)
Beispiel #4
0
    def forward(self, input_list):
        """Apply the modulated deformable convolution with a mask sampled
        from ``input_LO`` and weighted by ``output_LO``.

        ``input_list`` is unpacked into three items: the tensor to be
        convolved, ``input_LO`` and ``output_LO``.  Offsets come from
        ``self.conv_offset``.  For every kernel tap a sampling location is
        built from the ``get_grid`` base grid plus the tap index and its
        predicted offset; ``input_LO`` is sampled there with
        ``grid_sample``, multiplied by ``output_LO`` and summed over dim 1
        to give the mask for that tap.

        NOTE(review): "LO" presumably stands for "layout" (see the inline
        comment below) -- confirm with the caller.

        Args:
            input_list: sequence of exactly three tensors
                ``(input, input_LO, output_LO)``; ``input_LO`` must be 4-D.

        Returns:
            Whatever ``ModulatedDeformConvFunction.apply`` returns.
        """
        features, input_LO, output_LO = input_list
        # Only the first entry of stride / padding is read and it is used
        # for both axes.
        # NOTE(review): assumes square stride and padding -- confirm.
        stride_h = stride_w = self.stride[0]
        pad_h = pad_w = self.padding[0]
        kernel_h = self.kernel_size[0]
        kernel_w = self.kernel_size[1]
        num_taps = kernel_w * kernel_h
        batchsize, label_nc, in_h, in_w = input_LO.size()

        predicted = self.conv_offset(features)
        _, _, out_h, out_w = predicted.size()
        first_half, second_half = torch.chunk(predicted, 2, dim=1)
        offset = torch.cat((first_half, second_half), dim=1)

        base_x, base_y = get_grid(batchsize, out_h, out_w,
                                  gpu_id=features.get_device(),
                                  dtype=features.dtype)
        base_x = base_x * stride_w - pad_w
        base_y = base_y * stride_h - pad_h

        # here we enable layout-constrained sampling
        # NOTE(review): dividing by (size - 1) / 2 presumably converts pixel
        # units into grid_sample's normalized coordinates -- confirm against
        # get_grid.
        half_w = (in_w - 1.0) / 2.0
        half_h = (in_h - 1.0) / 2.0
        per_tap_locations = []
        for tap in range(kernel_h * kernel_w):
            hk, wk = divmod(tap, kernel_w)
            # Channel 2*tap + 1 feeds x and channel 2*tap feeds y.
            per_tap_locations.append(
                base_x + (wk + offset[:, 2 * tap + 1]) / half_w)
            per_tap_locations.append(
                base_y + (hk + offset[:, 2 * tap]) / half_h)
        # Single concatenation along dim 1, ordered x, y for tap 0, then
        # x, y for tap 1, and so on.
        sample_location = torch.cat(per_tap_locations, 1)

        sample_location = sample_location.permute(0, 2, 3, 1).contiguous().view(-1, out_h, out_w, 2)
        tiled_input_LO = input_LO.repeat(1, num_taps, 1, 1).view(-1, label_nc, in_h, in_w)

        sampled_LO = torch.nn.functional.grid_sample(
            tiled_input_LO, sample_location, mode='bilinear', padding_mode='border')
        tiled_output_LO = output_LO.repeat(1, num_taps, 1, 1).view(-1, label_nc, out_h, out_w)
        sampled_LO = sampled_LO * tiled_output_LO

        # Sum over dim 1, then regroup so each kernel tap has its own
        # mask channel.
        mask = torch.sum(sampled_LO, dim=1, keepdim=True).view(batchsize, num_taps, out_h, out_w)

        conv_args = (
            features, offset, mask,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.deformable_groups,
            self.im2col_step,
        )
        return ModulatedDeformConvFunction.apply(*conv_args)