Exemple #1
0
def plot_rec(x, epoch):
    """Save ground-truth sequences next to teacher-forced reconstructions.

    The content vector is taken from the last conditioning frame
    ``x[opt.n_past - 1]`` and held fixed. At every step the pose encoder sees
    the ground-truth previous frame ``x[i - 1]`` and the LSTM produces the
    next pose. Conditioning frames are copied through unchanged; later frames
    are decoded from ``netD([h_c, h_pred])``.

    Args:
        x: sequence indexable by time step; each ``x[t]`` is a batch of
            frames (presumably ``(batch, C, H, W)`` -- TODO confirm).
        epoch: integer used in the output filename.

    Writes ``<opt.log_dir>/gen/rec_<epoch>.png`` via ``utils.save_tensors_image``.
    """
    seq_len = opt.n_past + opt.n_future

    # netEC may return (content, skips); the LSTM only consumes the content.
    h_c = netEC(x[opt.n_past - 1])
    vec_h_c = (h_c[0] if isinstance(h_c, tuple) else h_c).detach()

    lstm.hidden = lstm.init_hidden()
    gen_seq = [x[0]]
    for i in range(1, seq_len):
        h_p = netEP(x[i - 1]).detach()
        # Step the LSTM on every frame so its hidden state tracks the sequence.
        h_pred = lstm(torch.cat([h_p, vec_h_c], 1))
        if i < opt.n_past:
            gen_seq.append(x[i])
        else:
            # netD receives the full netEC output (possibly including skips).
            gen_seq.append(netD([h_c, h_pred]).detach())

    # Never index past the batch: show at most 10 examples.
    nrow = min(opt.batch_size, 10)
    to_plot = []
    for i in range(nrow):
        to_plot.append([x[t][i] for t in range(seq_len)])        # ground truth
        to_plot.append([gen_seq[t][i] for t in range(seq_len)])  # reconstruction
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #2
0
def plot_rec(x, y, s, epoch):
    """Save reconstructions at pyramid level ``s`` next to their inputs.

    Level 0 encodes ``x[0]`` directly. Higher levels encode ``x[s]``
    concatenated with its residual against ``x[s - 1]`` and decode
    conditioned on the coarser image ``x[s - 1]``. Each output row holds
    ``batch_size / nclass`` (input, reconstruction) pairs.

    Args:
        x: per-level image batches, indexed by level.
        y: labels, reshaped here to ``(batch_size, nclass, 1, 1)``.
        s: pyramid level to visualise.
        epoch: integer used in the output filename.
    """
    labels = y.view(opt.batch_size, opt.nclass, 1, 1)
    target = x[s]

    if s != 0:
        coarse = x[s - 1]
        enc_in = torch.cat([target, target - coarse], 1)
        _, z, _ = netE[s]((enc_in, labels))
        rec = netD[s]([coarse, torch.cat([z, labels], 1)])
    else:
        _, z, _ = netE[s]((target, labels))
        rec = netD[s](torch.cat([z, labels], 1))

    rows = opt.nclass
    cols = int(opt.batch_size / rows)
    to_plot = []
    for r in range(rows):
        pairs = []
        for k in range(r * cols, (r + 1) * cols):
            pairs.extend([target[k], rec[k]])
        to_plot.append(pairs)

    # Width of the images at this level, used to tag the file.
    level_width = opt.image_width / (2**(opt.nlevels - s - 1))
    fname = '%s/rec/%d_%d.png' % (opt.log_dir, epoch, level_width)
    utils.save_tensors_image(fname, to_plot)
Exemple #3
0
def plot_rec(x, epoch):
    """Save ground-truth sequences next to teacher-forced reconstructions.

    The content vector is taken from the last conditioning frame
    ``x[opt.n_past - 1]`` and held fixed. At every step the pose encoder sees
    the ground-truth previous frame ``x[i - 1]`` and the LSTM produces the
    next pose. Conditioning frames are copied through unchanged; later frames
    are decoded from ``netD([h_c, h_pred])``.

    Args:
        x: sequence indexable by time step; each ``x[t]`` is a batch of
            frames (presumably ``(batch, C, H, W)`` -- TODO confirm).
        epoch: integer used in the output filename.

    Writes ``<opt.log_dir>/gen/rec_<epoch>.png`` via ``utils.save_tensors_image``.
    """
    seq_len = opt.n_past + opt.n_future

    # netEC may return (content, skips); the LSTM only consumes the content.
    h_c = netEC(x[opt.n_past - 1])
    vec_h_c = (h_c[0] if isinstance(h_c, tuple) else h_c).detach()

    lstm.hidden = lstm.init_hidden()
    gen_seq = [x[0]]
    for i in range(1, seq_len):
        h_p = netEP(x[i - 1]).detach()
        # Step the LSTM on every frame so its hidden state tracks the sequence.
        h_pred = lstm(torch.cat([h_p, vec_h_c], 1))
        if i < opt.n_past:
            gen_seq.append(x[i])
        else:
            # netD receives the full netEC output (possibly including skips).
            gen_seq.append(netD([h_c, h_pred]).detach())

    # Never index past the batch: show at most 10 examples.
    nrow = min(opt.batch_size, 10)
    to_plot = []
    for i in range(nrow):
        to_plot.append([x[t][i] for t in range(seq_len)])        # ground truth
        to_plot.append([gen_seq[t][i] for t in range(seq_len)])  # reconstruction
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot_rec(x, epoch):
    """Save per-frame encoder/decoder reconstructions of a sequence.

    Every frame ``x[i]`` (``i >= 1``) is encoded and decoded back. The skip
    connections come from ``x[i - 1]`` while ``i < opt.n_past`` and are held
    fixed afterwards. Frame 0 is copied through unchanged.

    Args:
        x: sequence indexable by time step; each ``x[t]`` is a batch of frames.
        epoch: integer used in the output filename.

    Writes ``<opt.log_dir>/gen/rec_<epoch>.png``.
    """
    # Reset recurrent state as the sibling plotting helpers do; neither
    # module is used below.
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()

    seq_len = opt.n_past + opt.n_future
    gen_seq = [x[0]]
    skip = None
    for i in range(1, seq_len):
        # Encoder errors now propagate. The previous bare ``except: continue``
        # silently dropped frames and misaligned gen_seq with the time axis.
        h_target = encoder(x[i])[0].detach()
        # `skip is None` covers opt.n_past == 1, where skip was never bound.
        if i < opt.n_past or skip is None:
            _, skip = encoder(x[i - 1])
        gen_seq.append(decoder([h_target, skip]).detach())

    nrow = min(opt.batch_size, 10)
    to_plot = [[gen_seq[t][i] for t in range(seq_len)] for i in range(nrow)]
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #5
0
def plot_rec(x, epoch):
    """Save posterior reconstructions of a sequence to ``gen_1/rec_<epoch>.png``.

    Conditioning frames (``i < opt.n_past``) are copied from ``x``; they are
    also fed through ``frame_predictor`` so it warms up its hidden state.
    Later frames are decoded from the predictor output. The predictor always
    sees the encoding of the ground-truth previous frame, plus a latent drawn
    from the posterior of the ground-truth target frame. Skip connections
    come from the conditioning steps unless ``opt.last_frame_skip`` is set,
    in which case they follow ``x[i - 1]``.
    """
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    steps = opt.n_past + opt.n_future
    gen_seq = [x[0]]
    for i in range(1, steps):
        prev_enc = encoder(x[i - 1])
        target_enc = encoder(x[i])
        h_prev, frame_skip = prev_enc
        if opt.last_frame_skip or i < opt.n_past:
            skip = frame_skip
        h_tgt, _ = target_enc
        h_prev = h_prev.detach()
        z_t, _, _ = posterior(h_tgt.detach())
        predictor_in = torch.cat([h_prev, z_t], 1)
        if i >= opt.n_past:
            h_pred = frame_predictor(predictor_in)
            gen_seq.append(decoder([h_pred, skip]).detach())
        else:
            frame_predictor(predictor_in)
            gen_seq.append(x[i])

    nrow = min(opt.batch_size, 10)
    to_plot = [[gen_seq[t][r] for t in range(steps)] for r in range(nrow)]
    fname = '%s/gen_1/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #6
0
def plot(x, epoch):
    """Save ``nsample`` stochastic rollouts per sequence as a PNG grid and a GIF.

    Conditioning frames are encoded once and shared by all samples. For
    ``i < opt.n_past`` the predictor is warmed up with a posterior latent and
    the ground-truth frame is kept. Afterwards ``z_t ~ N(0, I)`` and each
    generated frame is fed back through the encoder.

    Each of the first ``min(opt.batch_size, 10)`` sequences contributes the
    ground-truth row plus one row per sample. Writes
    ``gen/sample_<epoch>.png`` and ``gen/sample_<epoch>.gif``.
    """
    nsample = 5
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]

    # Encodings of the conditioning frames are shared by all samples.
    h_seq = [encoder(x[i]) for i in range(opt.n_past)]
    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        for i in range(1, opt.n_eval):
            if i <= opt.n_past:
                # x[i - 1] is a conditioning frame: reuse its cached encoding.
                h, skip = h_seq[i - 1]
            else:
                # Feed the previously generated frame back through the encoder.
                # The original indexed h_seq[i - 1] here whenever
                # opt.last_frame_skip was set, which is out of range.
                h, new_skip = encoder(x_in)
                if opt.last_frame_skip:
                    skip = new_skip
            h = h.detach()

            if i < opt.n_past:
                # NOTE(review): the posterior sees the encoding of x[i - 1],
                # not of the target x[i] as in the other SVG helpers -- confirm.
                z_t, _, _ = posterior(h_seq[i - 1][0])
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
            else:
                # Device-agnostic replacement for
                # torch.cuda.FloatTensor(...).normal_() (standard normal).
                z_t = torch.randn(opt.batch_size, opt.z_dim, device=h.device)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
            gen_seq[s].append(x_in)

    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # Ground-truth row, then one row per sample.
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])
        for s in range(nsample):
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # One GIF frame per time step: ground truth next to every sample.
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in range(nsample)])

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
Exemple #7
0
def plot_gen(epoch):
    """Sample from the multi-scale generator and save two image grids.

    Row ``i`` of the batch is conditioned on class ``i`` through a one-hot
    ``y``. Level 0 of ``netD`` decodes from ``[z, y]``; each finer level
    also receives the previous level's output. Saves the finest scale as
    ``gen/<epoch>.png``. Also saves a fixed 6 x 3 grid as
    ``gen/scales_<epoch>.png``, where each cell lists one sample at every
    scale from coarse to fine.
    """
    nrow = opt.nclass
    ncol = int(opt.batch_size / nrow)

    # Class-per-row one-hot labels of shape (batch_size, nclass, 1, 1).
    y = torch.Tensor(opt.batch_size, opt.nclass, 1, 1).cuda().zero_()
    for cls in range(nrow):
        y[cls * ncol:(cls + 1) * ncol, cls] = 1

    scales = []
    for s in range(opt.nlevels):
        z = make_plot_z()
        code = torch.cat([z, y], 1)
        if s == 0:
            out = netD[s](code)
        else:
            out = netD[s]([scales[s - 1], code])
        scales.append(out.detach())

    finest = scales[-1]
    to_plot = [[finest[r * ncol + c] for c in range(ncol)] for r in range(nrow)]
    utils.save_tensors_image('%s/gen/%d.png' % (opt.log_dir, epoch), to_plot)

    grid_rows, grid_cols = 6, 3
    to_plot = []
    for r in range(grid_rows):
        cells = []
        for c in range(grid_cols):
            k = r * grid_cols + c
            cells.extend(level[k] for level in scales)
        to_plot.append(cells)
    utils.save_tensors_image('%s/gen/scales_%d.png' % (opt.log_dir, epoch), to_plot)
Exemple #8
0
def plot(x, epoch):
    """Decode ``z_fixed`` with the labels ``yp`` and save a class-per-row grid.

    ``x`` is a 3-tuple whose last element ``yp`` is reshaped to
    ``(batch_size, nclass, 1, 1)`` and concatenated with ``z_fixed`` before
    ``netG``. Writes ``gen/yp_<epoch>.png``.
    """
    _, _, yp = x
    cond = yp.view(opt.batch_size, opt.nclass, 1, 1)
    gen = netG(torch.cat([z_fixed, cond], 1))

    rows = opt.nclass
    per_row = int(opt.batch_size / rows)
    to_plot = [[gen[r * per_row + c] for c in range(per_row)] for r in range(rows)]

    utils.save_tensors_image('%s/gen/yp_%d.png' % (opt.log_dir, epoch), to_plot)
Exemple #9
0
def plot_rec(x, epoch):
    """Save content/pose reconstructions as rows of (content, pose, rec) triplets.

    Content comes from ``x[0]`` and pose from a random frame ``x[k]`` with
    ``1 <= k < opt.max_step``; ``netD`` reconstructs from both. Up to 20
    samples are shown, 5 triplets per row, and only complete rows are
    emitted. Writes ``rec/<epoch>.png``.
    """
    x_c = x[0]
    x_p = x[np.random.randint(1, opt.max_step)]

    h_c = netEC(x_c)
    h_p = netEP(x_p)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '%s/rec/%d.png' % (opt.log_dir, epoch)
    to_plot = []
    row_sz = 5
    # Never read past the batch.
    nplot = min(20, len(x_c))
    # The old bound `nplot - row_sz` stopped one row early (3 rows, not 4).
    for i in range(0, nplot - row_sz + 1, row_sz):
        triplets = zip(x_c[i:i + row_sz], x_p[i:i + row_sz], rec[i:i + row_sz])
        to_plot.append(list(itertools.chain.from_iterable(triplets)))
    utils.save_tensors_image(fname, to_plot)
Exemple #10
0
def plot_rec(x, epoch):
    """Save content/pose reconstructions as rows of (content, pose, rec) triplets.

    Content comes from ``x[0]`` and pose from a random frame ``x[k]`` with
    ``1 <= k < opt.max_step``; ``netD`` reconstructs from both. Up to 20
    samples are shown, 5 triplets per row, and only complete rows are
    emitted. Writes ``rec/<epoch>.png``.
    """
    x_c = x[0]
    x_p = x[np.random.randint(1, opt.max_step)]

    h_c = netEC(x_c)
    h_p = netEP(x_p)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '%s/rec/%d.png' % (opt.log_dir, epoch)
    to_plot = []
    row_sz = 5
    # Never read past the batch.
    nplot = min(20, len(x_c))
    # The old bound `nplot - row_sz` stopped one row early (3 rows, not 4).
    for i in range(0, nplot - row_sz + 1, row_sz):
        triplets = zip(x_c[i:i + row_sz], x_p[i:i + row_sz], rec[i:i + row_sz])
        to_plot.append(list(itertools.chain.from_iterable(triplets)))
    utils.save_tensors_image(fname, to_plot)
Exemple #11
0
def plot_gen(x, epoch):
    """Decode ``z_fixed`` with one class per row and save the resulting grid.

    ``x`` must be a 3-tuple ``(x, y, yp)``; the images below depend only on
    ``z_fixed`` and synthetic class-per-row labels. When ``opt.all_labels``
    is set the grid is also written as ``gen/p_<epoch>.png``. It is always
    written as ``gen/<epoch>.png``.
    """
    x, y, yp = x  # unpacked so a malformed argument still fails loudly

    nrow = opt.nclass
    ncol = int(opt.batch_size / nrow)

    def _class_grid():
        # One-hot labels of shape (batch_size, nclass, 1, 1): grid row i is class i.
        y_onehot = torch.Tensor(opt.batch_size, opt.nclass, 1, 1).cuda().zero_()
        for i in range(nrow):
            for j in range(ncol):
                y_onehot[i * ncol + j][i] = 1
        gen = model.decode(z_fixed, y_onehot)
        return [[gen[i * ncol + j] for j in range(ncol)] for i in range(nrow)]

    # The original assigned y_onehot from yp and later re-viewed y_onehot
    # before rebuilding it. Both values were immediately overwritten, and the
    # re-view raised UnboundLocalError whenever opt.all_labels was False.
    if opt.all_labels:
        utils.save_tensors_image('%s/gen/p_%d.png' % (opt.log_dir, epoch), _class_grid())

    utils.save_tensors_image('%s/gen/%d.png' % (opt.log_dir, epoch), _class_grid())
Exemple #12
0
def plot_rec(x, y, epoch):
    """Save posterior reconstructions conditioned on five reference sequences.

    At every step ``i``, the predictor input concatenates three things:
    - the change in encoding of each reference sequence ``y[m]``
      (``m`` in ``range(5)``) between steps ``i - 1`` and ``i``;
    - the encoding of the ground-truth previous frame ``x[i - 1]``;
    - a latent ``z_t``, built by ``reparameterize`` from the
      ``posterior_mu`` output for ``x[i]`` and the reference-averaged
      ``var_encoder`` output.
    Conditioning frames are copied from ``x``; later frames are decoded from
    ``frame_predictor``'s output. Writes ``gen/rec_<epoch>.png``.

    NOTE(review): the number of references (5) is hard-coded and ``y`` is
    presumably indexable as ``y[m][t]`` -- confirm with the caller.
    """
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior_mu.hidden = posterior_mu.init_hidden()
    gen_seq = []
    gen_seq.append(x[0])
    x_in = x[0]  # not used below
    # Reference encodings at step 0; replaced at the end of every iteration.
    h_match_prev = [encoder(y[m][0])[0].detach() for m in range(5)]
    for i in range(1, opt.n_past+opt.n_future):
        h = encoder(x[i-1])
        h_target = encoder(x[i])
        h_match = [encoder(y[m][i])[0].detach() for m in range(5)]
        # Skips are refreshed only during conditioning unless
        # opt.last_frame_skip is set.
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h_target, _ = h_target
        h = h.detach()
        h_target = h_target.detach()

        mu = posterior_mu(h_target)
        # Mean over the 5 references of var_encoder(reference change + h).
        ref_var = torch.mean(torch.cat([var_encoder(h_match[m] - h_match_prev[m] + h).unsqueeze(1) for m in range(5)], 1), 1)
        z_t = reparameterize(mu, ref_var)

        if i < opt.n_past:
            # Output discarded: this call only advances the predictor's hidden state.
            frame_predictor(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h, z_t], 1))
            gen_seq.append(x[i])
        else:
            h_pred = frame_predictor(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h, z_t], 1))
            x_pred = decoder([h_pred, skip]).detach()
            gen_seq.append(x_pred)

        h_match_prev = h_match

    # One row per example (at most 10), one column per time step.
    to_plot = []
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        row = []
        for t in range(opt.n_past+opt.n_future):
            row.append(gen_seq[t][i])
        to_plot.append(row)
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #13
0
def plot_1hot(epoch):
    """Generate one image per batch slot from ``z_fixed``, one class per row.

    Writes ``gen/<epoch>.png`` as an ``nclass`` x ``batch_size / nclass`` grid.
    """
    n_rows = opt.nclass
    n_cols = int(opt.batch_size / n_rows)

    # Slots r * n_cols .. (r + 1) * n_cols - 1 are conditioned on class r.
    labels = torch.Tensor(opt.batch_size, opt.nclass, 1, 1).cuda().zero_()
    for r in range(n_rows):
        labels[r * n_cols:(r + 1) * n_cols, r] = 1
    gen = netG(torch.cat([z_fixed, labels], 1))

    to_plot = [[gen[r * n_cols + c] for c in range(n_cols)] for r in range(n_rows)]
    utils.save_tensors_image('%s/gen/%d.png' % (opt.log_dir, epoch), to_plot)
Exemple #14
0
def plot_rec(model, x, epoch):
    """Save ``model`` reconstructions of ``x`` to ``<checkpoint_dir>/gen/rec_<epoch>.png``.

    Every frame after the first goes through ``model.reconstruction``. While
    ``i < opt.n_past`` the ground-truth frame is kept and the returned
    features become ``model.skips``. After that, the frame is replaced by
    ``model.decoding`` of the reconstructed hidden states.
    """
    model.init_states(x[0])
    total = opt.n_past + opt.n_future
    gen_seq = [x[0]]
    for i in range(1, total):
        hs_rec, feats, _, _, _ = model.reconstruction(x[i])
        if i >= opt.n_past:
            gen_seq.append(model.decoding(hs_rec))
            continue
        model.skips = feats
        gen_seq.append(x[i])

    nrow = min(opt.batch_size, 10)
    to_plot = [[gen_seq[t][r] for t in range(total)] for r in range(nrow)]
    fname = '%s/gen/rec_%d.png' % (checkpoint_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #15
0
def plot_ind(x, epoch):
    """Save an analogy grid: content from ``x[0]``, pose from each time step.

    Layout:
    - Top row: a blank cell, then sample 0 of every frame in ``x``.
    - Each following row ``r``: content frame ``x[0][r]``, then the ``netD``
      output combining that content with sample 0's pose at each time step
      ``j`` in ``range(opt.max_step)``.
    Writes ``analogy/<epoch>.png``.
    """
    h_c = netEC(x[0])
    nrow = 10

    header = [frame[0].data for frame in x]
    blank = torch.zeros(opt.channels, opt.image_width, opt.image_width)
    to_plot = [[blank] + header]
    to_plot.extend([x[0][r].data] for r in range(nrow))

    for step in range(opt.max_step):
        pose = netEP(x[step]).data
        # Copy sample 0's pose into the first nrow slots of the batch.
        for r in range(nrow):
            pose[r] = pose[0]
        rec = netD([h_c, Variable(pose)])
        for r in range(nrow):
            to_plot[r + 1].append(rec[r].data.clone())

    utils.save_tensors_image('%s/analogy/%d.png' % (opt.log_dir, epoch), to_plot)
Exemple #16
0
def plot_rec(x, epoch):
    """Save an 8 x 8 grid of (input, reconstruction) pairs.

    ``x`` is ``(images, y, yp)``. The model is conditioned on ``yp`` when
    ``opt.all_labels`` is set. Otherwise it is conditioned on a one-hot
    encoding of the integer labels ``y``. Writes ``rec/<epoch>.png``.
    """
    x, y, yp = x

    if not opt.all_labels:
        # Integer labels -> one-hot rows via scatter.
        labels = torch.Tensor(opt.batch_size, opt.nclass).cuda().zero_()
        labels.scatter_(1, y.data.view(opt.batch_size, 1).long(), 1)
    else:
        labels = yp
    labels = labels.view(opt.batch_size, opt.nclass, 1, 1)

    rec, _, _ = model((x, labels))

    side = 8
    to_plot = []
    for r in range(side):
        cells = []
        for k in range(r * side, (r + 1) * side):
            cells += [x[k], rec[k]]
        to_plot.append(cells)

    utils.save_tensors_image('%s/rec/%d.png' % (opt.log_dir, epoch), to_plot)
Exemple #17
0
def plot_analogy(x, epoch):
    """Save an analogy grid combining per-row content with sample 0's pose.

    The first row is a blank cell followed by sample 0 of every frame in
    ``x``. Row ``r + 1`` starts with content frame ``x[0][r]``. It then holds
    one ``netD`` reconstruction per time step ``j`` in
    ``range(opt.max_step)``, using that content and the pose of sample 0 at
    step ``j``. Writes ``analogy/<epoch>.png``.
    """
    content = netEC(x[0])
    n_rows = 10

    zeros = torch.zeros(opt.channels, opt.image_width, opt.image_width)
    grid = [[zeros] + [xi[0].data for xi in x]]
    for r in range(n_rows):
        grid.append([x[0][r].data])

    for j in range(opt.max_step):
        h_p = netEP(x[j]).data
        # Overwrite the first n_rows poses with sample 0's pose.
        for r in range(n_rows):
            h_p[r] = h_p[0]
        rec = netD([content, Variable(h_p)])
        for r in range(n_rows):
            grid[r + 1].append(rec[r].data.clone())

    fname = '%s/analogy/%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, grid)
Exemple #18
0
def plot_gen(x, epoch):
    """Save conditioning frames followed by LSTM-rolled-out predictions.

    The content vector comes from the last conditioning frame and is held
    fixed. During conditioning (``i < opt.n_past``) the LSTM is fed the
    ground-truth pose of ``x[i - 1]``. After that, its own output is fed
    back as the next pose, and each step is decoded with
    ``netD([h_c, h_p])``. Writes ``gen/gen_<epoch>.png``.
    """
    seq_len = opt.n_past + opt.n_future
    h_c = netEC(x[opt.n_past - 1])
    # netEC may return (content, skips); the LSTM consumes only the content.
    vec_h_c = (h_c[0] if isinstance(h_c, tuple) else h_c).detach()

    lstm.hidden = lstm.init_hidden()
    h_p = netEP(x[0]).detach()
    gen_seq = [x[0]]
    for i in range(1, seq_len):
        if i < opt.n_past:
            lstm(torch.cat([h_p, vec_h_c], 1))
            h_p = netEP(x[i]).detach()
            gen_seq.append(x[i])
        else:
            # Reshape the pose to (batch, dim, 1, 1) before concatenation;
            # presumably the LSTM output is flat -- confirm.
            h_p = h_p.view([-1, h_p.shape[1], 1, 1])
            h_p = lstm(torch.cat([h_p, vec_h_c], 1))
            gen_seq.append(netD([h_c, h_p]).detach())

    # Never index past the batch: show at most 10 examples.
    nrow = min(opt.batch_size, 10)
    to_plot = [[gen_seq[t][i] for t in range(seq_len)] for i in range(nrow)]
    fname = '%s/gen/gen_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #19
0
def plot(x, epoch):
    """Save conditioning frames followed by decoded frames for one batch.

    The LSTM is stepped on the ground-truth pose of every frame, but its
    output is discarded. Future frames are decoded from the ground-truth pose
    ``netEP(x[i])`` together with the content of the last conditioning frame.

    NOTE(review): this makes the "future" frames reconstructions rather than
    predictions -- confirm whether ``netD`` should receive the LSTM output.

    Writes ``gen/<epoch>.png``.
    """
    seq_len = opt.n_past + opt.n_future
    h_c = netEC(x[opt.n_past - 1])
    lstm.hidden = lstm.init_hidden()
    gen_seq = []
    for i in range(opt.n_past):
        h_p = netEP(x[i]).detach()
        lstm(h_p)
        gen_seq.append(x[i])

    for i in range(opt.n_past, seq_len):
        h_p = netEP(x[i]).detach()
        lstm(h_p)
        # Detached: the image is only plotted, so don't keep the graph alive.
        gen_seq.append(netD([h_c, h_p]).detach())

    # Never index past the batch: show at most 10 examples.
    nrow = min(opt.batch_size, 10)
    to_plot = [[gen_seq[t][i] for t in range(seq_len)] for i in range(nrow)]
    fname = '%s/gen/%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
Exemple #20
0
def plot(x, y, epoch):
    """Sample ``nsample`` rollouts per sequence and save a PNG grid and a GIF.

    Generation is conditioned on five reference sequences ``y[0..4]``. At
    every step, the change in each reference's encoding is concatenated with
    the current encoding ``h``. During conditioning (``i < opt.n_past``) the
    latent uses the ``posterior_mu`` output for the ground-truth next frame;
    afterwards it uses the ``prior_mu`` output. In both cases the variance is
    the mean of ``var_encoder`` over the references.

    For each of the first ``min(opt.batch_size, 10)`` sequences, the grid
    shows the ground truth, then the sample with the smallest summed squared
    error, then four random samples. Writes ``gen/sample_<epoch>.png`` and
    ``gen/sample_<epoch>.gif``.
    """
    nsample = 20
    n_ref = 5
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]

    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior_mu.hidden = posterior_mu.init_hidden()
        prior_mu.hidden = prior_mu.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        h_match_prev = [encoder(y[m][0])[0].detach() for m in range(n_ref)]
        for i in range(1, opt.n_eval):
            h_match = [encoder(y[m][i])[0].detach() for m in range(n_ref)]
            # Per-reference change in encoding between steps i - 1 and i.
            h_diff = [h_match[m] - h_match_prev[m] for m in range(n_ref)]
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                mu = posterior_mu(h_target)
                # Output unused: this call keeps prior_mu's hidden state in step.
                prior_mu(torch.cat(h_diff + [h], -1))
                ref_var = torch.mean(torch.cat(
                    [var_encoder(d + h).unsqueeze(1) for d in h_diff], 1), 1)
                z_t = reparameterize(mu, ref_var)
                frame_predictor(torch.cat(h_diff + [h, z_t], 1))
                x_in = x[i]
            else:
                mu_p = prior_mu(torch.cat(h_diff + [h], -1))
                ref_var = torch.mean(torch.cat(
                    [var_encoder(d + h).unsqueeze(1) for d in h_diff], 1), 1)
                z_t = reparameterize(mu_p, ref_var)
                h = frame_predictor(torch.cat(h_diff + [h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
            gen_seq[s].append(x_in)

            h_match_prev = h_match

    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])

        # Best sample: smallest summed squared error to the ground truth.
        # Start from +inf so min_idx is always bound for this row. The old
        # 1e7 threshold could leave it unset, or silently reuse the previous
        # row's index, on large images.
        min_mse = float('inf')
        min_idx = 0
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s

        # Best sample followed by four random ones.
        s_list = [min_idx] + [np.random.randint(nsample) for _ in range(4)]
        for s in s_list:
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # One GIF frame per time step: ground truth next to each chosen sample.
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in s_list])

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
Exemple #21
0
def plot(x, epoch):
    """Sample ``nsample`` learned-prior rollouts per sequence; save a PNG grid and a GIF.

    For each sample the recurrent modules are reset. Conditioning frames
    (``i < opt.n_past``) are fed through the posterior (on the ground-truth
    target), the prior and the predictor to warm up their state. Later
    frames are generated autoregressively with ``z_t`` drawn from the prior.

    For each of the first ``min(opt.batch_size, 10)`` sequences, the grid
    shows the ground truth, then the sample with the smallest summed squared
    error, then four random samples. Writes ``gen_1/sample_<epoch>.png`` and
    ``gen_1/sample_<epoch>.gif``.
    """
    nsample = 20
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]

    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            # Skips are refreshed only during conditioning unless
            # opt.last_frame_skip is set.
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)  # advance the prior's hidden state only
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
            else:
                z_t, _, _ = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
            gen_seq[s].append(x_in)

    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])

        # Best sample: smallest summed squared error to the ground truth.
        # Start from +inf so min_idx is always bound for this row. The old
        # 1e7 threshold could leave it unset, or silently reuse the previous
        # row's index, on large images.
        min_mse = float('inf')
        min_idx = 0
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s

        # Best sample followed by four random ones.
        s_list = [min_idx] + [np.random.randint(nsample) for _ in range(4)]
        for s in s_list:
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # One GIF frame per time step: ground truth next to each chosen sample.
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in s_list])

    fname = '%s/gen_1/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen_1/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
Exemple #22
0
def plot(x, epoch):
    """Sample ``nsample`` rollouts from ``frame_generator``; save a PNG grid and a GIF.

    Each rollout starts from ``x[0]``. Every subsequent frame is the first
    output of ``frame_generator(previous_frame, step)``.

    For each of the first ``min(opt.batch_size, 10)`` sequences, the grid
    shows the ground truth, then the sample with the smallest summed squared
    error, then four random samples. Writes ``gen/sample_<epoch>.png`` and
    ``gen/sample_<epoch>.gif``.
    """
    nsample = 20
    gen_seq = [[] for _ in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]

    for s in range(nsample):
        # NOTE(review): no recurrent state is reset between samples. The
        # original kept frame_predictor/posterior/prior.init_hidden() calls
        # disabled inside a string literal. Confirm that frame_generator
        # manages its own state.
        gen_seq[s].append(x[0])
        x_in = x[0]
        # NOTE(review): generates opt.max_step frames but plots the first
        # opt.n_eval; this assumes opt.max_step >= opt.n_eval.
        for i in range(1, opt.max_step):
            x_in, _, _, _, _ = frame_generator(x_in, i - 1)
            gen_seq[s].append(x_in)

    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])

        # Best sample: smallest summed squared error to the ground truth.
        # Start from +inf so min_idx is always bound for this row. The old
        # 1e7 threshold could leave it unset, or silently reuse the previous
        # row's index, on large images.
        min_mse = float('inf')
        min_idx = 0
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s

        # Best sample followed by four random ones.
        s_list = [min_idx] + [np.random.randint(nsample) for _ in range(4)]
        for s in s_list:
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # One GIF frame per time step: ground truth next to each chosen sample.
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in s_list])

    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
Exemple #23
0
def plot(model, x, epoch):
    """Sample ``nsample`` rollouts of ``model`` per sequence; save a PNG grid and a GIF.

    For each sample the model state is initialised from ``x[0]``.
    Conditioning frames are passed through ``model.reconstruction``, whose
    features become ``model.skips``, and are copied through unchanged. Later
    frames come from ``model.inference()``.

    For each of the first ``min(opt.batch_size, 10)`` sequences, the grid
    shows the ground truth, then the sample with the smallest summed squared
    error, then four random samples. Writes
    ``<checkpoint_dir>/gen/sample_<epoch>.png`` and ``.gif``.
    """
    nsample = 20
    gen_seq = [[x[0]] for _ in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]

    for s in range(nsample):
        model.init_states(x[0])
        for i in range(1, opt.n_eval):
            if i < opt.n_past:
                _, feats, _, _, _ = model.reconstruction(x[i])
                model.skips = feats
                gen_seq[s].append(x[i])
            else:
                gen_seq[s].append(model.inference())

    to_plot = []
    gifs = [[] for _ in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        to_plot.append([gt_seq[t][i] for t in range(opt.n_eval)])

        # Best sample: smallest summed squared error to the ground truth.
        # Start from +inf so min_idx is always bound for this row. The old
        # 1e7 threshold could leave it unset, or silently reuse the previous
        # row's index, on large images.
        min_mse = float('inf')
        min_idx = 0
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s

        # Best sample followed by four random ones.
        s_list = [min_idx] + [np.random.randint(nsample) for _ in range(4)]
        for s in s_list:
            to_plot.append([gen_seq[s][t][i] for t in range(opt.n_eval)])
        # One GIF frame per time step: ground truth next to each chosen sample.
        for t in range(opt.n_eval):
            gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in s_list])

    fname = '%s/gen/sample_%d.png' % (checkpoint_dir, epoch)
    utils.save_tensors_image(fname, to_plot)

    fname = '%s/gen/sample_%d.gif' % (checkpoint_dir, epoch)
    utils.save_gif(fname, gifs)
Exemple #24
0
def make_gifs(x, idx, name):
    """Sample sequences for one batch, save per-sequence GIFs/PNGs, and score them.

    Steps:
    1. Build one approximate-posterior rollout (``posterior_gen``), taking
       the second output of ``posterior`` (commented as the mean) at every
       step and feeding generated frames back into the encoder.
    2. Draw ``opt.nsample`` rollouts from the learned prior. Score the
       ``opt.n_future`` predicted frames against ground truth with
       ``utils.eval_seq`` (SSIM and PSNR), and record ``exp(logvar)`` of the
       prior at each predicted step.

    For every sequence ``i`` in the batch it writes:
    - ``<opt.log_dir>/quality_results/<name>_<idx + i>.gif``: ground truth,
      posterior rollout, best-SSIM sample and three random samples. Frames
      are bordered green for conditioning steps and red for predicted ones.
    - A matching ``.png`` grid: ground truth, best-SSIM, best-PSNR and three
      random samples.

    Args:
        x: sequence indexable by time step; each ``x[t]`` is a batch of frames.
        idx: offset added to the in-batch index to build output filenames.
        name: output filename prefix.

    Returns:
        ``(best_ssim, best_psnr, best_ssim_var, best_psnr_var)``:
        - per-sequence, per-step SSIM of the sample with the highest mean SSIM;
        - per-sequence, per-step PSNR of the sample with the highest mean PSNR;
        - the prior variances recorded for those two samples.
    """
    # get approx posterior sample
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        # Skips are refreshed only during conditioning unless
        # opt.last_frame_skip is set.
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _= posterior(h_target) # take the mean
        if i < opt.n_past:
            # Output discarded: only advances the predictor's hidden state.
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            # The generated frame becomes the next encoder input.
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)


    nsample = opt.nsample
    # Per (sequence, sample, predicted step) scores.
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # Prior variance per (sequence, sample, predicted step, latent dim).
    var_np = np.zeros((opt.batch_size, nsample, opt.n_future, opt.z_dim))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s+1)
        gen_seq = []  # predicted frames only, as numpy arrays, for scoring
        gt_seq = []   # matching ground-truth frames
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # Conditioning: step posterior, prior and predictor to warm
                # up their hidden state; keep the ground-truth frame.
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # Prediction: sample z_t from the prior; feed the frame back.
                z_t, mu, logvar = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)

                var = torch.exp(logvar)  # BxC
                var_np[:,s, i - opt.n_past,:] = var.data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))

    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [ [] for t in range(opt.n_eval) ]
        text = [ [] for t in range(opt.n_eval) ]

        # argsort is ascending, so ordered[-1] is the sample with the
        # highest mean SSIM over the predicted steps.
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i,:] = ssim[i,ordered[-1],:]
        # best ssim var
        best_ssim_var[i,:,:] = var_np[i, ordered[-1], :, :]

        # Same selection by mean PSNR.
        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i,:] = psnr[i, ordered_p[-1],:]
        # best psnr var
        best_psnr_var[i,:,:] = var_np[i, ordered_p[-1], :, :]

        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3 (reuses the green/red colour chosen above)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s+1))

        #fname = '%s/%s_%d.gif' % (opt.log_dir, name, idx+i)
        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx+i)
        utils.save_gif_with_text(fname, gifs, text)

        # -- generate samples
        # PNG rows: ground truth, best SSIM, best PSNR, then the random samples.
        to_plot = []
        gts = []
        best_s = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_s.append(all_gen[ordered[-1]][t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])

            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])

        to_plot.append(gts)
        to_plot.append(best_s)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx+i)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
Exemple #25
0
def make_gifs(x, idx, name):
    """Sample future frames with ``hsvg_net`` and save qualitative results.

    For every batch element this writes under
    ``<opt.log_dir>/quality_results/``:

    * a GIF with columns ground truth, approximate posterior, best-SSIM
      sample and three random prior samples;
    * a PNG grid with rows ground truth, best-SSIM sample, best-PSNR sample
      and the same three random samples.

    Args:
        x: frames indexed as ``x[t][i]`` (time step ``t``, batch element
            ``i``) for ``t < opt.n_eval``; presumably a list of
            (B, C, H, W) tensors -- TODO confirm against the caller.
        idx: offset added to the batch index to build output file names.
        name: prefix used in output file names.

    Returns:
        ``(best_ssim, best_psnr, best_ssim_var, best_psnr_var)``:
        per-example, per-future-step SSIM / PSNR of the sample with the
        highest mean SSIM (resp. PSNR) among ``opt.nsample`` prior samples,
        and the prior variances recorded for that same sample.
    """
    # -- approximate posterior: every frame goes through the reconstruction
    # path; context frames are shown as ground truth and only refresh skips.
    posterior_gen = [x[0]]
    hsvg_net.init_states(x[0])
    for i in range(1, opt.n_eval):
        hs_rec, feats, _, _, _ = hsvg_net.reconstruction(x[i])
        if i < opt.n_past:
            hsvg_net.skips = feats
            posterior_gen.append(x[i])
        else:
            posterior_gen.append(hsvg_net.decoding(hs_rec))

    # -- samples from the prior
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # prior variance for every sample at every future step
    var_np = np.zeros((opt.batch_size, nsample, opt.n_future, opt.z_dim))
    all_gen = []
    # ground-truth future frames as numpy arrays; identical for every sample
    gt_seq = [x[i].data.cpu().numpy() for i in range(opt.n_past, opt.n_eval)]

    for s in tqdm(range(nsample)):
        gen_seq = []
        sample_frames = [x[0]]
        all_gen.append(sample_frames)

        hsvg_net.init_states(x[0])
        for i in range(1, opt.n_eval):
            if i < opt.n_past:
                # condition on ground-truth context; only the skips are kept
                _, feats, _, _, _ = hsvg_net.reconstruction(x[i])
                hsvg_net.skips = feats
                sample_frames.append(x[i])
            else:
                x_pred = hsvg_net.inference()
                gen_seq.append(x_pred.data.cpu().numpy())
                sample_frames.append(x_pred)

                logvar = torch.cat(hsvg_net.logvars_prior, -1)
                var_np[:, s, i - opt.n_past, :] = torch.exp(logvar).data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))

    steps = range(opt.n_eval)
    for i in range(opt.batch_size):
        # samples ranked by mean metric over the future steps; last = best
        ordered = np.argsort(np.mean(ssim[i], 1))
        ordered_p = np.argsort(np.mean(psnr[i], 1))
        best_s_idx = ordered[-1]
        best_p_idx = ordered_p[-1]

        best_ssim[i, :] = ssim[i, best_s_idx, :]
        best_ssim_var[i, :, :] = var_np[i, best_s_idx, :, :]
        best_psnr[i, :] = psnr[i, best_p_idx, :]
        best_psnr_var[i, :, :] = var_np[i, best_p_idx, :, :]

        rand_sidx = [np.random.randint(nsample) for _ in range(3)]

        # -- GIF: green border = context frame, red border = generated frame
        gifs = [[] for _ in steps]
        text = [[] for _ in steps]
        for t in steps:
            color = 'green' if t < opt.n_past else 'red'
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            gifs[t].append(add_border(all_gen[best_s_idx][t][i], color))
            text[t].append('Best SSIM')
            for k, sidx in enumerate(rand_sidx):
                gifs[t].append(add_border(all_gen[sidx][t][i], color))
                text[t].append('Random\nsample %d' % (k + 1))

        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)

        # -- PNG grid: one row per sequence, one column per time step
        to_plot = [
            [x[t][i] for t in steps],
            [all_gen[best_s_idx][t][i] for t in steps],
            [all_gen[best_p_idx][t][i] for t in steps],
        ]
        to_plot.extend([all_gen[sidx][t][i] for t in steps] for sidx in rand_sidx)
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
Exemple #26
0
def make_gifs(x, idx, name):
    """Sample future frames from the prior and save qualitative results.

    For every batch element this writes under ``<opt.log_dir>/samples/``:

    * a GIF with columns ground truth, approximate posterior, best-SSIM
      sample and three random prior samples;
    * a PNG grid with rows ground truth, best-PSNR sample and the same
      three random samples.

    Args:
        x: frames indexed as ``x[t][i]`` (time step ``t``, batch element
            ``i``) for ``t < opt.n_eval``; presumably a list of
            (B, C, H, W) tensors -- TODO confirm against the caller.
        idx: offset added to the batch index to build output file names.
        name: prefix used in output file names.

    Returns:
        ``(best_ssim, best_psnr)``: per-example, per-future-step SSIM / PSNR
        of the sample with the highest mean SSIM (resp. PSNR) among
        ``opt.nsample`` prior samples.
    """
    # -- content code: the context frames are shuffled and concatenated
    # along dim 1 before encoding; mu_c and skip are reused for every frame.
    xs = [x[i] for i in range(opt.n_past)]
    random.shuffle(xs)
    mu_c, _, skip = cont_encoder(torch.cat(xs, 1))
    mu_c = mu_c.detach()

    # -- approximate posterior: z is inferred from the true frame x[i]
    posterior_gen = [x[0]]
    for i in range(1, opt.n_eval):
        h_target = pose_encoder(x[i]).detach()
        mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1),
                                            time_step=i - 1)
        z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
        h_pred = frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
        if i < opt.n_past:
            # context step: predictor output is discarded (presumably called
            # for its internal state update) and ground truth is shown
            posterior_gen.append(x[i])
        else:
            posterior_gen.append(decoder([h_pred.detach(), skip]).detach())

    # -- samples from the prior
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # ground-truth future frames as numpy arrays; identical for every sample
    gt_seq = [x[i].data.cpu().numpy() for i in range(opt.n_past, opt.n_eval)]
    all_gen = []

    progress = progressbar.ProgressBar(maxval=nsample).start()
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        sample_frames = [x[0]]
        all_gen.append(sample_frames)

        h = pose_encoder(x[0]).detach()
        for i in range(1, opt.n_eval):
            if i < opt.n_past:
                # context step: drive prior and predictor with ground truth
                h_target = pose_encoder(x[i]).detach()
                mu_t_p, logvar_t_p = posterior_pose(
                    torch.cat([h_target, mu_c], 1), time_step=i - 1)
                z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
                prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
                sample_frames.append(x[i])
                h = h_target
            else:
                # prediction step: z is drawn from the prior
                mu_t_pp, logvar_t_pp = prior(torch.cat([h, mu_c], 1),
                                             time_step=i - 1)
                z_t = utils.reparameterize(mu_t_pp, logvar_t_pp)
                h_pred = frame_predictor(torch.cat([z_t, mu_c], 1),
                                         time_step=i - 1).detach()
                x_pred = decoder([h_pred, skip]).detach()
                gen_seq.append(x_pred.data.cpu().numpy())
                sample_frames.append(x_pred)
                # the next step conditions on the encoding of this prediction
                h = pose_encoder(x_pred).detach()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))

    steps = range(opt.n_eval)
    for i in range(opt.batch_size):
        # samples ranked by mean metric over the future steps; last = best
        ordered = np.argsort(np.mean(ssim[i], 1))
        ordered_p = np.argsort(np.mean(psnr[i], 1))
        best_s_idx = ordered[-1]
        best_p_idx = ordered_p[-1]
        best_ssim[i, :] = ssim[i, best_s_idx, :]
        best_psnr[i, :] = psnr[i, best_p_idx, :]

        rand_sidx = [np.random.randint(nsample) for _ in range(3)]

        # -- GIF: green border = context frame, red border = generated frame
        gifs = [[] for _ in steps]
        text = [[] for _ in steps]
        for t in steps:
            color = 'green' if t < opt.n_past else 'red'
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            gifs[t].append(add_border(all_gen[best_s_idx][t][i], color))
            text[t].append('Best SSIM')
            for k, sidx in enumerate(rand_sidx):
                gifs[t].append(add_border(all_gen[sidx][t][i], color))
                text[t].append('Random\nsample %d' % (k + 1))

        fname = '%s/samples/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)

        # -- PNG grid: one row per sequence, one column per time step
        to_plot = [
            [x[t][i] for t in steps],
            [all_gen[best_p_idx][t][i] for t in steps],
        ]
        to_plot.extend([all_gen[sidx][t][i] for t in steps] for sidx in rand_sidx)
        fname = '%s/samples/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr