def train(self, data_loader):
    """Run one training epoch of the refiner GAN.

    For each batch: update the discriminator D on real images and on
    refined reconstructions produced by G from the trace-to-image model,
    then update the generator G against the real label. Logs running
    losses and saves sample images at the end of the epoch.

    Args:
        data_loader: iterable of (trace, image) batch pairs.
    """
    print('Training...')
    with torch.autograd.set_detect_anomaly(True):
        self.epoch += 1
        self.G.train()
        self.D.train()
        record_G = utils.Record()
        record_D = utils.Record()
        start_time = time.time()
        progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
        for i, (trace, image) in enumerate(data_loader):
            progress.update(i + 1)
            trace = trace.cuda()
            image = image.cuda()
            self.D.zero_grad()
            # update D with real images
            real_output = self.D(image)
            err_D_real = self.loss(real_output, self.real_label)
            D_x = real_output.data.mean()
            # update D with reconstructed images
            fake_input, *_ = self.trace2image(trace)
            fake_refine = self.G(fake_input)
            # detach so the D update does not backprop into G
            fake_output = self.D(fake_refine.detach())
            err_D_fake = self.loss(fake_output, self.fake_label)
            D_G_z = fake_output.data.mean()
            err_D = err_D_fake + err_D_real
            err_D.backward()
            self.optimizerD.step()
            self.G.zero_grad()
            # update G
            fake_output = self.D(fake_refine)
            err_G = self.loss(fake_output, self.real_label)
            err_G.backward()
            self.optimizerG.step()
            record_D.add(err_D.item())
            record_G.add(err_G.item())
        progress.finish()
        utils.clear_progressbar()
        print('----------------------------------------')
        print('Epoch: %d' % self.epoch)
        print('Costs time: %.2f s' % (time.time() - start_time))
        print('Loss of G: %f' % (record_G.mean()))
        print('Loss of D: %f' % (record_D.mean()))
        print('D(x): %f, D(G(z)): %f' % (D_x, D_G_z))
        print('----------------------------------------')
        # BUG FIX: the original referenced `trace2image` and `image2image`,
        # which are undefined locals here (`self.trace2image` is a method) —
        # a NameError at the end of every epoch. Save the tensors this loop
        # actually produced instead: the coarse trace2image output and the
        # refined result.
        # NOTE(review): paths write into image/test/ although this is the
        # train method — confirm whether image/train/ was intended.
        utils.save_image(image.data, ('%s/image/test/target_%03d.jpg' %
                                      (self.args['gan_dir'], self.epoch)))
        utils.save_image(fake_input.data, ('%s/image/test/tr2im_%03d.jpg' %
                                           (self.args['gan_dir'], self.epoch)))
        utils.save_image(fake_refine.data, ('%s/image/test/im2im_%03d.jpg' %
                                            (self.args['gan_dir'], self.epoch)))
def train(self, data_loader):
    """Run one cross-domain VAE training epoch.

    Encodes traces and images into a shared latent space, decodes both
    back to images, and optimizes an L1 reconstruction loss for each path
    plus a beta-weighted KL term aligning the two posteriors.

    Args:
        data_loader: iterable of (trace, image) batch pairs.
    """
    print('Training...')
    with torch.autograd.set_detect_anomaly(True):
        self.epoch += 1
        self.set_train()
        record_trace = utils.Record()
        record_image = utils.Record()
        record_inter = utils.Record()
        record_kld = utils.Record()
        start_time = time.time()
        progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
        for i, (trace, image) in enumerate(data_loader):
            progress.update(i + 1)
            trace = trace.cuda()
            image = image.cuda()
            self.zero_grad()
            trace_embed = self.TraceEncoder(trace)
            image_embed = self.ImageEncoder(image)
            # NOTE(review): mu and logvar alias the same tensor — the
            # encoders appear to emit a single embedding; confirm a
            # separate (mu, logvar) head was not intended.
            trace_mu, trace_logvar = trace_embed, trace_embed
            image_mu, image_logvar = image_embed, image_embed
            trace_z = utils.reparameterize(trace_mu, trace_logvar)
            image_z = utils.reparameterize(image_mu, image_logvar)
            trace2image, trace_inter = self.Decoder(trace_z)
            image2image, image_inter = self.Decoder(image_z)
            err_trace = self.l1(trace2image, image)
            err_image = self.l1(image2image, image)
            #err_inter = self.l2(trace_inter, image_inter)
            err_kld = self.kld(image_mu, image_logvar, trace_mu, trace_logvar)
            #(err_trace + err_image + err_inter + self.args['beta'] * err_kld).backward()
            (err_trace + err_image + self.args['beta'] * err_kld).backward()
            self.optimizer.step()
            # BUG FIX: record scalar .item() values instead of the loss
            # tensors themselves — storing tensors kept every batch's
            # autograd graph alive for the whole epoch (memory leak) and
            # was inconsistent with the Record usage in the GAN methods.
            record_trace.add(err_trace.item())
            record_image.add(err_image.item())
            #record_inter.add(err_inter.item())
            record_kld.add(err_kld.item())
        progress.finish()
        utils.clear_progressbar()
        print('----------------------------------------')
        print('Epoch: %d' % self.epoch)
        print('Costs time: %.2fs' % (time.time() - start_time))
        print('Loss of Trace to Image: %f' % (record_trace.mean()))
        print('Loss of Image to Image: %f' % (record_image.mean()))
        print('Loss of KL-Divergence: %f' % (record_kld.mean()))
        print('----------------------------------------')
        utils.save_image(image.data, ('%s/image/train/target_%03d.jpg' %
                                      (self.args['vae_dir'], self.epoch)))
        utils.save_image(trace2image.data, ('%s/image/train/tr2im_%03d.jpg' %
                                            (self.args['vae_dir'], self.epoch)))
        utils.save_image(image2image.data,
                         ('%s/image/train/im2im_%03d.jpg' %
                          (self.args['vae_dir'], self.epoch)))
def make_gifs(x, idx, name, frame_predictor, posterior, prior, encoder, decoder):
    """Roll out `opt.nsample` stochastic predictions from sequence `x` and
    plot rainfall panels for every batch element.

    Conditioning frames (i < opt.n_past) feed ground truth through the
    models using the posterior latent; afterwards latents come from the
    learned prior. Depends on module-level `opt`, `inv_scaler`,
    `plot_rainfall`, `np`, `utils`.
    """
    nsample = opt.nsample
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        # NOTE(review): gen_seq/gt_seq are filled but never consumed here
        # (other variants pass them to utils.eval_seq) — confirm intent.
        gen_seq = []
        gt_seq = []
        # reset recurrent state before each independent rollout
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            # `skip` connections are taken from the last conditioning frame
            # (or every frame when last_frame_skip is set)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning phase: use the posterior latent, advance the
                # prior's recurrent state to keep it in sync
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # generation phase: sample the latent from the learned prior
                z_t, _, _ = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
    # prep np.array to be plotted
    TRU = np.zeros(
        [opt.n_eval, opt.batch_size, 1, opt.image_width, opt.image_width])
    GEN = np.zeros([
        nsample, opt.n_eval, opt.batch_size, 1, opt.image_width,
        opt.image_width
    ])
    for i in range(opt.n_eval):
        # undo data scaling before plotting
        TRU[i, :, :, :, :] = inv_scaler(x[i].cpu().numpy())
        for k in range(nsample):
            GEN[k, i, :, :, :, :] = inv_scaler(all_gen[k][i].cpu().numpy())
    # plot
    print(" ground truth max:", np.max(TRU), " gen max:", np.max(GEN))
    for j in range(opt.batch_size):
        plot_rainfall(TRU[:, j, 0, :, :], GEN[:, :, j, 0, :, :], opt.log_dir,
                      name + "_sample" + str(j), nsample)
    progress.finish()
    utils.clear_progressbar()
def run():
    """Drive the full XML-dump to CSV conversion pipeline.

    Ensures the output directory exists, then generates the stations,
    bikes, positions and movements CSV files in order.
    """
    if not os.path.exists(CSV_DIRECTORY):
        os.makedirs(CSV_DIRECTORY)

    sample_count = count_xml_dumps()

    print('Creating stations.csv…')
    create_stations()

    print('Creating bikes.csv…')
    create_bikes(sample_count)
    clear_progressbar()

    print('Creating bike_positions.csv and bike_movements.csv…')
    create_bike_positions_and_movement(sample_count)
    clear_progressbar()
def test(self, data_loader):
    """Evaluate the refiner GAN for one pass over `data_loader`.

    Computes (without gradient updates) the discriminator and generator
    losses on real images and refined reconstructions, prints summary
    statistics, and saves sample images.
    """
    print('Testing...')
    with torch.no_grad():
        self.G.eval()
        self.D.eval()
        record_G = utils.Record()
        record_D = utils.Record()
        start_time = time.time()
        progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
        for i, (trace, image) in enumerate(data_loader):
            progress.update(i + 1)
            trace = trace.cuda()
            image = image.cuda()
            real_output = self.D(image)
            err_D_real = self.loss(real_output, self.real_label)
            D_x = real_output.data.mean()
            fake_input, *_ = self.trace2image(trace)
            fake_refine = self.G(fake_input)
            fake_output = self.D(fake_refine.detach())
            err_D_fake = self.loss(fake_output, self.fake_label)
            D_G_z = fake_output.data.mean()
            err_D = err_D_fake + err_D_real
            fake_output = self.D(fake_refine)
            err_G = self.loss(fake_output, self.real_label)
            record_D.add(err_D.item())
            record_G.add(err_G.item())
        progress.finish()
        utils.clear_progressbar()
        print('----------------------------------------')
        print('Test at Epoch %d' % self.epoch)
        print('Costs time: %.2f s' % (time.time() - start_time))
        print('Loss of G: %f' % (record_G.mean()))
        print('Loss of D: %f' % (record_D.mean()))
        print('D(x): %f, D(G(z)): %f' % (D_x, D_G_z))
        print('----------------------------------------')
        # BUG FIX: the original referenced `trace2image` and `image2image`,
        # undefined local names in this method (`self.trace2image` is a
        # method) — a NameError at the end of every test pass. Save the
        # tensors the loop actually produced instead.
        utils.save_image(image.data, ('%s/image/test/target_%03d.jpg' %
                                      (self.args['gan_dir'], self.epoch)))
        utils.save_image(fake_input.data, ('%s/image/test/tr2im_%03d.jpg' %
                                           (self.args['gan_dir'], self.epoch)))
        utils.save_image(fake_refine.data, ('%s/image/test/im2im_%03d.jpg' %
                                            (self.args['gan_dir'], self.epoch)))
def render_hourly(session):
    """Render one Graphviz .dot file per STEP-sized window between
    START_DATE and END_DATE, one edge per bike movement.

    Args:
        session: a Neo4j driver session used to query BIKE_MOVED
            relationships.

    NOTE(review): `{start}`/`{end}` Cypher parameter syntax is the old
    deprecated form (modern drivers/servers use `$start`) — confirm the
    installed Neo4j version still accepts it.
    """
    date = START_DATE
    total_steps = (END_DATE - START_DATE) / STEP
    i = 0
    while date < END_DATE:
        i += 1
        print_progressbar(i / total_steps)
        graph = nx.MultiDiGraph()
        start = date
        end = (date + STEP)
        result = session.run(
            """
            MATCH (a:Station)-[r:BIKE_MOVED]->(b:Station)
            WHERE {start} <= r.timestamp_start < {end}
            RETURN a, r, b""", {
                'start': start.timestamp(),
                'end': end.timestamp()
            })
        for record in result:
            # break long station names at '/' so dot labels wrap
            station_a = record['a']['name'].replace('/', ' /\n')
            station_b = record['b']['name'].replace('/', ' /\n')
            # NOTE(review): bike_id is unused except in the commented-out
            # add_edge below.
            bike_id = record['r']['bike_id']
            # '\:' is a literal backslash+colon (not a recognized escape) —
            # presumably to escape ':' in the dot label; emits a
            # DeprecationWarning on modern Python. TODO confirm and use a
            # raw string.
            start_time = datetime.fromtimestamp(
                record['r']['timestamp_start']).strftime('%H\:%M')
            end_time = datetime.fromtimestamp(
                record['r']['timestamp_end']).strftime('%H\:%M')
            label = f'{start_time} -\n{end_time}'
            # transporter (relocation truck) movements are highlighted
            color = 'red' if record['r']['transporter'] else '#aaaaaa'
            penwidth = 2 if record['r']['transporter'] else 1
            graph.add_edge(station_a,
                           station_b,
                           label=label,
                           color=color,
                           penwidth=penwidth)
            # graph.add_edge(station_a, station_b, label=bike_id)
        filename = f"{start.strftime('%Y-%m-%d_%H_%M')} - {end.strftime('%Y-%m-%d_%H_%M')}.dot"
        write_dot(graph, os.path.join(OUTPUT_DIRECTORY, filename))
        date = end
    clear_progressbar()
# One epoch of training plus end-of-epoch logging and plotting.
# NOTE(review): this fragment reads `epoch`, `training_batch_generator`,
# `testing_batch_generator`, `train`, `plot`, `plot_rec` from enclosing
# scope — it appears to be the body of an outer epoch loop; confirm its
# original indentation.
encoder.train()
decoder.train()
epoch_mse = 0
epoch_kld = 0
progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
for i in range(opt.epoch_size):
    progress.update(i + 1)
    x = next(training_batch_generator)
    # train frame_predictor
    mse, kld = train(x)
    epoch_mse += mse
    epoch_kld += kld
progress.finish()
utils.clear_progressbar()
print('[%02d] mse loss: %.5f | kld loss: %.5f (%d)' %
      (epoch, epoch_mse / opt.epoch_size, epoch_kld / opt.epoch_size,
       epoch * opt.epoch_size * opt.batch_size))
# plot some stuff
# NOTE(review): encoder/decoder deliberately left in train mode here
# (their .eval() calls are commented out) — confirm this is intended.
frame_predictor.eval()
#encoder.eval()
#decoder.eval()
posterior.eval()
prior.eval()
x = next(testing_batch_generator)
plot(x, epoch)
plot_rec(x, epoch)
def make_gifs(x, idx, name):
    """Generate comparison GIFs for sequence batch `x`: ground truth,
    an approximate-posterior reconstruction, the best-SSIM sample, and
    three random samples drawn from a fixed N(0, I) prior.

    Uses module-level `frame_predictor`, `posterior`, `encoder`,
    `decoder`, `opt`, `add_border`, `utils`.
    """
    # ---- approximate posterior rollout: latent always from posterior ----
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    # NOTE(review): psnr is filled by eval_seq but not used afterwards.
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        # reset recurrent state before each independent rollout
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning phase: posterior latent from the true frame
                h_target = encoder(x[i])[0].detach()
                _, z_t, _ = posterior(h_target)
            else:
                # generation phase: fixed prior, sample z ~ N(0, I)
                z_t = torch.cuda.FloatTensor(opt.batch_size,
                                             opt.z_dim).normal_()
            if i < opt.n_past:
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
    utils.clear_progressbar()
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        # rank samples by mean SSIM for this batch element
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            # green border for conditioning frames, red for generated ones
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3 (border color reuses the last value of `color`)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))
        fname = '%s/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
def make_gifs(x, idx, name):
    """Generate comparison GIFs/PNGs for batch `x` using a learned prior,
    tracking the prior variance of the best samples.

    Returns:
        (best_ssim, best_psnr, best_ssim_var, best_psnr_var): per-batch
        best scores and the corresponding prior variances per future step.
    """
    # ---- approximate posterior rollout: latent always from posterior ----
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            frame_predictor(torch.cat([h, z_t], 1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    # prior variance per (batch, sample, future step, latent dim)
    var_np = np.zeros((opt.batch_size, nsample, opt.n_future, opt.z_dim))
    progress = progressbar.ProgressBar(max_value=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        # reset recurrent state before each independent rollout
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning phase: posterior latent; advance prior state
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # generation phase: latent from learned prior, record its
                # variance for later analysis
                z_t, mu, logvar = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
                var = torch.exp(logvar)  # BxC
                var_np[:, s, i - opt.n_past, :] = var.data.cpu().numpy()
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
    utils.clear_progressbar()
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    best_ssim_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    best_psnr_var = np.zeros((opt.batch_size, opt.n_future, opt.z_dim))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        # rank samples by mean SSIM / PSNR; keep the best of each plus
        # the prior variances that produced it
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]
        # best ssim var
        best_ssim_var[i, :, :] = var_np[i, ordered[-1], :, :]
        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]
        # best psnr var
        best_psnr_var[i, :, :] = var_np[i, ordered_p[-1], :, :]
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            # green border for conditioning frames, red for generated ones
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3 (border color reuses the last value of `color`)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))
        #fname = '%s/%s_%d.gif' % (opt.log_dir, name, idx+i)
        fname = '%s/quality_results/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
        # -- generate samples
        to_plot = []
        gts = []
        best_s = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_s.append(all_gen[ordered[-1]][t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])
            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])
        to_plot.append(gts)
        to_plot.append(best_s)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/quality_results/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)
    return best_ssim, best_psnr, best_ssim_var, best_psnr_var
# --------- training loop ------------------------------------
# Train the deterministic LSTM predictor for opt.niter epochs; after each
# epoch, plot generations/reconstructions on a test batch and checkpoint
# the model. Relies on module-level `opt`, `lstm`, `train`,
# `training_batch_generator`, `testing_batch_generator`, `plot_gen`,
# `plot_rec`, `utils`.
for epoch in range(opt.niter):
    lstm.train()
    epoch_loss = 0
    progress = progressbar.ProgressBar(max_value=opt.epoch_size).start()
    for i in range(opt.epoch_size):
        progress.update(i + 1)
        x = next(training_batch_generator)
        # train lstm
        loss = train(x)
        epoch_loss += loss
    progress.finish()
    utils.clear_progressbar()
    lstm.eval()
    # plot some stuff
    x = next(testing_batch_generator)
    plot_gen(x, epoch)
    plot_rec(x, epoch)
    print('[%02d] mse loss: %.6f (%d)' %
          (epoch, epoch_loss / opt.epoch_size,
           epoch * opt.epoch_size * opt.batch_size))
    # save the model every epoch (overwrites the previous checkpoint)
    torch.save({'lstm': lstm, 'opt': opt}, '%s/model.pth' % opt.log_dir)
    # BUG FIX: removed stray debug statement `print('here')` left at the
    # end of each epoch iteration.
def make_gifs(x, idx, name):
    """Generate comparison GIFs/PNGs for batch `x` using a content/pose
    factorized model: a content code from the conditioning frames plus a
    per-step pose latent.

    Returns:
        (best_ssim, best_psnr): per-batch-element best scores across
        `opt.nsample` rollouts.
    """
    # get approx posterior sample
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    # ------------ calculate the content posterior
    # content code is inferred once from all conditioning frames
    xs = []
    for i in range(0, opt.n_past):
        xs.append(x[i])
    #if True: random.shuffle(xs)
    #xc = torch.cat(xs, 1)
    mu_c, logvar_c, skip = cont_encoder(torch.cat(xs, 1))
    mu_c = mu_c.detach()
    for i in range(1, opt.n_eval):
        h_target = pose_encoder(x[i]).detach()
        mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1),
                                            time_step=i - 1)
        z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
        if i < opt.n_past:
            frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([z_t_p, mu_c], 1),
                                     time_step=i - 1).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)
    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    #ccm_pred = np.zeros((opt.batch_size, nsample, opt.n_eval-1))
    #ccm_gt = np.zeros((opt.batch_size, opt.n_eval-1))
    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    '''for i in range(1, opt.n_eval): out_gt = discriminator(torch.cat([x[0], x[i]],dim=1)) ccm_i_gt = out_gt.mean().data.cpu().numpy() print('time step %d, mean out gt: %.4f'%(i,ccm_i_gt)) ccm_gt[:,i-1] = out_gt.squeeze().data.cpu().numpy()'''
    # NOTE(review): `hs` is filled but never used below — leftover?
    hs = []
    for i in range(0, opt.n_past):
        hs.append(pose_encoder(x[i]).detach())
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        h = pose_encoder(x[0]).detach()
        for i in range(1, opt.n_eval):
            h_target = pose_encoder(x[i]).detach()
            if i < opt.n_past:
                # conditioning phase: posterior pose; advance prior state
                mu_t_p, logvar_t_p = posterior_pose(torch.cat(
                    [h_target, mu_c], 1),
                                                    time_step=i - 1)
                z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
                prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
                x_in = x[i]
                all_gen[s].append(x_in)
                h = h_target
            else:
                # generation phase: pose latent sampled from learned prior
                mu_t_pp, logvar_t_pp = prior(torch.cat([h, mu_c], 1),
                                             time_step=i - 1)
                z_t = utils.reparameterize(mu_t_pp, logvar_t_pp)
                h_pred = frame_predictor(torch.cat([z_t, mu_c], 1),
                                         time_step=i - 1).detach()
                x_in = decoder([h_pred, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
                h = pose_encoder(x_in).detach()
                #out_pred = discriminator(torch.cat([x[0],x_in],dim=1))
                #ccm_i_pred = out_pred.mean().data.cpu().numpy()
                #print('time step %d, mean out pred: %.4f'%(i,ccm_i_pred))
                #ccm_pred[:, s, i-1] = out_pred.squeeze().data.cpu().numpy()
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
    utils.clear_progressbar()
    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]
        # rank rollouts by mean SSIM / PSNR for this batch element
        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]
        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]
        rand_sidx = [np.random.randint(nsample) for s in range(3)]
        # -- generate gifs
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            # green border for conditioning frames, red for generated ones
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3 (border color reuses the last value of `color`)
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))
        fname = '%s/samples/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)
        # -- generate samples
        to_plot = []
        gts = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])
            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])
        to_plot.append(gts)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/samples/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)
    return best_ssim, best_psnr  #, ccm_pred, ccm_gt
def make_gifs(x, idx, names):
    """Generate per-sample GIFs (and a ground-truth GIF) for batch `x`,
    optionally conditioning the frame predictor on actions.

    Args:
        x: frame sequence, or [frames, actions] when opt.use_action.
        idx: batch offset (unused here; kept for interface parity with the
            other make_gifs variants).
        names: per-batch-element base filenames for the output GIFs.
    """
    # NOTE(review): when opt.use_action, all_gt is a copy of the
    # [frames, actions] pair, yet it is later indexed as all_gt[t][i] for
    # t up to opt.n_eval — confirm this path is correct for action data.
    all_gt = x.copy()
    if opt.use_action:
        actions = x[1]
        x = x[0]
    # get approx posterior sample
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_eval):
        h = encoder(x_in)
        h_target = encoder(x[i])[0].detach()
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h
        h = h.detach()
        _, z_t, _ = posterior(h_target)  # take the mean
        if i < opt.n_past:
            if not opt.use_action:
                frame_predictor(torch.cat([h, z_t], 1))
            else:
                # action at step i-1 is tiled to a_dim and appended
                frame_predictor(
                    torch.cat([h, z_t, actions[i - 1].repeat(1, opt.a_dim)],
                              1))
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            if not opt.use_action:
                h_pred = frame_predictor(torch.cat([h, z_t], 1)).detach()
            else:
                h_pred = frame_predictor(
                    torch.cat([h, z_t, actions[i - 1].repeat(1, opt.a_dim)],
                              1)).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)
    nsample = opt.nsample
    # NOTE(review): ssim/psnr and gen_seq/gt_seq are computed but never
    # used after eval_seq — only the GIFs are written out.
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))
    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        # reset recurrent state before each independent rollout
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if opt.last_frame_skip or i < opt.n_past:
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                # conditioning phase: posterior latent; advance prior state
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                if not opt.use_action:
                    frame_predictor(torch.cat([h, z_t], 1))
                else:
                    frame_predictor(
                        torch.cat(
                            [h, z_t, actions[i - 1].repeat(1, opt.a_dim)], 1))
                x_in = x[i]
                all_gen[s].append(x_in)
            else:
                # generation phase: latent sampled from learned prior
                z_t, _, _ = prior(h)
                if not opt.use_action:
                    h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                else:
                    h = frame_predictor(
                        torch.cat(
                            [h, z_t, actions[i - 1].repeat(1, opt.a_dim)],
                            1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)
    progress.finish()
    utils.clear_progressbar()
    ###### ssim ######
    for i in range(opt.batch_size):
        # one GIF per sample: CHW tensor -> HWC uint8 frames
        for s in range(nsample):
            imgs = []
            for t in range(opt.n_eval):
                img = all_gen[s][t][i].cpu().transpose(0, 1).transpose(
                    1, 2).clamp(0, 1).numpy()
                img = (img * 255).astype(np.uint8)
                imgs.append(img)
            fname = '%s/%s_%d.gif' % (opt.log_dir, names[i], s)
            utils.save_gif_IROS_2019(fname, imgs)
        # save ground truth
        imgs_gt = []
        for t in range(opt.n_eval):
            img = all_gt[t][i].cpu().transpose(0, 1).transpose(1, 2).clamp(
                0, 1).numpy()
            img = (img * 255).astype(np.uint8)
            imgs_gt.append(img)
        fname = '%s/%s_gt.gif' % (opt.log_dir, names[i])
        utils.save_gif_IROS_2019(fname, imgs_gt)