Example #1
0
def compute_const_point_loss(model, images, logits, point_list=None):
    """Flip-consistency loss plus optional point-level supervision.

    Runs ``model`` on a horizontally flipped copy of ``images`` and
    penalizes the worst-case (max) absolute disagreement between the
    flipped-back prediction and ``logits``. When ``point_list`` is
    provided, a point-level supervision term is added on top.
    """
    from kornia.geometry.transform import flips

    hflip = flips.Hflip()

    # Consistency term: map the flipped-input prediction back to the
    # original frame and take the maximum absolute difference.
    flipped_logits = model(hflip(images))
    loss = torch.max(torch.abs(hflip(flipped_logits) - logits))

    # Optional point-level supervision.
    if point_list is None:
        return loss
    return loss + compute_point_level(images, logits, point_list)
Example #2
0
def compute_rot_point_loss(model, images, logits, point_list):
    """Rotation-consistency loss plus point-level supervision.

    Horizontally flips the images, rotates each by a random multiple of
    90 degrees, and penalizes the mean absolute difference between the
    recovered (un-rotated, un-flipped) predictions and ``logits``. A
    point-level loss on the original images is then added.

    Args:
        model: callable producing logits from an image batch.
        images: input image batch (N, C, H, W).
        logits: model predictions on the unmodified ``images``.
        point_list: point annotations; skipped when ``None``.

    Returns:
        Scalar loss tensor.
    """
    from kornia.geometry.transform import flips

    # rotation-consistency loss
    rotations = np.random.choice([0, 90, 180, 270], images.shape[0], replace=True)
    # Keep the original batch intact: the previous implementation
    # reassigned `images` to its flipped copy, so the point loss below
    # was computed on flipped images — inconsistent with
    # compute_const_point_loss, which uses the originals.
    images_flip = flips.Hflip()(images)
    images_rotated = sst.batch_rotation(images_flip, rotations)
    logits_rotated = model(images_rotated)
    logits_recovered = sst.batch_rotation(logits_rotated, 360 - rotations)
    logits_recovered = flips.Hflip()(logits_recovered)

    loss = torch.mean(torch.abs(logits_recovered - logits))

    # point loss on the (unmodified) input images; guard mirrors
    # compute_const_point_loss and keeps the call backward-compatible.
    if point_list is not None:
        loss = loss + compute_point_level(images, logits, point_list)

    return loss
    def train_on_batch(self, batch):
        """Run one optimization step on ``batch``.

        Computes the loss selected by ``exp_dict['model']['loss']``,
        backpropagates it (unless it is exactly 0, which happens when no
        labeled points exist), optionally clips gradients, and steps the
        optimizer.

        Args:
            batch: dict with at least ``'images'`` and, for point-based
                losses, ``'points'`` (255 marks unlabeled pixels).

        Returns:
            dict with the scalar ``'train_loss'``.

        Raises:
            ValueError: if the configured loss name is not recognized.
        """
        self.opt.zero_grad()

        images = batch["images"]
        images = images.cuda()

        logits = self.model_base(images)

        # compute loss
        loss_name = self.exp_dict['model']['loss']
        if loss_name == 'lcfcn_loss':
            points = batch['points'].cuda()
            loss = 0.

            # Per-image LCFCN loss on the sigmoid probabilities.
            for lg, pt in zip(logits, points):
                loss += lcfcn_loss.compute_loss((pt == 1).long(), lg.sigmoid())

        elif loss_name == 'const_lcfcn_loss':
            points = batch['points'].cuda()

            # Flip-consistency term plus per-image LCFCN loss.
            logits_flip = self.model_base(flips.Hflip()(images))
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip) - logits))

            for lg, pt in zip(logits, points):
                loss += lcfcn_loss.compute_loss((pt == 1).long(), lg.sigmoid())

        elif loss_name == 'point_loss':
            points = batch['points'].cuda()[:, None]
            ind = points != 255

            # BCE only on labeled pixels; zero loss if none are labeled.
            if ind.sum() == 0:
                loss = 0.
            else:
                loss = F.binary_cross_entropy_with_logits(
                    logits[ind], points[ind].float().cuda(), reduction='mean')

        else:
            # Fail fast: previously an unknown loss name left `loss`
            # unbound and crashed with a NameError below.
            raise ValueError('unknown loss: {}'.format(loss_name))

        if loss != 0:
            loss.backward()
        if self.exp_dict['model'].get('clip_grad'):
            ut.clip_gradient(self.opt, 0.5)

        self.opt.step()

        return {'train_loss': float(loss)}
Example #4
0
def compute_other_loss(self, loss_name, images, logits, points, point_list=None):
        """Compute the training loss selected by ``loss_name``.

        Each branch combines a transformation-consistency term
        (horizontal flip, affine, elastic, or 90-degree rotation) with
        point-level supervision, where a value of 255 in ``points``
        marks an unlabeled pixel.

        NOTE(review): if ``loss_name`` matches no branch, the final
        ``return loss`` raises NameError on the unbound name.
        """
        if loss_name == 'toponet':
            # Lazily build the auxiliary VGG used for perceptual consistency.
            # NOTE(review): assumes at least 4 CUDA devices (cuda:1..3) — confirm.
            if self.first_time:
                self.first_time = False
                self.vgg = nn.DataParallel(lanenet.VGG().cuda(1), list(range(1,4)))
                self.vgg.train()
            points = points[:,None]
            images_flip = flips.Hflip()(images)
            logits_flip = self.model_base(images_flip)
            # Flip-consistency on the raw logits.
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            
            # Perceptual (VGG-feature) consistency between the flipped pair,
            # computed on cuda:1 and moved back to cuda:0.
            logits_flip_vgg = self.vgg(F.sigmoid(logits_flip.cuda(1)))
            logits_vgg = self.vgg(F.sigmoid(logits.cuda(1))) 
            loss += self.exp_dict["model"]["loss_weight"] * torch.mean(torch.abs(flips.Hflip()(logits_flip_vgg)-logits_vgg)).cuda(0)

            # Point supervision on both the original and the flipped logits.
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'multiscale_cons_point_loss':
            # `logits` arrives as a (logits, features) pair in this branch.
            logits, features = logits
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            logits_flip, features_flip = self.model_base(flips.Hflip()(images), return_features=True)

            # Flip-consistency on logits plus each intermediate feature map.
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            for f, f_flip in zip(features, features_flip):
                loss += torch.mean(torch.abs(flips.Hflip()(f_flip)-f)) * self.exp_dict["model"]["loss_weight"]
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                # logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                # if 1:
                #     pf = points_flip.clone()
                #     pf[pf==1] = 2
                #     pf[pf==0] = 1
                #     pf[pf==255] = 0
                #     lcfcn_loss.save_tmp('tmp.png', flips.Hflip()(images[[0]]), logits_flip[[0]], 3, pf[[0]])
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == "affine_cons_point_loss":
            # NOTE(review): this nested helper is never called in this branch.
            def rotate_img(self, img, rot):
                if rot == 0:  # 0 degrees rotation
                    return img
                elif rot == 90:  # 90 degrees rotation
                    return np.flipud(np.transpose(img, (1, 0, 2)))
                elif rot == 180:  # 90 degrees rotation
                    return np.fliplr(np.flipud(img))
                elif rot == 270:  # 270 degrees rotation / or -90
                    return np.transpose(np.flipud(img), (1, 0, 2))
                else:
                    raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
            affine_params = self.exp_dict['model']["affine_params"]
            points = points[:,None].float().cuda()
            # Randomly flip half the time; remembered so labels match later.
            if np.random.randint(2) == 0:
                images_flip = flips.Hflip()(images)
                flipped = True
            else:
                images_flip = images
                flipped = False
            batch_size, C, height, width = logits.shape
            random_affine = RandomAffine(**affine_params, return_transform=True)
            images_aff, transform = random_affine(images_flip)
            logits_aff = self.model_base(images_aff)
            
            # hu.save_image('tmp1.png', images_aff[0])
            # Warp the augmented logits back to the original frame.
            itransform = transform.inverse()
            logits_aff = kornia.geometry.transform.warp_affine(logits_aff, itransform[:,:2, :], dsize=(height, width))
            if flipped:
                logits_aff = flips.Hflip()(logits_aff)
            # hu.save_image('tmp2.png', logits_aff[0])

            
            # logits_flip = self.model_base(flips.Hflip()(images))

            loss = self.exp_dict['model']["loss_weight"] * torch.mean(torch.abs(logits_aff-logits))
            # NOTE(review): bilinear warp of the label map yields fractional
            # values; the `points_aff <= 1` mask below keeps only pixels whose
            # interpolated label stayed in [0, 1] — confirm this is intended.
            points_aff = kornia.geometry.transform.warp_affine(points, transform[:,:2, :], dsize=(height, width), flags="bilinear")
            # points_aff = points
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind], 
                                        reduction='mean')
                if flipped:
                    points_aff = flips.Hflip()(points_aff)
                # logits_flip = self.model_base(flips.Hflip()(images))
                ind = points_aff <= 1
                loss += F.binary_cross_entropy_with_logits(logits_aff[ind], 
                                        points_aff[ind].detach(), 
                                        reduction='mean')
        elif loss_name == "elastic_cons_point_loss":
            points = points[:,None].float().cuda()
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            B, C, H, W = images.shape
            # ELASTIC TRANSFORM
            # Normalize a sampling grid to grid_sample's [-1, 1] range.
            def norm_grid(grid):
                grid -= grid.min()
                grid /= grid.max()
                grid = (grid - 0.5) * 2
                return grid 
            grid_x, grid_y = torch.meshgrid(torch.arange(H), torch.arange(W))
            grid_x = grid_x.float().cuda()
            grid_y = grid_y.float().cuda()
            sigma=self.exp_dict["model"]["sigma"]
            alpha=self.exp_dict["model"]["alpha"]
            indices = torch.stack([grid_y, grid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous()
            indices = norm_grid(indices)
            # Smoothed random displacement field (classic elastic transform).
            dx = gaussian_filter((np.random.rand(H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha
            dy = gaussian_filter((np.random.rand(H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha
            dx = torch.from_numpy(dx).cuda().float()
            dy = torch.from_numpy(dy).cuda().float()
            dgrid_x = grid_x + dx
            dgrid_y = grid_y + dy
            dgrid_y = norm_grid(dgrid_y)
            dgrid_x = norm_grid(dgrid_x)
            dindices = torch.stack([dgrid_y, dgrid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous()
            # dindices0 = dindices.permute(0, 3, 1, 2).contiguous().view(B*2, H, W)
            # indices0 = indices.permute(0, 3, 1, 2).contiguous().view(B*2, H, W)
            # iindices = torch.bmm(indices0, dindices0.pinverse()).view(B, 2, H, W).permute(0, 2, 3, 1)
            # indices_im = indices.permute(0, 3, 1, 2)
            # iindices = F.grid_sample(indices_im, dindices).permute(0, 2, 3, 1)
            # Consistency: model(warped images) vs warped model(images).
            images_aug = F.grid_sample(images, dindices)
            logits_aug = self.model_base(images_aug)
            aug_logits = F.grid_sample(logits, dindices)
            points_aug = F.grid_sample(points, dindices, mode='nearest')
            loss = self.exp_dict['model']["loss_weight"] * torch.mean(torch.abs(logits_aug-aug_logits))


            # logits_aff = self.model_base(images_aff)
            # inv_transform = transform.inverse()
            
            # Debug-only visualization helper (unused unless uncommented below).
            import pylab
            def save_im(image, name):
                _images_aff = image.data.cpu().numpy()
                _images_aff -= _images_aff.min()
                _images_aff /= _images_aff.max()
                _images_aff *= 255
                _images_aff = _images_aff.transpose((1,2,0))
                pylab.imsave(name, _images_aff.astype('uint8'))
            # save_im(images_aug[0], 'tmp1.png')
            ind = points!=255
            if ind.sum() != 0:
                # NOTE(review): original point loss carries a 2x weight here,
                # unlike the other branches — confirm intentional.
                loss += 2*F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind], 
                                        reduction='mean')
                # if flipped:
                #     points_aff = flips.Hflip()(points_aff)
                # logits_flip = self.model_base(flips.Hflip()(images))
                ind = points_aug != 255
                loss += F.binary_cross_entropy_with_logits(logits_aug[ind], 
                                        points_aug[ind].detach(), 
                                        reduction='mean')

        elif loss_name == 'rot_point_loss':
            points = points[:,None]
            # grid = sst.get_grid(images.shape, normalized=True)
            # images = images[0, None, ...].repeat(8, 1, 1, 1)
            # Flip + random multiple-of-90 rotation; recover and compare.
            rotations = np.random.choice([0, 90, 180, 270], points.shape[0], replace=True)
            images = flips.Hflip()(images)
            images_rotated = sst.batch_rotation(images, rotations)
            logits_rotated = self.model_base(images_rotated)
            logits_recovered = sst.batch_rotation(logits_rotated, 360 - rotations)
            logits_recovered = flips.Hflip()(logits_recovered)
            
            loss = torch.mean(torch.abs(logits_recovered-logits))
            
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].detach().float().cuda(), 
                                        reduction='mean')

                # Apply the same flip+rotation to the labels so they line up
                # with logits_rotated.
                points_rotated = flips.Hflip()(points)
                points_rotated = sst.batch_rotation(points_rotated, rotations)
                ind = points_rotated!=255
                loss += F.binary_cross_entropy_with_logits(logits_rotated[ind], 
                                        points_rotated[ind].detach().float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'lcfcn_loss':
            loss = 0.
  
            # Per-image LCFCN loss on the sigmoid probabilities.
            for lg, pt in zip(logits, points):
                loss += lcfcn_loss.compute_loss((pt==1).long(), lg.sigmoid())

                # loss += lcfcn_loss.compute_binary_lcfcn_loss(l[None], 
                #         p[None].long().cuda())

        elif loss_name == 'point_level':
            
            # Multi-label image-level supervision from the classes present
            # in the first image's point list.
            class_labels = torch.zeros(self.n_classes).cuda().long()
            class_labels[np.unique([p['cls'] for p in point_list[0]])] = 1
            n,c,h,w = logits.shape
            class_logits = logits.view(n,c,h*w)
            # REGION LOSS
            loss = F.multilabel_soft_margin_loss(class_logits.max(2)[0], 
                                                    class_labels[None], reduction='mean')

            # POINT LOSS
            # Rasterize the point annotations (255 = ignore).
            points = torch.ones(h,w).cuda()*255
            for p in point_list[0]:
                if p['y'] >= h or p['x'] >= w: 
                    continue
                points[int(p['y']), int(p['x'])] = p['cls']
            probs = F.log_softmax(logits, dim=1)
            loss += F.nll_loss(probs, 
                            points[None].long(), reduction='mean', ignore_index=255)

            # NOTE(review): this branch returns early, unlike the others.
            return loss

        elif loss_name == 'point_loss':
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            if ind.sum() == 0:
                loss = 0.
            else:
                loss = F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')
                                        
            # print(points[ind].sum().item(), float(loss))
        elif loss_name == 'att_point_loss':
            points = points[:,None]
            ind = points!=255

            loss = 0.
            if ind.sum() != 0:
                loss = F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                # Additional point loss on the flipped prediction.
                logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'cons_point_loss':
            points = points[:,None]
            
            # Flip-consistency plus point supervision on both orientations.
            logits_flip = self.model_base(flips.Hflip()(images))
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        return loss
    def compute_point_loss(self, loss_name, images, logits, points):
        """Compute the point-supervision training loss named by ``loss_name``.

        All branches treat a value of 255 in ``points`` as "unlabeled";
        consistency branches additionally compare ``logits`` against the
        model's output on a transformed copy of ``images``.
        """
        if loss_name == 'rot_point_loss':
            """ Flips the image and computes a random rotation of 
                {0, 90, 180, 270} degrees"""
            points = points[:, None]
            rotations = np.random.choice([0, 90, 180, 270],
                                         points.shape[0],
                                         replace=True)
            images = flips.Hflip()(images)
            images_rotated = sst.batch_rotation(images, rotations)
            logits_rotated = self.model_base(images_rotated)
            logits_recovered = sst.batch_rotation(logits_rotated,
                                                  360 - rotations)
            logits_recovered = flips.Hflip()(logits_recovered)

            # Rotation-consistency term.
            loss = torch.mean(torch.abs(logits_recovered - logits))

            ind = points != 255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(
                    logits[ind],
                    points[ind].detach().float().cuda(),
                    reduction='mean')

                # Transform the labels the same way so they align with
                # logits_rotated.
                points_rotated = flips.Hflip()(points)
                points_rotated = sst.batch_rotation(points_rotated, rotations)
                ind = points_rotated != 255
                loss += F.binary_cross_entropy_with_logits(
                    logits_rotated[ind],
                    points_rotated[ind].detach().float().cuda(),
                    reduction='mean')

        elif loss_name == 'cons_point_loss':
            """ CB point loss, see Laradji et al. 2020 """
            points = points[:, None]

            # Flip-consistency term.
            logits_flip = self.model_base(flips.Hflip()(images))
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip) - logits))

            ind = points != 255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(
                    logits[ind], points[ind].float().cuda(), reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip != 255
                loss += F.binary_cross_entropy_with_logits(
                    logits_flip[ind],
                    points_flip[ind].float().cuda(),
                    reduction='mean')

        elif loss_name == "elastic_cons_point_loss":
            """ Performs an elastic transformation to the images and logits and 
                computes the consistency between the transformed logits and the
                logits of the transformed images see: https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a """
            points = points[:, None].float().cuda()
            ind = points != 255

            B, C, H, W = images.shape

            # Sample normalized elastic grid
            # Rescale a grid to grid_sample's [-1, 1] coordinate range.
            def norm_grid(grid):
                grid -= grid.min()
                grid /= grid.max()
                grid = (grid - 0.5) * 2
                return grid

            grid_x, grid_y = torch.meshgrid(torch.arange(H), torch.arange(W))
            grid_x = grid_x.float().cuda()
            grid_y = grid_y.float().cuda()
            sigma = self.exp_dict["model"]["sigma"]
            alpha = self.exp_dict["model"]["alpha"]
            indices = torch.stack([grid_y, grid_x],
                                  -1).view(1, H, W, 2).expand(B, H, W,
                                                              2).contiguous()
            indices = norm_grid(indices)
            # Gaussian-smoothed random displacement field, scaled by alpha.
            dx = gaussian_filter(
                (np.random.rand(H, W) * 2 - 1), sigma, mode="constant",
                cval=0) * alpha
            dy = gaussian_filter(
                (np.random.rand(H, W) * 2 - 1), sigma, mode="constant",
                cval=0) * alpha
            dx = torch.from_numpy(dx).cuda().float()
            dy = torch.from_numpy(dy).cuda().float()
            dgrid_x = grid_x + dx
            dgrid_y = grid_y + dy
            dgrid_y = norm_grid(dgrid_y)
            dgrid_x = norm_grid(dgrid_x)
            dindices = torch.stack([dgrid_y, dgrid_x],
                                   -1).view(1, H, W, 2).expand(B, H, W,
                                                               2).contiguous()
            # Use the grid to sample from the image and the logits
            # Consistency: model(warped images) vs warped model(images).
            images_aug = F.grid_sample(images, dindices)
            logits_aug = self.model_base(images_aug)
            aug_logits = F.grid_sample(logits, dindices)
            points_aug = F.grid_sample(points, dindices, mode='nearest')
            loss = self.exp_dict['model']["loss_weight"] * torch.mean(
                torch.abs(logits_aug - aug_logits))

            ind = points != 255
            if ind.sum() != 0:
                # NOTE(review): the original point loss carries a 2x weight
                # here, unlike the other branches — confirm intentional.
                loss += 2 * F.binary_cross_entropy_with_logits(
                    logits[ind], points[ind], reduction='mean')
                ind = points_aug != 255
                loss += F.binary_cross_entropy_with_logits(
                    logits_aug[ind],
                    points_aug[ind].detach(),
                    reduction='mean')

        elif loss_name == 'lcfcn_loss':
            loss = 0.

            # Per-image LCFCN loss on the sigmoid probabilities.
            for lg, pt in zip(logits, points):
                loss += lcfcn_loss.compute_loss((pt == 1).long(), lg.sigmoid())

                # loss += lcfcn_loss.compute_binary_lcfcn_loss(l[None],
                #         p[None].long().cuda())

        elif loss_name == 'point_loss':
            points = points[:, None]
            ind = points != 255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            # BCE on labeled pixels only; zero loss if none are labeled.
            if ind.sum() == 0:
                loss = 0.
            else:
                loss = F.binary_cross_entropy_with_logits(
                    logits[ind], points[ind].float().cuda(), reduction='mean')

            # print(points[ind].sum().item(), float(loss))
        elif loss_name == 'att_point_loss':
            points = points[:, None]
            ind = points != 255

            loss = 0.
            if ind.sum() != 0:
                loss = F.binary_cross_entropy_with_logits(
                    logits[ind], points[ind].float().cuda(), reduction='mean')

                # Additional point loss on the flipped prediction.
                logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip != 255
                loss += F.binary_cross_entropy_with_logits(
                    logits_flip[ind],
                    points_flip[ind].float().cuda(),
                    reduction='mean')

        # NOTE(review): an unrecognized loss_name reaches this return with
        # `loss` unbound and raises NameError.
        return loss
Example #6
0
    def compute_point_loss(self, loss_name, images, logits, points):
        """Compute the training loss selected by ``loss_name``.

        Branches combine a transformation-consistency term (flip, affine,
        elastic, or rotation) with point-level supervision; 255 in
        ``points`` marks unlabeled pixels.

        NOTE(review): the 'elastic_cons_point_loss' branch never assigns
        ``loss`` (it only produces debug images), so taking it raises
        NameError at the final ``return loss``. An unrecognized
        ``loss_name`` fails the same way.
        """
        if loss_name == 'toponet':
            # Lazily build the auxiliary VGG used for perceptual consistency.
            # NOTE(review): assumes at least 4 CUDA devices (cuda:1..3) — confirm.
            if self.first_time:
                self.first_time = False
                self.vgg = nn.DataParallel(lanenet.VGG().cuda(1), list(range(1,4)))
                self.vgg.train()
            points = points[:,None]
            images_flip = flips.Hflip()(images)
            logits_flip = self.model_base(images_flip)
            # Flip-consistency on the raw logits.
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            
            # Perceptual (VGG-feature) consistency between the flipped pair.
            logits_flip_vgg = self.vgg(F.sigmoid(logits_flip.cuda(1)))
            logits_vgg = self.vgg(F.sigmoid(logits.cuda(1))) 
            loss += self.exp_dict["model"]["loss_weight"] * torch.mean(torch.abs(flips.Hflip()(logits_flip_vgg)-logits_vgg)).cuda(0)

            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'multiscale_cons_point_loss':
            # `logits` arrives as a (logits, features) pair in this branch.
            logits, features = logits
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            logits_flip, features_flip = self.model_base(flips.Hflip()(images), return_features=True)

            # Flip-consistency on logits plus each intermediate feature map.
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            for f, f_flip in zip(features, features_flip):
                loss += torch.mean(torch.abs(flips.Hflip()(f_flip)-f)) * self.exp_dict["model"]["loss_weight"]
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                # logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                # if 1:
                #     pf = points_flip.clone()
                #     pf[pf==1] = 2
                #     pf[pf==0] = 1
                #     pf[pf==255] = 0
                #     lcfcn_loss.save_tmp('tmp.png', flips.Hflip()(images[[0]]), logits_flip[[0]], 3, pf[[0]])
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == "affine_cons_point_loss":
            points = points[:,None]
            # Randomly flip half the time; remembered so labels match later.
            if np.random.randint(2) == 0:
                images_flip = flips.Hflip()(images)
                flipped = True
            else:
                images_flip = images
                flipped = False
            batch_size, C, height, width = logits.shape
            random_affine = RandomAffine(degrees=2, translate=None, scale=(0.85, 1), shear=[-2, 2], return_transform=True)
            images_aff, transform = random_affine(images_flip)
            logits_aff = self.model_base(images_aff)
            
            # hu.save_image('tmp1.png', images_aff[0])
            # Warp the augmented logits back to the original frame.
            itransform = transform.inverse()
            logits_aff = kornia.geometry.transform.warp_affine(logits_aff, itransform[:,:2, :], dsize=(height, width))
            if flipped:
                logits_aff = flips.Hflip()(logits_aff)
            # hu.save_image('tmp2.png', images_recovered[0])

            
            # logits_flip = self.model_base(flips.Hflip()(images))

            loss = torch.mean(torch.abs(logits_aff-logits))
            # NOTE(review): the labels are warped with the INVERSE transform
            # while logits_aff has already been mapped back to the original
            # frame — verify these two actually live in the same frame.
            points_aff = kornia.geometry.transform.warp_affine(points.float(), itransform[:,:2, :], dsize=(height, width), flags="nearest").long()
            # points_aff = points
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')
                if flipped:
                    points_aff = flips.Hflip()(points_aff)
                # logits_flip =  self.model_base(flips.Hflip()(images))
                ind = points_aff!=255
                loss += F.binary_cross_entropy_with_logits(logits_aff[ind], 
                                        points_aff[ind].float().cuda(), 
                                        reduction='mean')
        elif loss_name == "elastic_cons_point_loss":
            # NOTE(review): this branch is debug/work-in-progress — it writes
            # tmp1.png / tmp2.png and never computes a loss.
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            B, C, H, W = images.shape
            # ELASTIC TRANSFORM
            # Rescale a grid to grid_sample's [-1, 1] coordinate range.
            def norm_grid(grid):
                grid -= grid.min()
                grid /= grid.max()
                grid = (grid - 0.5) * 2
                return grid
            grid_x, grid_y = torch.meshgrid(torch.arange(H), torch.arange(W))
            grid_x = grid_x.float().cuda()
            grid_y = grid_y.float().cuda()
            sigma=4
            alpha=34
            indices = torch.stack([grid_y, grid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous()
            indices = norm_grid(indices)
            # Gaussian-smoothed random displacement field, scaled by alpha.
            dx = gaussian_filter((np.random.rand(H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha
            dy = gaussian_filter((np.random.rand(H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha
            dx = torch.from_numpy(dx).cuda().float()
            dy = torch.from_numpy(dy).cuda().float()
            dgrid_x = grid_x + dx
            dgrid_y = grid_y + dy
            dgrid_y = norm_grid(dgrid_y)
            dgrid_x = norm_grid(dgrid_x)
            dindices = torch.stack([dgrid_y, dgrid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous()
            dindices0 = dindices.permute(0, 3, 1, 2).contiguous().view(B*2, H, W)
            indices0 = indices.permute(0, 3, 1, 2).contiguous().view(B*2, H, W)
            # Attempted pseudo-inverse of the warp grid (experimental).
            iindices = torch.bmm(indices0, dindices0.pinverse()).view(B, 2, H, W).permute(0, 2, 3, 1)
            # indices_im = indices.permute(0, 3, 1, 2)
            # iindices = F.grid_sample(indices_im, dindices).permute(0, 2, 3, 1)
            aug = F.grid_sample(images, dindices)
            iaug = F.grid_sample(aug,iindices)


            # logits_aff = self.model_base(images_aff)
            # inv_transform = transform.inverse()
            
            # Debug-only visualization helper.
            import pylab
            def save_im(image, name):
                _images_aff = image.data.cpu().numpy()
                _images_aff -= _images_aff.min()
                _images_aff /= _images_aff.max()
                _images_aff *= 255
                _images_aff = _images_aff.transpose((1,2,0))
                pylab.imsave(name, _images_aff.astype('uint8'))
            save_im(aug[0], 'tmp1.png')
            save_im(iaug[0], 'tmp2.png')
            pass


        elif loss_name == 'lcfcn_loss':
            loss = 0.
  
            # Per-image LCFCN loss on the sigmoid probabilities.
            for lg, pt in zip(logits, points):
                loss += lcfcn_loss.compute_loss((pt==1).long(), lg.sigmoid())

                # loss += lcfcn_loss.compute_binary_lcfcn_loss(l[None], 
                #         p[None].long().cuda())

        elif loss_name == 'point_loss':
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            # BCE on labeled pixels only; zero loss if none are labeled.
            if ind.sum() == 0:
                loss = 0.
            else:
                loss = F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')
                                        
            # print(points[ind].sum().item(), float(loss))
        elif loss_name == 'att_point_loss':
            points = points[:,None]
            ind = points!=255

            loss = 0.
            if ind.sum() != 0:
                loss = F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                # Additional point loss on the flipped prediction.
                logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'cons_point_loss':
            points = points[:,None]
            
            # Flip-consistency plus point supervision on both orientations.
            logits_flip = self.model_base(flips.Hflip()(images))
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        return loss 
Exemple #7
0
        def train_on_batch(self, batch):
            """Run one optimization step under region-based weak supervision.

            A region-of-interest (ROI) is built from the rectangular boxes
            whose entry in ``self.label_map`` is 1; the loss selected by
            ``self.exp_dict['model']['loss']`` is computed inside that ROI
            and one optimizer step is taken.

            Args:
                batch (dict): expects keys ``images``, ``masks`` and ``meta``
                    (a list of dicts carrying ``hash`` and ``index``).
                    NOTE(review): masks are assumed to take values in
                    {0, 1} (255 is used internally as "ignore" for sampled
                    points) — confirm against the dataset loader.

            Returns:
                dict: ``{'train_loss': float}`` for logging.
            """
            # Record which images this model has been trained on.
            for meta in batch['meta']:
                self.train_hashes.add(meta['hash'])

            self.opt.zero_grad()

            images = batch["images"].cuda()
            logits = self.model_base(images)

            index = batch['meta'][0]['index']
            n, c, h, w = images.shape
            # Keep only the rectangular regions flagged as labeled for this image.
            bbox_yxyx = get_rect_bbox(
                h, w, n_regions=self.n_regions)[self.label_map[index] == 1]

            assert (len(bbox_yxyx) > 0)
            loss_name = self.exp_dict['model']['loss']

            def roi_mask_from_boxes():
                # Boolean (h, w) mask that is True inside every labeled box.
                m = torch.zeros((h, w), dtype=bool)
                for y1, x1, y2, x2 in bbox_yxyx:
                    m[y1:y2, x1:x2] = 1
                return m

            if loss_name == 'joint_cross_entropy':
                masks = batch['masks'].cuda()
                if len(bbox_yxyx) == self.n_regions:
                    # Every region is labeled: the ROI is the whole image.
                    roi_mask = torch.ones((h, w), dtype=bool)
                else:
                    roi_mask = roi_mask_from_boxes()
                loss = ut.joint_loss_flat(logits, masks.float(), roi_mask)

            elif loss_name == 'cross_entropy':
                masks = batch['masks'].cuda()
                L = F.binary_cross_entropy_with_logits(
                    logits, masks.float(), reduction='none').squeeze()
                if len(bbox_yxyx) != self.n_regions:
                    # Average the per-pixel loss over the labeled regions only.
                    L = L[roi_mask_from_boxes()]
                loss = L.mean()

            elif loss_name in ['image_level', 'image_level3']:
                masks = batch['masks'].cuda().squeeze()
                logits = logits.squeeze()

                loss_fg = 0.
                n_fg = 0
                loss_bg = 0.
                n_bg = 0
                loss = 0.
                assert (logits.ndim == 2)
                assert (masks.ndim == 2)
                for y1, x1, y2, x2 in bbox_yxyx:
                    u_list = masks[y1:y2, x1:x2].unique()
                    l_box = logits[y1:y2, x1:x2]
                    if 1 in u_list:
                        # Box contains foreground: its max logit should be 1.
                        n_fg += 1
                        loss_fg += F.binary_cross_entropy_with_logits(
                            l_box.max()[None],
                            torch.ones(1, device=l_box.device),
                            reduction='mean')
                        if 0 in u_list:
                            # Mixed box: its min logit should be background.
                            loss_fg += F.binary_cross_entropy_with_logits(
                                l_box.min()[None],
                                torch.zeros(1, device=l_box.device),
                                reduction='mean')
                    elif 0 in u_list:
                        # Pure-background box: even the max logit should be 0.
                        n_bg += 1
                        loss_bg += F.binary_cross_entropy_with_logits(
                            l_box.max()[None],
                            torch.zeros(1, device=l_box.device),
                            reduction='mean')
                # BUG FIX: the original used `loss_name in 'image_level3'`,
                # a substring test that is also True for 'image_level', so the
                # flip-consistency term leaked into the plain image_level loss.
                if loss_name == 'image_level3':
                    logits_flip = self.model_base(
                        flips.Hflip()(images)).squeeze()
                    loss += torch.mean(
                        torch.abs(flips.Hflip()(logits_flip) - logits))

                loss += loss_fg / max(n_fg, 1) + loss_bg / max(n_bg, 1)

            elif loss_name in [
                    'const_point_level', 'point_level', 'point_level_2'
            ]:
                masks = batch['masks'].cuda().squeeze()
                logits = logits.squeeze()
                assert (logits.ndim == 2)
                assert (masks.ndim == 2)
                # Sample one point annotation per labeled box; 255 = ignore.
                points = torch.ones(masks.shape) * 255
                for y1, x1, y2, x2 in bbox_yxyx:
                    mask_box = masks[y1:y2, x1:x2]
                    u_list = mask_box.unique()
                    if 1 in u_list:
                        # One foreground point; seed makes the draw
                        # deterministic per box location.
                        pts = lcfcn_loss.get_random_points(
                            mask_box.cpu().numpy() == 1,
                            n_points=1,
                            seed=y1 + x1 + y2 + x2)
                        yi, xi = np.where(pts)
                        points[y1:y2, x1:x2][yi, xi] = 1
                        if loss_name == 'point_level_2' and 0 in u_list:
                            # point_level_2 also samples a background point
                            # in mixed boxes.
                            pts = lcfcn_loss.get_random_points(
                                mask_box.cpu().numpy() == 0,
                                n_points=1,
                                seed=y1 + x1 + y2 + x2)
                            yi, xi = np.where(pts)
                            points[y1:y2, x1:x2][yi, xi] = 0
                    elif 0 in u_list:
                        # Pure-background box: one background point.
                        pts = lcfcn_loss.get_random_points(
                            mask_box.cpu().numpy() == 0,
                            n_points=1,
                            seed=y1 + x1 + y2 + x2)
                        yi, xi = np.where(pts)
                        points[y1:y2, x1:x2][yi, xi] = 0
                loss = 0.

                # Foreground point loss.
                ind = ((points != 255) & (points != 0))
                if ind.sum() > 0:
                    loss += F.binary_cross_entropy_with_logits(
                        logits[ind],
                        points[ind].float().cuda(),
                        reduction='mean')
                # Background point loss.
                ind = ((points != 255) & (points != 1))
                if ind.sum() > 0:
                    loss += F.binary_cross_entropy_with_logits(
                        logits[ind],
                        points[ind].float().cuda(),
                        reduction='mean')

                if loss_name == 'const_point_level':
                    # Horizontal-flip consistency regularizer.
                    logits_flip = self.model_base(
                        flips.Hflip()(images)).squeeze()
                    loss += torch.mean(
                        torch.abs(flips.Hflip()(logits_flip) - logits))

            elif loss_name in ['image_level2']:
                # Like 'image_level' but without the min-logit term on mixed
                # boxes and without any flip-consistency variant.
                masks = batch['masks'].cuda().squeeze()
                logits = logits.squeeze()

                loss_fg = 0.
                n_fg = 0
                loss_bg = 0.
                n_bg = 0
                assert (logits.ndim == 2)
                assert (masks.ndim == 2)
                for y1, x1, y2, x2 in bbox_yxyx:
                    u_list = masks[y1:y2, x1:x2].unique()
                    l_box = logits[y1:y2, x1:x2]
                    if 1 in u_list:
                        n_fg += 1
                        loss_fg += F.binary_cross_entropy_with_logits(
                            l_box.max()[None],
                            torch.ones(1, device=l_box.device),
                            reduction='mean')
                    elif 0 in u_list:
                        n_bg += 1
                        loss_bg += F.binary_cross_entropy_with_logits(
                            l_box.max()[None],
                            torch.zeros(1, device=l_box.device),
                            reduction='mean')
                loss = loss_fg / max(n_fg, 1) + loss_bg / max(n_bg, 1)

            else:
                # Fail loudly on a typo'd config instead of a NameError below.
                raise ValueError('unknown loss: %s' % loss_name)

            # Backprop only when some labeled pixels produced a loss tensor.
            if loss != 0:
                loss.backward()
                if self.exp_dict['model'].get('clip_grad'):
                    ut.clip_gradient(self.opt, 0.5)
                self.opt.step()

            return {'train_loss': float(loss)}
    def train_on_batch(self, batch):
        """Run one optimization step with the configured supervision loss.

        Args:
            batch (dict): expects ``images`` plus, depending on the selected
                loss, ``masks``, ``point_list``, ``points`` or ``ndvi``.

        Returns:
            dict: ``{'train_loss': float}`` for logging.
        """
        self.opt.zero_grad()

        images = batch["images"].cuda()

        # BUG FIX: the original dispatched with `loss_name in '<name>'`,
        # which is a substring test on a string (e.g. 'point' would match
        # 'point_level', and '' would match everything); every branch now
        # uses equality or explicit list membership.
        loss_name = self.exp_dict['model']['loss']
        if loss_name == 'cross_entropy':
            # Full supervision on dense masks.
            logits = self.model_base(images)
            loss = losses.compute_cross_entropy(
                images, logits, masks=batch["masks"].cuda())

        elif loss_name == 'point_level':
            # Point supervision.
            logits = self.model_base(images)
            loss = losses.compute_point_level(
                images, logits, point_list=batch['point_list'])

        elif loss_name in ['lcfcn_ndvi']:
            # LCFCN with an extra NDVI input channel concatenated to RGB.
            images = torch.cat(
                [images, batch["ndvi"][0].cuda()[None, None]], dim=1)
            logits = self.model_base(images)
            loss = lcfcn_loss.compute_loss(
                points=batch['points'], probs=logits.sigmoid())

        elif loss_name in ['lcfcn', 'lcfcn_nopretrain']:
            logits = self.model_base(images)
            loss = lcfcn_loss.compute_loss(
                points=batch['points'], probs=logits.sigmoid())

        elif loss_name == 'prm':
            # Peak response: regress the mean activation at the detected
            # peaks toward the ground-truth count.
            counts = batch['points'].sum()
            logits = self.model_base(images)
            peak_map = get_peak_map(logits, win_size=3, counts=counts)
            act_avg = (logits * peak_map).sum((2, 3)) / peak_map.sum((2, 3))
            loss = F.mse_loss(act_avg.squeeze(),
                              counts.float().squeeze().cuda())

        elif loss_name == 'cam':
            # Image-level presence/absence classification on the mean logit.
            counts = batch['points'].sum()
            logits = self.model_base(images).mean()
            loss = F.binary_cross_entropy_with_logits(
                logits, (counts > 0).float().squeeze().cuda(),
                reduction='mean')

        elif loss_name == 'prm_points':
            # Like 'prm', but the peak map comes from the point annotations.
            counts = batch['points'].sum()
            logits = self.model_base(images)
            peak_map = batch['points'].cuda()[None]
            logits_avg = (logits * peak_map).mean()
            loss = F.binary_cross_entropy_with_logits(
                logits_avg, (counts > 0).float().squeeze().cuda(),
                reduction='mean')

        elif loss_name == 'density':
            # Density regression: blur the point map with a Gaussian and
            # fit the logits to the resulting density map (RMSE).
            if batch['points'].sum() == 0:
                density = 0
            else:
                import kornia
                sigma = 1
                kernel_size = (3, 3)
                sigma_list = (sigma, sigma)
                gfilter = kornia.filters.get_gaussian_kernel2d(
                    kernel_size, sigma_list).cuda()
                density = kornia.filters.filter2D(
                    batch['points'][None].float().cuda(),
                    kernel=gfilter[None], border_type='reflect')

            logits = self.model_base(images)
            diff = (logits - density) ** 2
            loss = torch.sqrt(diff.mean())

        elif loss_name == 'lcfcn_consistency':
            # LCFCN plus a horizontal-flip consistency regularizer.
            logits = self.model_base(images)
            loss = lcfcn_loss.compute_loss(
                points=batch['points'], probs=logits.sigmoid())

            logits_flip = self.model_base(flips.Hflip()(images))
            loss += torch.mean(torch.abs(flips.Hflip()(logits_flip) - logits))

        elif loss_name == 'lcfcn_rot_loss':
            # LCFCN plus a flip+rotation consistency regularizer: predictions
            # on the transformed image, mapped back, should match the
            # original logits.
            logits = self.model_base(images)
            loss = lcfcn_loss.compute_loss(
                points=batch['points'], probs=logits.sigmoid())

            rotations = np.random.choice(
                [0, 90, 180, 270], images.shape[0], replace=True)
            images = flips.Hflip()(images)
            images_rotated = sst.batch_rotation(images, rotations)
            logits_rotated = self.model_base(images_rotated)
            logits_recovered = sst.batch_rotation(logits_rotated,
                                                  360 - rotations)
            logits_recovered = flips.Hflip()(logits_recovered)

            loss += torch.mean(torch.abs(logits_recovered - logits))

        else:
            # Fail loudly on a typo'd config instead of a NameError below.
            raise ValueError('unknown loss: %s' % loss_name)

        if loss != 0:
            loss.backward()
            if self.exp_dict['model'].get('clip_grad'):
                ut.clip_gradient(self.opt, 0.5)
            try:
                self.opt.step()
            except TypeError:
                # BUG FIX: was a bare `except:`, which also swallowed real
                # optimizer failures (and KeyboardInterrupt). The fallback
                # exists only for optimizers whose step() requires the loss
                # closure (e.g. SLS-style) — NOTE(review): confirm the
                # intended optimizer here.
                self.opt.step(loss=loss)

        return {'train_loss': float(loss)}