def sample(_, decoder, args: Options, trace: ViewMem):
    """Decode random latent vectors and visualize the resulting point samples.

    When replaying a stored trace, the previously chosen item count is reused
    instead of prompting the user again.
    """
    # Number of shapes to draw: prompt unless we are replaying a trace.
    count = get_integer((1, 8)) if trace.memory is None else trace.memory[1]
    latents = torch.randn(count, args.dim_z).to(trace.device)
    mixtures = decoder(latents)
    samples, seg = gm_utils.hierarchical_gm_sample(mixtures, trace.points_in_sample, False)
    samples = samples.cpu().numpy()
    seg = list(seg)
    colors = create_palettes(seg)
    image, pts = view(list(samples), seg, colors)
    return True, (sample, count), image, (pts, seg, colors)
def interpolate(encoder, decoder, args: Options, trace: ViewMem):
    """Interpolate between two encoded shapes and render the sampled points.

    Prompts for the number of interpolation steps (8-20) unless a stored
    trace supplies it.
    """
    if trace.memory is None:
        msg = '\tPlease choose number of interpolation: from 8 to 20\n\t'
        num_interpulate = get_integer((7, 21), msg)
    else:
        # Replay path: reuse the step count recorded on the previous run.
        num_interpulate = trace.memory[1]
    z = get_z_by_id(encoder, args, 2, trace.last_idx, trace)
    gms = decoder.interpulate(z, num_interpulate)
    vs, splits = gm_utils.hierarchical_gm_sample(gms, trace.points_in_sample)
    spread = [vs[i].cpu().numpy() for i in range(vs.shape[0])]
    splits = list(splits)
    palette = create_palettes(splits)
    im, points = view(spread, splits, palette)
    # Fix: store the scalar step count (the original stored the 1-tuple
    # (num_interpulate,), so replaying via trace.memory[1] handed a tuple to
    # decoder.interpulate). Storing the int matches how `sample` records its
    # count and makes the replay branch above correct.
    return True, (interpolate, num_interpulate), im, (points, splits, palette)
def hgmms(encoder, decoder, args: Options, trace: ViewMem):
    """Render progressively deeper prefixes of the hierarchical GM levels.

    For each level i, samples from the first i+1 Gaussian mixtures so the
    hierarchy can be inspected one refinement at a time.
    """
    z = get_z_by_id(encoder, args, 1, None, trace)
    gms = decoder(z)
    num_gms = len(gms)
    vs = []
    splits = []
    for i in range(num_gms):
        # Take the first i+1 levels. The original comprehension reused `i` as
        # its own index, shadowing the loop variable — renamed to `j` for
        # clarity (behavior unchanged: range(i+1) is evaluated in outer scope).
        gms_ = [gms[j] for j in range(i + 1)]
        vs_, splits_ = gm_utils.hierarchical_gm_sample(gms_, trace.points_in_sample)
        vs.append(vs_.squeeze(0).cpu().numpy())
        splits.append(splits_.squeeze(0).cpu().numpy())
    palette = create_palettes(splits)
    im, points = view(vs, splits, palette)
    return True, (hgmms, ), im, (points, splits, palette)
def before_plot(self, data, eval_size=-1):
    """Encode a batch, decode it to GM samples, and return numpy arrays for plotting.

    Returns [vs, encoder_inp, y, splits] as numpy arrays, where vs are points
    sampled from the decoded hierarchical Gaussian mixtures.
    """
    y, encoder_inp = self.arrange_data(data, eval_size)
    z, _, _ = self.encoder(encoder_inp)
    gms = self.decoder(z)
    vs, splits = gm_utils.hierarchical_gm_sample(
        gms, self.opt.partial_samples[0], self.opt.flatten_sigma)
    # splits_inp = np.array([0, encoder_inp.shape[1]])
    # splits_y = np.array([0, y[1]])
    # Undo the augmentation transforms (applied last-to-first) so the decoded
    # samples line up with the untransformed data.
    if self.opt.recon or len(self.opt.transforms) > 0:
        transforms = data[3:]
        for i in range(len(transforms)):
            transform = transforms[-(i + 1)][:vs.shape[0]].to(self.device)
            if transform.dim() == 2:
                # 2-D transform: treated as a per-batch translation offset.
                t = lambda x: x - transform[:, None, :]
            else:
                # Otherwise: per-batch linear map (presumably a rotation) —
                # applied via batched matrix product.
                t = lambda x: torch.einsum('bnd,brd->bnr', x, transform)
            # NOTE(review): the transformed encoder input is bound to `vs_in`
            # but never used — `out` below returns the UNtransformed
            # encoder_inp. Confirm whether `vs_in` was meant to replace
            # encoder_inp in the output.
            vs, vs_in, y = list(map(t, [vs, encoder_inp, y]))
    out = list(
        map(lambda x: x.data.cpu().numpy(), [vs, encoder_inp, y, splits]))
    return out