def _generate(self, template_src):
    """Synthesise the MNIST-objects dataset and store it in ``self.data``.

    Builds a ``TemplateImageDecoder`` from ``scae/config/mnist.yaml``. If a
    wandb run path is given, its templates are replaced with the ones from
    that run's ``last.ckpt``. The method then rejection-samples one part
    configuration (part poses + presences) per class. Finally it renders
    ``MNISTObjects.NUM_SAMPLES`` images; each sample's part poses are the
    product of a per-sample jitter, the class part poses and a per-sample
    object pose.

    Args:
        template_src: wandb run path to restore template weights from, or
            ``None`` to keep the decoder's own templates.

    Raises:
        ValueError: if ``MNISTObjects.NUM_SAMPLES`` is not a positive multiple
            of ``MNISTObjects.NUM_CLASSES``, or if ``self.template_mixing`` is
            neither ``'pdf'`` nor ``'max'``.
        FileNotFoundError: if ``wandb.restore`` returns no checkpoint file.
    """
    num_classes = MNISTObjects.NUM_CLASSES
    num_samples = MNISTObjects.NUM_SAMPLES
    # Samples are built by tiling the per-class configurations
    # num_samples // num_classes times. Any other ratio either fails later
    # with a broadcasting error or silently produces an empty dataset.
    if num_classes <= 0 or num_samples <= 0 or num_samples % num_classes:
        raise ValueError(
            'MNISTObjects.NUM_SAMPLES must be a positive multiple of '
            f'MNISTObjects.NUM_CLASSES, got {num_samples} and {num_classes}')
    repeats = num_samples // num_classes

    with torch.no_grad():
        args = parse_args('--cfg scae/config/mnist.yaml --debug'.split(' '))
        args.pcae.num_caps = self.num_caps
        args.im_channels = 1
        pcae_decoder = TemplateImageDecoder(args).cuda()

        if template_src is not None:
            import wandb
            from pytorch_lightning.utilities.cloud_io import load as pl_load

            checkpoint_file = wandb.restore(
                'last.ckpt', run_path=template_src, replace=True)
            if checkpoint_file is None:
                raise FileNotFoundError(
                    f"could not restore 'last.ckpt' from wandb run "
                    f"{template_src!r}")
            # wandb.restore hands back an open file object; close it once the
            # checkpoint has been read instead of leaking the handle.
            with checkpoint_file:
                state_dict = pl_load(checkpoint_file.name,
                                     map_location='cpu')['state_dict']
            # Load on CPU, then move explicitly so the templates share the
            # device of everything else here, wherever the ckpt was saved.
            pcae_decoder.templates = torch.nn.Parameter(
                state_dict['decoder.templates'].contiguous().cuda())

        templates = pcae_decoder._template_nonlin(pcae_decoder.templates)

        # Rejection-sample one part configuration per class. A candidate row
        # is kept when every entry of its masked overlap metric is either 0
        # or strictly between 10 and 20, and at least one entry is positive.
        size_ratio = (args.pcae.decoder.template_size[0]
                      / args.pcae.decoder.output_size[0] / 2)
        valid_part_poses = []
        valid_presences = []
        while len(valid_part_poses) < num_classes:
            presences = Bernoulli(.99).sample(
                (num_classes, self.num_caps)).float().cuda()
            part_poses = self.rand_poses(
                (num_classes, self.num_caps), size_ratio=size_ratio)
            part_poses = math_utils.geometric_transform(
                part_poses, similarity=True, inverse=True, as_matrix=True)
            # Top two rows of each pose matrix, flattened to 6 parameters.
            temp_poses = part_poses[..., :2, :]
            temp_poses = temp_poses.reshape(*temp_poses.shape[:-2], 6)
            transformed_templates = self.transform_templates(
                templates, temp_poses)
            metric = self.overlap_metric(transformed_templates, presences)
            present = presences.bool()
            metric = metric * (present.unsqueeze(-1)
                               | present.unsqueeze(-2)).float()
            for i in range(num_classes):
                if len(valid_part_poses) >= num_classes:
                    break  # enough configurations collected
                row = metric[i]
                in_band = (row == 0) | ((10 < row) & (row < 20))
                if in_band.all() and (row > 0).any():
                    valid_part_poses.append(part_poses[i])
                    valid_presences.append(presences[i])

        part_poses = torch.stack(valid_part_poses)
        presences = torch.stack(valid_presences)

        # One object-level pose per sample, shared by all of its parts.
        object_poses = self.rand_poses((num_samples, 1), size_ratio=6)
        object_poses = math_utils.geometric_transform(
            object_poses, similarity=True, inverse=True, as_matrix=True)
        # Independent per-part jitter for every sample.
        jitter_poses = self.rand_jitter_poses((num_samples, self.num_caps))
        jitter_poses = math_utils.geometric_transform(
            jitter_poses, similarity=True, inverse=True, as_matrix=True)
        # Tiling along dim 0 means sample k uses class k % num_classes.
        poses = (jitter_poses
                 @ part_poses.repeat((repeats, 1, 1, 1))
                 @ object_poses.expand((num_samples, self.num_caps, -1, -1)))
        # Top two rows of each matrix flattened -> (num_samples, num_caps, 6).
        poses = poses[..., :2, :]
        poses = poses.reshape(*poses.shape[:-2], 6)
        presences = presences.repeat((repeats, 1))

        if self.template_mixing == 'pdf':
            rec = pcae_decoder(poses, presences)
            images = rec.pdf.mean()
        elif self.template_mixing == 'max':
            transformed_templates = self.transform_templates(templates, poses)
            # Broadcast presences over the trailing template dims (replaces
            # the deprecated ``.T`` on >2-D tensors), then take the max over
            # the part dimension.
            mask = presences.reshape(
                *presences.shape,
                *([1] * (transformed_templates.dim() - presences.dim())))
            images = (transformed_templates * mask).max(dim=1)[0]
        else:
            raise ValueError(
                f'Invalid template_mixing value {self.template_mixing}')

        self.data = EasyDict(images=images,
                             templates=pcae_decoder.templates,
                             jitter_poses=jitter_poses,
                             caps_poses=part_poses,
                             sample_poses=object_poses)
import torch
from torch.distributions import Bernoulli

# Scratch experiments with torch.randint / repeat / unsqueeze broadcasting.
bs = 3  # batch size
hw = 4  # spatial side length; a flattened map has hw * hw positions
c = 3  # number of channels

# One random flat position per batch element, shape (bs,).
# torch.randint already returns int64, so no extra .long() cast is needed.
rads = torch.randint(0, hw * hw, (bs, ))
print(rads)

# A single random flat position shared by the whole batch, shape (bs,).
rads = torch.randint(0, hw * hw, (1, )).repeat(bs)
print(rads)

# Per-channel 0/1 mask (P(1) = 0.4), identical for every batch element,
# shape (bs, c).
index = Bernoulli(0.4).sample((c, )).repeat(bs, 1)
print(index)

# Add a trailing axis -> (bs, c, 1) ...
index = torch.unsqueeze(index, 2)
print(index)

# ... and tile it along that axis -> (bs, c, hw).
br = index.repeat(1, 1, hw)
print(br)