Ejemplo n.º 1
0
def be_sample(segm: Tensor) -> Tuple[Tensor, Mask]:
    """Sample a hard mask from ``segm`` and score the sample.

    ``segm`` is passed to ``Bernoulli`` as its success probabilities, so every
    element must lie in [0, 1].

    Returns:
        A pair ``(mean_log_prob, mask)``:

        * ``mean_log_prob`` is the summed element-wise log-probability of the
          drawn outcome, divided by the number of elements in ``segm``. It is
          differentiable w.r.t. ``segm``; the sample itself is not.
        * ``mask`` is a ``Mask`` wrapping the drawn 0/1 sample.
    """
    sampled = Bernoulli(segm).sample()

    # Probability of the outcome that was actually drawn:
    # p where the sample is 1, (1 - p) where it is 0.
    chosen_prob = segm * sampled + (1 - segm) * (1 - sampled)

    mean_log_prob: Tensor = chosen_prob.log().sum() / segm.numel()
    return mean_log_prob, Mask(sampled)
Ejemplo n.º 2
0
    def get_random_segment(masks: Mask) -> Mask:
        """Keep a single channel per sample of a multi-channel mask.

        The channel kept for sample ``i`` is whatever
        ``Transformer.generate_segment_index(masks.tensor[i])`` returns
        (presumably a random choice, per this method's name — confirm).
        The final reshape assumes exactly one channel is selected per sample.

        Args:
            masks: mask whose tensor is indexed as (batch, channels, height, width).

        Returns:
            Mask wrapping a (batch, 1, height, width) tensor holding the kept channel.
        """
        src = masks.tensor
        batch_size, nc, height, width = (src.size(dim) for dim in range(4))

        # Float one-hot selector over the channel axis: 1 on the chosen channel of each sample.
        selector: Tensor = torch.zeros(
            (batch_size, nc, height, width), device=src.device, dtype=torch.float32
        )
        for sample_idx in range(batch_size):
            chosen = Transformer.generate_segment_index(src[sample_idx])
            selector[sample_idx, chosen, :, :] = 1

        # Boolean indexing flattens in row-major order, so the picked HxW planes come out
        # grouped by sample and can be viewed back as (batch, 1, H, W).
        picked = src[selector == 1]
        return Mask(picked.view(batch_size, 1, height, width))
Ejemplo n.º 3
0
    def generator_loss(self, images: Tensor, segments: Mask) -> Loss:
        """Return the combined generator loss of the front and background models.

        The front model receives ``segments`` as-is. The background model receives
        the complementary mask ``1 - segments.tensor``. The two losses are added.
        """
        front_loss = self.front_model.generator_loss(images, segments)
        background = Mask(1 - segments.tensor)
        return front_loss + self.bk_model.generator_loss(images, background)
Ejemplo n.º 4
0
    def test(self, images: Tensor, segments: Mask) -> Tuple[Tensor, Tensor]:
        """Run ``test`` on both sub-models and return their results as a pair.

        Returns:
            ``(front_result, back_result)``. The front model is evaluated on
            ``segments``; the background model on the complement
            ``1 - segments.tensor``.
        """
        front_result = self.front_model.test(images, segments)
        background = Mask(1 - segments.tensor)
        back_result = self.bk_model.test(images, background)
        return front_result, back_result
Ejemplo n.º 5
0
    def train(self, images: Tensor, segments: Mask):
        """Forward one batch to the ``train`` method of both sub-models.

        The front model gets ``segments`` unchanged. The background model gets the
        complementary mask ``1 - segments.tensor``.

        NOTE(review): if the enclosing class derives from ``torch.nn.Module``,
        this overrides ``Module.train(mode)`` and ``.eval()`` would call it with
        the wrong arguments — confirm the base class.
        """
        self.front_model.train(images, segments)
        inverted = Mask(1 - segments.tensor)
        self.bk_model.train(images, inverted)
Ejemplo n.º 6
0
)

# Only the segmentation network's parameters are optimised here.
opt = torch.optim.Adam(segm_net.parameters(), lr=0.0002, betas=(0.5, 0.999))


classifier = MaskClassifier()


print("Starting Training Loop...")
for epoch in range(5):
    # One pass over the dataloader per epoch; labels are not used by this loss.
    for i, (imgs, labels) in enumerate(trainloader):

        imgs = imgs.to(ParallelConfig.MAIN_DEVICE)
        segm: Tensor = segm_net(imgs)

        # Classifier loss on the soft segmentation, plus two weighted regularisers:
        # half the mean mask value, and 3x a NeighbourDiffLoss(5) term
        # (presumably a smoothness penalty between neighbouring pixels — confirm).
        cl_loss = classifier.loss(imgs, segm)
        cl_loss += Loss(segm.mean() / 2)
        cl_loss += NeighbourDiffLoss(5)(Mask(segm)) * 3

        opt.zero_grad()
        # Assumed to run backward on the accumulated loss before the optimiser step.
        cl_loss.minimize()
        opt.step()
        print(cl_loss.item())

        if i % 20 == 0:
            show_images(imgs.cpu(), 2, 2)
            # segm is a network output with autograd history: detach before plotting so
            # the helper does not keep the graph alive and can convert it to NumPy.
            show_segmentation(segm.detach().cpu())