Ejemplo n.º 1
0
def get_random_rotation_LAFs(patches, angle_mag=math.pi):
    """Create one randomly rotated 2x3 frame per patch, plus the inverse rotations.

    Every frame starts from the template ``[[0.5, 0, 0.5], [0, 0.5, 0.5]]``.
    Its left 2x2 part is then left-multiplied by the matrix that
    ``get_rotation_matrix`` returns for a random angle drawn uniformly from
    ``[-angle_mag, angle_mag)``.

    Args:
        patches: tensor whose first dimension is the batch size. Only
            ``size(0)`` and ``is_cuda`` are read.
        angle_mag: scale applied to the uniform samples in ``[-1, 1)``.

    Returns:
        A tuple ``(rot_LAFs, inv_rotmat)``: the ``(batch, 2, 3)`` frames and
        the result of ``get_rotation_matrix`` for the negated angles.
    """
    batch = patches.size(0)

    # Template frame, replicated once per patch.
    template = torch.FloatTensor([[0.5, 0, 0.5], [0, 0.5, 0.5]])
    rot_LAFs = Variable(template.unsqueeze(0).repeat(batch, 1, 1))

    # torch.rand samples [0, 1); rescale to [-1, 1) and shape as (batch, 1, 1).
    unit_samples = Variable(2.0 * torch.rand(batch) - 1.0).view(-1, 1, 1)

    if patches.is_cuda:
        rot_LAFs, unit_samples = rot_LAFs.cuda(), unit_samples.cuda()

    # NOTE(review): get_rotation_matrix is defined elsewhere; presumably it
    # returns a (batch, 2, 2) rotation per angle -- confirm.
    forward_rot = get_rotation_matrix(angle_mag * unit_samples)
    inv_rotmat = get_rotation_matrix(-angle_mag * unit_samples)

    # Rotate only the 2x2 linear part; the last column is left as-is.
    linear_part = rot_LAFs[:, 0:2, 0:2]
    rot_LAFs[:, 0:2, 0:2] = torch.bmm(forward_rot, linear_part)
    return rot_LAFs, inv_rotmat
Ejemplo n.º 2
0
 def forward(self, input, return_A_matrix=False):
     """Predict a batch of 2x2 matrices as rotation times lower-triangular.

     The network output is reshaped to ``(-1, 5)``. Columns 0-2 fill the
     lower-triangular matrix ``[[x0, 0], [x1, x2]]``. Columns 3 and 4 are
     turned into an angle with ``atan2`` and passed to
     ``get_rotation_matrix``.

     Args:
         input: batch fed through ``self.input_norm`` then ``self.features``.
         return_A_matrix: accepted for interface compatibility; not used.

     Returns:
         ``torch.bmm(rot, lower_triangular)``, one 2x2 matrix per row of the
         reshaped network output.
     """
     x = self.features(self.input_norm(input)).view(-1, 5)
     # The 1e-8 offset keeps atan2 away from the undefined (0, 0) input.
     rot = get_rotation_matrix(torch.atan2(x[:, 3], x[:, 4] + 1e-8))

     a11 = x[:, 0:1].view(-1, 1, 1)
     # zeros_like takes dtype and device from the network output, so a single
     # code path serves CPU and any GPU, with no separate .cuda() branch.
     top_row = torch.cat([a11, torch.zeros_like(a11)], dim=2)
     bottom_row = x[:, 1:3].view(-1, 1, 2).contiguous()

     return torch.bmm(rot, torch.cat([top_row, bottom_row], dim=1))
Ejemplo n.º 3
0
 def forward(self, input, return_A_matrix=False):
     """Predict a batch of 2x2 matrices as rotation times diagonal tilt.

     The network output is reshaped to ``(-1, 3)``. Column 0 gives the tilt
     ``exp(1.8 * tanh(x0))``, which therefore lies in
     ``(exp(-1.8), exp(1.8))``. Columns 1 and 2 are turned into an angle with
     ``atan2``. The diagonal matrix is ``diag(sqrt(tilt), 1 / sqrt(tilt))``.

     Args:
         input: batch fed through ``self.input_norm`` then ``self.features``.
         return_A_matrix: accepted for interface compatibility; not used.

     Returns:
         The contiguous result of ``rectifyAffineTransformationUpIsUp``
         applied to ``torch.bmm(rot, tilt_matrix)``.
     """
     x = self.features(self.input_norm(input)).view(-1, 3)
     # The 1e-8 offset keeps atan2 away from the undefined (0, 0) input.
     angle = torch.atan2(x[:, 1], x[:, 2] + 1e-8)
     rot = get_rotation_matrix(angle)

     # torch.tanh replaces the deprecated torch.nn.functional.tanh.
     tilt = torch.exp(1.8 * torch.tanh(x[:, 0]))
     sqrt_tilt = torch.sqrt(tilt)

     # Build the identity directly with x's dtype and device, and size it from
     # x (the tensor whose rows are written into it) rather than from input.
     tilt_matrix = torch.eye(2, dtype=x.dtype, device=x.device)
     tilt_matrix = tilt_matrix.unsqueeze(0).repeat(x.size(0), 1, 1)
     tilt_matrix[:, 0, 0] = sqrt_tilt
     tilt_matrix[:, 1, 1] = 1.0 / sqrt_tilt

     return rectifyAffineTransformationUpIsUp(torch.bmm(
         rot, tilt_matrix)).contiguous()
Ejemplo n.º 4
0
 def forward(self, input, return_A_matrix=False):
     """Predict a batch of 2x2 matrices as rotation times symmetric matrix.

     The network output is reshaped to ``(-1, 5)``. Columns 0-2 fill the
     symmetric matrix ``[[x0, x1], [x1, x2]]``. Columns 3 and 4 are turned
     into an angle with ``atan2`` and passed to ``get_rotation_matrix``.

     Args:
         input: batch fed through ``self.input_norm`` then ``self.features``.
         return_A_matrix: accepted but not used by this method.

     Returns:
         ``torch.bmm(rotation, symmetric)``, one 2x2 matrix per output row.
     """
     feats = self.features(self.input_norm(input)).view(-1, 5)
     rows = feats.size(0)

     # The 1e-8 offset keeps atan2 away from the undefined (0, 0) input.
     theta = torch.atan2(feats[:, 3], feats[:, 4] + 1e-8)
     rotation = get_rotation_matrix(theta)

     top_left = feats[:, 0:1].view(-1, 1, 1)
     top_right = feats[:, 1:2].view(rows, 1, 1).contiguous()
     top_row = torch.cat([top_left, top_right], dim=2)
     bottom_row = feats[:, 1:3].view(-1, 1, 2).contiguous()

     symmetric = torch.cat([top_row, bottom_row], dim=1)
     return torch.bmm(rotation, symmetric)
Ejemplo n.º 5
0
 def forward(self, input, return_rot_matrix = True):
     """Estimate one angle per sample from a two-value network output.

     The output of ``self.features(self.input_norm(input))`` is reshaped to
     ``(-1, 2)`` and the angle is ``atan2(col0 + 1e-8, col1 + 1e-8)``.

     Args:
         input: batch fed through ``self.input_norm`` then ``self.features``.
         return_rot_matrix: when true (the default), return
             ``get_rotation_matrix(angle)`` instead of the raw angle.

     Returns:
         The angle tensor, or the result of ``get_rotation_matrix`` for it.
     """
     pair = self.features(self.input_norm(input)).view(-1, 2)
     # Both atan2 arguments are nudged by 1e-8 so they are never exactly (0, 0).
     y_term = pair[:, 0] + 1e-8
     x_term = pair[:, 1] + 1e-8
     theta = torch.atan2(y_term, x_term)
     if not return_rot_matrix:
         return theta
     return get_rotation_matrix(theta)
Ejemplo n.º 6
0
 def forward(self, input, return_A_matrix = False):
     """Predict a batch of 2x2 matrices as rotation times near-identity matrix.

     The network output is reshaped to ``(-1, 5)``. Columns 0-2 fill the
     symmetric matrix ``[[1 + x0, x1], [x1, 1 + x2]]``. Columns 3 and 4,
     each offset by 1e-8, are turned into an angle with ``atan2`` and passed
     to ``get_rotation_matrix``.

     Args:
         input: batch fed through ``self.input_norm`` then ``self.features``.
         return_A_matrix: accepted but not used by this method.

     Returns:
         ``torch.bmm(rotation, shape_matrix)``, one 2x2 matrix per output row.
     """
     out = self.features(self.input_norm(input)).view(-1, 5)

     def entry(index):
         # One output column reshaped to a (rows, 1, 1) matrix entry.
         return out[:, index].contiguous().view(-1, 1, 1)

     top_row = torch.cat([1.0 + entry(0), entry(1)], dim=2).contiguous()
     bottom_row = torch.cat([entry(1), 1.0 + entry(2)], dim=2).contiguous()

     theta = torch.atan2(out[:, 3] + 1e-8, out[:, 4] + 1e-8)
     rotation = get_rotation_matrix(theta)
     shape_matrix = torch.cat([top_row, bottom_row], dim=1).contiguous()
     return torch.bmm(rotation, shape_matrix)
Ejemplo n.º 7
0
 def forward(self, input, return_rot_matrix=False):
     """Estimate one angle per sample from the first two output channels.

     The angle is ``atan2(out[:, 0] + 1e-8, out[:, 1] + 1e-8)``, where ``out``
     is ``self.features(self.input_norm(input))`` used without reshaping.

     Args:
         input: batch fed through ``self.input_norm`` then ``self.features``.
         return_rot_matrix: when true, return ``get_rotation_matrix(-angle)``
             (note the negated angle) instead of the raw angle.

     Returns:
         The angle tensor by default, or the result of
         ``get_rotation_matrix`` for the negated angle.
     """
     out = self.features(self.input_norm(input))
     # Both atan2 arguments are nudged by 1e-8 so they are never exactly (0, 0).
     y_term = out[:, 0] + 1e-8
     x_term = out[:, 1] + 1e-8
     theta = torch.atan2(y_term, x_term)
     if not return_rot_matrix:
         return theta
     return get_rotation_matrix(-theta)