Example #1
        def loss(image: Tensor, mask: ProbabilityMeasure):
            # Estimate, without tracking gradients, the linear map A and
            # translation T that align the predicted measure with the
            # barycenter (barycenter, ca, cw, ct come from the enclosing scope).
            with torch.no_grad():
                A, T = LinearTransformOT.forward(mask, barycenter)

            # Penalize the translation, the linear deformation, and the
            # remaining gap to the barycenter with three OT loss terms.
            t_loss = Samples_Loss(scaling=0.8, border=0.0001)(mask, mask.detach() + T)
            a_loss = Samples_Loss(scaling=0.8, border=0.0001)(
                mask.centered(), mask.centered().multiply(A).detach())
            w_loss = Samples_Loss(scaling=0.85, border=0.00001)(
                mask.centered().multiply(A), barycenter.centered().detach())

            return a_loss * ca + w_loss * cw + t_loss * ct
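Samples_Loss is the project's own wrapper; a rough standalone analogue with geomloss is sketched below (an assumption: the `border` argument presumably plays the role of geomloss's `blur`).

import torch
from geomloss import SamplesLoss  # pip install geomloss

loss_fn = SamplesLoss("sinkhorn", scaling=0.8, blur=0.0001)
x = torch.rand(16, 70, 2)  # two batches of 2-D point clouds
y = torch.rand(16, 70, 2)
print(loss_fn(x, y).shape)  # torch.Size([16]): one loss per batch element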
def content_to_measure(content):
    # Interpret the encoder output as 70 points per sample, each carrying
    # uniform mass 1/70.
    batch_size = content.shape[0]
    pred_measures: ProbabilityMeasure = ProbabilityMeasure(
        torch.ones(batch_size, 70, device=content.device) / 70,
        content.reshape(batch_size, 70, 2)
    )
    return pred_measures
    def forward(self, x: Tensor) -> ProbabilityMeasure:
        conv = self.main(x)

        # Predict a 2-D coordinate for each of the measure_size points.
        coord = self.coord(
            conv.view(conv.shape[0], -1)
        ).view(x.shape[0], self.measure_size, 2)
        # Shrink coordinates slightly (255/256) so points stay inside the grid.
        coord = coord * 255 / 256

        # Uniform probabilities; a learned head (self.prob) was commented
        # out in the original in favor of this constant.
        prob = torch.ones(x.shape[0], coord.shape[1], device=coord.device, dtype=torch.float32)
        prob = prob / prob.sum(dim=1, keepdim=True)

        return ProbabilityMeasure(prob, coord)
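A quick usage sketch for content_to_measure above; the content tensor is a random stand-in for an encoder output, and ProbabilityMeasure is assumed importable from the project.

import torch

content = torch.rand(4, 140)           # B = 4 samples of 70 flattened 2-D points
measure = content_to_measure(content)  # uniform mass 1/70 on every point
print(measure.probability.shape, measure.coord.shape)  # (4, 70), (4, 70, 2)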
Example #4
    def forward(self, mask: torch.Tensor, border=0.5):
        # Accept (B, 1, H, W) masks and drop the channel dimension.
        if len(mask.shape) == 4:
            assert mask.shape[1] == 1
            mask = mask[:, 0, :, :]

        gmm = GaussianMixture(n_components=self.n_components,
                              n_features=2).cuda()

        # Threshold the mask into a point cloud, fit the mixture, and use
        # its weights and means as the output measure.
        mes = self.from_mask(mask, border)
        gmm.fit(mes.coord.permute(1, 2, 0)[:, None, :, :], n_iter=150)
        probs = gmm.pi[0].permute(2, 0, 1).squeeze()
        means = gmm.mu[0].permute(2, 0, 1)
        if probs.ndim == 1:
            # Restore the batch dimension when squeeze() removed it (B == 1).
            probs = probs[None, :]

        return ProbabilityMeasure(probs, means)
Example #5
    def forward(self, mask: Tensor, border=0.5):
        # Accept (B, 1, H, W) masks and drop the channel dimension.
        if len(mask.shape) == 4:
            assert mask.shape[1] == 1
            mask = mask[:, 0, :, :]

        mask = mask.cpu()

        probability_measure_list = [
            self._from_one_mask(mask[i], border) for i in range(len(mask))
        ]

        def compute(coord, n_components):
            # Degenerate case: fewer points than components; return the
            # points themselves with uniform weights.
            if coord.shape[0] < n_components:
                return coord[np.newaxis, :, :], np.ones(
                    (1, coord.shape[0])) / coord.shape[0]

            gmm = mixture.GaussianMixture(n_components=n_components,
                                          covariance_type='spherical',
                                          max_iter=100,
                                          tol=1e-4)
            gmm.fit(coord)
            return gmm.means_[np.newaxis, :, :], gmm.weights_[np.newaxis, :]

        means_and_probs = [
            compute(mes.coord[0].numpy(), self.n_components)
            for mes in probability_measure_list
        ]

        means = [
            torch.from_numpy(mm[0]).type(torch.float32)
            for mm in means_and_probs
        ]
        probs = [
            torch.from_numpy(mm[1]).type(torch.float32)
            for mm in means_and_probs
        ]

        return ProbabilityMeasure(torch.cat(probs), torch.cat(means))
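A minimal standalone sketch of the per-mask GMM step above, on a random placeholder point cloud:

import numpy as np
from sklearn import mixture

coord = np.random.rand(500, 2)  # stand-in for the (N, 2) pixel coordinates of one mask
gmm = mixture.GaussianMixture(n_components=8, covariance_type='spherical',
                              max_iter=100, tol=1e-4)
gmm.fit(coord)
means, weights = gmm.means_, gmm.weights_  # (8, 2) centroids, (8,) mixture masses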
Example #6
    def forward(pred: ProbabilityMeasure, targets: ProbabilityMeasure):
        # The transport plan P between the centered point clouds is
        # computed without gradients.
        with torch.no_grad():
            P = compute_ot_matrix_par(pred.centered().coord.cpu().numpy(),
                                      targets.centered().coord.cpu().numpy())
            P = torch.from_numpy(P).type_as(pred.coord).cuda()

        xs = pred.centered().coord
        xsT = xs.transpose(1, 2)
        xt = targets.centered().coord

        # Regularized source weights, renormalized and shaped (B, N, 1).
        a: Tensor = pred.probability + 1e-8
        a = a / a.sum(dim=1, keepdim=True)
        a = a.reshape(a.shape[0], -1, 1)

        # Weighted least-squares estimate of the linear map (see the note
        # below), plus the translation that matches the two means.
        A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))

        T = targets.mean() - pred.mean()

        return A.type_as(pred.coord), T.detach()
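Reading the closed form off the snippet (a derivation from the code, not a documented API): with centered source points $X_s$, centered targets $X_t$, renormalized weights $a$, and plan $P$, each batch element computes

    A = \left(X_s^\top \operatorname{diag}(a)\, X_s\right)^{-1} X_s^\top P\, X_t,
    \qquad T = \bar{X}_t - \bar{X}_s

i.e. the weighted least-squares linear map sending the source cloud toward its OT-matched targets, and the translation matching the two means.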
Example #7
    def forward(pred: ProbabilityMeasure,
                targets: ProbabilityMeasure,
                iters: int = 200):
        # Same closed form as in the previous example, except that the
        # plan P comes from the iterative SOT solver with regularization
        # strength lambd.
        lambd = 0.002

        with torch.no_grad():
            P = SOT(iters, lambd).forward(pred.centered(), targets.centered())

        xs = pred.centered().coord
        xsT = xs.transpose(1, 2)
        xt = targets.centered().coord

        a = pred.probability + 1e-8
        a = a / a.sum(dim=1, keepdim=True)
        a = a.reshape(a.shape[0], -1, 1)

        A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))

        T = targets.mean() - pred.mean()

        return A.type_as(pred.coord), T.detach()
Example #8
    def apply_to_keypoint(self, kp, **params):
        # Albumentations keypoints are (x, y, angle, scale); only the
        # coordinates are used, stored back in (y, x) order.
        x, y, a, s = kp
        return ProbabilityMeasure(
            params["prob"], torch.cat([y[..., None], x[..., None]], dim=-1))
Example #9
    def apply_to_mask(self, img: ProbabilityMeasure, **params):
        # Rasterize the measure back into an image of side self.size.
        return img.toImage(self.size)
Example #10
import numpy as np
import torch
from joblib import Parallel, delayed
from scipy.ndimage import generate_binary_structure, label


def clusterization(images: torch.Tensor, size=256, padding=70):
    imgs = images.cpu().numpy().squeeze()
    # 8-connected structuring element for the component labeling below.
    pattern = generate_binary_structure(2, 2)
    coord_result, prob_result = [], []

    def compute(sample):
        # Support of the heatmap and the values on that support.
        x, y = np.where(imgs[sample] > 1e-6)
        measure_mask = np.zeros((2, size, size))
        measure_mask[0, x, y] = 1
        measure_mask[1, x, y] = imgs[sample, x, y]
        labeled_array, num_features = label(measure_mask[0], structure=pattern)

        x_coords, y_coords, prob_value = [], [], []

        # One point per connected component: the centroid (normalized to
        # [0, 1]) weighted by the summed heatmap mass of the component.
        for i in range(1, num_features + 1):
            x_clust, y_clust = np.where(labeled_array == i)
            x_coords.append(np.average(x_clust) / size)
            y_coords.append(np.average(y_clust) / size)
            prob_value.append(np.sum(measure_mask[1, x_clust, y_clust]))
            assert measure_mask[1, x_clust, y_clust].all()

        # Zero-pad each sample to a fixed number of points.
        # NOTE: assumes num_features <= padding; otherwise the samples end
        # up with different lengths and the final torch.cat fails.
        x_coords += [0] * (padding - len(x_coords))
        y_coords += [0] * (padding - len(y_coords))
        prob_value += [0] * (padding - len(prob_value))

        sample_probs_value = np.array([prob_value])

        return x_coords, y_coords, sample_probs_value / (
            sample_probs_value.sum() + 1e-8)

    processed_list = Parallel(n_jobs=16)(delayed(compute)(i)
                                         for i in range(imgs.shape[0]))

    for x, y, p in processed_list:
        coord_result.append(
            torch.cat((torch.tensor(y)[:, None], torch.tensor(x)[:, None]),
                      dim=1)[None, ...])
        prob_result.append(p)

    return ProbabilityMeasure(
        torch.tensor(np.concatenate(prob_result, axis=0)).type(torch.float32),
        torch.cat(coord_result).type(torch.float32)).cuda()
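A minimal usage sketch for clusterization, assuming a CUDA device and the project's ProbabilityMeasure on the import path; the heatmaps here are random placeholders:

heatmaps = torch.rand(4, 1, 256, 256).cuda()  # stand-in for encoder heatmaps
measure = clusterization(heatmaps, size=256, padding=70)
print(measure.coord.shape)                    # torch.Size([4, 70, 2])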
Example #11
def train(generator, discriminator, encoder, style_encoder, device,
          starting_model_number):

    batch = 32
    Celeba.batch_size = batch

    latent_size = 512
    model = CondStyleGanModel(generator, StyleGANLoss(discriminator),
                              (0.001, 0.0015))

    style_opt = optim.Adam(style_encoder.parameters(),
                           lr=5e-4,
                           betas=(0.5, 0.97))

    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(
            albumentations.ElasticTransform(p=0.8,
                                            alpha=150,
                                            alpha_affine=1,
                                            sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device)
    ])
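    # For reference, the geometric part of g_transforms, standalone on a
    # plain numpy image (the project's ToNumpy/NumpyBatch/ToTensor wrappers
    # are omitted; an albumentations version that still accepts
    # alpha_affine is assumed):
    #
    #     import numpy as np
    #     import albumentations
    #
    #     aug = albumentations.Compose([
    #         albumentations.ElasticTransform(p=0.8, alpha=150,
    #                                         alpha_affine=1, sigma=10),
    #         albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10),
    #     ])
    #     img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    #     warped = aug(image=img)["image"]  # same shape, perturbed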

    # Style-consistency regularizer: the style code predicted from a
    # geometrically transformed image should match the original latent.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img, ltnt: L1("R_s")(ltnt, style_encoder(trans_dict['image'])))

    sample_z = torch.randn(batch, latent_size, device=device)
    test_img = next(LazyLoader.celeba().loader).to(device)
    print(test_img.shape)
    test_cond = encoder(test_img)

    requires_grad(encoder, False)  # REMOVE BEFORE TRAINING

    t_start = time.time()

    for i in range(100000):
        counter.update(i)
        real_img = next(LazyLoader.celeba().loader).to(device)

        # The frozen encoder supplies the content condition for the generator.
        img_content = encoder(real_img).detach()

        noise = mixing_noise(batch, latent_size, 0.9, device)
        fake, _ = generator(img_content, noise)

        model.discriminator_train([real_img], [fake.detach()], img_content)

        writable("Generator loss", model.generator_loss)([real_img], [fake], [], img_content)\
            .minimize_step(model.optimizer.opt_min)

        # print("gen train", time.time() - t1)

        if i % 5 == 0 and i > 0:
            noise = mixing_noise(batch, latent_size, 0.9, device)

            img_content = encoder(real_img).detach()
            fake, fake_latent = generator(img_content,
                                          noise,
                                          return_latents=True)

            fake_latent_test = fake_latent[:, [0, 13], :].detach()
            fake_latent_pred = style_encoder(fake)
            fake_content_pred = encoder(fake)

            restored = generator.module.decode(
                img_content[:batch // 2], style_encoder(real_img[:batch // 2]))
            (HMLoss("BCE content gan", 5000)(fake_content_pred, img_content) +
             L1("L1 restored")(restored, real_img[:batch // 2]) * 50 +
             L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 30 +
             R_s(fake.detach(), fake_latent_pred) * 50).minimize_step(
                 model.optimizer.opt_min, style_opt)

        if i % 100 == 0:
            t_100 = time.time()
            print(i, t_100 - t_start)
            t_start = time.time()
            with torch.no_grad():

                fake_img, _ = generator(test_cond, [sample_z])
                coords, p = heatmap_to_measure(test_cond)
                test_mes = ProbabilityMeasure(p, coords)
                iwm = imgs_with_mask(fake_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)

                iwm = imgs_with_mask(test_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                restored = generator.module.decode(test_cond,
                                                   style_encoder(test_img))
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.state_dict(),
                    'd': discriminator.state_dict(),
                    'style': style_encoder.state_dict()
                    # 'enc': cont_style_encoder.state_dict(),
                },
                f'/trinity/home/n.buzun/PycharmProjects/saved/stylegan2_w300_{str(starting_model_number + i).zfill(6)}.pt',
            )
    # Evaluation block that references names not defined in this snippet
    # (encoder_HG, w300_test_image, test); note it sits outside the
    # training loop above.
    if i % 50 == 0 and i > 0:
        with torch.no_grad():
            test_loss = test(encoder_HG)
            print(test_loss)
            coord, p = heatmap_to_measure(encoder_HG(w300_test_image))
            pred_measure = ProbabilityMeasure(p, coord)
            iwm = imgs_with_mask(w300_test_image, pred_measure.toImage(256))
            send_images_to_tensorboard(writer, iwm, "W300_test_image", i)
            writer.add_scalar("test_loss", test_loss, i)

Example #13
def content_to_measure(content):
    # Same helper as in Example #1, with the batch size taken from args.
    pred_measures: ProbabilityMeasure = ProbabilityMeasure(
        torch.ones(args.batch_size, 70, device=device) / 70,
        content.reshape(args.batch_size, 70, 2)
    )
    return pred_measures
coord = barycenter.coord

# Optimize the barycenter's support coordinates directly.
opt = optim.Adam([coord], lr=0.0006)

encoder_HG = HG_softmax2020(num_classes=68, heatmap_size=64)
encoder_HG.load_state_dict(
    torch.load(f"{Paths.default.models()}/hg2_e29.pt", map_location="cpu"))
encoder_HG = encoder_HG.cuda()

# Loop variable renamed from `iter` so the builtin is not shadowed.
for it in range(3000):

    img = next(LazyLoader.celeba().loader).cuda()
    content = encoder_HG(img)
    pred_coord, p = heatmap_to_measure(content)  # renamed so the optimized
    mes = ProbabilityMeasure(p, pred_coord)      # `coord` is not clobbered

    barycenter_cat = fabric.cat([barycenter] * batch_size)

    loss = Samples_Loss()(barycenter_cat, mes)

    opt.zero_grad()
    loss.to_tensor().backward()
    opt.step()

    # Project the barycenter weights back onto the probability simplex
    # (see the note below).
    barycenter.probability.data = barycenter.probability.relu().data
    barycenter.probability.data /= barycenter.probability.sum(dim=1,
                                                              keepdim=True)

    if it % 100 == 0:
        print(it, loss.item())
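The projection step at the end of the loop body, written out (a restatement of the code, with p the barycenter's weight vector):

    p \leftarrow \frac{\max(p, 0)}{\sum_j \max(p_j, 0)}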