def recover_z_shape(t_z):
    """Reassemble the per-level latents in ``t_z`` into a single tensor.

    Starting from the ``l1`` latent, the tensor is un-squeezed
    (``squeeze(..., reverse=True)``, presumably the inverse of the
    space-to-channel squeeze used in ``flow_forward``). The next level's latent
    (``l2``, then ``l3``) is then appended along dim 1 and the result is
    un-squeezed again. Only the tensor layout is rebuilt; no flow level is
    inverted here.

    Args:
        t_z: Object exposing ``l1``, ``l2`` and ``l3`` tensors, e.g. the first
            ``Z_splits`` returned by ``flow_forward``.

    Returns:
        The recombined tensor after three reverse squeezes.
    """
    z = squeeze(t_z.l1, reverse=True)
    # Concatenation order [current, split] mirrors the order in which
    # flow_forward chunks each level into (kept, split-off) halves.
    for level_split in (t_z.l2, t_z.l3):
        z = squeeze(torch.cat([z, level_split], dim=1), reverse=True)
    return z
# Example #2
def _sample_next_frame(context_frame, glow, nn_theta, temperature):
    """Sample one frame conditioned on ``context_frame`` (helper for sample_100).

    ``context_frame`` is encoded with ``flow_forward``. The levels are then
    walked from ``l1`` (the most-squeezed level in ``flow_forward``) to ``l3``.
    At each level:

    * the prior net in ``nn_theta`` yields ``(mu, logsigma)``;
    * a latent is drawn from ``Normal(mu, temperature * exp(logsigma))``;
    * it is concatenated with the partially decoded tensor (levels 2 and 3);
    * the matching ``glow`` level is run with ``reverse=True`` and the result
      is un-squeezed.

    Returns:
        ``torch.sigmoid`` of the fully decoded tensor.
    """
    t0_zi, _, _ = flow_forward(context_frame, glow)

    # Level 1: sample the deepest latent and invert glow.l1.
    mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
    g1 = Normal(loc=mu_l1, scale=temperature * torch.exp(logsigma_l1))
    z1_sample = g1.sample()
    # Log-det accumulator required by the glow levels; not used further here.
    sldj = torch.zeros(context_frame.size(0), device=device)
    h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
    h1 = squeeze(h1, reverse=True)

    # Level 2: prior is conditioned on the context's l2 latent and on h1.
    mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
    g2 = Normal(loc=mu_l2, scale=temperature * torch.exp(logsigma_l2))
    z2_sample = g2.sample()
    h12 = torch.cat([h1, z2_sample], dim=1)
    h12, sldj = glow.l2(h12, sldj, reverse=True)
    h12 = squeeze(h12, reverse=True)

    # Level 3: prior is conditioned on the context's l3 latent and on h12.
    mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
    g3 = Normal(loc=mu_l3, scale=temperature * torch.exp(logsigma_l3))
    z3_sample = g3.sample()
    x_t = torch.cat([h12, z3_sample], dim=1)
    x_t, sldj = glow.l3(x_t, sldj, reverse=True)
    x_t = squeeze(x_t, reverse=True)

    return torch.sigmoid(x_t)


def sample_100(context, glow, nn_theta, temperature=0.5):
    """Write 100 independent two-frame rollouts starting from ``context[:, 0]``.

    For each of 100 runs, the first context frame is used to sample a new
    frame. That sample is then fed back in as the context to sample a second
    frame. Outputs are saved as ``samples_100/sample{run}_{step}.png`` with
    ``step`` in {1, 2}. The first context frame is saved once as
    ``samples_100/context.png``.

    Args:
        context: Tensor indexed as ``context[:, 0, ...]`` for the conditioning
            frame (presumably (batch, time, C, H, W) -- TODO confirm).
        glow: Iterable of flow nets that also exposes ``l1``/``l2``/``l3``.
        nn_theta: Iterable of prior nets that also exposes ``l1``/``l2``/``l3``;
            each returns ``(mu, logsigma)``.
        temperature: Multiplier applied to every prior's standard deviation.

    Side effects:
        Puts every net in ``glow`` and ``nn_theta`` into eval mode (left that
        way), creates ``samples_100/`` if missing, and prints each run index.
    """
    for net in glow:
        net.eval()
    for net in nn_theta:
        net.eval()

    # Create the output directory before the first write. The original code
    # only created it inside the sampling loop, so saving context.png failed
    # on a fresh run.
    os.makedirs('samples_100/', exist_ok=True)
    torchvision.utils.save_image(context[:, 0, ...].squeeze(),
                                 'samples_100/context.png')

    # Pure inference: no gradients are needed for the reverse passes.
    with torch.no_grad():
        for m in range(100):
            print(m)
            context_frame = context[:, 0, ...]
            # Generate two frames, each conditioned on the previous one.
            for n in range(2):
                x_t = _sample_next_frame(context_frame, glow, nn_theta,
                                         temperature)
                torchvision.utils.save_image(
                    x_t, 'samples_100/sample{}_{}.png'.format(m, n + 1))

                # The sample must be usable as the next context frame.
                assert context_frame.shape == x_t.shape
                context_frame = x_t.clone()
def flow_inverse_smovement(context, glow, nn_theta, epoch, temperature=1.0):
    """Sample one frame conditioned on ``context[:, 0]`` and save it for inspection.

    ``context[:, 0, ...]`` is encoded with ``flow_forward``. A latent is then
    sampled at each level ``l1`` -> ``l2`` -> ``l3`` from
    ``Normal(mu, temperature * exp(logsigma))``, where ``(mu, logsigma)`` come
    from ``nn_theta``. Each latent is decoded through the matching ``glow``
    level in reverse, and the result is passed through a sigmoid.

    Args:
        context: Tensor indexed as ``context[:, 0, ...]`` (conditioning frame)
            and ``context[:, 1, ...]`` (saved as ``gt``, presumably the true
            next frame -- TODO confirm).
        glow: Iterable of flow nets that also exposes ``l1``/``l2``/``l3``.
        nn_theta: Iterable of prior nets that also exposes ``l1``/``l2``/``l3``.
        epoch: Only used to tag the output filenames.
        temperature: Multiplier on every prior's standard deviation. The
            default of 1.0 gives the untempered distribution.

    Side effects:
        Puts all nets into eval mode (left that way), prints the z1 sample
        shape, creates ``samples/`` if missing, and writes
        ``samples/sample{epoch}.png``, ``samples/context{epoch}.png`` and
        ``samples/gt{epoch}.png``.
    """
    for net in glow:
        net.eval()
    for net in nn_theta:
        net.eval()

    # Pure inference: no gradients are needed for the reverse passes.
    with torch.no_grad():
        # Encode the context frame.
        b_s = context.size(0)
        context_frame = context[:, 0, ...]
        t0_zi, _, _ = flow_forward(context_frame, glow)

        mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
        g1 = Normal(loc=mu_l1, scale=temperature * torch.exp(logsigma_l1))
        z1_sample = g1.sample()
        print("z1", z1_sample.shape)
        # Log-det accumulator required by the glow levels; not used further here.
        sldj = torch.zeros(b_s, device=device)

        # Inverse L1
        h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
        h1 = squeeze(h1, reverse=True)

        # Sample z2, conditioned on the context's l2 latent and h1.
        mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
        g2 = Normal(loc=mu_l2, scale=temperature * torch.exp(logsigma_l2))
        z2_sample = g2.sample()
        h12 = torch.cat([h1, z2_sample], dim=1)
        h12, sldj = glow.l2(h12, sldj, reverse=True)
        h12 = squeeze(h12, reverse=True)

        # Sample z3, conditioned on the context's l3 latent and h12.
        mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
        g3 = Normal(loc=mu_l3, scale=temperature * torch.exp(logsigma_l3))
        z3_sample = g3.sample()

        x_t = torch.cat([h12, z3_sample], dim=1)
        x_t, sldj = glow.l3(x_t, sldj, reverse=True)
        x_t = squeeze(x_t, reverse=True)

        x_t = torch.sigmoid(x_t)

    # save_image does not create missing directories, so ensure it exists.
    os.makedirs('samples', exist_ok=True)
    torchvision.utils.save_image(x_t, 'samples/sample{}.png'.format(epoch))
    torchvision.utils.save_image(context[:, 0, ...].squeeze(),
                                 'samples/context{}.png'.format(epoch))
    torchvision.utils.save_image(context[:, 1, ...].squeeze(),
                                 'samples/gt{}.png'.format(epoch))
# Example #4
def flow_forward(x, flow):
    """Encode ``x`` through the three flow levels, splitting channels between levels.

    After ``pre_process``, levels ``l3`` and ``l2`` each do three things:
    squeeze the tensor, run the level forward, and chunk the output in two
    along dim 1. The first half continues to the next level; the second half
    is set aside as that level's latent. Level ``l1`` then runs on the squeezed
    remainder without any split.

    Args:
        x: Tensor whose values must all lie in [0, 1].
        flow: Object exposing the ``l1``, ``l2`` and ``l3`` flow levels.

    Returns:
        A tuple ``(partition_out, partition_h, sldj)``:

        * ``partition_out`` -- ``Z_splits`` holding the halves split off at l3
          and l2, plus the final ``l1`` output;
        * ``partition_h`` -- ``Z_splits`` holding the halves that were kept at
          l3 and l2 (``l1=None``);
        * ``sldj`` -- the accumulator returned by ``pre_process`` and updated
          by each level.

    Raises:
        ValueError: If ``x`` has a value below 0 or above 1.
    """
    lo, hi = x.min(), x.max()
    if lo < 0 or hi > 1:
        raise ValueError('Expected x in [0, 1], got min/max {}/{}'.format(
            lo, hi))

    h, sldj = pre_process(x)

    kept, split_off = [], []
    for level in (flow.l3, flow.l2):
        h = squeeze(h, reverse=False)
        h, sldj = level(h, sldj, reverse=False)
        h, z_level = h.chunk(2, dim=1)
        kept.append(h)
        split_off.append(z_level)

    # The last level consumes everything that is left; nothing is split off.
    h = squeeze(h, reverse=False)
    z_l1, sldj = flow.l1(h, sldj)

    (h3, h2), (z3, z2) = kept, split_off
    return (Z_splits(l3=z3, l2=z2, l1=z_l1),
            Z_splits(l3=h3, l2=h2, l1=None),
            sldj)