def valid_vqvae(train_cnt, do_plot=False):
    """Evaluate the frame-diff VQ-VAE on one validation minibatch.

    Draws one minibatch from the module-level ``valid_data_loader`` via
    ``get_framediff_minibatch()`` and scores ``vqvae_model`` on it. The model
    output ``x_d`` is split at ``nmix`` channels into mixture parameters for
    the reconstruction target (``pred_states[:, 0]``) and for the frame-diff
    target (``pred_states[:, 1]``).

    Args:
        train_cnt: number of training examples seen so far; only used to name
            the comparison image when ``do_plot`` is set.
        do_plot: if True, save a gold-vs-sample image for up to 8 examples to
            ``model_base_filepath + "_<train_cnt>_valid_reconstruction.png"``.

    Returns:
        Per-example losses in the order
        ``[values, actions, reconstruction, diff, loss_2, loss_3]``.
    """
    vqvae_model.eval()
    # Validation never updates weights, so no autograd graph is needed. The
    # previous version called ``.backward()`` on the action/value losses here,
    # which wasted compute and memory and accumulated gradients into the
    # model's ``.grad`` buffers outside of a training step.
    with torch.no_grad():
        (states, actions, rewards, values, pred_states, terminals,
         is_new_epoch, relative_indexes) = valid_data_loader.get_framediff_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states) - 1).to(DEVICE)
        rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(DEVICE)
        # NOTE(review): the other frame-diff variants in this file leave the
        # diff target un-normalized ("dont normalize diff"); this one maps it
        # with 2*x-1. Confirm it matches the paired training loop.
        diff = (2 * reshape_input(pred_states[:, 1][:, None]) - 1).to(DEVICE)
        actions = actions.to(DEVICE)
        values = values.to(DEVICE)
        x_d, z_e_x, z_q_x, latents, pred_actions, pred_values = vqvae_model(states)
        # (args.nr_logistic_mix/2)*3 is needed for each reconstruction
        rec_est = x_d[:, :nmix]
        diff_est = x_d[:, nmix:]
        loss_rec = discretized_mix_logistic_loss(
            rec_est, rec, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        loss_diff = discretized_mix_logistic_loss(
            diff_est, diff, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        loss_act = F.nll_loss(pred_actions, actions)
        loss_values = args.ralpha * F.mse_loss(pred_values, values)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        yhat = sample_from_discretized_mix_logistic(rec_est, args.nr_logistic_mix)
        if do_plot:
            print('writing img')
            n_imgs = 8
            n = min(states.shape[0], n_imgs)
            # undo the 2*x-1 mapping so gold is in the plotting range
            gold = (rec.to('cpu') + 1) / 2.0
            bs, _, h, w = gold.shape
            # samples are mapped with (yhat + 1) / 2 below, i.e. assumed to be
            # in the same 2*x-1 space as the target — TODO confirm
            print("yhat sample", yhat[:, 0].min().item(), yhat[:, 0].max().item())
            yimg = ((yhat + 1.0) / 2.0).to('cpu')
            print("yhat img", yhat.min().item(), yhat.max().item())
            print("gold img", gold.min().item(), gold.max().item())
            comparison = torch.cat(
                [gold.view(bs, 1, h, w)[:n], yimg.view(bs, 1, h, w)[:n]])
            img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
            save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    loss_list = [
        loss_values.item() / bs,
        loss_act.item() / bs,
        loss_rec.item() / bs,
        loss_diff.item() / bs,
        loss_2.item() / bs,
        loss_3.item() / bs,
    ]
    return loss_list
def valid_vqvae(train_cnt, vqvae_model, info, valid_data_loader, do_plot=True):
    """Evaluate the ``info``-configured frame-diff VQ-VAE on one validation minibatch.

    Args:
        train_cnt: training examples seen so far; used to name the plot file.
        vqvae_model: model called as ``vqvae_model(states)``. Its first five
            outputs are ``x_d, z_e_x, z_q_x, latents, pred_actions``. An
            optional sixth output is treated as reward predictions. That is
            the output shape ``train_vqvae`` in this file unpacks, and the
            old 5-way unpack raised ``ValueError`` for it.
        info: config dict. Reads ``DEVICE``, ``nmix``, ``NR_LOGISTIC_MIX``,
            ``ALPHA_REC``, ``ALPHA_ACT``, ``actions_weight``, ``BETA`` and
            ``vq_model_base_filepath``. When a reward head is present it also
            reads ``ALPHA_REW`` and ``rewards_weight``.
        valid_data_loader: object providing ``get_framediff_minibatch()``.
        do_plot: if True, save a gold-vs-sample image for up to 8 examples.

    Returns:
        Per-example losses ``[act, rec, diff, loss_2, loss_3]``. When the model
        has a reward head, the list is prefixed with the reward loss, matching
        the order of the training losses reported alongside it.
    """
    vqvae_model.eval()
    device = info['DEVICE']
    # Evaluation only: no autograd graph (the previous loss_act.backward()
    # here accumulated gradients into the model during validation).
    with torch.no_grad():
        (states, actions, rewards, values, pred_states, terminals,
         is_new_epoch, relative_indexes) = valid_data_loader.get_framediff_minibatch()
        states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(device)
        pred_states = torch.FloatTensor(pred_states)
        rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(device)
        actions = torch.LongTensor(actions).to(device)
        # dont normalize diff
        diff = reshape_input(pred_states[:, 1][:, None]).to(device)

        outputs = vqvae_model(states)
        x_d, z_e_x, z_q_x, latents, pred_actions = outputs[:5]
        pred_rewards = outputs[5] if len(outputs) > 5 else None

        rec_est = x_d[:, :info['nmix']]
        diff_est = x_d[:, info['nmix']:]
        loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
            rec_est, rec, info['NR_LOGISTIC_MIX'], DEVICE=device)
        loss_diff = discretized_mix_logistic_loss(
            diff_est, diff, nr_mix=info['NR_LOGISTIC_MIX'], DEVICE=device)
        loss_act = info['ALPHA_ACT'] * F.nll_loss(
            pred_actions, actions, weight=info['actions_weight'])
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())

        losses = []
        if pred_rewards is not None:
            # same reward loss as the training loop computes
            rewards = torch.LongTensor(rewards).to(device)
            losses.append(info['ALPHA_REW'] * F.nll_loss(
                pred_rewards, rewards, weight=info['rewards_weight']))
        losses.extend([loss_act, loss_rec, loss_diff, loss_2, loss_3])

        yhat = sample_from_discretized_mix_logistic(rec_est, info['NR_LOGISTIC_MIX'])
        if do_plot:
            n_imgs = 8
            n = min(states.shape[0], n_imgs)
            # undo the 2*x-1 mapping so gold is in the plotting range
            gold = (rec.to('cpu') + 1) / 2.0
            bs, _, h, w = gold.shape
            # samples are mapped with (yhat + 1) / 2 below, i.e. assumed to be
            # in the same 2*x-1 space as the target — TODO confirm
            print("yhat sample", yhat[:, 0].min().item(), yhat[:, 0].max().item())
            yimg = ((yhat + 1.0) / 2.0).to('cpu')
            print("yhat img", yhat.min().item(), yhat.max().item())
            print("gold img", gold.min().item(), gold.max().item())
            comparison = torch.cat(
                [gold.view(bs, 1, h, w)[:n], yimg.view(bs, 1, h, w)[:n]])
            img_name = info[
                'vq_model_base_filepath'] + "_%010d_valid_reconstruction.png" % train_cnt
            save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    loss_list = [loss.item() / bs for loss in losses]
    return loss_list
def train_vqvae(train_cnt):
    """Train the frame-diff VQ-VAE until ``args.num_examples_to_train`` examples.

    Each step draws a minibatch with ``train_data_loader.get_framediff_minibatch()``.
    The model output is split at ``nmix`` channels into mixture parameters for
    the reconstruction target (``pred_states[:, 0]``, mapped with 2*x-1) and
    the frame-diff target (``pred_states[:, 1]``, left un-normalized).
    Checkpointing is delegated to ``handle_checkpointing`` after the first
    few batches.

    Args:
        train_cnt: number of training examples already seen.

    Returns:
        The example count when training stops.
    """
    start_time = time.time()
    n_batches = 0
    while train_cnt < args.num_examples_to_train:
        vqvae_model.train()
        opt.zero_grad()
        states, _, _, _, pred_states, _, _, _ = train_data_loader.get_framediff_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states) - 1).to(DEVICE)
        rec_target = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(DEVICE)
        # the diff target is deliberately not rescaled
        diff_target = reshape_input(pred_states[:, 1][:, None]).to(DEVICE)

        x_d, z_e_x, z_q_x, _ = vqvae_model(states)
        # keep the gradient of the quantized codes for the copy step below
        z_q_x.retain_grad()
        rec_params, diff_params = x_d[:, :nmix], x_d[:, nmix:]
        loss_rec = discretized_mix_logistic_loss(
            rec_params, rec_target, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        loss_diff = discretized_mix_logistic_loss(
            diff_params, diff_target, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())

        # The order of these calls matters. Reconstruction grads land on
        # z_q_x; the embedding's share is cleared; z_q_x.grad is then copied
        # onto z_e_x (encoder path); finally the codebook and commitment
        # terms are added.
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        clip_grad_value_(list(vqvae_model.parameters()), 10)
        opt.step()

        denom = float(x_d.shape[0])
        loss_list = [loss.item() / denom
                     for loss in (loss_rec, loss_diff, loss_2, loss_3)]
        if n_batches > 5:
            handle_checkpointing(train_cnt, loss_list)
        train_cnt += len(states)
        n_batches += 1
        if n_batches % 1000 == 0:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (n_batches, time.time() - start_time, train_cnt))
    return train_cnt
def sample_batch(data, episode_number, episode_reward, name):
    """Plot true vs. sampled VQ-VAE reconstructions for every step of ``data``.

    For each step, the last frame of ``states`` is mapped with 2*x-1 and run
    through ``vqvae_model``. One sample is drawn from the output mixture, and
    a side-by-side "true"/"est" PNG is written to ``output_savepath``.
    Afterwards, ImageMagick ``convert`` is invoked (via the shell) to stitch
    the PNGs of this episode into a GIF.

    Args:
        data: tuple ``(states, actions, rewards, next_states, terminals,
            reset, relative_indexes)``; ``actions``/``rewards`` are only used
            for the plot titles.
        episode_number: episode id used in the file names.
        episode_reward: episode reward used in the file names.
        name: file-name prefix.
    """
    with torch.no_grad():
        states, actions, rewards, next_states, terminals, reset, relative_indexes = data
        n_steps = states.shape[0]
        if n_steps == 0:
            # Nothing to plot (previously this raised NameError on the GIF step).
            return
        # Shared per-episode prefix. This replaces slicing ``iname[:-10]``,
        # which broke once the step index needed more than 5 digits.
        prefix = os.path.join(
            output_savepath, '%s_E%05d_R%03d' %
            (name, int(episode_number), int(episode_reward)))
        x = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        for i in range(n_steps):
            x_d, z_e_x, z_q_x, latents = vqvae_model(x[i:i + 1])
            yhat = sample_from_discretized_mix_logistic(x_d, largs.nr_logistic_mix)
            # ``np.int`` was an alias of the builtin ``int`` and was removed in
            # NumPy 1.24; ``int`` yields the same dtype.
            yhat = (((yhat + 1) / 2.0) * 255.0).cpu().numpy().astype(int)
            true = (states[i:i + 1, -1:] * 255.0).cpu().numpy().astype(int)
            f, ax = plt.subplots(1, 2)
            iname = prefix + '_%05d.png' % i
            print("writing", os.path.split(iname)[1])
            title = 'step %s/%s action %s reward %s' % (
                i, states.shape[0], actions[i].item(), rewards[i].item())
            ax[0].imshow(true[0, 0])
            ax[0].set_title('true')
            ax[1].imshow(yhat[0, 0])
            ax[1].set_title('est')
            plt.suptitle(title)
            plt.savefig(iname)
            # release the figure; otherwise one figure per step stays open
            plt.close(f)
            print('saving', iname)
    search_path = prefix + '*.png'
    gif_path = prefix + '.gif'
    # NOTE(review): built as a shell string; fine for trusted config paths,
    # but paths containing spaces/shell metacharacters will break it.
    cmd = 'convert %s %s' % (search_path, gif_path)
    print('creating gif', gif_path)
    os.system(cmd)
def valid_vqvae(train_cnt, do_plot=False):
    """Evaluate the conditional (states -> next frame) VQ-VAE on one validation minibatch.

    The model is called as ``vqvae_model(states, targets)``, where ``targets``
    is the last frame of ``states`` mapped with 2*x-1. It is scored with the
    discretized-mixture-of-logistics loss against those targets.

    Args:
        train_cnt: training examples seen so far; used to name the plot file.
        do_plot: if True, save a gold-vs-sample image for up to 8 examples.

    Returns:
        ``(loss_1, loss_2, loss_3)``, each divided by the batch size.
    """
    vqvae_model.eval()
    # NOTE(review): zeroing optimizer gradients during validation is unusual;
    # kept so the optimizer state seen by callers is unchanged.
    opt.zero_grad()
    with torch.no_grad():
        (states, actions, rewards, next_states, terminals, is_new_epoch,
         relative_indexes) = valid_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = reshape_input(states).to(DEVICE)
        # only predict future observation - normalize
        targets = (2 * states[:, -1:] - 1).to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states, targets)
        loss_1 = discretized_mix_logistic_loss(
            x_d, targets, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        yhat = sample_from_discretized_mix_logistic(x_d, args.nr_logistic_mix)
        if do_plot:
            print('writing img')
            n_imgs = 8
            n = min(states.shape[0], n_imgs)
            gold = states[:, -1:]
            bs, _, h, w = gold.shape
            # The model was trained against 2*x-1 targets, so map the sample
            # back with (yhat + 1) / 2 to match ``gold`` (as the other
            # validation variants do). Plotting raw samples let save_image
            # clip everything below 0.
            yimg = (yhat + 1.0) / 2.0
            comparison = torch.cat([
                gold.to('cpu').view(bs, 1, h, w)[:n],
                yimg.to('cpu').view(bs, 1, h, w)[:n]
            ])
            img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
            save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    return loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs
def valid_vqvae(train_cnt, do_plot=False):
    """Evaluate the single-frame VQ-VAE on one validation minibatch.

    The last frame of each state is mapped with 2*x-1 and used as both the
    model input and the discretized-mixture-of-logistics target.

    Args:
        train_cnt: training examples seen so far; used to name the plot file.
        do_plot: if True, save a gold-vs-sample image for up to 8 examples.

    Returns:
        ``(loss_1, loss_2, loss_3)``, each divided by the batch size.
    """
    vqvae_model.eval()
    # Evaluation only: skip the autograd graph (the previous version built
    # one and called ``z_q_x.retain_grad()`` for no use).
    with torch.no_grad():
        (states, actions, rewards, next_states, terminals, is_new_epoch,
         relative_indexes) = valid_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        x_d, z_e_x, z_q_x, latents = vqvae_model(states)
        loss_1 = discretized_mix_logistic_loss(
            x_d, states, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
        yhat = sample_from_discretized_mix_logistic(x_d, args.nr_logistic_mix)
        if do_plot:
            print('writing img')
            n_imgs = 8
            n = min(states.shape[0], n_imgs)
            # undo the 2*x-1 mapping for plotting
            gold = (states.to('cpu') + 1) / 2.0
            bs, _, h, w = gold.shape
            # samples are mapped with (yhat + 1) / 2 below, i.e. assumed to be
            # in the same 2*x-1 space as the input — TODO confirm
            print("yhat sample", yhat.min(), yhat.max())
            yimg = ((yhat + 1.0) / 2.0).to('cpu')
            print("yhat img", yhat.min().item(), yhat.max().item())
            print("gold img", gold.min().item(), gold.max().item())
            comparison = torch.cat(
                [gold.view(bs, 1, h, w)[:n], yimg.view(bs, 1, h, w)[:n]])
            img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
            save_image(comparison, img_name, nrow=n)
    bs = float(states.shape[0])
    return loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs
def forward_pass(x, y, beta=0.25):
    """Run ``vmodel`` on ``x`` and compute the three VQ-VAE losses against ``y``.

    Args:
        x: model input batch; detached and moved to ``DEVICE``.
        y: reconstruction target batch; detached, moved to ``DEVICE`` and
            mapped with ``2 * y - 1`` before the discretized-mixture-of-
            logistics loss (so presumably in [0, 1] — TODO confirm).
        beta: weight of the commitment term ``loss_3``. Defaults to the
            previously hard-coded 0.25.

    Returns:
        ``(loss_1, loss_2, loss_3, x_d, z_e_x, z_q_x, latents)``, where
        ``loss_1`` is the reconstruction loss, ``loss_2`` the codebook term
        and ``loss_3`` the weighted commitment term.
    """
    # ``Variable(t, requires_grad=False)`` is deprecated; it produced a
    # detached alias of ``t``, which is what ``.detach()`` returns.
    x = x.detach().to(DEVICE)
    y = y.detach().to(DEVICE)
    x_d, z_e_x, z_q_x, latents = vmodel(x)
    # with bigger model - latents is 64, 6, 6
    # keep the gradient of the quantized codes so callers can read z_q_x.grad
    z_q_x.retain_grad()
    loss_1 = discretized_mix_logistic_loss(x_d, 2 * y - 1, DEVICE=DEVICE)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = beta * F.mse_loss(z_e_x, z_q_x.detach())
    return loss_1, loss_2, loss_3, x_d, z_e_x, z_q_x, latents
def find_rec_losses(alpha, nr, nmix, x_d, true, DEVICE):
    """Compute one weighted mixture-of-logistics loss per target channel.

    Channel ``c`` of ``true`` is scored against the ``nmix``-wide slice
    ``x_d[:, c * nmix:(c + 1) * nmix]``.

    Args:
        alpha: multiplier applied to every channel loss.
        nr: number of logistic mixture components (passed as ``nr_mix``).
        nmix: number of ``x_d`` channels belonging to each target channel.
        x_d: model output holding the stacked per-channel parameters.
        true: targets; channel ``c`` is ``true[:, c]`` (kept 4-D via ``[:, None]``).
        DEVICE: forwarded to ``discretized_mix_logistic_loss``.

    Returns:
        ``(rec_losses, rec_ests)``: the per-channel weighted losses and the
        detached per-channel parameter slices, both in channel order.
    """
    losses, estimates = [], []
    for channel in range(true.shape[1]):
        start = channel * nmix
        channel_params = x_d[:, start:start + nmix]
        estimates.append(channel_params.detach())
        target = true[:, channel][:, None]
        losses.append(alpha * discretized_mix_logistic_loss(
            channel_params, target, nr_mix=nr, DEVICE=DEVICE))
    return losses, estimates
def train_vqvae(train_cnt): st = time.time() #for batch_idx, (data, label, data_index) in enumerate(train_loader): batches = 0 while train_cnt < args.num_examples_to_train: vqenc.train() pcnn_decoder.train() opt.zero_grad() states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch( ) # because we have 4 layers in vqvae, need to be divisible by 2, 4 times states = reshape_input(states).to(DEVICE) # only predict future observation - normalize targets = (2 * states[:, -1:] - 1).to(DEVICE) #actions = actions.to(DEVICE) x_d, z_e_x, z_q_x, latents = vqvae_model(states, targets) #z_e_x, z_q_x, latents = vqenc(states) #float_condition = latents.view(latents.shape[0], latents.shape[1]*latents.shape[2]).float() #x_d = pcnn_decoder(targets, class_condition=actions, float_condition=float_condition) z_q_x.retain_grad() vqvae_model.spatial_condition.retain_grad() loss_1 = discretized_mix_logistic_loss(x_d, targets, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE) loss_2 = F.mse_loss(z_q_x, z_e_x.detach()) loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach()) #loss_1, loss_2, loss_3 = get_vqvae_loss(x_d, targets, z_e_x, z_q_x, nr_logistic_mix=args.nr_logistic_mix, beta=args.beta, device=DEVICE) loss_1.backward(retain_graph=True) #vqvae_model.encoder.embedding.zero_grad() #z_e_x.backward(z_q_x.grad, retain_graph=True) z_e_x.backward(vqvae_model.spatial_condition.grad, retain_graph=True) loss_2.backward(retain_graph=True) loss_3.backward() parameters = list(vqvae_model.parameters()) clip_grad_value_(parameters, 10) opt.step() bs = float(x_d.shape[0]) handle_checkpointing(train_cnt, loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs) train_cnt += len(states) batches += 1 if not batches % 1000: print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt)) return train_cnt
def train_acn(train_cnt):
    """Run one pass of ACN training over every unique sample in ``train_buffer``.

    For each minibatch: encode the last frame of ``next_states`` with
    ``vae_model``, write the posterior means into ``prior_model.codes`` and
    refit its KNN, score the posterior against the prior with
    ``kl_loss_function``, and decode through ``pcnn_decoder`` +
    ``vae_model.decoder`` for a mixture-of-logistics reconstruction loss.
    The summed loss is optimized with ``opt``. Checkpointing starts once
    ``train_cnt`` exceeds 50000.

    Args:
        train_cnt: number of training examples already seen.

    Returns:
        The example count after the pass.
    """
    vae_model.train()
    prior_model.train()
    running_loss = 0
    start_cnt = train_cnt
    start_time = time.time()
    train_buffer.reset_unique()
    while train_buffer.unique_available:
        batch = train_buffer.get_unique_minibatch(args.batch_size)
        sample_indexes = batch[-1]
        _, _, _, next_states = make_state(batch[:-1], DEVICE, 255.)
        data = next_states[:, -1:]
        opt.zero_grad()
        z, u_q, s_q = vae_model(data)
        # register this batch's posterior means as prior codes, then refit
        prior_model.codes[sample_indexes] = u_q.detach().cpu().numpy()
        prior_model.fit_knn(prior_model.codes)
        u_p, s_p = prior_model(u_q)
        kl = kl_loss_function(u_q, s_q, u_p, s_p)
        # the decoder maps pcnn output to the channel count the loss expects
        mixture_params = vae_model.decoder(pcnn_decoder(x=data, float_condition=z))
        rec_loss = discretized_mix_logistic_loss(
            mixture_params, data, nr_mix=nr_logistic_mix, DEVICE=DEVICE)
        loss = kl + rec_loss
        loss.backward()
        running_loss += loss.item()
        opt.step()
        # this batch is not yet counted in train_cnt, so include it explicitly
        examples_so_far = (train_cnt + data.shape[0]) - start_cnt
        avg_train_loss = running_loss / float(examples_so_far)
        if train_cnt > 50000:
            handle_checkpointing(train_cnt, avg_train_loss)
        train_cnt += len(data)
    print("finished epoch after %s seconds at cnt %s" %
          (time.time() - start_time, train_cnt))
    return train_cnt
def train_vqvae(train_cnt):
    """Train the single-frame VQ-VAE until ``args.num_examples_to_train`` examples.

    Each step takes the last frame of a unique minibatch, maps it with 2*x-1,
    and uses it as both the model input and the mixture-of-logistics
    reconstruction target. Losses (divided by batch size) go to
    ``handle_checkpointing`` every step.

    Args:
        train_cnt: number of training examples already seen.

    Returns:
        The example count when training stops.
    """
    start_time = time.time()
    n_batches = 0
    while train_cnt < args.num_examples_to_train:
        vqvae_model.train()
        opt.zero_grad()
        states, _, _, _, _, _, _ = train_data_loader.get_unique_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        frames = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        x_d, z_e_x, z_q_x, _ = vqvae_model(frames)
        # keep the gradient of the quantized codes for the copy step below
        z_q_x.retain_grad()
        rec_loss = discretized_mix_logistic_loss(
            x_d, frames, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE)
        codebook_loss = F.mse_loss(z_q_x, z_e_x.detach())
        commit_loss = args.beta * F.mse_loss(z_e_x, z_q_x.detach())

        # The order of these calls matters. Reconstruction grads land on
        # z_q_x; the embedding's share is cleared; z_q_x.grad is copied onto
        # z_e_x (encoder path); then the codebook and commitment terms are added.
        rec_loss.backward(retain_graph=True)
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        codebook_loss.backward(retain_graph=True)
        commit_loss.backward()
        clip_grad_value_(list(vqvae_model.parameters()), 10)
        opt.step()

        denom = float(x_d.shape[0])
        handle_checkpointing(train_cnt,
                             rec_loss.item() / denom,
                             codebook_loss.item() / denom,
                             commit_loss.item() / denom)
        train_cnt += len(frames)
        n_batches += 1
        if n_batches % 1000 == 0:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (n_batches, time.time() - start_time, train_cnt))
    return train_cnt
def train_vqvae(train_cnt, vqvae_model, opt, info, train_data_loader,
                valid_data_loader):
    """Train the frame-diff VQ-VAE with action and reward heads.

    Runs until ``train_cnt`` reaches ``info['VQ_NUM_EXAMPLES_TO_TRAIN']``.
    After ``info['VQ_MIN_BATCHES_BEFORE_SAVE']`` batches, every
    ``info['VQ_SAVE_EVERY']`` examples it validates via ``valid_vqvae``,
    plots via ``handle_plot_ckpt`` and writes a checkpoint. It mutates
    ``info['vq_last_save']`` and ``info['vq_save_times']``.

    Args:
        train_cnt: number of training examples already seen.
        vqvae_model: model returning
            ``(x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards)``.
        opt: optimizer over ``vqvae_model``'s parameters.
        info: config/state dict (see the keys read below).
        train_data_loader: provides ``get_framediff_minibatch()``.
        valid_data_loader: forwarded to ``valid_vqvae``.

    Returns:
        The example count when training stops.
    """
    st = time.time()
    device = info['DEVICE']
    batches = 0
    while train_cnt < info['VQ_NUM_EXAMPLES_TO_TRAIN']:
        vqvae_model.train()
        opt.zero_grad()
        (states, actions, rewards, values, pred_states, terminals,
         is_new_epoch, relative_indexes) = train_data_loader.get_framediff_minibatch()
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(device)
        pred_states = torch.FloatTensor(pred_states)
        rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.LongTensor(rewards).to(device)
        # dont normalize diff
        diff = reshape_input(pred_states[:, 1][:, None]).to(device)

        x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(states)
        # keep the gradient of the quantized codes for the copy step below
        z_q_x.retain_grad()
        rec_est = x_d[:, :info['nmix']]
        diff_est = x_d[:, info['nmix']:]
        loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
            rec_est, rec, nr_mix=info['NR_LOGISTIC_MIX'], DEVICE=device)
        loss_diff = discretized_mix_logistic_loss(
            diff_est, diff, nr_mix=info['NR_LOGISTIC_MIX'], DEVICE=device)
        loss_act = info['ALPHA_ACT'] * F.nll_loss(
            pred_actions, actions, weight=info['actions_weight'])
        loss_rewards = info['ALPHA_REW'] * F.nll_loss(
            pred_rewards, rewards, weight=info['rewards_weight'])
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())

        # Head and reconstruction losses first; their gradients w.r.t. z_q_x
        # accumulate in z_q_x.grad. loss_rewards was previously computed and
        # reported but never backpropagated, so the reward head never learned.
        loss_rewards.backward(retain_graph=True)
        loss_act.backward(retain_graph=True)
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)
        # Clear the embedding's share, copy z_q_x.grad onto the encoder
        # output, then add the codebook and commitment terms.
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        clip_grad_value_(list(vqvae_model.parameters()), 10)
        opt.step()

        bs = float(x_d.shape[0])
        avg_train_losses = [
            loss_rewards.item() / bs,
            loss_act.item() / bs,
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs,
        ]
        if batches > info['VQ_MIN_BATCHES_BEFORE_SAVE']:
            if (train_cnt - info['vq_last_save']) >= info['VQ_SAVE_EVERY']:
                # capture the gap before vq_last_save is overwritten (the
                # message below previously always reported 0)
                since_last_save = train_cnt - info['vq_last_save']
                info['vq_last_save'] = train_cnt
                info['vq_save_times'].append(time.time())
                avg_valid_losses = valid_vqvae(train_cnt, vqvae_model, info,
                                               valid_data_loader)
                handle_plot_ckpt(train_cnt, info, avg_train_losses,
                                 avg_valid_losses)
                filename = info[
                    'vq_model_base_filepath'] + "_%010dex.pt" % train_cnt
                print("SAVING MODEL:%s" % filename)
                print("Saving model at cnt:%s cnt since last saved:%s" %
                      (train_cnt, since_last_save))
                state = {
                    'vqvae_state_dict': vqvae_model.state_dict(),
                    'vq_optimizer': opt.state_dict(),
                    'vq_embedding': vqvae_model.embedding,
                    'vq_info': info,
                }
                save_checkpoint(state, filename=filename)

        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
def test_acn(train_cnt, do_plot):
    """Evaluate the ACN on up to 10 unique minibatches from ``valid_buffer``.

    Mirrors the loss of ``train_acn`` (KL against the prior plus the
    mixture-of-logistics reconstruction loss) under ``torch.no_grad()``. If
    ``do_plot`` is set, it writes a gold-vs-mean-sample image for the first
    minibatch.

    Args:
        train_cnt: training examples seen so far; used for logging and the
            plot file name.
        do_plot: whether to write the reconstruction image.

    Returns:
        Summed loss divided by the number of evaluated examples, or
        ``float('nan')`` when the buffer yielded no data (previously a
        ``ZeroDivisionError``).
    """
    vae_model.eval()
    prior_model.eval()
    test_loss = 0
    print('starting test', train_cnt)
    st = time.time()
    seen = 0
    with torch.no_grad():
        valid_buffer.reset_unique()
        for i in range(10):
            if not valid_buffer.unique_available:
                continue
            batch = valid_buffer.get_unique_minibatch(args.batch_size)
            states, actions, rewards, next_states = make_state(
                batch[:-1], DEVICE, 255.)
            data = next_states[:, -1:]
            z, u_q, s_q = vae_model(data)
            u_p, s_p = prior_model(u_q)
            kl = kl_loss_function(u_q, s_q, u_p, s_p)
            yhat_batch = vae_model.decoder(
                pcnn_decoder(x=data, float_condition=z))
            rec_loss = discretized_mix_logistic_loss(
                yhat_batch, data, nr_mix=nr_logistic_mix, DEVICE=DEVICE)
            loss = kl + rec_loss
            test_loss += loss.item()
            seen += data.shape[0]
            if i == 0 and do_plot:
                print('writing img')
                n = min(data.size(0), 8)
                bs = data.shape[0]
                yhat = sample_from_discretized_mix_logistic(
                    yhat_batch, nr_logistic_mix, only_mean=True)
                # map samples and data back with (v + 1) / 2; both are assumed
                # to be in [-1, 1] here — TODO confirm against make_state
                yimg = ((yhat + 1.0) / 2.0)
                print('data', data.max(), data.min())
                gold = (data + 1) / 2.0
                print('bef', yhat_batch.max(), yhat_batch.min())
                print('yimg', yimg.max(), yimg.min())
                print('gold', gold.max(), gold.min())
                bs, _, h, w = data.shape
                # data should be between 0 and 1 to be plotted with
                # save_image
                assert (yimg.min() >= 0)
                assert (yimg.max() <= 1)
                comparison = torch.cat([
                    gold.view(bs, 1, h, w)[:n],
                    yimg.view(bs, 1, h, w)[:n]
                ])
                img_name = vae_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
                save_image(comparison.cpu(), img_name, nrow=n)
                print('finished writing img', img_name)
    test_loss = test_loss / seen if seen else float('nan')
    print('====> Test set loss: {:.4f}'.format(test_loss))
    print('finished test', time.time() - st)
    return test_loss