def pred(config, mode='cifar10'):
    """Build a PixelCNN, draw one batch of samples and save them as a grid.

    The network is constructed from ``config.nr_resnet``,
    ``config.nr_filters`` and ``config.nr_logistic_mix`` and moved to the
    GPU. When ``config.load_params`` is truthy the weights are restored with
    ``load_part_of_model``. A zero canvas of 25 images is then filled one
    pixel position at a time from the model output, mapped through
    ``.5 * x + .5`` and written to ``images/sample.png``.

    Only ``mode == 'cifar10'`` sets the image shape; with any other mode the
    shape is never bound and the function fails when it is first read.
    """
    if mode == 'cifar10':
        obs = (3, 32, 32)
    sample_batch_size = 25

    model = PixelCNN(
        nr_resnet=config.nr_resnet,
        nr_filters=config.nr_filters,
        input_channels=obs[0],
        nr_logistic_mix=config.nr_logistic_mix,
    ).cuda()

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample_op(net_out):
        # Draw from the mixture described by the network output.
        return sample_from_discretized_mix_logistic(net_out,
                                                    config.nr_logistic_mix)

    def rescaling_inv(img):
        # Affine map applied to the samples before they are saved.
        return .5 * img + .5

    def sample(model):
        model.train(False)
        channels, height, width = obs
        canvas = torch.zeros(sample_batch_size, channels, height, width)
        canvas = canvas.cuda()
        # Every position needs a fresh forward pass because the canvas that
        # is fed back in changes after each write.
        for row in range(height):
            for col in range(width):
                with torch.no_grad():
                    drawn = sample_op(model(canvas, sample=True))
                    canvas[:, :, row, col] = drawn.data[:, :, row, col]
        return canvas

    print('sampling...')
    grid = rescaling_inv(sample(model))
    save_image(grid, 'images/sample.png', nrow=5, padding=0)
def generate_results(model, data_loader, nr_logistic_mix, do_use_cuda):
    """Run ``model`` over ``data_loader`` and save one code array per image.

    For every ``(data, img_names)`` batch the model is called and the third
    element of its 4-tuple result (``z_q_x``) is flattened per example and
    written with ``np.save`` to ``img_name`` with ``'.png'`` replaced by
    ``'vqvae_z_q_x.npy'``. Progress is printed every tenth batch.

    Args:
        model: callable returning ``(x_d, z_e_x, z_q_x, latents)``.
        data_loader: iterable of ``(data, img_names)`` batches.
        nr_logistic_mix: kept for interface compatibility; no longer read,
            because the sampled reconstruction it fed was never used.
        do_use_cuda: move each batch to the GPU before the forward pass.
    """
    start_time = time.time()
    for batch_idx, (data, img_names) in enumerate(data_loader):
        x = Variable(data, requires_grad=False)
        if do_use_cuda:
            x = x.cuda()
        # z_e_x is output of encoder
        # z_q_x is input into decoder
        # latents is code book
        _x_d, _z_e_x, z_q_x, _latents = model(x)
        # A leftover interactive ``embed()`` call used to sit here and block
        # every batch; it and the unused sampled image have been removed.
        nz_q_x = z_q_x.contiguous().view(z_q_x.shape[0], -1).cpu().data.numpy()
        for ind, img_name in enumerate(img_names):
            z_q_x_name = img_name.replace('.png', 'vqvae_z_q_x.npy')
            np.save(z_q_x_name, nz_q_x[ind])
        if not batch_idx % 10:
            # Single-argument call form: valid on Python 2 and Python 3.
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx, time.time() - start_time))
def generate_results(data_loader):
    """Decode every batch of ``data_loader`` and write the images as PNGs.

    Each batch goes through the module-level ``vae``; the output is viewed
    as ``(batch, 32, 10, 10)``, rescaled with ``z_q_x_std`` / ``z_q_x_mean``,
    decoded by ``qmodel.decoder`` and sampled. The first channel of each
    sample is written to ``img_name`` with ``'.npy'`` replaced by
    ``'vv_gen.png'``.

    Relies on module-level ``use_cuda``, ``vae``, ``qmodel``, ``z_q_x_std``,
    ``z_q_x_mean`` and ``nr_logistic_mix``.
    """
    start_time = time.time()
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)
        dec = vae(x)
        decr = dec.contiguous().view(dec.shape[0], 32, 10, 10)
        udec = (decr * z_q_x_std) + z_q_x_mean
        # TODO - knearest neighbors
        # going to use vae.z_mean, and vae.z_std
        # look at slides from laurent and vincent - cifar summer school
        # to prune unused dimensions
        # now we have mu and sigma that we will run knn lookup within our
        # training set. then once we've mapped to the original frame that made
        # the mu/sigma - we know frame and mu and sigma
        x_d = qmodel.decoder(udec)
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        # Apply 0.5 * x + 0.5, scale by 255 and store as uint8 for writing.
        nx_tilde = (0.5 * nx_tilde + 0.5) * 255
        nx_tilde = nx_tilde.astype(np.uint8)
        for ind, img_name in enumerate(img_names):
            print('int', ind)
            gen_img_name = img_name.replace('.npy', 'vv_gen.png')
            imwrite(gen_img_name, nx_tilde[ind][0])
        if not batch_idx % 10:
            # Was a Python-2-only print statement; the single-argument call
            # form prints the same text on Python 2 and Python 3.
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx, time.time() - start_time))
def generate_reconstruction(base_path, data_loader, nr_logistic_mix, do_use_cuda):
    """Decode noisy draws around stored (mu, sigma) pairs and save summaries.

    For every pair from the module-level ``mus`` / ``sigmas`` whose index
    satisfies ``100 < i < 200``, 20 vectors ``mu + sigma * noise`` are drawn,
    the ``worst_inds`` columns are zeroed, and the result is decoded by
    ``vae.decoder`` and sampled. Three images are written under
    ``base_path``: the per-pixel mean (``gmean_*``), the per-pixel max
    (``gmax_*``) and the max with pixels that are non-zero in fewer than
    three draws cleared (``gadapt_*``).

    Args:
        base_path: output directory, created if missing.
        data_loader: not used by this function.
        nr_logistic_mix: forwarded to the sampler.
        do_use_cuda: run the tensor maths on the GPU.
    """
    if not os.path.exists(base_path):
        os.makedirs(base_path)
    bs = 20
    for i, (mu, sigma) in enumerate(zip(mus, sigmas)):
        if not 100 < i < 200:
            continue
        tmu = Variable(torch.FloatTensor(mu), requires_grad=False)
        tsigma = Variable(torch.FloatTensor(sigma), requires_grad=False)
        base = Variable(torch.from_numpy(np.zeros((bs, 800))).float(),
                        requires_grad=False)
        if do_use_cuda:
            # Previously only tmu/tsigma were moved, so the arithmetic below
            # combined GPU and CPU tensors. Keep every operand together.
            tmu = tmu.cuda()
            tsigma = tsigma.cuda()
            base = base.cuda()
        for s in range(bs):
            base_noise = Variable(torch.from_numpy(
                rdn.normal(0, 1, size=tsigma.size())).float(),
                requires_grad=False)
            if do_use_cuda:
                base_noise = base_noise.cuda()
            base[s] = tmu + tsigma * base_noise
        base[:, worst_inds] = 0.0
        z = base.contiguous().view(bs, 32, 5, 5)
        x_d = vae.decoder(z)
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        inx_tilde = ((0.5 * nx_tilde + 0.5) * 255).astype(np.uint8)
        # Reduce over the draws (axis 0), then keep the first channel.
        mean_tilde = np.mean(inx_tilde, axis=0)[0].astype(np.uint8)
        max_tilde = np.max(inx_tilde, axis=0)[0].astype(np.uint8)
        mean_img_name = os.path.join(base_path, 'gmean_%05d.png' % (i))
        a_img_name = os.path.join(base_path, 'gadapt_%05d.png' % (i))
        max_img_name = os.path.join(base_path, 'gmax_%05d.png' % (i))
        imwrite(mean_img_name, mean_tilde)
        imwrite(max_img_name, max_tilde)
        nonzero = np.count_nonzero(inx_tilde, axis=0)[0]
        # must have 3 instances to go into adapt
        adapt_tilde = max_tilde.copy()
        adapt_tilde[nonzero < 3] = 0
        imwrite(a_img_name, adapt_tilde)
def test(epoch, test_loader, do_use_cuda, save_img_path=None):
    """Evaluate the module-level ``vmodel`` and optionally plot one example.

    Three loss terms are collected per batch: the discretized
    mixture-of-logistics loss against ``2 * x - 1`` and two MSE terms between
    ``z_q_x`` and ``z_e_x`` (the second weighted by .25). Their per-term mean
    over all batches is returned. When ``save_img_path`` is given, a
    three-panel figure (original / prediction / squared error) is saved for
    the first example of the last batch.
    """
    per_batch = []
    for data, _ in test_loader:
        x = Variable(data, requires_grad=False)
        if do_use_cuda:
            x = x.cuda()
        x_d, z_e_x, z_q_x, latents = vmodel(x)
        mix_loss = discretized_mix_logistic_loss(x_d, 2 * x - 1,
                                                 use_cuda=do_use_cuda)
        mse_q_vs_e = F.mse_loss(z_q_x, z_e_x.detach())
        mse_e_vs_q = .25 * F.mse_loss(z_e_x, z_q_x.detach())
        per_batch.append(to_scalar([mix_loss, mse_q_vs_e, mse_e_vs_q]))
    test_loss_mean = np.asarray(per_batch).mean(0)

    if save_img_path is None:
        return test_loss_mean

    # Everything below uses x / x_d left over from the final batch.
    x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
    # The concatenated pair is built as before although nothing reads it.
    torch.cat([x[0], x_tilde[0]], 0).cpu().data

    pixel_span = float(max_pixel - min_pixel)
    unit_pred = (np.array(x_tilde.cpu().data)[0, 0] + 1.0) / 2.0
    pred = (unit_pred * pixel_span) + min_pixel
    # input x is between 0 and 1
    real = (np.array(x.cpu().data)[0, 0] * pixel_span) + min_pixel

    f, ax = plt.subplots(1, 3, figsize=(10, 3))
    left, middle, right = ax[0], ax[1], ax[2]
    left.imshow(real, vmin=0, vmax=max_pixel)
    left.set_title("original")
    middle.imshow(pred, vmin=0, vmax=max_pixel)
    middle.set_title("pred epoch %s test loss %s" % (epoch, np.mean(test_loss_mean)))
    right.imshow((pred - real)**2, cmap='gray')
    right.set_title("error")
    f.tight_layout()
    plt.savefig(save_img_path)
    plt.close()
    print("saving example image")
    print("rsync -avhp [email protected]://%s" % os.path.abspath(save_img_path))
    return test_loss_mean
def test(x, model, nr_logistic_mix, do_use_cuda=False, save_img_path=None):
    """Compute the three loss terms for one batch and optionally save images.

    Args:
        x: input batch passed straight to ``model``.
        model: callable returning ``(x_d, z_e_x, z_q_x, latents)``.
        nr_logistic_mix: forwarded to the sampler.
        do_use_cuda: forwarded to ``discretized_mix_logistic_loss``.
        save_img_path: when not ``None``, the first channel of the first
            sampled example (after ``0.5 * x + 0.5``) is written there and
            the matching input to the same path with ``'.png'`` replaced by
            ``'orig.png'``.

    Returns:
        ``to_scalar`` applied to the list of the three loss terms.
    """
    x_d, z_e_x, z_q_x, latents = model(x)
    x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
    loss_1 = discretized_mix_logistic_loss(x_d, 2 * x - 1, use_cuda=do_use_cuda)
    loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
    loss_3 = .25 * F.mse_loss(z_e_x, z_q_x.detach())
    test_loss = to_scalar([loss_1, loss_2, loss_3])
    if save_img_path is not None:
        # An index used to be drawn from len(test_data), a global unrelated
        # to this batch, and used to index x. That could run past the batch
        # and raise IndexError, and its result was never used, so it is gone.
        oo = 0.5 * np.array(x_tilde.cpu().data)[0, 0] + 0.5
        ii = np.array(x.cpu().data)[0, 0]
        imwrite(save_img_path, oo)
        imwrite(save_img_path.replace('.png', 'orig.png'), ii)
    return test_loss
def generate_results(data_loader, nr_logistic_mix, do_use_cuda):
    """Pass every batch through the module-level ``vae`` and save PNGs.

    Each batch is flattened to ``(batch, -1)`` before the forward pass; the
    output is viewed as ``(batch, probs_size, dsize, dsize)`` and sampled.
    The first channel of every sample is written to ``img_name`` with
    ``'.png'`` replaced by ``'pixelvae_gen.png'``.

    Relies on module-level ``vae``, ``probs_size`` and ``dsize``.
    """
    start_time = time.time()
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if do_use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)
        x_d = vae(x.contiguous().view(x.shape[0], -1))
        x_di = x_d.contiguous().view(x_d.shape[0], probs_size, dsize, dsize)
        x_tilde = sample_from_discretized_mix_logistic(x_di, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        # NOTE(review): the sample is scaled by 255 directly, with no
        # 0.5 * x + 0.5 shift beforehand - confirm the sampler's output range
        # for this model before relying on the uint8 values.
        inx_tilde = ((nx_tilde) * 255).astype(np.uint8)
        for ind, img_name in enumerate(img_names):
            gen_img_name = img_name.replace('.png', 'pixelvae_gen.png')
            imwrite(gen_img_name, inx_tilde[ind][0])
        if not batch_idx % 10:
            # Was a Python-2-only print statement; the single-argument call
            # form prints the same text on Python 2 and Python 3.
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx, time.time() - start_time))
def generate_results(base_path, data_loader, nr_logistic_mix, do_use_cuda):
    """Reconstruct the leading batches with ``vae`` and dump images and stats.

    Batches are processed while ``batch_idx * 32 <= 4000``. For each one the
    sampled reconstruction's first channel is written to ``base_path`` under
    the input file name with ``'.png'`` replaced by ``'conv_vae_gen.png'``,
    and ``vae.z_mean`` / ``vae.z_sigma`` are collected. The collected rows
    are saved to ``mu_conv_vae.npz`` and ``sigma_conv_vae.npz`` at the end.

    Args:
        base_path: output directory, created if missing.
        data_loader: iterable of ``(data, img_paths)`` batches.
        nr_logistic_mix: forwarded to the sampler.
        do_use_cuda: move each batch to the GPU before the forward pass.
    """
    if not os.path.exists(base_path):
        os.makedirs(base_path)
    start_time = time.time()
    # Seeding with the empty (0, 800) array keeps the saved result identical
    # to the old incremental vstack, including when no batch is processed.
    mu_chunks = [np.empty((0, 800))]
    sigma_chunks = [np.empty((0, 800))]
    limit = 4000
    for batch_idx, (data, img_names) in enumerate(data_loader):
        if batch_idx * 32 > limit:
            # batch_idx only grows, so nothing later can pass this test;
            # stop here instead of loading the rest of the dataset.
            break
        if do_use_cuda:
            x = Variable(data, requires_grad=False).cuda()
        else:
            x = Variable(data, requires_grad=False)
        x_d = vae(x)
        x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
        nx_tilde = x_tilde.cpu().data.numpy()
        inx_tilde = ((0.5 * nx_tilde + 0.5) * 255).astype(np.uint8)
        # vae.z_mean is batch_sizex800
        # vae.z_sigma is batch_sizex800
        mu_chunks.append(vae.z_mean.cpu().data.numpy())
        sigma_chunks.append(vae.z_sigma.cpu().data.numpy())
        for ind, img_path in enumerate(img_names):
            img_name = os.path.split(img_path)[1]
            gen_img_name = img_name.replace('.png', 'conv_vae_gen.png')
            imwrite(os.path.join(base_path, gen_img_name), inx_tilde[ind][0])
        if not batch_idx % 10:
            # Single-argument call form: valid on Python 2 and Python 3.
            print('Generate batch_idx: {} Time: {}'.format(
                batch_idx, time.time() - start_time))
    # One stack at the end replaces re-copying the whole array per batch.
    data_mu = np.vstack(mu_chunks)
    data_sigma = np.vstack(sigma_chunks)
    np.savez(os.path.join(base_path, 'mu_conv_vae.npz'), data_mu)
    np.savez(os.path.join(base_path, 'sigma_conv_vae.npz'), data_sigma)
def generate(frame_num, gen_latents, orig_img_path, save_img_path):
    """Decode ``gen_latents`` with ``vmodel`` and plot it against a frame.

    The latents are flattened per example, looked up in ``vmodel.embedding``,
    viewed as ``(batch, 6, 6, -1)``, permuted to channel-first and decoded.
    When ``save_img_path`` is not ``None`` a three-panel figure (original
    frame read from ``orig_img_path`` / prediction / squared error) is
    saved there. Nothing is returned.
    """
    flat_latents = gen_latents.view(gen_latents.size(0), -1)
    embedded = vmodel.embedding(flat_latents)
    z_q_x = embedded.view(gen_latents.shape[0], 6, 6, -1).permute(0, 3, 1, 2)
    x_d = vmodel.decoder(z_q_x)
    if save_img_path is None:
        return

    x_tilde = sample_from_discretized_mix_logistic(x_d, nr_logistic_mix)
    unit_pred = (np.array(x_tilde.cpu().data)[0, 0] + 1.0) / 2.0
    pred = (unit_pred * float(max_pixel - min_pixel)) + min_pixel
    # input x is between 0 and 1
    real = imread(orig_img_path)

    f, ax = plt.subplots(1, 3, figsize=(10, 3))
    left, middle, right = ax[0], ax[1], ax[2]
    left.imshow(real, vmin=0, vmax=max_pixel)
    left.set_title("original frame %s" % frame_num)
    middle.imshow(pred, vmin=0, vmax=max_pixel)
    middle.set_title("pred")
    right.imshow((pred - real)**2, cmap='gray')
    right.set_title("error")
    f.tight_layout()
    plt.savefig(save_img_path)
    plt.close()
    print("saving example image")
    print("rsync -avhp [email protected]://%s" % os.path.abspath(save_img_path))
def train(config, mode='cifar10'):
    """Train a PixelCNN, reporting train/test loss every epoch.

    Args:
        config: object providing ``lr``, ``lr_decay``, ``nr_resnet``,
            ``nr_filters``, ``nr_logistic_mix``, ``batch_size``,
            ``max_epochs``, ``save_interval`` and ``load_params``.
        mode: dataset name passed to ``load_data``; ``'cifar10'`` and
            ``'faces'`` use 3x32x32 images, ``'mnist'`` uses 1x28x28.

    Raises:
        ValueError: if ``mode`` is none of the three supported names.

    Every ``config.save_interval`` epochs the weights are written under
    ``models/`` and a 5x5 grid of samples under ``images/``.
    """
    model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(
        config.lr, config.nr_resnet, config.nr_filters)

    # Create each directory on its own. With one try block around both
    # calls, an existing 'models' directory raised before 'images' was ever
    # attempted. OSError is still ignored, as before.
    for out_dir in ('models', 'images'):
        try:
            os.makedirs(out_dir)
        except OSError:
            pass

    seed = np.random.randint(0, 10000)
    print("Random Seed: ", seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = True

    trainset, train_loader, testset, test_loader, classes = load_data(
        mode=mode, batch_size=config.batch_size)

    if mode == 'cifar10' or mode == 'faces':
        obs = (3, 32, 32)
        loss_op = lambda real, fake: discretized_mix_logistic_loss(
            real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic(
            x, config.nr_logistic_mix)
    elif mode == 'mnist':
        obs = (1, 28, 28)
        loss_op = lambda real, fake: discretized_mix_logistic_loss_1d(
            real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(
            x, config.nr_logistic_mix)
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('unsupported mode: %r' % (mode,))

    sample_batch_size = 25
    rescaling_inv = lambda x: .5 * x + .5
    values_per_image = np.prod(obs)

    model = PixelCNN(nr_resnet=config.nr_resnet,
                     nr_filters=config.nr_filters,
                     input_channels=obs[0],
                     nr_logistic_mix=config.nr_logistic_mix).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1,
                                    gamma=config.lr_decay)

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample(model):
        model.train(False)
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.cuda()
        with tqdm(total=obs[1] * obs[2]) as pbar:
            # One forward pass per position: the canvas fed back into the
            # model changes after every write.
            for i in range(obs[1]):
                for j in range(obs[2]):
                    with torch.no_grad():
                        out = model(data, sample=True)
                        out_sample = sample_op(out)
                        data[:, :, i, j] = out_sample.data[:, :, i, j]
                    pbar.update(1)
        return data

    print('starting training')
    for epoch in range(config.max_epochs):
        model.train()
        torch.cuda.synchronize()
        train_loss = 0.
        time_ = time.time()
        train_seen = 0
        with tqdm(total=len(train_loader)) as pbar:
            for batch_idx, (data, label) in enumerate(train_loader):
                data = data.requires_grad_(True).cuda()
                output = model(data)
                loss = loss_op(data, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                train_seen += data.size(0)
                pbar.update(1)
        # Normalise by the values actually processed. The old denominator
        # used the last batch index, which dropped one batch and was zero
        # for a single-batch loader.
        deno = max(train_seen, 1) * values_per_image
        print('train loss : %s' % (train_loss / deno), end='\t')

        # decrease learning rate
        scheduler.step()

        model.eval()
        test_loss = 0.
        test_seen = 0
        with tqdm(total=len(test_loader)) as pbar:
            # No gradients are needed for evaluation.
            with torch.no_grad():
                for batch_idx, (data, _) in enumerate(test_loader):
                    data = data.requires_grad_(False).cuda()
                    output = model(data)
                    loss = loss_op(data, output)
                    test_loss += loss.item()
                    test_seen += data.size(0)
                    del loss, output
                    pbar.update(1)
        deno = max(test_seen, 1) * values_per_image
        print('test loss : {:.4f}, time : {:.4f}'.format(
            (test_loss / deno), (time.time() - time_)))

        torch.cuda.synchronize()

        if (epoch + 1) % config.save_interval == 0:
            torch.save(model.state_dict(),
                       'models/{}_{}.pth'.format(model_name, epoch))
            print('sampling...')
            sample_t = sample(model)
            sample_t = rescaling_inv(sample_t)
            save_image(sample_t,
                       'images/{}_{}.png'.format(model_name, epoch),
                       nrow=5, padding=0)
def call_plot(model_dict, data_dict, info, sample, tsne, pca):
    """Plot reconstructions, nearest neighbours and 2-D code embeddings.

    For both the 'train' and 'valid' loaders a fixed-seed random batch is
    drawn, pushed through ``forward_pass`` in 'valid' mode and sampled.

    Args:
        model_dict: dict of models; 'prior_model' and 'acn_model' are read.
        data_dict: dict with 'train' and 'valid' loaders whose ``dataset``
            exposes ``indexed_dataset``.
        info: settings dict (keys read here include 'batch_size',
            'vq_decoder', 'nr_logistic_mix', 'sample_mean',
            'sampling_temperature', 'code_length', 'model_loadpath',
            'device', 'perplexity').
        sample: when truthy, save one neighbour figure per plotted example.
        tsne: when truthy, write a t-SNE HTML plot unless the file exists.
        pca: when truthy, write a PCA HTML plot unless the file exists.

    All output paths are derived from ``info['model_loadpath']`` by
    replacing '.pt'. The neighbour count comes from ``args.num_k``, a name
    that is not defined inside this function.
    """
    from utils import tsne_plot
    from utils import pca_plot
    from sklearn.cluster import KMeans
    # always be in eval mode - so we dont swap neighbors
    model_dict = set_model_mode(model_dict, 'valid')
    # Fixed seed: the same example indexes are drawn on every call.
    srandom_state = np.random.RandomState(1234)
    with torch.no_grad():
        for phase in ['train', 'valid']:
            batch_index = srandom_state.randint(0, len(data_dict[phase].dataset), info['batch_size'])
            print(batch_index)
            data = torch.stack([
                data_dict[phase].dataset.indexed_dataset[index][0]
                for index in batch_index
            ])
            label = torch.stack([
                data_dict[phase].dataset.indexed_dataset[index][1]
                for index in batch_index
            ])
            batch_index = torch.LongTensor(batch_index)
            data = torch.FloatTensor(data)
            fp_out = forward_pass(model_dict, data, label, batch_index, 'valid', info)
            # forward_pass returns three extra values when a VQ decoder is used.
            if info['vq_decoder']:
                model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
            else:
                model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out
            bs, c, h, w = data.shape
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            data = data.detach().cpu().numpy()
            rec = rec_yhat.detach().cpu().numpy()
            u_q_flat = u_q.view(bs, info['code_length'])
            # choose limited number to plot
            n = min([20, bs])
            # NOTE(review): `args` is not a parameter or local here;
            # presumably a module-level namespace - confirm.
            n_neighbors = args.num_k
            if sample:
                all_neighbor_distances, all_neighbor_indexes = model_dict[
                    'prior_model'].kneighbors(u_q_flat, n_neighbors=n_neighbors)
                all_neighbor_indexes = all_neighbor_indexes.cpu().numpy()
                all_neighbor_distances = all_neighbor_distances.cpu().numpy()
                # Column 0 holds the query example and neighbours start at
                # column 2; column 1 is never drawn into.
                n_cols = 2 + n_neighbors
                tbatch_index = batch_index.cpu().numpy()
                np_label = label.cpu().numpy()
                for i in np.arange(0, n):
                    # plot each base image
                    plt_path = info['model_loadpath'].replace(
                        '.pt', '_batch_rec_neighbors_%s_%06d_plt.png' %
                        (phase, tbatch_index[i]))
                    # bi 5136
                    neighbor_indexes = all_neighbor_indexes[i]
                    code = u_q[i].view(
                        (1,
                         model_dict['acn_model'].bottleneck_channels,
                         model_dict['acn_model'].eo,
                         model_dict['acn_model'].eo)).cpu().numpy()
                    # Rows: true image, reconstruction, code channel 0,
                    # code channel 1.
                    f, ax = plt.subplots(4, n_cols)
                    ax[0, 0].set_title('L%sI%s' % (np_label[i], tbatch_index[i]))
                    ax[0, 0].set_ylabel('true')
                    ax[0, 0].matshow(data[i, 0])
                    ax[1, 0].set_ylabel('rec')
                    ax[1, 0].matshow(rec[i, 0])
                    ax[2, 0].matshow(code[0, 0])
                    ax[3, 0].matshow(code[0, 1])
                    # Neighbour images and labels always come from the
                    # 'train' dataset, whichever phase is being plotted.
                    neighbor_data = torch.stack([
                        data_dict['train'].dataset.indexed_dataset[index][0]
                        for index in neighbor_indexes
                    ])
                    neighbor_label = torch.stack([
                        data_dict['train'].dataset.indexed_dataset[index][1]
                        for index in neighbor_indexes
                    ])
                    # u_q_flat
                    neighbor_codes_flat = model_dict['prior_model'].codes[
                        neighbor_indexes]
                    neighbor_codes = neighbor_codes_flat.view(
                        n_neighbors,
                        model_dict['acn_model'].bottleneck_channels,
                        model_dict['acn_model'].eo,
                        model_dict['acn_model'].eo)
                    if info['vq_decoder']:
                        neighbor_rec_dml, _, _, _ = model_dict[
                            'acn_model'].decode(
                                neighbor_codes.to(info['device']))
                    else:
                        neighbor_rec_dml = model_dict['acn_model'].decode(
                            neighbor_codes.to(info['device']))
                    neighbor_data = neighbor_data.cpu().numpy()
                    neighbor_rec_yhat = sample_from_discretized_mix_logistic(
                        neighbor_rec_dml,
                        info['nr_logistic_mix'],
                        only_mean=info['sample_mean'],
                        sampling_temperature=info['sampling_temperature']).cpu(
                        ).numpy()
                    for ni in range(n_neighbors):
                        nindex = all_neighbor_indexes[i, ni].item()
                        nlabel = neighbor_label[ni].cpu().numpy()
                        ncode = neighbor_codes[ni].cpu().numpy()
                        ax[0, ni + 2].set_title('L%sI%s' % (nlabel, nindex))
                        ax[0, ni + 2].matshow(neighbor_data[ni, 0])
                        ax[1, ni + 2].matshow(neighbor_rec_yhat[ni, 0])
                        ax[2, ni + 2].matshow(ncode[0])
                        ax[3, ni + 2].matshow(ncode[1])
                    ax[2, 0].set_ylabel('lc0')
                    ax[3, 0].set_ylabel('lc1')
                    # Comprehensions used only for their side effects: hide
                    # ticks on column 0 and whole axes on the other columns.
                    [ax[xx, 0].set_xticks([]) for xx in range(4)]
                    [ax[xx, 0].set_yticks([]) for xx in range(4)]
                    for xx in range(4):
                        [ax[xx, col].axis('off') for col in range(1, n_cols)]
                    plt.subplots_adjust(wspace=0, hspace=0)
                    plt.tight_layout()
                    print('plotting', plt_path)
                    plt.savefig(plt_path)
                    plt.close()
            X = u_q_flat.cpu().numpy()
            #km = KMeans(n_clusters=10)
            #y = km.fit_predict(X)
            # color points based on clustering, label, or index
            color = label.cpu().numpy()  #y #batch_indexes
            if tsne:
                param_name = '_tsne_%s_P%s.html' % (phase, info['perplexity'])
                html_path = info['model_loadpath'].replace('.pt', param_name)
                # An existing HTML file is never regenerated.
                if not os.path.exists(html_path):
                    tsne_plot(X=X,
                              images=data[:, 0],
                              color=color,
                              perplexity=info['perplexity'],
                              html_out_path=html_path,
                              serve=False)
            if pca:
                param_name = '_pca_%s.html' % (phase)
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    pca_plot(X=X,
                             images=data[:, 0],
                             color=color,
                             html_out_path=html_path,
                             serve=False)
def run(train_cnt, model_dict, data_dict, phase, info):
    """Run one pass over ``data_dict[phase]``, training when phase is 'train'.

    Args:
        train_cnt: running count of training examples seen so far.
        model_dict: dict of models plus the optimizer under 'opt'; every
            entry has ``zero_grad`` called on it each batch.
        data_dict: dict mapping phase name to its data loader.
        phase: name of the pass; gradients are enabled and the optimizer is
            stepped only when it equals 'train'.
        info: settings dict (keys read here include 'rec_loss_type',
            'vq_decoder', 'batch_size', 'code_length', 'device',
            'reduction', 'nr_logistic_mix', 'vq_commitment_beta',
            'sample_mean', 'sampling_temperature').

    Returns:
        ``(loss_avg, example)``: ``account_losses`` applied to the summed
        losses, and a dict with the 'target' and sampled 'rec' arrays of the
        last batch.

    Note:
        ``torch.set_grad_enabled`` is called without being restored, so the
        grad mode chosen here stays in effect after the function returns.
    """
    st = time.time()
    rec_key = 'rec_%s' % info['rec_loss_type']
    loss_dict = {
        'running': 0,
        'kl': 0,
        rec_key: 0,
        'loss': 0,
    }
    if info['vq_decoder']:
        loss_dict['vq'] = 0
        loss_dict['commit'] = 0
    dataset = data_dict[phase]
    num_batches = len(dataset) // info['batch_size']
    print(phase, 'num batches', num_batches)
    set_model_mode(model_dict, phase)
    torch.set_grad_enabled(phase == 'train')
    batch_cnt = 0
    data_loader = data_dict[phase]
    num_batches = len(data_loader)
    for idx, (data, label, batch_index) in enumerate(data_loader):
        for key in model_dict.keys():
            model_dict[key].zero_grad()
        fp_out = forward_pass(model_dict, data, label, batch_index, phase, info)
        # forward_pass returns three extra values when a VQ decoder is used.
        if info['vq_decoder']:
            model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
        else:
            model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out
        bs, c, h, w = data.shape
        # Rebuild the zero tensor on the first batch and whenever the batch
        # size changes, e.g. for a short final batch.
        if batch_cnt == 0 or bs != log_ones.shape[0]:
            log_ones = torch.zeros(bs, info['code_length']).to(info['device'])
        kl = kl_loss_function(u_q.view(bs, info['code_length']),
                              log_ones,
                              u_p.view(bs, info['code_length']),
                              s_p.view(bs, info['code_length']),
                              reduction=info['reduction'])
        rec_loss = discretized_mix_logistic_loss(
            rec_dml, target, nr_mix=info['nr_logistic_mix'],
            reduction=info['reduction'])
        if info['vq_decoder']:
            vq_loss = F.mse_loss(z_q_x, z_e_x.detach(),
                                 reduction=info['reduction'])
            commit_loss = F.mse_loss(z_e_x, z_q_x.detach(),
                                     reduction=info['reduction'])
            commit_loss *= info['vq_commitment_beta']
            loss_dict['vq'] += vq_loss.detach().cpu().item()
            loss_dict['commit'] += commit_loss.detach().cpu().item()
            loss = kl + rec_loss + commit_loss + vq_loss
        else:
            loss = kl + rec_loss
        # Accumulate every total exactly once per batch. These four updates
        # used to appear twice in a row, so they no longer agreed with the
        # single vq/commit update above.
        loss_dict['running'] += bs
        loss_dict['loss'] += loss.detach().cpu().item()
        loss_dict['kl'] += kl.detach().cpu().item()
        loss_dict[rec_key] += rec_loss.detach().cpu().item()
        if phase == 'train':
            model_dict = clip_parameters(model_dict)
            loss.backward()
            model_dict['opt'].step()
            train_cnt += bs
        if batch_cnt == num_batches - 1:
            # store example near end for plotting
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            example = {
                'target': data.detach().cpu().numpy(),
                'rec': rec_yhat.detach().cpu().numpy(),
            }
        if not batch_cnt % 100:
            print(train_cnt, batch_cnt, account_losses(loss_dict))
            print(phase, 'cuda', torch.cuda.memory_allocated(device=None))
        batch_cnt += 1

    loss_avg = account_losses(loss_dict)
    torch.cuda.empty_cache()
    print("finished %s after %s secs at cnt %s" % (
        phase,
        time.time() - st,
        train_cnt,
    ))
    del data
    del target
    return loss_avg, example
return x_out


if __name__ == '__main__':
    # Disabled check: forward a random batch and print the loss.
    # img = torch.zeros(8, 3, 32, 32).float().uniform_(-1, 1).cuda()
    # # img = torch.zeros(8, 3, 32, 32).float().cuda()
    # model = PixelCNN(nr_resnet=3, nr_filters=100, input_channels=img.size(1)).cuda()
    # out = model(img)
    # # loss = discretized_mix_logistic_loss(img, out)
    # print('loss : %s' % loss.item())

    # Sample a single 3x32x32 image from a saved checkpoint and write it
    # to '1.jpg'.
    img = torch.zeros(1, 3, 32, 32).float().cuda()
    model = PixelCNN(nr_resnet=5, nr_filters=160, input_channels=img.size(1), nr_logistic_mix=10).cuda()
    load_part_of_model(model, 'models/pcnn_lr_0.00020_nr-resnet5_nr-filters160_58.pth')
    # NOTE(review): the model is never switched out of training mode before
    # sampling here - confirm that is intended.
    sample_op = lambda x: sample_from_discretized_mix_logistic(x, 10)
    from tqdm import tqdm
    with tqdm(total=32*32) as pbar:
        # One forward pass per position: img is fed back into the model
        # after every write.
        for i in range(32):
            for j in range(32):
                with torch.no_grad():
                    data_v = img
                    out = model(data_v, sample=True)
                    out_sample = sample_op(out)
                    img[:, :, i, j] = out_sample.data[:, :, i, j]
                pbar.update(1)
    from torchvision.utils import save_image
    save_image(img.data.cpu(), '1.jpg', nrow=1)
def generate_imgs(dataloader, output_filepath, true_img_path, data_type, transform):
    """Roll an RNN over latent-mean differences, decode and plot the result.

    For every batch the scaled mu-differences are fed step by step through
    the module-level ``rnn``. Predictions and the ground-truth sequence are
    optionally unscaled with ``mu_diff_std`` / ``mu_diff_mean``, cumulatively
    summed from the first original frame, decoded with ``vae.decoder`` and
    plotted next to the true frames read from ``true_img_path``.

    Args:
        dataloader: yields 7-tuples (scaled/unscaled/original mu, the same
            three for sigma, and file names).
        output_filepath: directory for the plots, created if missing.
        true_img_path: directory holding the true frame images.
        data_type: not referenced in this function.
        transform: ``"std"`` undoes the standardisation; any other value
            leaves the values as they are.

    Relies on module-level ``hidden_size``, ``use_cuda``, ``rnn``, ``vae``,
    ``best_inds``, ``mu_diff_std``, ``mu_diff_mean``, ``nr_logistic_mix``
    and ``get_cuts``.
    """
    if not os.path.exists(output_filepath):
        os.makedirs(output_filepath)
    for batch_idx, (data_mu_diff_scaled, data_mu_diff, data_mu_orig,
                    data_sigma_diff_scaled, data_sigma_diff, data_sigma_orig,
                    name) in enumerate(dataloader):
        # data_mu_orig will be one longer than the diff versions
        batch_size = data_mu_diff_scaled.shape[0]
        # predict one less time step than available (first is input)
        n_timesteps = data_mu_diff_scaled.shape[1]
        vae_input_size = 800
        #######################
        # get rnn details
        #######################
        # Put the time axis first so that seq[i] is one step for all items.
        rnn_data = data_mu_diff_scaled.permute(1, 0, 2)
        seq = Variable(torch.FloatTensor(rnn_data), requires_grad=False)
        h1_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))), requires_grad=False)
        c1_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))), requires_grad=False)
        h2_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))), requires_grad=False)
        c2_tm1 = Variable(torch.FloatTensor(np.zeros(
            (batch_size, hidden_size))), requires_grad=False)
        if use_cuda:
            # NOTE(review): mus_vae and out_mu are assigned here with no
            # earlier binding in this function and no visible `global`
            # statement, so this branch looks like it raises
            # UnboundLocalError - confirm.
            mus_vae = mus_vae.cuda()
            seq = seq.cuda()
            out_mu = out_mu.cuda()
            h1_tm1 = h1_tm1.cuda()
            c1_tm1 = c1_tm1.cuda()
            h2_tm1 = h2_tm1.cuda()
            c2_tm1 = c2_tm1.cuda()
        # get time offsets correct
        x = seq[:-1]
        # put initial step in
        rnn_outputs = [seq[0]]
        gt_outputs = [seq[0]]
        nrnn_outputs = [seq[0].cpu().data.numpy()]
        ngt_outputs = [seq[0].cpu().data.numpy()]
        for i in range(len(x)):
            # number of frames to start with
            #if i < 4:
            # Every step is fed the true input x[i], not the previous output.
            output, h1_tm1, c1_tm1, h2_tm1, c2_tm1 = rnn(
                x[i], h1_tm1, c1_tm1, h2_tm1, c2_tm1)
            #else:
            #    output, h1_tm1, c1_tm1, h2_tm1, c2_tm1 = rnn(output, h1_tm1, c1_tm1, h2_tm1, c2_tm1)
            nrnn_outputs += [output.cpu().data.numpy()]
            rnn_outputs += [output]
            # put ground truth in to check pipeline
            ngt_outputs += [seq[i + 1].cpu().data.numpy()]
            gt_outputs += [seq[i + 1]]
            print(output.sum().data[0], seq[i + 1].sum().data[0])
        # vae data should be batch,timestep(example),features
        # 0th frame is the same here
        gt_rnn_pred = torch.stack(gt_outputs, 0)
        rnn_pred = torch.stack(rnn_outputs, 0)
        # 0th frame is the same here
        rnn_mu_diff_scaled = rnn_pred.permute(1, 0, 2).data.numpy()
        gt_rnn_mu_diff_scaled = gt_rnn_pred.permute(1, 0, 2).data.numpy()
        # The next two arrays are not referenced again in this function.
        nrnn_mu_diff_scaled = np.swapaxes(np.array(nrnn_outputs), 0, 1)
        ngt_rnn_mu_diff_scaled = np.swapaxes(np.array(ngt_outputs), 0, 1)
        # only use relevant mus
        # These buffers are created once per batch and only their best_inds
        # columns are overwritten for each episode below.
        orig_mu_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))), requires_grad=False)
        diff_mu_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))), requires_grad=False)
        diff_mu_unscaled_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))), requires_grad=False)
        diff_mu_unscaled_rnn_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))), requires_grad=False)
        gt_diff_mu_unscaled_rnn_placeholder = Variable(torch.FloatTensor(
            np.zeros((n_timesteps, vae_input_size))), requires_grad=False)
        if transform == "std":
            print("removing standard deviation transform")
            # convert to numpy so broadcasting works
            rnn_mu_diff_unscaled = torch.FloatTensor((rnn_mu_diff_scaled * mu_diff_std) + mu_diff_mean[None])
            gt_rnn_mu_diff_unscaled = torch.FloatTensor(
                (gt_rnn_mu_diff_scaled * mu_diff_std) + mu_diff_mean[None])
            data_mu_diff_unscaled = torch.FloatTensor(
                (data_mu_diff_scaled.numpy() * mu_diff_std) + mu_diff_mean[None])
        else:
            print("no transform")
            # NOTE(review): here the rnn arrays stay as the `.numpy()`
            # results from above, yet `.numpy()` is called on their rows
            # further down - confirm this branch is ever exercised.
            rnn_mu_diff_unscaled = rnn_mu_diff_scaled
            gt_rnn_mu_diff_unscaled = gt_rnn_mu_diff_scaled
            data_mu_diff_unscaled = data_mu_diff_scaled
        # go through each distinct episode (should be length of 167)
        for e in range(batch_size):
            basename = os.path.split(name[e])[1].replace('.npz', '')
            if not e:
                print("starting %s" % basename)
            basepath = os.path.join(output_filepath, basename)
            # reconstruct rnn vae
            # now the size going through the decoder is 169x32x5x5
            # original data is one longer since there was no diff applied
            ep_mu_orig = data_mu_orig[e, 1:]
            ep_mu_diff = data_mu_diff[e]
            ep_mu_diff_unscaled = data_mu_diff_unscaled[e]
            ep_mu_diff_unscaled_rnn = rnn_mu_diff_unscaled[e]
            gt_ep_mu_diff_unscaled_rnn = gt_rnn_mu_diff_unscaled[e]
            primer_frame = data_mu_orig[e, 0, :]
            # need to reconstruct from original
            # get the first frame from the original dataset to add diffs to
            # data_mu_orig will be one frame longer
            # unscale the scaled version
            # The += updates below write into rows indexed out of the batch
            # arrays in place.
            ep_mu_diff[0] += primer_frame
            ep_mu_diff_unscaled[0] += primer_frame
            ep_mu_diff_unscaled_rnn[0] += primer_frame
            gt_ep_mu_diff_unscaled_rnn[0] += primer_frame
            print("before diff add")
            # Running sum over time: each frame adds the already-summed
            # previous frame.
            for diff_frame in range(1, n_timesteps):
                #print("adding diff to %s" %diff_frame)
                ep_mu_diff[diff_frame] += ep_mu_diff[diff_frame - 1]
                ep_mu_diff_unscaled[diff_frame] += ep_mu_diff_unscaled[
                    diff_frame - 1]
                ep_mu_diff_unscaled_rnn[diff_frame] += ep_mu_diff_unscaled_rnn[
                    diff_frame - 1]
                gt_ep_mu_diff_unscaled_rnn[
                    diff_frame] += gt_ep_mu_diff_unscaled_rnn[diff_frame - 1]
            rnn_mu_img = ep_mu_diff_unscaled_rnn.numpy()
            gt_rnn_mu_img = gt_ep_mu_diff_unscaled_rnn.numpy()
            # Side-by-side image of the two reconstructed mu sequences.
            ff, axf = plt.subplots(1, 2, figsize=(5, 10))
            axf[0].imshow(gt_rnn_mu_img, origin='lower')
            axf[0].set_title("gt_rnn_mu")
            axf[1].imshow(rnn_mu_img, origin='lower')
            axf[1].set_title("rnn_mu")
            ff.tight_layout()
            fimg_name = basepath + '_rnn_mu_plot.png'
            fimg_name = fimg_name.replace('_frame_%05d' % 0, '')
            print("plotted %s" % fimg_name)
            plt.savefig(fimg_name)
            plt.close()
            # Write the episode values into the best_inds columns of the
            # 800-wide buffers.
            orig_mu_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_orig))
            diff_mu_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_diff))
            diff_mu_unscaled_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_diff_unscaled))
            diff_mu_unscaled_rnn_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(ep_mu_diff_unscaled_rnn))
            gt_diff_mu_unscaled_rnn_placeholder[:, best_inds] = Variable(
                torch.FloatTensor(gt_ep_mu_diff_unscaled_rnn))
            #for i in range(1,diff_mu_unscaled_rnn_placeholder.shape[0]):
            #    diff_mu_unscaled_rnn_placeholder[i] = gt_diff_mu_unscaled_rnn_placeholder[i]
            # add a placeholder here if you want to process it
            mu_types = OrderedDict([
                ('orig', orig_mu_placeholder),
                # ('diff',diff_mu_placeholder),
                # ('diff_unscaled',diff_mu_unscaled_placeholder),
                ('gtrnn', gt_diff_mu_unscaled_rnn_placeholder),
                ('rnn', diff_mu_unscaled_rnn_placeholder),
            ])
            mu_reconstructed = OrderedDict()
            # get reconstructed image for each type
            for xx, mu_output_name in enumerate(mu_types.keys()):
                mu_output = mu_types[mu_output_name]
                cuts = get_cuts(mu_output.shape[0], 1)
                print(mu_output_name, mu_output.sum().data[0], mu_output[0].sum().data[0])
                x_tildes = []
                # Note: this loop rebinds `e`, the episode index of the
                # enclosing loop; `e` is not read again before the outer loop
                # assigns its next value.
                for (s, e) in cuts:
                    mu_batch = mu_output[s:e]
                    # only put part of the episode through
                    x_d = vae.decoder(mu_batch.contiguous().view(
                        mu_batch.shape[0], 32, 5, 5))
                    x_tilde = sample_from_discretized_mix_logistic(
                        x_d, nr_logistic_mix, deterministic=True)
                    x_tildes.append(x_tilde.cpu().data.numpy())
                # Keep the first example and first channel of every chunk.
                nx_tilde = np.array(x_tildes)[:, 0, 0]
                inx_tilde = ((0.5 * nx_tilde + 0.5) * 255).astype(np.uint8)
                mu_reconstructed[mu_output_name] = inx_tilde
            for frame_num in range(n_timesteps):
                # The true frame's file name comes from the episode base
                # name with the frame number substituted in.
                true_img_name = os.path.join(
                    true_img_path,
                    basename.replace('_conv_vae', '.png')).replace('frame_%05d' % 0, 'frame_%05d' % frame_num)
                true_img = imread(true_img_name)
                print("true img %s" % true_img_name)
                num_imgs = len(mu_reconstructed.keys()) + 1
                f, ax = plt.subplots(1, num_imgs, figsize=(3 * num_imgs, 3))
                ax[0].imshow(true_img, origin='lower')
                ax[0].set_title('true frame %04d' % frame_num)
                for ii, mu_output_name in enumerate(mu_reconstructed.keys()):
                    ax[ii + 1].imshow(
                        mu_reconstructed[mu_output_name][frame_num],
                        origin='lower')
                    ax[ii + 1].set_title(mu_output_name)
                f.tight_layout()
                img_name = basepath + '_rnn_plot.png'
                img_name = img_name.replace('frame_%05d' % 0, 'frame_%05d' % frame_num)
                print("plotted %s" % img_name)
                plt.savefig(img_name)
                plt.close()