Example #1
    def __init__(self,
                 image_size: int,
                 mask_channels_count: int,
                 image_channels_count: int = 3,
                 generator_size: int = 32,
                 discriminator_size: int = 32):

        super().__init__()
        # netG = UNetGenerator(noise, image_size, mask_channels_count, image_channels_count, generator_size) \
        netG = ResnetGenerator(mask_channels_count, image_channels_count, generator_size, n_blocks=9)\
            .to(ParallelConfig.MAIN_DEVICE)
        # netD = Discriminator(discriminator_size, image_channels_count + mask_channels_count, image_size) \
        netD = ResDiscriminator(image_channels_count + mask_channels_count, discriminator_size, img_f=256, layers=4) \
            .to(ParallelConfig.MAIN_DEVICE)

        if torch.cuda.device_count() > 1:
            netD = nn.DataParallel(netD, ParallelConfig.GPU_IDS)
            netG = nn.DataParallel(netG, ParallelConfig.GPU_IDS)

        # Hinge adversarial loss plus a 0.1-weighted L1 term that matches
        # the first fake output against the first real one.
        self.gan_model = ConditionalGANModel(
            netG,
            HingeLoss(netD) + GANLossObject(
                lambda x, y: Loss.ZERO(),
                lambda dgz, real, fake: Loss(nn.L1Loss()(fake[0], real[0])) * 0.1,
                netD
            )
        )
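The multi-GPU wrapping above follows the standard nn.DataParallel pattern. A minimal self-contained sketch of just that pattern (the small Sequential module is a placeholder, not the generators used above):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
if torch.cuda.device_count() > 1:
    # Replicate the module across GPUs; batches are split among them.
    net = nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
out = net(torch.randn(4, 3, 16, 16, device=device))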
Example #2
    def train(self, image: Tensor, sp: Tensor) -> Loss:
        segm = self.segmentation(image)
        # Accumulate every penalty term, then take one optimizer step.
        loss: Loss = Loss.ZERO()
        for pen in self.penalties:
            loss = loss + pen.forward(image, sp, segm)
        loss.minimize_step(self.opt)

        return loss
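These snippets all lean on a Loss wrapper that is not shown here. A minimal sketch of the interface they assume (ZERO, +, * by a scalar, minimize_step), under the assumption that Loss simply wraps a scalar tensor:

import torch

class Loss:
    # Assumed interface: a thin wrapper over a scalar loss tensor.
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    @staticmethod
    def ZERO() -> "Loss":
        return Loss(torch.zeros(1))

    def __add__(self, other: "Loss") -> "Loss":
        return Loss(self.tensor + other.tensor)

    def __mul__(self, weight: float) -> "Loss":
        return Loss(self.tensor * weight)

    def minimize_step(self, opt: torch.optim.Optimizer) -> None:
        opt.zero_grad()
        self.tensor.backward()
        opt.step()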
Example #3
    def forward(self, image: Tensor, sp: Tensor, segm: Tensor) -> Loss:
        nc = segm.shape[1]
        # Broadcast sp across every segmentation channel.
        sp = torch.cat([sp] * nc, dim=1).detach()

        # Pool the detached segmentation with sp and take the argmax
        # channel as the target label.
        sp_argmax = self.pooling.forward(segm.detach(), sp).detach().max(dim=1)[1]

        return Loss(self.loss(segm.sigmoid(), sp_argmax)) * self.weight
Example #4
    def forward(self, x: Tensor, y: Tensor) -> Loss:

        # Flatten the VGG features to (batch, features).
        x_vgg = self.vgg(x).view(x.size(0), -1)
        y_vgg = self.vgg(y).view(y.size(0), -1)
        # Center each feature over the batch; the means are detached.
        x_vgg = x_vgg - x_vgg.mean(dim=0, keepdim=True).detach()
        y_vgg = y_vgg - y_vgg.mean(dim=0, keepdim=True).detach()

        return Loss((x_vgg * y_vgg).mean())
Example #5
    def loss_backward(self, condition: Dict[str, Tensor]) -> Loss:
        # Backward cycle consistency: decode, re-encode, and compare.
        condition_pred: Dict[str, Tensor] = self.g_forward(
            self.g_backward(condition))

        loss = Loss.ZERO()
        for name in condition.keys():
            loss += self.loss_2[name](condition_pred[name], condition[name])

        return loss
Example #6
    def loss_forward(self, condition: Dict[str, T1]) -> Loss:
        # Forward cycle consistency: encode, decode, and compare.
        t2 = self.g_forward(condition)
        condition_pred: Dict[str, T1] = self.g_backward(t2)

        loss = Loss.ZERO()
        for name in condition.keys():
            loss += self.loss_1[name](condition_pred[name], condition[name])

        return loss
Example #7
    def generator_loss(self, dgz: Tensor, real: Tensor, fake: Tensor) -> Loss:
        batch_size = dgz.size(0)
        nc = dgz.size(1)

        # BCE targets must be float; label every generated sample as real.
        real_labels = torch.full((batch_size, nc), 1.0, device=dgz.device)
        errG = self.__criterion(
            dgz.view(batch_size, nc).sigmoid(), real_labels)
        return Loss(errG)
Example #8
    def forward(self, images: Tensor, sp: Tensor,
                segmentation: Tensor) -> Loss:

        # VGG features, batch-normalized and detached from the graph.
        fich = self.vgg(images).detach()
        fich = nn.BatchNorm2d(fich.shape[1]).to(
            fich.device).forward(fich).detach()

        down_segm = self.down_sample_to(segmentation, fich)

        norm = 2 * len(ParallelConfig.GPU_IDS)

        return Loss(self.modularity_vgg(fich,
                                        down_segm).sum()) * self.weight / norm
Example #9
    def __call__(self, Dx: Tensor, x: List[Tensor]) -> Loss:

        gradients = torch.autograd.grad(outputs=Dx,
                                        inputs=x,
                                        grad_outputs=torch.ones(
                                            Dx.size(), device=Dx.device),
                                        create_graph=True,
                                        retain_graph=True,
                                        only_inputs=True)[0]

        # Drive the per-sample gradient norm toward 1 (WGAN-GP style).
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty_value = (gradients.norm(2, dim=1) - 1).mean()
        res = self.weight * gradient_penalty_value
        # Adapt the penalty weight online and keep it non-negative.
        self.weight += self.lr * gradient_penalty_value.item()
        self.weight = max(0, self.weight)
        return Loss(res)
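The penalty above adapts its own weight from the measured norm gap. The same pattern in a self-contained form (plain tensors instead of the Loss wrapper; the stand-in critic exists only to make the sketch runnable):

import torch

class AdaptiveGradientPenalty:
    def __init__(self, weight: float = 10.0, lr: float = 0.01):
        self.weight = weight
        self.lr = lr

    def __call__(self, d_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        grads = torch.autograd.grad(outputs=d_out, inputs=x,
                                    grad_outputs=torch.ones_like(d_out),
                                    create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]
        gap = (grads.view(grads.size(0), -1).norm(2, dim=1) - 1).mean()
        penalty = self.weight * gap
        # Raise the weight while gradients overshoot the target norm of 1.
        self.weight = max(0.0, self.weight + self.lr * gap.item())
        return penalty

x = torch.randn(8, 3, 16, 16, requires_grad=True)
d_out = x.mean(dim=(1, 2, 3))  # stand-in for per-sample critic scores
gp = AdaptiveGradientPenalty()
print(gp(d_out, x).item(), gp.weight)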
Example #10
    def discriminator_loss(self, dx: Tensor, dy: Tensor) -> Loss:

        batch_size = dx.size(0)
        nc = dx.size(1)

        # BCE targets must be float: ones for real, zeros for fake.
        real_labels = torch.full((batch_size, nc), 1.0, device=dx.device)
        err_real = self.__criterion(
            dx.view(batch_size, nc).sigmoid(), real_labels)

        fake_labels = torch.full((batch_size, nc), 0.0, device=dx.device)
        err_fake = self.__criterion(
            dy.view(batch_size, nc).sigmoid(), fake_labels)

        # Negated to match this codebase's sign convention for
        # discriminator losses (cf. the hinge loss in Example #11).
        return Loss(-(err_fake + err_real))
Example #11
    def _discriminator_loss(self, d_real: Tensor, d_fake: Tensor) -> Loss:
        # Hinge loss for the discriminator, negated per this codebase's
        # sign convention.
        discriminator_loss = (1 - d_real).relu().mean() + (
            1 + d_fake).relu().mean()
        return Loss(-discriminator_loss)
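The hinge discriminator loss is normally paired with a generator loss of -dgz.mean(), as in Examples #17 and #19. Both sides as a quick self-contained reference, written with the usual minimization sign (no framework-level negation):

import torch

def d_hinge(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Push real scores above +1 and fake scores below -1.
    return torch.relu(1 - d_real).mean() + torch.relu(1 + d_fake).mean()

def g_hinge(d_fake: torch.Tensor) -> torch.Tensor:
    # The generator simply raises the critic's score on fakes.
    return -d_fake.mean()

d_real, d_fake = torch.randn(8), torch.randn(8)
print(d_hinge(d_real, d_fake).item(), g_hinge(d_fake).item())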
Example #12
    def __call__(self, dx: Tensor, x: List[Tensor]) -> Loss:
        # Penalize large discriminator outputs directly (|dx|^1.5).
        penalty_value = dx.abs().pow(1.5).mean()
        return Loss(self.weight * penalty_value)
Example #13
    def discriminator_loss(self, d_real: Tensor, d_fake: Tensor) -> Loss:
        return Loss.ZERO()
Example #14
    def __call__(self, segm: Mask) -> Loss:

        return Loss(self.mapper(segm.tensor, self.diff).mean())
Example #15
    def forward(self, x: Tensor, y: Tensor) -> Loss:

        # Perceptual loss on VGG features; the target features are
        # detached so gradients flow only through x.
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        return Loss(self.criterion(x_vgg, y_vgg.detach()) * self.weight)
Example #16
    def _discriminator_loss(self, d_real: Tensor, d_fake: Tensor) -> Loss:
        # Wasserstein-style critic objective: real minus fake scores.
        discriminator_loss = d_real.mean() - d_fake.mean()

        return Loss(discriminator_loss)
Example #17
    def _generator_loss(self, dgz: Tensor, real: Tensor, fake: Tensor) -> Loss:
        return Loss(-dgz.mean())
Example #18
    def _compute(self, delta: Tensor) -> Loss:
        # One-sided penalty: only positive components of delta count.
        gradient_penalty_value = delta.relu().norm(2, dim=1).pow(2).mean()

        return Loss(self.weight * gradient_penalty_value)
Example #19
    def _generator_loss(self, dgz: Tensor, real: List[Tensor],
                        fake: List[Tensor]) -> Loss:
        return Loss(-dgz.mean())
Example #20
    def forward(self, segm: Tensor) -> Loss:

        # Per-pixel entropy of the soft segmentation over channels.
        return Loss((-segm * (segm + 1e-8).log()).sum(dim=1).mean())
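This entropy term assumes segm is already a probability map, with channels summing to 1; fed raw logits, the expression is not an entropy. A quick check of that assumption:

import torch

logits = torch.randn(2, 5, 4, 4)   # (batch, classes, H, W)
segm = logits.softmax(dim=1)       # channels now form a distribution
entropy = (-segm * (segm + 1e-8).log()).sum(dim=1).mean()
print(entropy.item())              # at most log(5) per pixel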
Example #21
    def _compute(self, gradients: Tensor) -> Loss:

        # Two-sided penalty: (||grad||_2 - 1)^2 averaged over the batch.
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty_value = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return Loss(self.weight * gradient_penalty_value)
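Here _compute receives the gradients precomputed; in WGAN-GP they are usually taken at random interpolates between real and fake batches. A sketch of that outer step (the lambda critic is a placeholder):

import torch

def interpolate_gradients(D, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    # Sample points on the segments between real and fake examples.
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_hat = D(x_hat)
    return torch.autograd.grad(outputs=d_hat, inputs=x_hat,
                               grad_outputs=torch.ones_like(d_hat),
                               create_graph=True)[0]

D = lambda x: x.mean(dim=(1, 2, 3))  # stand-in critic
real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
grads = interpolate_gradients(D, real, fake)
print(((grads.view(4, -1).norm(2, dim=1) - 1) ** 2).mean().item())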