def recover_z_shape(t_z):
    """Merge per-level latents back into a single tensor.

    Starting from ``t_z.l1``, apply ``squeeze(..., reverse=True)``. Then, for
    each of ``t_z.l2`` and ``t_z.l3`` in turn, concatenate that latent onto the
    running tensor along dim 1 and apply ``squeeze(..., reverse=True)`` again.

    Args:
        t_z: object exposing ``l1``, ``l2`` and ``l3`` tensor attributes
            (presumably a ``Z_splits`` such as the ``partition_out`` returned
            by ``flow_forward`` -- confirm against callers).

    Returns:
        The reassembled tensor.
    """
    merged = squeeze(t_z.l1, reverse=True)
    for level_latent in (t_z.l2, t_z.l3):
        merged = squeeze(torch.cat([merged, level_latent], dim=1), reverse=True)
    return merged
def _sample_next_frame(context_frame, glow, nn_theta, temperature):
    """Sample one frame conditioned on ``context_frame``.

    The context is encoded with ``flow_forward``. Then, for levels l1, l2, l3
    in that order, a latent is drawn from a Normal whose mean and log-std come
    from ``nn_theta``, with the std multiplied by ``temperature``. The latent
    is concatenated with the running activation (levels 2 and 3), pushed
    through the matching Glow level in reverse, and un-squeezed. The final
    activation is passed through a sigmoid.
    """
    b_s = context_frame.size(0)
    t0_zi, _, _ = flow_forward(context_frame, glow)

    # Level 1: prior conditioned only on the context's l1 latent.
    mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
    g1 = Normal(loc=mu_l1, scale=temperature * torch.exp(logsigma_l1))
    z1_sample = g1.sample()
    sldj = torch.zeros(b_s, device=device)
    h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
    h1 = squeeze(h1, reverse=True)

    # Level 2: prior conditioned on the context's l2 split and on h1.
    mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
    g2 = Normal(loc=mu_l2, scale=temperature * torch.exp(logsigma_l2))
    z2_sample = g2.sample()
    h12 = torch.cat([h1, z2_sample], dim=1)
    h12, sldj = glow.l2(h12, sldj, reverse=True)
    h12 = squeeze(h12, reverse=True)

    # Level 3: prior conditioned on the context's l3 split and on h12.
    mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
    g3 = Normal(loc=mu_l3, scale=temperature * torch.exp(logsigma_l3))
    z3_sample = g3.sample()
    x_t = torch.cat([h12, z3_sample], dim=1)
    x_t, sldj = glow.l3(x_t, sldj, reverse=True)
    x_t = squeeze(x_t, reverse=True)
    return torch.sigmoid(x_t)


def sample_100(context, glow, nn_theta, temperature=0.5, n_samples=100,
               n_frames=2, out_dir='samples_100'):
    """Autoregressively sample frames from the first context frame and save them.

    For each of ``n_samples`` runs, generation starts from ``context[:, 0]``.
    ``n_frames`` frames are produced in sequence, each conditioned on the
    previous one, and each is written to
    ``<out_dir>/sample{run}_{frame}.png``, where ``frame`` starts at 1.
    ``context[:, 0]`` itself is saved as ``<out_dir>/context.png``.

    Args:
        context: batch indexed as ``context[:, t, ...]``; only ``t == 0`` is
            used. Values must lie in [0, 1] (``flow_forward`` raises
            ValueError otherwise).
        glow: iterable of modules exposing ``l1``, ``l2``, ``l3`` flow levels.
        nn_theta: iterable of modules exposing ``l1``, ``l2``, ``l3`` prior
            networks.
        temperature: multiplier applied to the std of every sampled latent.
        n_samples: number of independent sampling runs (default 100).
        n_frames: number of frames generated per run (default 2).
        out_dir: output directory; it is created if missing.

    Side effects: every network in ``glow`` and ``nn_theta`` is switched to
    eval mode and is not switched back. The run index is printed, and PNG
    files are written.
    """
    for net in glow:
        net.eval()
    for net in nn_theta:
        net.eval()

    # save_image does not create parent directories, so the directory must
    # exist before the first write (context.png below).
    os.makedirs(out_dir, exist_ok=True)
    torchvision.utils.save_image(context[:, 0, ...].squeeze(),
                                 os.path.join(out_dir, 'context.png'))

    # Pure inference: avoid building autograd graphs across all runs.
    with torch.no_grad():
        for m in range(n_samples):
            print(m)
            context_frame = context[:, 0, ...]
            for n in range(n_frames):
                x_t = _sample_next_frame(context_frame, glow, nn_theta, temperature)
                torchvision.utils.save_image(
                    x_t, os.path.join(out_dir, 'sample{}_{}.png'.format(m, n + 1)))
                # x_t is fed back in as the next context, so shapes must agree.
                assert context_frame.shape == x_t.shape
                context_frame = x_t.clone()
def flow_inverse_smovement(context, glow, nn_theta, epoch, temperature=1.0,
                           out_dir='samples'):
    """Sample one frame conditioned on ``context[:, 0]`` and save images for ``epoch``.

    The first context frame is encoded with ``flow_forward``. Then, for levels
    l1, l2, l3 in that order, a latent is drawn from a Normal whose mean and
    log-std come from ``nn_theta``, with the std multiplied by ``temperature``.
    Each latent is inverted through the matching Glow level, and the result is
    passed through a sigmoid.

    Three images are written to ``out_dir``:
    ``sample{epoch}.png`` (the sampled frame), ``context{epoch}.png``
    (``context[:, 0]``) and ``gt{epoch}.png`` (``context[:, 1]``).

    Args:
        context: batch indexed as ``context[:, t, ...]`` with at least two
            entries along dim 1. Values must lie in [0, 1] (checked by
            ``flow_forward``).
        glow: iterable of modules exposing ``l1``, ``l2``, ``l3`` flow levels.
        nn_theta: iterable of modules exposing ``l1``, ``l2``, ``l3`` prior
            networks.
        epoch: value substituted into the output file names.
        temperature: multiplier on the std of each sampled latent. The
            default of 1.0 samples from the untempered prior.
        out_dir: output directory; it is created if missing.

    NOTE(review): every network in ``glow`` and ``nn_theta`` is left in eval
    mode. If this is called mid-training, the caller presumably switches
    them back to ``train()`` -- confirm.
    """
    import os

    for net in glow:
        net.eval()
    for net in nn_theta:
        net.eval()

    # Pure inference/visualisation: no autograd graph needed.
    with torch.no_grad():
        # pre-process the context frame
        b_s = context.size(0)
        context_frame = context[:, 0, ...]
        t0_zi, _, _ = flow_forward(context_frame, glow)

        # Level 1: prior conditioned only on the context's l1 latent.
        mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
        g1 = Normal(loc=mu_l1, scale=temperature * torch.exp(logsigma_l1))
        z1_sample = g1.sample()
        sldj = torch.zeros(b_s, device=device)
        h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
        h1 = squeeze(h1, reverse=True)

        # Level 2: prior conditioned on the context's l2 split and on h1.
        mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
        g2 = Normal(loc=mu_l2, scale=temperature * torch.exp(logsigma_l2))
        z2_sample = g2.sample()
        h12 = torch.cat([h1, z2_sample], dim=1)
        h12, sldj = glow.l2(h12, sldj, reverse=True)
        h12 = squeeze(h12, reverse=True)

        # Level 3: prior conditioned on the context's l3 split and on h12.
        mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
        g3 = Normal(loc=mu_l3, scale=temperature * torch.exp(logsigma_l3))
        z3_sample = g3.sample()
        x_t = torch.cat([h12, z3_sample], dim=1)
        x_t, sldj = glow.l3(x_t, sldj, reverse=True)
        x_t = squeeze(x_t, reverse=True)
        x_t = torch.sigmoid(x_t)

    # save_image does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    torchvision.utils.save_image(
        x_t, os.path.join(out_dir, 'sample{}.png'.format(epoch)))
    torchvision.utils.save_image(
        context[:, 0, ...].squeeze(),
        os.path.join(out_dir, 'context{}.png'.format(epoch)))
    torchvision.utils.save_image(
        context[:, 1, ...].squeeze(),
        os.path.join(out_dir, 'gt{}.png'.format(epoch)))
def flow_forward(x, flow):
    """Encode ``x`` through the multi-level flow ``flow``.

    After ``pre_process``, the levels run in the order l3, l2, l1. For l3 and
    l2 the activation is squeezed, passed through the level, and chunked in
    two along dim 1. The second chunk is that level's emitted latent; the
    first chunk is carried on to the next level. l1 receives the squeezed
    carried chunk, and its whole output becomes the l1 latent.

    Args:
        x: input tensor whose values must lie in [0, 1].
        flow: object exposing ``l1``, ``l2``, ``l3`` flow levels.

    Returns:
        ``(partition_out, partition_h, sldj)``, where:
        partition_out: ``Z_splits`` of the emitted l3/l2 chunks and the l1
            output.
        partition_h: ``Z_splits`` of the chunks carried forward after l3 and
            l2, with ``l1=None``.
        sldj: the value threaded through ``pre_process`` and every level
            (presumably the summed log-det-Jacobian).

    Raises:
        ValueError: if ``x.min() < 0`` or ``x.max() > 1``.
    """
    lo, hi = x.min(), x.max()
    if lo < 0 or hi > 1:
        raise ValueError('Expected x in [0, 1], got min/max {}/{}'.format(
            lo, hi))

    h, sldj = pre_process(x)

    emitted = {}
    carried = {}
    for level_name in ('l3', 'l2'):
        level = getattr(flow, level_name)
        h = squeeze(h, reverse=False)
        h, sldj = level(h, sldj, reverse=False)
        h, emitted[level_name] = h.chunk(2, dim=1)
        carried[level_name] = h

    # NOTE(review): l1 is called without an explicit `reverse`, relying on its
    # default (presumably False) -- confirm in the level's signature.
    z1, sldj = flow.l1(squeeze(h, reverse=False), sldj)

    partition_out = Z_splits(l3=emitted['l3'], l2=emitted['l2'], l1=z1)
    partition_h = Z_splits(l3=carried['l3'], l2=carried['l2'], l1=None)
    return partition_out, partition_h, sldj