Example #1
    def _generate(self, template_src):
        with torch.no_grad():
            args = parse_args(
                '--cfg scae/config/mnist.yaml --debug'.split(' '))
            args.pcae.num_caps = self.num_caps
            args.im_channels = 1

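            # Build a PCAE template decoder; optionally restore its learned
            # templates from the 'last.ckpt' checkpoint of the given wandb run.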
            pcae_decoder = TemplateImageDecoder(args).cuda()
            if template_src is not None:
                import wandb
                from pytorch_lightning.utilities.cloud_io import load as pl_load
                best_model = wandb.restore('last.ckpt',
                                           run_path=template_src,
                                           replace=True)
                pcae_decoder.templates = torch.nn.Parameter(
                    pl_load(best_model.name)['state_dict']
                    ['decoder.templates'].contiguous())
            templates = pcae_decoder._template_nonlin(pcae_decoder.templates)

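            # Rejection-sample part poses/presences until NUM_CLASSES valid
            # class prototypes (by the overlap criterion below) are collected.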
            valid_part_poses = []
            valid_presences = []
            while len(valid_part_poses) < MNISTObjects.NUM_CLASSES:
                presences_shape = (MNISTObjects.NUM_CLASSES, self.num_caps)
                presences = Bernoulli(.99).sample(
                    presences_shape).float().cuda()

                part_poses = self.rand_poses(
                    (MNISTObjects.NUM_CLASSES, self.num_caps),
                    size_ratio=args.pcae.decoder.template_size[0] /
                    args.pcae.decoder.output_size[0] / 2)
                part_poses = math_utils.geometric_transform(part_poses,
                                                            similarity=True,
                                                            inverse=True,
                                                            as_matrix=True)

                temp_poses = part_poses[..., :2, :]
                temp_poses = temp_poses.reshape(*temp_poses.shape[:-2], 6)

                transformed_templates = self.transform_templates(
                    templates, temp_poses)
                metric = self.overlap_metric(transformed_templates, presences)
                metric = metric * (presences.bool().unsqueeze(-1)
                                   | presences.bool().unsqueeze(-2)).float()

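                # Accept class i only if every pairwise overlap is either zero
                # or within (10, 20), and at least one pair of parts overlaps.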
                for i in range(MNISTObjects.NUM_CLASSES):
                    if (((metric[i] == 0) | ((10 < metric[i]) & (metric[i] < 20))).all()
                            and (metric[i] > 0).any()):
                        valid_part_poses.append(part_poses[i])
                        valid_presences.append(presences[i])

            part_poses = torch.stack(
                valid_part_poses[:MNISTObjects.NUM_CLASSES])
            presences = torch.stack(valid_presences[:MNISTObjects.NUM_CLASSES])

            # Vis final objects
            # temp_poses = part_poses[..., :2, :]
            # temp_poses = temp_poses.reshape(*temp_poses.shape[:-2], 6)
            # transformed_templates = self.transform_templates(templates, temp_poses)
            # plot_image_tensor((transformed_templates.T * presences.T).T.max(dim=1)[0])

            # Build decoder poses of final shape (NUM_SAMPLES, self.num_caps, 6):
            # start with one random object-level pose per sample, shared by all of its parts.
            object_poses = self.rand_poses((MNISTObjects.NUM_SAMPLES, 1),
                                           size_ratio=6)
            object_poses = math_utils.geometric_transform(object_poses,
                                                          similarity=True,
                                                          inverse=True,
                                                          as_matrix=True)

            jitter_poses = self.rand_jitter_poses(
                (MNISTObjects.NUM_SAMPLES, self.num_caps))
            jitter_poses = math_utils.geometric_transform(jitter_poses,
                                                          similarity=True,
                                                          inverse=True,
                                                          as_matrix=True)

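            # Compose per-part jitter, class-level part poses, and the per-sample
            # object pose into full affine matrices, then flatten to 6-dof vectors.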
            poses = (jitter_poses
                     @ part_poses.repeat((MNISTObjects.NUM_SAMPLES // MNISTObjects.NUM_CLASSES, 1, 1, 1))
                     @ object_poses.expand((MNISTObjects.NUM_SAMPLES, self.num_caps, -1, -1)))
            poses = poses[..., :2, :]
            poses = poses.reshape(*poses.shape[:-2], 6)

            presences = presences.repeat(
                (MNISTObjects.NUM_SAMPLES // MNISTObjects.NUM_CLASSES, 1))

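            # Render images either from the decoder's output pdf ('pdf') or by a
            # presence-weighted max over the transformed templates ('max').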
            if self.template_mixing == 'pdf':
                rec = pcae_decoder(poses, presences)
                images = rec.pdf.mean()
            elif self.template_mixing == 'max':
                transformed_templates = self.transform_templates(
                    templates, poses)
                # templates = templates.repeat((MNISTObjects.NUM_SAMPLES // MNISTObjects.NUM_CLASSES, 1))
                images = (transformed_templates.T *
                          presences.T).T.max(dim=1)[0]
            else:
                raise ValueError(
                    f'Invalid template_mixing value {self.template_mixing}')

            self.data = EasyDict(images=images,
                                 templates=pcae_decoder.templates,
                                 jitter_poses=jitter_poses,
                                 caps_poses=part_poses,
                                 sample_poses=object_poses)
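
The 'max' branch above uses a transpose trick to broadcast per-capsule presences across the transformed templates before max-reducing over capsules. Below is a minimal standalone sketch of just that step; the shapes (B samples, K capsules, H x W templates) and the random stand-in tensors are assumptions for illustration, not values from the example.

import torch

B, K, H, W = 2, 4, 11, 11                        # assumed batch, capsules, template size
transformed_templates = torch.rand(B, K, H, W)   # stand-in for self.transform_templates(...)
presences = torch.randint(0, 2, (B, K)).float()  # stand-in for the sampled Bernoulli presences

# .T reverses all dims: (B, K, H, W) -> (W, H, K, B), which broadcasts against
# presences.T of shape (K, B); transposing back and max-reducing over the capsule
# axis keeps, per pixel, the brightest template among the present capsules.
images = (transformed_templates.T * presences.T).T.max(dim=1)[0]
print(images.shape)  # torch.Size([B, H, W])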

Example #2
import torch
from torch.distributions import Bernoulli

# zz = torch.zeros((3,4,5))
# oo = torch.ones(zz.size())
# oo = torch.ones_like(zz)
# print(oo)
# test scatter
bs = 3
hw = 4
c = 3
rads = torch.randint(0, hw * hw, (bs, )).long()
print(rads)
rads = torch.randint(0, hw * hw, (1, )).long().repeat(bs)
print(rads)
# print(rads)
# z = torch.zeros(bs, hw*hw).scatter_(1, rads, 1)
# print(z)
# print(z.reshape((bs,c,hw,hw)))

# repeat along axis
index = Bernoulli(0.4).sample((c, )).repeat(bs).reshape(bs, -1)
print(index)
# print(index)
index = torch.unsqueeze(index, 2)
print(index)
# ar = index.repeat(1,hw,1)
# print(ar)
br = index.repeat(1, 1, hw)
print(br)
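
# Sketch (not part of the scratch code above): the commented-out scatter_ call would
# fail because scatter_ expects an index with the same number of dims as the target,
# so rads must be unsqueezed. A working one-hot variant using bs/hw/rads from above:
one_hot = torch.zeros(bs, hw * hw).scatter_(1, rads.unsqueeze(1), 1.0)
print(one_hot)
print(one_hot.reshape(bs, 1, hw, hw))  # one hw x hw one-hot map per batch element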