def call_plot(model_dict, data_dict, info):
    """Write t-SNE and/or PCA HTML plots of the ``u_q`` codes of one minibatch.

    The models are put in 'valid' mode, a unique minibatch of 84 examples is
    drawn from the phase's loader and pushed through ``forward_pass``. The
    returned ``u_q`` is viewed as ``(batch, info['code_length'])`` and passed,
    together with channel 0 of ``target``, to the ``acn_utils`` plot helpers.
    The module-level ``args.tsne`` / ``args.pca`` flags select the plots.
    Output paths are ``info['model_loadpath']`` with '.pt' replaced by a
    plot-specific suffix.

    Args:
        model_dict: models handed to ``set_model_mode`` and ``forward_pass``.
        data_dict: maps a phase name to its data loader.
        info: settings; reads 'code_length', 'perplexity', 'model_loadpath'.
    """
    from acn_utils import pca_plot, tsne_plot

    # always be in eval mode
    model_dict = set_model_mode(model_dict, 'valid')
    with torch.no_grad():
        for phase in ('valid', 'train'):
            loader = data_dict[phase]
            loader.reset_unique()
            minibatch = loader.get_unique_minibatch(84)
            (states, actions, rewards, next_states,
             _, _, batch_indexes, index_indexes) = minibatch
            outputs = forward_pass(model_dict, states, actions, rewards,
                                   next_states, index_indexes, phase, info)
            (model_dict, states, actions, rewards, target, u_q, u_p, s_p,
             rec_dml, pcnn_dml, z_e_x, z_q_x, latents) = outputs

            n_examples = states.shape[0]
            codes = u_q.view(n_examples, info['code_length']).cpu().numpy()
            point_colors = index_indexes
            frames = target[:,0].cpu().numpy()

            if args.tsne:
                suffix = '_tsne_%s_P%s.html'%(phase, info['perplexity'])
                out_path = info['model_loadpath'].replace('.pt', suffix)
                tsne_plot(X=codes, images=frames, color=point_colors,
                          perplexity=info['perplexity'],
                          html_out_path=out_path, serve=False)
            if args.pca:
                suffix = '_pca_%s.html'%(phase)
                out_path = info['model_loadpath'].replace('.pt', suffix)
                pca_plot(X=codes, images=frames, color=point_colors,
                         html_out_path=out_path, serve=False)
            # NOTE(review): the loop stops after the first phase, so 'train'
            # is never plotted - confirm this is intended.
            break
def call_plot(model_dict, data_dict, info, sample, tsne, pca):
    """Save reconstruction comparison images and optional embedding plots.

    For each phase ('valid', then 'train') a unique minibatch of 84 examples
    is run through ``forward_pass``. A 10x3 grid (columns 'prev', 'true',
    'pred') showing the last channel of ``states``, ``target`` and a sample
    drawn from ``rec_dml`` is written next to the checkpoint as
    ``<model_loadpath minus '.pt'>_<phase>_plt.png``. The ``u_q`` codes,
    viewed as ``(batch, info['code_length'])``, are then optionally plotted
    with t-SNE and/or PCA into HTML files.

    Args:
        model_dict: models handed to ``set_model_mode`` and ``forward_pass``.
        data_dict: maps a phase name to its data loader.
        info: settings; reads 'nr_logistic_mix', 'sample_mean',
            'sampling_temperature', 'code_length', 'perplexity' and
            'model_loadpath'.
        sample: not used by this function; kept for existing callers.
        tsne: when truthy, write the t-SNE HTML plot.
        pca: when truthy, write the PCA HTML plot.
    """
    from acn_utils import tsne_plot
    from acn_utils import pca_plot
    # always be in eval mode - so we dont swap neighbors
    model_dict = set_model_mode(model_dict, 'valid')
    with torch.no_grad():
        for phase in ['valid', 'train']:
            data_loader = data_dict[phase]
            data_loader.reset_unique()
            batch = data_loader.get_unique_minibatch(84)
            states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
            fp_out = forward_pass(model_dict, states, actions, rewards,
                                  next_states, index_indexes, phase, info)
            (model_dict, states, actions, rewards, target, u_q, u_p, s_p,
             rec_dml, z_e_x, z_q_x, latents) = fp_out
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml, info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])

            f, ax = plt.subplots(10, 3)
            ax[0, 0].set_title('prev')
            ax[0, 1].set_title('true')
            ax[0, 2].set_title('pred')
            for i in range(10):
                ax[i, 0].matshow(states[i, -1])
                ax[i, 1].matshow(target[i, -1])
                ax[i, 2].matshow(rec_yhat[i, -1])
                ax[i, 0].axis('off')
                ax[i, 1].axis('off')
                ax[i, 2].axis('off')
            plt.subplots_adjust(wspace=0, hspace=0)
            plt.tight_layout()
            plt_path = info['model_loadpath'].replace('.pt', '_%s_plt.png' % phase)
            print('plotting', plt_path)
            plt.savefig(plt_path)
            # Release the figure; otherwise one figure per phase stays open.
            plt.close(f)

            bs = states.shape[0]
            u_q_flat = u_q.view(bs, info['code_length'])
            X = u_q_flat.cpu().numpy()
            color = index_indexes
            images = target[:, 0].cpu().numpy()
            if tsne:
                param_name = '_tsne_%s_P%s.html' % (phase, info['perplexity'])
                html_path = info['model_loadpath'].replace('.pt', param_name)
                tsne_plot(X=X, images=images, color=color,
                          perplexity=info['perplexity'],
                          html_out_path=html_path, serve=False)
            if pca:
                param_name = '_pca_%s.html' % (phase)
                html_path = info['model_loadpath'].replace('.pt', param_name)
                pca_plot(X=X, images=images, color=color,
                         html_out_path=html_path, serve=False)
def load_uvdeconv_representation_model(representation_model_path):
    """Load a representation model via ``create_models`` and set it to 'valid'.

    ``create_models`` is called with a fresh info dict holding the
    module-level ``device`` and ``args``, and with ``load_data=False``.

    Args:
        representation_model_path: path forwarded to ``create_models``.

    Returns:
        Tuple ``(rep_model_dict, rep_info, prepare_uv_state_latents,
        rescale, rescale_inv)``, where ``rep_model_dict`` is the result of
        ``set_model_mode(..., 'valid')``.
    """
    # model trained with this file:
    # ../models/train_atari_uvdeconv_tacn_midtwgradloss.py
    # will output a acn flat float representation and a vq discrete
    # representation - which to use?
    loaded = create_models({'device': device, 'args': args},
                           representation_model_path, load_data=False)
    # The data, train-count and epoch-count slots are not needed here.
    models, _, rep_info, _train_cnt, _epoch_cnt, rescale, rescale_inv = loaded
    return (
        set_model_mode(models, 'valid'),
        rep_info,
        prepare_uv_state_latents,
        rescale,
        rescale_inv,
    )
def sample(model_dict, data_dict, info):
    """Sample frames pixel by pixel from the pcnn decoder and save a figure.

    For each phase ('train', then 'valid') a unique minibatch of 10 examples
    is run through ``forward_pass``. A zero canvas shaped like ``target`` is
    then filled one position at a time: the 'pcnn_decoder_model' is evaluated
    on the current canvas with ``rec_dml`` as spatial condition, a sample is
    drawn from its output, and the canvas entry is set to
    ``rescale(rlast - rescale_inv(sample))`` at that position. The figure has
    three columns ('target', 'tf', 'sam'), each overlaying a difference map on
    the reconstructed frame, and is saved as
    ``<model_loadpath minus '.pt'>_zc_st<temperature>_sample_<phase>.png``.

    Reads the module-level ``args`` (model_loadpath, sampling_temperature,
    sample_mean) and the module-level ``rescale`` / ``rescale_inv``.

    Args:
        model_dict: models; must contain 'pcnn_decoder_model'.
        data_dict: maps a phase name to its data loader.
        info: settings; reads 'device', 'num_actions', 'num_rewards' and
            'nr_logistic_mix'.
    """
    model_dict = set_model_mode(model_dict, 'valid')
    output_savepath = args.model_loadpath.replace('.pt', '')
    bs = 10
    with torch.no_grad():
        for phase in ['train', 'valid']:
            data_loader = data_dict[phase]
            data_loader.reset_unique()
            batch = data_loader.get_unique_minibatch(bs)
            states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
            st_can = '_zc'
            iname = output_savepath + st_can + '_st%s'%args.sampling_temperature + '_sample_%s.png'%phase
            fp_out = forward_pass(model_dict, states, actions, rewards,
                                  next_states, index_indexes, phase, info)
            # The raw batch is prepared a second time only to get the last
            # input frame; states/target are rebound from fp_out just below.
            prep_batch = make_atari_channel_action_reward_diff_state(
                states, actions, rewards, next_states,
                info['device'], info['num_actions'], info['num_rewards'])
            states, action_cond, reward_cond, target = prep_batch
            last = states[:,-1:]
            rlast = rescale_inv(last)
            (model_dict, states, actions, rewards, target, u_q, u_p, s_p,
             rec_dml, pcnn_dml, z_e_x, z_q_x, latents) = fp_out
            # teacher forced version
            pcnn_yhat = sample_from_discretized_mix_logistic(
                pcnn_dml, info['nr_logistic_mix'],
                only_mean=args.sample_mean,
                sampling_temperature=args.sampling_temperature)
            # NOTE(review): rec_yhat is not used below; the call is kept in
            # case dropping it would change later random draws - confirm.
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml, info['nr_logistic_mix'],
                only_mean=args.sample_mean,
                sampling_temperature=args.sampling_temperature)
            # create blank canvas for autoregressive sampling
            nprlast = rlast.detach().cpu().numpy()
            print('using zero output as sample canvas')
            canvas = torch.zeros_like(target)
            for i in range(canvas.shape[1]):
                for j in range(canvas.shape[2]):
                    print('sampling row: %s'%j)
                    for k in range(canvas.shape[3]):
                        output = model_dict['pcnn_decoder_model'](
                            x=canvas, spatial_condition=rec_dml)
                        output = sample_from_discretized_mix_logistic(
                            output.detach(), info['nr_logistic_mix'],
                            only_mean=args.sample_mean,
                            sampling_temperature=args.sampling_temperature)
                        canvas[:,i,j,k] = rescale(rlast[:,i,j,k]-rescale_inv(output[:,i,j,k]))
                        # Debug trace for positions whose target is not -1.
                        if target[0,i,j,k] != -1:
                            print(j,k,output[0,i,j,k], canvas[0,i,j,k], target[0,i,j,k])

            f,ax = plt.subplots(bs, 3, sharex=True, sharey=True, figsize=(3,bs))
            np_pcnn_yhat = pcnn_yhat.detach().cpu().numpy()
            true_diff = target.detach().cpu().numpy()
            # The 'sam' column shows the last sampled output, not the canvas.
            pred_diff = output.detach().cpu().numpy()
            np_pred = nprlast+rescale_inv(pred_diff)
            np_true = nprlast+rescale_inv(true_diff)
            np_tf = nprlast+rescale_inv(np_pcnn_yhat)
            ma_true = true_diff * (np.abs(true_diff) > 0)
            ma_pred = pred_diff * (np.abs(pred_diff) > 0)
            ma_tf = np_pcnn_yhat * (np.abs(np_pcnn_yhat) > 0)
            for idx in range(bs):
                ax[idx,0].matshow(np_true[idx,0], cmap=plt.cm.gray)
                ax[idx,0].imshow(ma_true[idx,0], alpha=0.5, cmap=plt.cm.Reds_r)
                ax[idx,1].matshow(np_tf[idx,0], cmap=plt.cm.gray)
                ax[idx,1].imshow(ma_tf[idx,0], alpha=0.5, cmap=plt.cm.Reds_r)
                ax[idx,2].matshow(np_pred[idx,0], cmap=plt.cm.gray)
                ax[idx,2].imshow(ma_pred[idx,0], alpha=0.5, cmap=plt.cm.Reds_r)
                ax[idx,0].axis('off')
                ax[idx,1].axis('off')
                ax[idx,2].axis('off')
            ax[0,0].set_title('target')
            ax[0,1].set_title('tf')
            ax[0,2].set_title('sam')
            print('plotting %s'%iname)
            plt.savefig(iname)
            plt.close()
def run(train_cnt, model_dict, data_dict, phase, info):
    """Run one pass over the unique indexes of ``data_dict[phase]``.

    Each minibatch goes through ``forward_pass``; the loss is the sum of the
    KL term, the discretized-mix-logistic loss on ``pcnn_dml``, the VQ loss
    and the commitment loss (scaled by ``info['vq_commitment_beta']``). When
    ``phase == 'train'`` the parameters are clipped, the loss is
    back-propagated and ``model_dict['opt']`` takes a step.

    Gradient tracking is switched globally with ``torch.set_grad_enabled``
    and is not restored on exit.

    Args:
        train_cnt: running count of training examples; used for logging and
            incremented locally while training.
        model_dict: models plus the optimizer under 'opt'.
        data_dict: maps a phase name to its data loader.
        phase: 'train' enables gradients and optimizer steps.
        info: settings; reads 'rec_loss_type', 'batch_size', 'code_length',
            'device', 'reduction', 'nr_logistic_mix', 'vq_commitment_beta',
            'sample_mean' and 'sampling_temperature'.

    Returns:
        ``(loss_avg, example)``. ``loss_avg`` is ``account_losses(loss_dict)``.
        ``example`` is a dict with 'prev_frame', 'target', 'deconv_yhat' and
        'pcnn_yhat' taken from one batch near the end of the pass, or None if
        the loader produced no batch.
    """
    st = time.time()
    loss_dict = {'running': 0,
                 'kl':0,
                 'pcnn_%s'%info['rec_loss_type']:0,
                 'vq':0,
                 'commit':0,
                 'loss':0,
                 }
    print('starting', phase, 'cuda', torch.cuda.memory_allocated(device=None))
    data_loader = data_dict[phase]
    data_loader.reset_unique()
    num_batches = len(data_loader.unique_indexes)//info['batch_size']
    # Store an example near the end. Clamp at 0 so that passes with fewer
    # than three full batches still yield one instead of leaving it unbound.
    example_idx = max(num_batches-3, 0)
    example = None
    # Pre-bind the names deleted after the loop so an empty loader is safe.
    states = target = actions = rewards = None
    idx = 0
    set_model_mode(model_dict, phase)
    torch.set_grad_enabled(phase=='train')
    while data_loader.unique_available:
        for key in model_dict.keys():
            model_dict[key].zero_grad()
        batch = data_loader.get_unique_minibatch(info['batch_size'])
        states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
        fp_out = forward_pass(model_dict, states, actions, rewards,
                              next_states, index_indexes, phase, info)
        (model_dict, states, actions, rewards, target, u_q, u_p, s_p,
         rec_dml, pcnn_dml, z_e_x, z_q_x, latents) = fp_out
        bs,c,h,w = states.shape
        # (Re)build the zero tensor on the first batch and whenever the batch
        # size changes (e.g. a short final batch).
        if idx == 0 or bs != log_ones.shape[0]:
            log_ones = torch.zeros(bs, info['code_length']).to(info['device'])
        kl = kl_loss_function(u_q.view(bs, info['code_length']),
                              log_ones,
                              u_p.view(bs, info['code_length']),
                              s_p.view(bs, info['code_length']),
                              reduction=info['reduction'])
        # no loss on deconv rec
        pcnn_loss = discretized_mix_logistic_loss(pcnn_dml, target,
                                                  nr_mix=info['nr_logistic_mix'],
                                                  reduction=info['reduction'])
        vq_loss = F.mse_loss(z_q_x, z_e_x.detach(), reduction=info['reduction'])
        commit_loss = F.mse_loss(z_e_x, z_q_x.detach(), reduction=info['reduction'])
        commit_loss *= info['vq_commitment_beta']
        loss = kl+pcnn_loss+commit_loss+vq_loss
        loss_dict['running']+=bs
        loss_dict['loss']+=loss.detach().cpu().item()
        loss_dict['kl']+= kl.detach().cpu().item()
        loss_dict['vq']+= vq_loss.detach().cpu().item()
        loss_dict['commit']+= commit_loss.detach().cpu().item()
        loss_dict['pcnn_%s'%info['rec_loss_type']]+=pcnn_loss.detach().cpu().item()
        if phase == 'train':
            model_dict = clip_parameters(model_dict)
            loss.backward()
            model_dict['opt'].step()
            train_cnt+=bs
        if idx == example_idx:
            # store example near end for plotting
            pcnn_yhat = sample_from_discretized_mix_logistic(
                pcnn_dml, info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml, info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            example = {'prev_frame':rescale_inv(states[:,-1:].detach().cpu()),
                       'target':rescale_inv(target.detach().cpu()),
                       'deconv_yhat':rescale_inv(rec_yhat.detach().cpu()),
                       'pcnn_yhat':rescale_inv(pcnn_yhat.detach().cpu()),
                       }
        if not idx % 10:
            print(train_cnt, idx, account_losses(loss_dict))
            print(phase, 'cuda', torch.cuda.memory_allocated(device=None))
        idx+=1

    loss_avg = account_losses(loss_dict)
    print("finished %s after %s secs at cnt %s"%(phase,
                                                 time.time()-st,
                                                 train_cnt,
                                                 ))
    print(loss_avg)
    print('end', phase, 'cuda', torch.cuda.memory_allocated(device=None))
    # Drop the batch references before asking CUDA to release cached memory.
    del states; del target; del actions; del rewards
    torch.cuda.empty_cache()
    print('after delete end', phase, 'cuda', torch.cuda.memory_allocated(device=None))
    return loss_avg, example
def sample(model_dict, data_dict, info):
    """Sample frames pixel by pixel from the pcnn decoder and save a figure.

    For each phase ('valid', then 'train') a unique minibatch of 10 examples
    is run through ``forward_pass``. The canvas is either the last input
    frame (``args.teacher_force_prev``) or zeros shaped like ``target``. It
    is filled one position at a time: 'pcnn_decoder_model' is evaluated on
    the current canvas with ``rec_dml`` as spatial condition, a sample is
    drawn from its output and copied into the canvas at that position. The
    saved figure has five columns: 'last', 'true', 'conv' (sample from
    ``rec_dml``), 'tf' (sample from ``pcnn_dml``) and 'sam'.

    Reads the module-level ``args`` (model_loadpath, sample_mean,
    sampling_temperature, teacher_force_prev).

    Args:
        model_dict: models; must contain 'pcnn_decoder_model'.
        data_dict: maps a phase name to its data loader.
        info: settings; reads 'nr_logistic_mix'.
    """
    model_dict = set_model_mode(model_dict, 'valid')
    output_savepath = args.model_loadpath.replace('.pt', '')
    bs = 10
    with torch.no_grad():
        for phase in ['valid', 'train']:
            data_loader = data_dict[phase]
            data_loader.reset_unique()
            batch = data_loader.get_unique_minibatch(bs)
            states, actions, rewards, next_states, _, _, batch_indexes, index_indexes = batch
            fp_out = forward_pass(model_dict, states, actions, rewards,
                                  next_states, index_indexes, phase, info)
            (model_dict, states, actions, rewards, target, u_q, u_p, s_p,
             rec_dml, pcnn_dml, z_e_x, z_q_x, latents) = fp_out
            # teacher forced version
            pcnn_yhat = sample_from_discretized_mix_logistic(
                pcnn_dml, info['nr_logistic_mix'],
                only_mean=args.sample_mean,
                sampling_temperature=args.sampling_temperature)
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml, info['nr_logistic_mix'],
                only_mean=args.sample_mean,
                sampling_temperature=args.sampling_temperature)
            # Copy the arrays used for plotting before the canvas is written:
            # with teacher forcing the canvas is a slice of ``states``.
            last = states[:, -1:]
            np_last = deepcopy(last.detach().cpu().numpy())
            np_target = deepcopy(target.detach().cpu().numpy())
            np_rec_yhat = rec_yhat.detach().cpu().numpy()
            np_pcnn_yhat = pcnn_yhat.detach().cpu().numpy()
            if args.teacher_force_prev:
                st_can = '_lf'
                canvas = last
                print('using last frame as sample canvas')
            else:
                # create blank canvas for autoregressive sampling
                st_can = '_zc'
                canvas = torch.zeros_like(target)
                print('using zero output as sample canvas')
            iname = output_savepath + st_can + '_st%s' % args.sampling_temperature + '_sample_%s.png' % phase
            for i in range(canvas.shape[1]):
                for j in range(canvas.shape[2]):
                    print('sampling row: %s' % j)
                    for k in range(canvas.shape[3]):
                        output = model_dict['pcnn_decoder_model'](
                            x=canvas, spatial_condition=rec_dml)
                        output_o = sample_from_discretized_mix_logistic(
                            output.detach(), info['nr_logistic_mix'],
                            only_mean=args.sample_mean,
                            sampling_temperature=args.sampling_temperature)
                        canvas[:, i, j, k] = output_o[:, i, j, k]
            # The 'sam' column is a fresh sample from the final decoder
            # output, not the canvas itself.
            output_o = sample_from_discretized_mix_logistic(
                output.detach(), info['nr_logistic_mix'],
                only_mean=args.sample_mean,
                sampling_temperature=args.sampling_temperature)

            f, ax = plt.subplots(bs, 5, sharex=True, sharey=True, figsize=(3, bs))
            np_output = output_o.detach().cpu().numpy()
            for idx in range(bs):
                ax[idx, 0].matshow(np_last[idx, 0], cmap=plt.cm.gray)
                ax[idx, 1].matshow(np_target[idx, 0], cmap=plt.cm.gray)
                ax[idx, 2].matshow(np_rec_yhat[idx, 0], cmap=plt.cm.gray)
                ax[idx, 3].matshow(np_pcnn_yhat[idx, 0], cmap=plt.cm.gray)
                ax[idx, 4].matshow(np_output[idx, 0], cmap=plt.cm.gray)
                ax[idx, 0].set_title('last')
                ax[idx, 1].set_title('true')
                ax[idx, 2].set_title('conv')
                ax[idx, 3].set_title('tf')
                ax[idx, 4].set_title('sam')
                ax[idx, 0].axis('off')
                ax[idx, 1].axis('off')
                ax[idx, 2].axis('off')
                ax[idx, 3].axis('off')
                ax[idx, 4].axis('off')
            print('plotting %s' % iname)
            plt.savefig(iname)
            plt.close()