def plot_rec(x, epoch):
    """Save a reconstruction grid: ground-truth rows paired with rows
    rebuilt from (content of last conditioning frame, per-frame pose)
    through the LSTM and decoder.

    x: sequence of frame batches (length >= opt.n_past + opt.n_future).
    epoch: used only in the output filename.
    Relies on module-level netEC, netEP, netD, lstm, opt and utils.
    """
    # get fixed content vector from last ground truth frame
    h_c = netEC(x[opt.n_past - 1])
    # netEC may return (vector, skips); keep raw h_c for netD and a
    # detached plain vector for the LSTM input
    if type(h_c) is tuple:
        vec_h_c = h_c[0].detach()
    else:
        vec_h_c = h_c.detach()
    lstm.hidden = lstm.init_hidden()  # reset recurrent state per sequence
    gen_seq = []
    gen_seq.append(x[0])
    for i in range(1, opt.n_past + opt.n_future):
        # pose from the previous frame; detached so plotting never backprops
        h_p = netEP(x[i - 1]).detach()
        h_pred = lstm(torch.cat([h_p, vec_h_c], 1))
        if i < opt.n_past:
            # conditioning phase: keep the ground-truth frame
            gen_seq.append(x[i])
        else:
            pred_x = netD([h_c, h_pred]).detach()
            gen_seq.append(pred_x)
    to_plot = []
    nrow = 10  # sequences shown; assumes batch size >= 10 — TODO confirm
    for i in range(nrow):
        # ground truth
        row = []
        for t in range(opt.n_past + opt.n_future):
            row.append(x[t][i])
        to_plot.append(row)
        # gen
        row = []
        for t in range(opt.n_past + opt.n_future):
            row.append(gen_seq[t][i])
        to_plot.append(row)
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot_rec(x, y, s, epoch):
    """Save a reconstruction grid for scale level s of the multi-scale model.

    Level 0 reconstructs the image directly from (z, y); higher levels
    encode the residual against the previous scale and decode conditioned
    on the previous-scale image.  Relies on module-level netE, netD, opt
    and utils.

    Fixes: use integer division so ncol and the resolution embedded in the
    filename are exact ints — under Python 3, ``opt.image_width / 2**k`` is
    a float that ``%d`` silently truncates.
    """
    y = y.view(opt.batch_size, opt.nclass, 1, 1)
    if s == 0:
        _, z, _ = netE[s]((x[s], y))
        rec = netD[s](torch.cat([z, y], 1))
    else:
        # higher scales model the residual w.r.t. the previous level
        residual = x[s] - x[s - 1]
        _, z, _ = netE[s]((torch.cat([x[s], residual], 1), y))
        rec = netD[s]([x[s - 1], torch.cat([z, y], 1)])
    to_plot = []
    nrow = opt.nclass
    ncol = opt.batch_size // nrow  # fix: exact integer division
    for i in range(nrow):
        row = []
        for j in range(ncol):
            # input and its reconstruction side by side
            row.append(x[s][i * ncol + j])
            row.append(rec[i * ncol + j])
        to_plot.append(row)
    # fix: // keeps the per-scale resolution an int (was true division)
    fname = '%s/rec/%d_%d.png' % (opt.log_dir, epoch,
                                  opt.image_width // (2 ** (opt.nlevels - s - 1)))
    utils.save_tensors_image(fname, to_plot)
def plot_rec(x, epoch):
    """Plot ground-truth sequences next to their reconstructions.

    Conditioning frames are copied through; future frames are produced by
    feeding pose vectors plus a fixed content vector through the LSTM and
    decoder.  Output goes to <log_dir>/gen/rec_<epoch>.png.
    """
    # Fixed content code taken from the last conditioning frame; netEC may
    # also return skip connections as a tuple.
    content = netEC(x[opt.n_past - 1])
    content_vec = (content[0] if type(content) is tuple else content).detach()

    lstm.hidden = lstm.init_hidden()
    sequence = [x[0]]
    for step in range(1, opt.n_past + opt.n_future):
        pose = netEP(x[step - 1]).detach()
        prediction = lstm(torch.cat([pose, content_vec], 1))
        if step < opt.n_past:
            sequence.append(x[step])
        else:
            sequence.append(netD([content, prediction]).detach())

    horizon = opt.n_past + opt.n_future
    to_plot = []
    for sample in range(10):
        # one ground-truth row followed by the corresponding generated row
        to_plot.append([x[t][sample] for t in range(horizon)])
        to_plot.append([sequence[t][sample] for t in range(horizon)])

    utils.save_tensors_image('%s/gen/rec_%d.png' % (opt.log_dir, epoch), to_plot)
def plot_rec(x, epoch):
    """Save per-frame reconstructions (encoder -> decoder round trip).

    Fixes:
    - the bare ``except:`` is narrowed to ``except Exception`` so
      KeyboardInterrupt/SystemExit are no longer swallowed;
    - frames at i >= opt.n_past are now reconstructed and appended too
      (previously nothing was appended there, leaving gen_seq shorter than
      the plotting loop expected, which raised IndexError);
    - the plotting loop is bounded by len(gen_seq) in case frames were
      skipped by the encoder failure guard.

    Relies on module-level frame_predictor, posterior, encoder, decoder,
    opt and utils.
    """
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    gen_seq = [x[0]]
    skip = None
    for i in range(1, opt.n_past + opt.n_future):
        try:
            h_target = encoder(x[i])[0]
        except Exception:
            # best-effort: skip frames the encoder cannot process
            continue
        if i < opt.n_past:
            # refresh skip connections from the previous ground-truth frame
            h, skip = encoder(x[i - 1])
        if skip is None:
            continue  # no skip connections available yet
        x_pred = decoder([h_target.detach(), skip])
        gen_seq.append(x_pred)
    to_plot = []
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        row = []
        for t in range(min(opt.n_past + opt.n_future, len(gen_seq))):
            row.append(gen_seq[t][i])
        to_plot.append(row)
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot_rec(x, epoch):
    """Teacher-forced reconstruction plot: condition on the first
    opt.n_past frames, then decode each future frame from the posterior
    latent of the *ground-truth* target frame (reconstruction, not
    sampling).

    Relies on module-level frame_predictor, posterior, encoder, decoder,
    opt and utils.  Output: <log_dir>/gen_1/rec_<epoch>.png.
    """
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    gen_seq = []
    gen_seq.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_past + opt.n_future):
        h = encoder(x[i - 1])
        h_target = encoder(x[i])
        # keep skip connections from conditioning frames (or from every
        # frame when opt.last_frame_skip is set)
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h_target, _ = h_target
        h = h.detach()
        h_target = h_target.detach()
        # posterior latent inferred from the ground-truth target frame
        z_t, _, _ = posterior(h_target)
        if i < opt.n_past:
            # advance the predictor's recurrent state; output unused
            frame_predictor(torch.cat([h, z_t], 1))
            gen_seq.append(x[i])
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1))
            x_pred = decoder([h_pred, skip]).detach()
            gen_seq.append(x_pred)
    to_plot = []
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        row = []
        for t in range(opt.n_past + opt.n_future):
            row.append(gen_seq[t][i])
        to_plot.append(row)
    fname = '%s/gen_1/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot(x, epoch):
    """Sample nsample future rollouts with z drawn from a unit Gaussian and
    save them as a png grid plus an animated gif.

    Relies on module-level encoder, decoder, frame_predictor, posterior,
    opt and utils.
    """
    nsample = 5
    gen_seq = [[] for i in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]
    # pre-encode the conditioning frames once; reused across samples
    h_seq = [encoder(x[i]) for i in range(opt.n_past)]
    for s in range(nsample):
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        for i in range(1, opt.n_eval):
            ## When i > opt.n_past, generated frame should be
            ## put back into encoder to generate content vector
            if opt.last_frame_skip or i <= opt.n_past:
                h, skip = h_seq[i - 1]
                h = h.detach()
            else:
                h, _ = encoder(x_in)
                h = h.detach()
            if i < opt.n_past:
                # conditioning: posterior latent, ground truth frame kept
                z_t, _, _ = posterior(h_seq[i - 1][0])
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                gen_seq[s].append(x_in)
            else:
                # generation: sample z ~ N(0, I)
                z_t = torch.cuda.FloatTensor(opt.batch_size, opt.z_dim).normal_()
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq[s].append(x_in)
    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        row = []
        for t in range(opt.n_eval):
            row.append(gt_seq[t][i])
        to_plot.append(row)
        # one row per sampled rollout
        for s in range(nsample):
            row = []
            for t in range(opt.n_eval):
                row.append(gen_seq[s][t][i])
            to_plot.append(row)
        # gif frame at time t: ground truth followed by each sample
        for t in range(opt.n_eval):
            row = []
            row.append(gt_seq[t][i])
            for s in range(nsample):
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)
    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
def plot_gen(epoch):
    """Generate class-conditional samples at every scale and save two
    grids: the final-scale samples (one class per row) and, for a few
    samples, all scales side by side.  Behaviour is unchanged.
    """
    nrow = opt.nclass
    ncol = int(opt.batch_size / nrow)

    # One-hot labels laid out so that each grid row is a single class.
    y = torch.Tensor(opt.batch_size, opt.nclass, 1, 1).cuda().zero_()
    for r in range(nrow):
        for c in range(ncol):
            y[r * ncol + c][r] = 1

    # Generate coarse-to-fine: level 0 from (z, y); later levels also
    # conditioned on the previous level's output.
    scales = []
    for level in range(opt.nlevels):
        z = make_plot_z()
        if level == 0:
            out = netD[level](torch.cat([z, y], 1)).detach()
        else:
            out = netD[level]([scales[level - 1], torch.cat([z, y], 1)]).detach()
        scales.append(out)

    # Grid 1: final-scale samples, one class per row.
    grid = [[scales[-1][r * ncol + c] for c in range(ncol)] for r in range(nrow)]
    utils.save_tensors_image('%s/gen/%d.png' % (opt.log_dir, epoch), grid)

    # Grid 2: for a few samples, show every scale next to each other.
    nrow, ncol = 6, 3
    grid = []
    for r in range(nrow):
        cells = []
        for c in range(ncol):
            idx = r * ncol + c
            cells.append(scales[0][idx])
            for level in range(1, opt.nlevels):
                cells.append(scales[level][idx])
        grid.append(cells)
    utils.save_tensors_image('%s/gen/scales_%d.png' % (opt.log_dir, epoch), grid)
def plot(x, epoch):
    """Decode a grid conditioned on the predicted labels yp and save it to
    <log_dir>/gen/yp_<epoch>.png.  Behaviour is unchanged.
    """
    _, _, yp = x
    nrow = opt.nclass
    ncol = int(opt.batch_size / nrow)
    labels = yp.view(opt.batch_size, opt.nclass, 1, 1)
    gen = netG(torch.cat([z_fixed, labels], 1))
    grid = [[gen[r * ncol + c] for c in range(ncol)] for r in range(nrow)]
    utils.save_tensors_image('%s/gen/yp_%d.png' % (opt.log_dir, epoch), grid)
def plot_rec(x, epoch):
    """Reconstruct a randomly offset frame from (content of frame 0, pose
    of frame k) and save rows of (content, pose, reconstruction) triples.
    """
    content_frame = x[0]
    pose_frame = x[np.random.randint(1, opt.max_step)]
    h_content = netEC(content_frame)
    h_pose = netEP(pose_frame)
    reconstruction = netD([h_content, h_pose])

    content_frame = content_frame.data
    pose_frame = pose_frame.data
    reconstruction = reconstruction.data

    fname = '%s/rec/%d.png' % (opt.log_dir, epoch)
    row_sz = 5
    nplot = 20
    to_plot = []
    for start in range(0, nplot - row_sz, row_sz):
        # interleave content / pose / reconstruction for row_sz samples
        triples = [[c, p, r] for c, p, r in zip(content_frame[start:start + row_sz],
                                                pose_frame[start:start + row_sz],
                                                reconstruction[start:start + row_sz])]
        to_plot.append(list(itertools.chain(*triples)))
    utils.save_tensors_image(fname, to_plot)
def plot_rec(x, epoch):
    """Reconstruct a randomly chosen frame from the content of frame 0 and
    the pose of frame k, then save rows of (content, pose, reconstruction)
    triples.

    NOTE(review): this duplicates an identical plot_rec defined earlier in
    this file; consider keeping only one copy.
    """
    x_c = x[0]
    # random pose frame from the rest of the clip
    x_p = x[np.random.randint(1, opt.max_step)]
    h_c = netEC(x_c)
    h_p = netEP(x_p)
    rec = netD([h_c, h_p])
    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '%s/rec/%d.png' % (opt.log_dir, epoch)
    to_plot = []
    row_sz = 5
    nplot = 20
    for i in range(0, nplot - row_sz, row_sz):
        # interleave content / pose / reconstruction for row_sz samples
        row = [[xc, xp, xr] for xc, xp, xr in zip(x_c[i:i + row_sz],
                                                  x_p[i:i + row_sz],
                                                  rec[i:i + row_sz])]
        to_plot.append(list(itertools.chain(*row)))
    utils.save_tensors_image(fname, to_plot)
def plot_gen(x, epoch):
    """Save two generated grids: one conditioned on predicted labels yp
    (p_<epoch>.png) and one on a fixed one-hot layout with one class per
    row (<epoch>.png).

    Fix: previously the ``opt.all_labels`` branch was immediately
    overwritten by the unconditional one-hot construction, so yp was never
    used and both images were identical.  The one-hot construction is now
    the else branch, and a dead re-``view`` of y_onehot was removed.
    """
    x, y, yp = x
    nrow = opt.nclass
    ncol = int(opt.batch_size / nrow)

    def _onehot_rows():
        # one class per grid row
        oh = torch.Tensor(opt.batch_size, opt.nclass, 1, 1).cuda().zero_()
        for i in range(nrow):
            for j in range(ncol):
                oh[i * ncol + j][i] = 1
        return oh

    def _save_grid(gen, fname):
        # lay the batch out nrow x ncol and save
        to_plot = []
        for i in range(nrow):
            row = []
            for j in range(ncol):
                row.append(gen[i * ncol + j])
            to_plot.append(row)
        utils.save_tensors_image(fname, to_plot)

    # grid 1: predicted labels when available, else the fixed layout
    if opt.all_labels:
        y_onehot = yp.view(opt.batch_size, opt.nclass, 1, 1)
    else:
        y_onehot = _onehot_rows()
    _save_grid(model.decode(z_fixed, y_onehot),
               '%s/gen/p_%d.png' % (opt.log_dir, epoch))

    # grid 2: always the fixed one-class-per-row layout
    _save_grid(model.decode(z_fixed, _onehot_rows()),
               '%s/gen/%d.png' % (opt.log_dir, epoch))
def plot_rec(x, y, epoch):
    """Reconstruction plot for the matched-example variant: the latent is
    drawn around the posterior mean with a variance aggregated over five
    retrieved matching sequences y[0..4].

    x: target sequence of frame batches; y: list of 5 matched sequences.
    Relies on module-level frame_predictor, posterior_mu, var_encoder,
    reparameterize, encoder, decoder, opt and utils.
    """
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior_mu.hidden = posterior_mu.init_hidden()
    gen_seq = []
    gen_seq.append(x[0])
    x_in = x[0]
    # encoded first frame of each matched sequence (detached: no grads)
    h_match_prev = [encoder(y[m][0])[0].detach() for m in range(5)]
    for i in range(1, opt.n_past + opt.n_future):
        h = encoder(x[i - 1])
        h_target = encoder(x[i])
        h_match = [encoder(y[m][i])[0].detach() for m in range(5)]
        # keep skips from conditioning frames (or always, with the flag)
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h_target, _ = h_target
        h = h.detach()
        h_target = h_target.detach()
        mu = posterior_mu(h_target)
        # variance estimated from the match motion deltas added to h,
        # averaged over the 5 matches
        ref_var = torch.mean(torch.cat([var_encoder(h_match[m] - h_match_prev[m] + h).unsqueeze(1) for m in range(5)], 1), 1)
        z_t = reparameterize(mu, ref_var)
        if i < opt.n_past:
            # warm up the predictor state on conditioning frames
            frame_predictor(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h, z_t], 1))
            gen_seq.append(x[i])
        else:
            h_pred = frame_predictor(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h, z_t], 1))
            x_pred = decoder([h_pred, skip]).detach()
            gen_seq.append(x_pred)
        h_match_prev = h_match
    to_plot = []
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        row = []
        for t in range(opt.n_past + opt.n_future):
            row.append(gen_seq[t][i])
        to_plot.append(row)
    fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot_1hot(epoch):
    """Decode the fixed z batch with one class per grid row and save the
    resulting image.  Behaviour is unchanged.
    """
    nrow = opt.nclass
    ncol = int(opt.batch_size / nrow)
    # build one-hot labels: row r of the grid gets class r
    y_onehot = torch.Tensor(opt.batch_size, opt.nclass, 1, 1).cuda().zero_()
    for r in range(nrow):
        for c in range(ncol):
            y_onehot[r * ncol + c][r] = 1
    gen = netG(torch.cat([z_fixed, y_onehot], 1))
    grid = [[gen[r * ncol + c] for c in range(ncol)] for r in range(nrow)]
    utils.save_tensors_image('%s/gen/%d.png' % (opt.log_dir, epoch), grid)
def plot_rec(model, x, epoch):
    """Save teacher-forced reconstructions for one batch of sequences.

    Conditioning frames (i < opt.n_past) are passed through unchanged while
    the model's skip features are refreshed; later frames are decoded from
    the model's reconstruction pass.  Behaviour is unchanged.
    """
    model.init_states(x[0])
    frames = [x[0]]
    for step in range(1, opt.n_past + opt.n_future):
        hs_rec, feats, zs, mus, logvars = model.reconstruction(x[step])
        if step < opt.n_past:
            model.skips = feats  # keep skips from the conditioning phase
            frames.append(x[step])
        else:
            frames.append(model.decoding(hs_rec))
    horizon = opt.n_past + opt.n_future
    rows = min(opt.batch_size, 10)
    grid = [[frames[t][b] for t in range(horizon)] for b in range(rows)]
    utils.save_tensors_image('%s/gen/rec_%d.png' % (checkpoint_dir, epoch), grid)
def plot_ind(x, epoch):
    """Analogy-style grid: each row keeps the content of one sample's
    first frame while adopting sample 0's pose at every timestep.

    NOTE(review): ``h_p[i] = h_p[0]`` copies sample 0's pose over the
    whole plotted slice in place, so every column shares one pose with
    varying content.
    """
    x_c = x[0]
    h_c = netEC(x_c)
    nrow = 10
    row_sz = opt.max_step
    to_plot = []
    # header row: blank corner cell + sample 0 of every frame
    row = [xi[0].data for xi in x]
    zeros = torch.zeros(opt.channels, opt.image_width, opt.image_width)
    to_plot.append([zeros] + row)
    # first column: each sample's first frame (the content source)
    for i in range(nrow):
        to_plot.append([x[0][i].data])
    for j in range(0, row_sz):
        h_p = netEP(x[j]).data
        # broadcast sample 0's pose to every plotted row
        for i in range(nrow):
            h_p[i] = h_p[0]
        rec = netD([h_c, Variable(h_p)])
        for i in range(nrow):
            to_plot[i + 1].append(rec[i].data.clone())
    fname = '%s/analogy/%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot_rec(x, epoch):
    """Save an 8x8 grid alternating inputs with their CVAE reconstructions.

    Integer labels y are expanded to one-hot for the decoder unless
    opt.all_labels supplies soft labels yp directly.  Behaviour unchanged.
    """
    x, y, yp = x
    # convert the integer y into a one_hot representation for decoder
    if opt.all_labels:
        y_onehot = yp
    else:
        y_onehot = torch.Tensor(opt.batch_size, opt.nclass).cuda().zero_()
        y_onehot.scatter_(1, y.data.view(opt.batch_size, 1).long(), 1)
    y_onehot = y_onehot.view(opt.batch_size, opt.nclass, 1, 1)
    rec, _, _ = model((x, y_onehot))
    nrow, ncol = 8, 8
    grid = []
    for r in range(nrow):
        cells = []
        for c in range(ncol):
            idx = r * ncol + c
            cells.append(x[idx])
            cells.append(rec[idx])
        grid.append(cells)
    utils.save_tensors_image('%s/rec/%d.png' % (opt.log_dir, epoch), grid)
def plot_analogy(x, epoch):
    """Analogy grid: top row shows the driving sequence (sample 0 at each
    timestep); each subsequent row keeps the content of one sample's first
    frame while adopting sample 0's pose at every timestep.
    """
    x_c = x[0]
    h_c = netEC(x_c)
    nrow = 10
    row_sz = opt.max_step
    to_plot = []
    # header row: blank corner cell + sample 0 of every frame
    row = [xi[0].data for xi in x]
    zeros = torch.zeros(opt.channels, opt.image_width, opt.image_width)
    to_plot.append([zeros] + row)
    # first column: each sample's first frame (the content source)
    for i in range(nrow):
        to_plot.append([x[0][i].data])
    for j in range(0, row_sz):
        h_p = netEP(x[j]).data
        # copy sample 0's pose into every plotted row (in-place)
        for i in range(nrow):
            h_p[i] = h_p[0]
        rec = netD([h_c, Variable(h_p)])
        for i in range(nrow):
            to_plot[i + 1].append(rec[i].data.clone())
    fname = '%s/analogy/%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot_gen(x, epoch):
    """Generation plot: condition the LSTM on ground-truth poses for
    opt.n_past frames, then roll forward feeding its own pose predictions
    back in.

    Relies on module-level netEC, netEP, netD, lstm, opt and utils.
    """
    # get fixed content vector from last ground truth frame
    h_c = netEC(x[opt.n_past - 1])
    if type(h_c) is tuple:
        vec_h_c = h_c[0].detach()
    else:
        vec_h_c = h_c.detach()
    lstm.hidden = lstm.init_hidden()
    gen_seq = []
    h_p = netEP(x[0]).detach()
    gen_seq.append(x[0])
    for i in range(1, opt.n_past + opt.n_future):
        if i < opt.n_past:
            # warm up the LSTM state; pose still comes from ground truth
            lstm(torch.cat([h_p, vec_h_c], 1))
            h_p = netEP(x[i]).detach()
            gen_seq.append(x[i])
        else:
            # reshape the predicted pose back to NCHW before concatenation
            h_p = h_p.view([-1, h_p.shape[1], 1, 1])
            h_p = lstm(torch.cat([h_p, vec_h_c], 1))
            pred_x = netD([h_c, h_p]).detach()
            gen_seq.append(pred_x)
    to_plot = []
    nrow = 10  # assumes batch size >= 10 — TODO confirm
    for i in range(nrow):
        row = []
        for t in range(opt.n_past + opt.n_future):
            row.append(gen_seq[t][i])
        to_plot.append(row)
    fname = '%s/gen/gen_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
def plot(x, epoch):
    """Plot sequences decoded from ground-truth pose vectors at every
    step, with content fixed from the last conditioning frame.

    Behaviour is unchanged: the LSTM state is advanced on each frame, and
    future frames are decoded from the ground-truth pose (not the LSTM
    output).
    """
    content = netEC(x[opt.n_past - 1])
    lstm.hidden = lstm.init_hidden()

    frames = []
    for t in range(opt.n_past):
        pose = netEP(x[t]).detach()
        lstm(pose)  # advance recurrent state only
        frames.append(x[t])
    for t in range(opt.n_past, opt.n_past + opt.n_future):
        pose = netEP(x[t]).detach()
        lstm(pose)
        frames.append(netD([content, pose]))

    horizon = opt.n_past + opt.n_future
    grid = [[frames[t][b] for t in range(horizon)] for b in range(10)]
    utils.save_tensors_image('%s/gen/%d.png' % (opt.log_dir, epoch), grid)
def plot(x, y, epoch):
    """Sample 20 rollouts with the matched-example prior, then plot the
    ground truth, the best-MSE sample and four random samples for each of
    up to 10 sequences, as a png grid and a gif.

    x: target sequence; y: list of 5 matched sequences.  Relies on
    module-level encoder, decoder, frame_predictor, posterior_mu, prior_mu,
    var_encoder, reparameterize, opt, np and utils.
    """
    nsample = 20
    gen_seq = [[] for i in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]
    for s in range(nsample):
        # reset all recurrent states per sampled rollout
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior_mu.hidden = posterior_mu.init_hidden()
        prior_mu.hidden = prior_mu.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        h_match_prev = [encoder(y[m][0])[0].detach() for m in range(5)]
        for i in range(1, opt.n_eval):
            h_match = [encoder(y[m][i])[0].detach() for m in range(5)]
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning: latent around the posterior mean; the prior
                # is still advanced so its recurrent state stays in sync
                h_target = encoder(x[i])
                h_target = h_target[0].detach()
                mu = posterior_mu(h_target)
                mu_p = prior_mu(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h], -1))
                ref_var = torch.mean(torch.cat([var_encoder(h_match[m] - h_match_prev[m] + h).unsqueeze(1) for m in range(5)], 1), 1)
                z_t = reparameterize(mu, ref_var)
                frame_predictor(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h, z_t], 1))
                x_in = x[i]
                gen_seq[s].append(x_in)
            else:
                # generation: latent sampled around the learned prior mean
                mu_p = prior_mu(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h], -1))
                ref_var = torch.mean(torch.cat([var_encoder(h_match[m] - h_match_prev[m] + h).unsqueeze(1) for m in range(5)], 1), 1)
                z_t = reparameterize(mu_p, ref_var)
                h = frame_predictor(torch.cat([h_match[m] - h_match_prev[m] for m in range(5)] + [h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq[s].append(x_in)
            h_match_prev = h_match
    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        row = []
        for t in range(opt.n_eval):
            row.append(gt_seq[t][i])
        to_plot.append(row)
        # best sequence (lowest total pixel MSE vs ground truth)
        min_mse = 1e7
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum((gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s
        # best sample plus four random ones
        s_list = [min_idx, np.random.randint(nsample), np.random.randint(nsample),
                  np.random.randint(nsample), np.random.randint(nsample)]
        for ss in range(len(s_list)):
            s = s_list[ss]
            row = []
            for t in range(opt.n_eval):
                row.append(gen_seq[s][t][i])
            to_plot.append(row)
        for t in range(opt.n_eval):
            row = []
            row.append(gt_seq[t][i])
            for ss in range(len(s_list)):
                s = s_list[ss]
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)
    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
def plot(x, epoch):
    """Sampling plot with a learned prior: 20 rollouts, then ground truth,
    the best-MSE sample and four random samples per sequence, saved as png
    and gif under <log_dir>/gen_1/.

    Relies on module-level encoder, decoder, frame_predictor, posterior,
    prior, opt, np and utils.
    """
    nsample = 20
    gen_seq = [[] for i in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]
    for s in range(nsample):
        # reset recurrent state for every sampled rollout
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        gen_seq[s].append(x[0])
        x_in = x[0]
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning: posterior latent; prior advanced to stay in sync
                h_target = encoder(x[i])
                h_target = h_target[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                gen_seq[s].append(x_in)
            else:
                # generation: latent sampled from the learned prior
                z_t, _, _ = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq[s].append(x_in)
    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        row = []
        for t in range(opt.n_eval):
            row.append(gt_seq[t][i])
        to_plot.append(row)
        # best sequence (lowest total pixel MSE)
        min_mse = 1e7
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s
        s_list = [
            min_idx,
            np.random.randint(nsample),
            np.random.randint(nsample),
            np.random.randint(nsample),
            np.random.randint(nsample)
        ]
        for ss in range(len(s_list)):
            s = s_list[ss]
            row = []
            for t in range(opt.n_eval):
                row.append(gen_seq[s][t][i])
            to_plot.append(row)
        for t in range(opt.n_eval):
            row = []
            row.append(gt_seq[t][i])
            for ss in range(len(s_list)):
                s = s_list[ss]
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)
    fname = '%s/gen_1/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
    fname = '%s/gen_1/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
def plot(x, epoch):
    """Sample autoregressive rollouts from frame_generator and plot ground
    truth, the best-MSE sample and four random samples, as png and gif.

    NOTE(review): gen_seq holds opt.max_step frames while the plotting
    loops read opt.n_eval entries — this assumes opt.n_eval <= opt.max_step;
    confirm against the caller's settings.
    """
    nsample = 20
    gen_seq = [[] for i in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]
    for s in range(nsample):
        """
        frame_predictor.init_hidden()
        posterior.init_hidden()
        prior.init_hidden()
        """
        gen_seq[s].append(x[0])
        x_in = x[0]
        for i in range(1, opt.max_step):
            # autoregressive: the generated frame is fed back in
            x_in, _, _, _, _ = frame_generator(x_in, i - 1)
            gen_seq[s].append(x_in)
    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        row = []
        for t in range(opt.n_eval):
            row.append(gt_seq[t][i])
        to_plot.append(row)
        # best sequence (lowest total pixel MSE vs ground truth)
        min_mse = 1e7
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s
        s_list = [
            min_idx,
            np.random.randint(nsample),
            np.random.randint(nsample),
            np.random.randint(nsample),
            np.random.randint(nsample)
        ]
        for ss in range(len(s_list)):
            s = s_list[ss]
            row = []
            for t in range(opt.n_eval):
                row.append(gen_seq[s][t][i])
            to_plot.append(row)
        for t in range(opt.n_eval):
            row = []
            row.append(gt_seq[t][i])
            for ss in range(len(s_list)):
                s = s_list[ss]
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)
    fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
    fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch)
    utils.save_gif(fname, gifs)
def plot(model, x, epoch):
    """Hierarchical-model sampling plot: 20 rollouts via model.inference(),
    then ground truth, best-MSE sample and four random samples per
    sequence, saved as png and gif under checkpoint_dir.
    """
    nsample = 20
    gen_seq = [[x[0]] for i in range(nsample)]
    gt_seq = [x[i] for i in range(len(x))]
    for s in range(nsample):
        ## initialization
        model.init_states(x[0])
        ## prediction
        for i in range(1, opt.n_eval):
            if i < opt.n_past:
                # teacher-forced conditioning; refresh skip features
                hs_rec, feats, zs, mus, logvars = model.reconstruction(x[i])
                model.skips = feats
                gen_seq[s].append(x[i])
            else:
                x_pred = model.inference()
                gen_seq[s].append(x_pred)
    to_plot = []
    gifs = [[] for t in range(opt.n_eval)]
    nrow = min(opt.batch_size, 10)
    for i in range(nrow):
        # ground truth sequence
        row = []
        for t in range(opt.n_eval):
            row.append(gt_seq[t][i])
        to_plot.append(row)
        # best sequence (lowest total pixel MSE vs ground truth)
        min_mse = 1e7
        for s in range(nsample):
            mse = 0
            for t in range(opt.n_eval):
                mse += torch.sum(
                    (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2)
            if mse < min_mse:
                min_mse = mse
                min_idx = s
        s_list = [
            min_idx,
            np.random.randint(nsample),
            np.random.randint(nsample),
            np.random.randint(nsample),
            np.random.randint(nsample)
        ]
        for ss in range(len(s_list)):
            s = s_list[ss]
            row = []
            for t in range(opt.n_eval):
                row.append(gen_seq[s][t][i])
            to_plot.append(row)
        for t in range(opt.n_eval):
            row = []
            row.append(gt_seq[t][i])
            for ss in range(len(s_list)):
                s = s_list[ss]
                row.append(gen_seq[s][t][i])
            gifs[t].append(row)
    fname = '%s/gen/sample_%d.png' % (checkpoint_dir, epoch)
    utils.save_tensors_image(fname, to_plot)
    fname = '%s/gen/sample_%d.gif' % (checkpoint_dir, epoch)
    utils.save_gif(fname, gifs)
def make_gifs(x, idx, name):
    """Qualitative evaluation for one batch: an approx-posterior rollout,
    opt.nsample prior rollouts scored with SSIM/PSNR, plus per-sequence
    gifs and image grids under <log_dir>/quality_results/.

    Returns (best_ssim, best_psnr, best_ssim_var, best_psnr_var); the
    *_var arrays hold the prior's predicted variance for the best-scoring
    sample per sequence and future frame.
    """
    # get approx posterior sample
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # prior variance per (sequence, sample, future step, latent dim)
    var_np = np.zeros((opt.batch_size, nsample, opt.n_future, opt.z_dim))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning: posterior latent; prior kept in sync
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # generation: sample from the learned prior
                z_t, mu, logvar = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
                var = torch.exp(logvar)  # BxC
                var_np[:, s, i - opt.n_past, :] = var.data.cpu().numpy()
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
    utils.clear_progressbar()

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        # argsort is ascending, so ordered[-1] is the best-scoring sample
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]
        # best ssim var
        best_ssim_var[i, :, :] = var_np[i, ordered[-1], :, :]
        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]
        # best psnr var
        best_psnr_var[i, :, :] = var_np[i, ordered_p[-1], :, :]
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            # posterior
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))
        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
        # -- generate samples
        to_plot = []
        gts = []
        best_s = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_s.append(all_gen[ordered[-1]][t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])
            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])
        to_plot.append(gts)
        to_plot.append(best_s)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)
    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
def make_gifs(x, idx, name):
    """Hierarchical-SVG qualitative eval: teacher-forced posterior rollout,
    opt.nsample prior rollouts scored by SSIM/PSNR, gifs and grids saved
    under <log_dir>/quality_results/.

    Returns (best_ssim, best_psnr, best_ssim_var, best_psnr_var).
    """
    # sample from approx posterior
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    # rec
    hsvg_net.init_states(x[0])
    for i in range(1, opt.n_eval):
        if i < opt.n_past:
            # conditioning: refresh skip features, keep ground truth
            hs_rec, feats, zs, mus, logvars = hsvg_net.reconstruction(x[i])
            hsvg_net.skips = feats
            posterior_gen.append(x[i])
        else:
            hs_rec, feats, zs, mus, logvars = hsvg_net.reconstruction(x[i])
            x_rec = hsvg_net.decoding(hs_rec)
            posterior_gen.append(x_rec)
    # sample from prior
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # variance
    var_np = np.zeros((opt.batch_size, nsample, opt.n_future, opt.z_dim))
    all_gen = []
    gt_seq = [x[i].data.cpu().numpy() for i in range(opt.n_past, opt.n_eval)]
    for s in tqdm(range(nsample)):
        gen_seq = []
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        hsvg_net.init_states(x[0])
        for i in range(1, opt.n_eval):
            if i < opt.n_past:
                hs_rec, feats, zs, mus, logvars = hsvg_net.reconstruction(x[i])
                hsvg_net.skips = feats
                all_gen[s].append(x[i])
            else:
                x_pred = hsvg_net.inference()
                gen_seq.append(x_pred.data.cpu().numpy())
                all_gen[s].append(x_pred)
                # concatenate per-level prior logvars into one latent vector
                logvar = torch.cat(hsvg_net.logvars_prior, -1)
                var = torch.exp(logvar)  # BxC
                var_np[:, s, i - opt.n_past, :] = var.data.cpu().numpy()
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        # argsort is ascending, so ordered[-1] is the best-scoring sample
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]
        # best ssim var
        best_ssim_var[i, :, :] = var_np[i, ordered[-1], :, :]
        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]
        # best psnr var
        best_psnr_var[i, :, :] = var_np[i, ordered_p[-1], :, :]
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        # -- generate gifs
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            # posterior
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))
        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
        # -- generate samples
        to_plot = []
        gts = []
        best_s = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_s.append(all_gen[ordered[-1]][t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])
            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])
        to_plot.append(gts)
        to_plot.append(best_s)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)
    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
def make_gifs(x, idx, name):
    """Content/pose model qualitative eval: a posterior rollout, then
    opt.nsample prior rollouts scored by SSIM/PSNR; gifs and grids are
    saved under <log_dir>/samples/.

    Returns (best_ssim, best_psnr).
    """
    # get approx posterior sample
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    # ------------ calculate the content posterior from all past frames
    xs = []
    for i in range(0, opt.n_past):
        xs.append(x[i])
    mu_c, logvar_c, skip = cont_encoder(torch.cat(xs, 1))
    mu_c = mu_c.detach()
    for i in range(1, opt.n_eval):
        h_target = pose_encoder(x[i]).detach()
        mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1), time_step=i - 1)
        z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
        if i < opt.n_past:
            # conditioning: advance predictor state, keep ground truth
            frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    # NOTE(review): hs is built but never read below — appears to be dead
    hs = []
    for i in range(0, opt.n_past):
        hs.append(pose_encoder(x[i]).detach())
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        h = pose_encoder(x[0]).detach()
        for i in range(1, opt.n_eval):
            h_target = pose_encoder(x[i]).detach()
            if i < opt.n_past:
                # conditioning: posterior latent; prior kept in sync
                mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1), time_step=i - 1)
                z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
                prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
                x_in = x[i]
                all_gen[s].append(x_in)
                h = h_target
            else:
                # generation: latent sampled from the learned prior
                mu_t_pp, logvar_t_pp = prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                z_t = utils.reparameterize(mu_t_pp, logvar_t_pp)
                h_pred = frame_predictor(torch.cat([z_t, mu_c], 1), time_step=i - 1).detach()
                x_in = decoder([h_pred, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
                # feed the generated frame's pose back in
                h = pose_encoder(x_in).detach()
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
    utils.clear_progressbar()
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        # argsort is ascending, so ordered[-1] is the best-scoring sample
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]
        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        # -- generate gifs
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            # posterior
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))
        fname = '%s/samples/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
        # -- generate samples
        to_plot = []
        gts = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])
            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])
        to_plot.append(gts)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/samples/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)
    return best_ssim, best_psnr  #, ccm_pred, ccm_gt