def sample_autoregressive_batch_last_state(data, episode_number, episode_reward, name):
    # Autoregressively decode the last next-state frame for every step in the batch,
    # one pixel at a time, and save a true/estimate comparison image per step.
    with torch.no_grad():
        states, actions, rewards, next_states, terminals, reset, relative_indexes = data
        states = reshape_input(states).to(DEVICE)
        # scale targets from [0, 1] to [-1, 1]
        targets = (2 * reshape_input(next_states[:, -1:]) - 1).to(DEVICE)
        actions = actions.to(DEVICE)
        # if args.teacher_force: name += '_tf'
        bs = states.shape[0]
        print('generating %s images' % bs)
        np_targets = deepcopy(targets.cpu().numpy())
        total_reward = 0
        for bi in range(bs):
            # sample one at a time due to memory constraints
            iname = os.path.join(output_savepath, '%s_E%05d_R%03d_%05d.png' % (
                name, int(episode_number), int(episode_reward), bi))
            if not os.path.exists(iname):
                total_reward += rewards[bi].item()
                # start from an empty canvas seeded with the first true pixel
                y = targets[bi:bi + 1] * 0.0
                y[0, 0, 0, 0] = targets[bi, 0, 0, 0]
                title = 'step:%05d action:%d reward:%s %s/%s' % (
                    bi, actions[bi].item(), int(rewards[bi]), total_reward, int(episode_reward))
                print("making", title)
                for i in range(y.shape[1]):
                    for j in range(y.shape[2]):
                        for k in range(y.shape[3]):
                            # rerun the model for every pixel so it conditions on what
                            # has been generated (or teacher-forced) so far
                            x_d, z_e_x, z_q_x, latents = vqvae_model(states[bi:bi + 1], y=y)
                            yhat = sample_from_discretized_mix_logistic(x_d, largs.nr_logistic_mix)
                            if not args.teacher_force:
                                # feed the sampled pixel back in, rescaled to [-1, 1]
                                y[0, 0, j, k] = 2 * (yhat[0, 0, j, k] / 255.0) - 1
                            else:
                                # teacher forcing: feed the true pixel back in
                                y[0, 0, j, k] = targets[bi, 0, j, k]
                np_canvas = yhat[0, 0].cpu().numpy()
                f, ax = plt.subplots(1, 2)
                ax[0].imshow(np_targets[bi, 0])
                ax[0].set_title('true')
                ax[1].imshow(np_canvas)
                ax[1].set_title('est')
                plt.suptitle(title)
                plt.savefig(iname)
                print('saving', iname)
        # strip the '_%05d.png' suffix (10 chars) to build the frame glob and gif name
        search_path = iname[:-10] + '*.png'
        gif_path = iname[:-10] + '.gif'
        cmd = 'convert %s %s' % (search_path, gif_path)
        print('creating gif', gif_path)
        os.system(cmd)
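# Hedged helper sketch (not in the original script): the samplers in this file shell out
# to ImageMagick's `convert` to turn their saved frames into a gif. This optional wrapper
# only factors out that call; the frame-prefix convention mirrors the iname[:-10] slicing
# used above (stripping the '_%05d.png' suffix).
def make_gif_from_frames(frame_prefix):
    search_path = frame_prefix + '*.png'
    gif_path = frame_prefix + '.gif'
    print('creating gif', gif_path)
    os.system('convert %s %s' % (search_path, gif_path))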
def sample_batch(data, episode_number, episode_reward, name):
    # Single-pass reconstruction of the last state frame for every step in the batch,
    # saved as side-by-side true/estimate images and collected into a gif.
    with torch.no_grad():
        states, actions, rewards, next_states, terminals, reset, relative_indexes = data
        # scale inputs from [0, 1] to [-1, 1]
        x = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE)
        for i in range(states.shape[0]):
            x_d, z_e_x, z_q_x, latents = vqvae_model(x[i:i + 1])
            loss_1 = discretized_mix_logistic_loss(x_d, x[i:i + 1],
                                                   nr_mix=largs.nr_logistic_mix,
                                                   DEVICE=DEVICE)
            yhat = sample_from_discretized_mix_logistic(x_d, largs.nr_logistic_mix)
            # rescale from [-1, 1] back to 8-bit pixel values for plotting
            # (np.int32 instead of the removed np.int alias)
            yhat = (((yhat + 1) / 2.0) * 255.0).cpu().numpy().astype(np.int32)
            true = (states[i:i + 1, -1:] * 255.0).cpu().numpy().astype(np.int32)
            f, ax = plt.subplots(1, 2)
            iname = os.path.join(output_savepath, '%s_E%05d_R%03d_%05d.png' % (
                name, int(episode_number), int(episode_reward), i))
            print("writing", os.path.split(iname)[1])
            title = 'step %s/%s action %s reward %s' % (
                i, states.shape[0], actions[i].item(), rewards[i].item())
            ax[0].imshow(true[0, 0])
            ax[0].set_title('true')
            ax[1].imshow(yhat[0, 0])
            ax[1].set_title('est')
            plt.suptitle(title)
            plt.savefig(iname)
            print('saving', iname)
        search_path = iname[:-10] + '*.png'
        gif_path = iname[:-10] + '.gif'
        cmd = 'convert %s %s' % (search_path, gif_path)
        print('creating gif', gif_path)
        os.system(cmd)
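# Hedged helper sketch (not in the original script): both samplers above map between the
# [-1, 1] range fed to the model and 8-bit pixel values, via (2*x - 1) on the way in and
# ((x + 1) / 2) * 255 on the way out. This standalone function only illustrates that
# inverse mapping on its own.
def to_uint8_image(x):
    """Convert a [-1, 1] float array to uint8 pixel values in [0, 255]."""
    x = np.asarray(x, dtype=np.float32)
    return np.clip(((x + 1.0) / 2.0) * 255.0, 0, 255).astype(np.uint8)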
def generate_forward_datasets():
    # Run the vqvae over (prev, current, next) state triples, collect their latents,
    # actions, rewards, and values, and save everything as a forward-model dataset.
    with torch.no_grad():
        #for dname, data_loader in {'valid':valid_data_loader, 'train':train_data_loader}.items():
        for dname, data_loader in {'valid': valid_data_loader}.items():
            rmax = data_loader.relative_indexes.max()
            new = True
            st = 1
            if args.debug:
                st = 260
            en = 0
            keep_going = True
            while keep_going:
                en = min(st + args.batch_size, rmax - 1)
                print("all", st, en)
                fdata = data_loader.get_data(np.arange(st, en, dtype=np.int64))
                # use reward as endpoint when debugging, otherwise use terminals
                if args.debug:
                    fterminals = list(fdata[2])
                else:
                    fterminals = list(fdata[5])
                # end at end of episode
                if 1 in fterminals:
                    en = st + list(fterminals).index(1) + 1
                    data = data_loader.get_data(np.arange(st, en, dtype=np.int64))
                else:
                    data = fdata
                    print("NO END")
                print('generating from %s to %s of %s' % (st, en, rmax))
                states, actions, rewards, values, next_states, terminals, reset, relative_indexes = data
                assert np.sum(terminals[:-1]) == 0
                prev_relative_indexes = relative_indexes - 1
                prev_data = data_loader.get_data(prev_relative_indexes)
                pstates, pactions, prewards, pvalues, pnext_states, pterminals, preset, prelative_indexes = prev_data
                # scale from [0, 1] to [-1, 1]
                ps = (2 * reshape_input(torch.FloatTensor(pstates)) - 1).to(DEVICE)
                s = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(DEVICE)
                ns = (2 * reshape_input(torch.FloatTensor(next_states)) - 1).to(DEVICE)
                # sanity check that prev/current/next frame stacks are aligned
                for xx in range(s.shape[0]):
                    try:
                        assert ps[xx, -1].sum() == s[xx, -2].sum() == ns[xx, -3].sum()
                    except AssertionError:
                        print("assert broke", xx)
                        embed()
                px_d, zp_e_x, pz_q_x, platents, prev_pred_actions, prev_pred_rewards = vqvae_model(ps)
                x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(s)
                nx_d, nz_e_x, nz_q_x, nlatents, next_pred_actions, next_pred_rewards = vqvae_model(ns)
                if new:
                    # store as numpy so the later np.concatenate calls stay consistent
                    all_prev_latents = platents.cpu().numpy()
                    all_latents = latents.cpu().numpy()
                    all_next_latents = nlatents.cpu().numpy()
                    if args.debug:
                        all_prev_states = pstates
                        all_states = states
                        all_next_states = next_states
                    all_next_pred_actions = next_pred_actions.cpu().numpy()
                    all_next_pred_rewards = next_pred_rewards.cpu().numpy()
                    all_pred_actions = pred_actions.cpu().numpy()
                    all_pred_rewards = pred_rewards.cpu().numpy()
                    all_prev_actions = pactions
                    all_prev_rewards = prewards
                    all_prev_values = pvalues
                    all_rewards = rewards
                    all_values = values
                    all_actions = actions
                    all_rel_inds = relative_indexes
                    new = False
                else:
                    all_prev_latents = np.concatenate((all_prev_latents, platents.cpu().numpy()), axis=0)
                    all_latents = np.concatenate((all_latents, latents.cpu().numpy()), axis=0)
                    all_next_latents = np.concatenate((all_next_latents, nlatents.cpu().numpy()), axis=0)
                    if args.debug:
                        all_prev_states = np.concatenate((all_prev_states, pstates), axis=0)
                        all_states = np.concatenate((all_states, states), axis=0)
                        all_next_states = np.concatenate((all_next_states, next_states), axis=0)
                    all_next_pred_rewards = np.concatenate((all_next_pred_rewards, next_pred_rewards.cpu().numpy()), axis=0)
                    all_pred_rewards = np.concatenate((all_pred_rewards, pred_rewards.cpu().numpy()), axis=0)
                    all_next_pred_actions = np.concatenate((all_next_pred_actions, next_pred_actions.cpu().numpy()), axis=0)
                    all_pred_actions = np.concatenate((all_pred_actions, pred_actions.cpu().numpy()), axis=0)
                    all_prev_rewards = np.concatenate((all_prev_rewards, prewards))
                    all_prev_values = np.concatenate((all_prev_values, pvalues))
                    all_prev_actions = np.concatenate((all_prev_actions, pactions))
                    all_rewards = np.concatenate((all_rewards, rewards))
                    all_values = np.concatenate((all_values, values))
                    all_actions = np.concatenate((all_actions, actions))
                    all_rel_inds = np.concatenate((all_rel_inds, relative_indexes))
                if 1 in fterminals:
                    # skip ahead one so that prev state is correct
                    st = en + 1
                else:
                    st = en
                if en > rmax - 2:
                    keep_going = False
            forward_dir = args.model_loadname.replace('.pt', '_%s_forward_imgs' % dname)
            if not os.path.exists(forward_dir):
                os.makedirs(forward_dir)
            forward_filename = args.model_loadname.replace('.pt', '_%s_forward.npz' % dname)
            if args.debug:
                # debug run also saves the raw states for plotting below
                forward_filename = forward_filename.replace('.npz', 'debug.npz')
                np.savez(forward_filename,
                         relative_indexes=all_rel_inds,
                         prev_latents=all_prev_latents,
                         latents=all_latents,
                         next_latents=all_next_latents,
                         rewards=all_rewards,
                         values=all_values,
                         actions=all_actions,
                         prev_rewards=all_prev_rewards,
                         prev_values=all_prev_values,
                         prev_actions=all_prev_actions,
                         prev_states=all_prev_states,
                         states=all_states,
                         next_states=all_next_states,
                         num_k=largs.num_k)
            else:
                print('saving', forward_filename)
                np.savez(forward_filename,
                         relative_indexes=all_rel_inds,
                         prev_latents=all_prev_latents,
                         latents=all_latents,
                         next_latents=all_next_latents,
                         rewards=all_rewards,
                         values=all_values,
                         actions=all_actions,
                         prev_rewards=all_prev_rewards,
                         prev_values=all_prev_values,
                         prev_actions=all_prev_actions,
                         num_k=largs.num_k)
            if args.debug:
                for i in range(all_prev_latents.shape[0]):
                    f, ax = plt.subplots(3, 3)
                    ax[0, 0].imshow(all_prev_states[i, -1])
                    # one ahead in action/reward because vq is predicting the transition
                    ax[0, 0].set_title('%04d A%sPA%s' % (i, all_prev_actions[i], np.argmax(all_pred_actions[i])))
                    ax[0, 1].imshow(all_states[i, -1])
                    ax[0, 1].set_title('%04d A%sPA%s' % (i + 1, all_actions[i], np.argmax(all_next_pred_actions[i])))
                    ax[0, 2].imshow(all_next_states[i, -1])
                    ax[0, 2].set_title('%04d' % (i + 2))
                    ax[1, 0].imshow(all_prev_latents[i], vmin=0, vmax=512)
                    ax[1, 0].set_title('%04d R%sRA%s' % (i, all_prev_rewards[i], np.argmax(all_pred_rewards[i])))
                    ax[1, 1].imshow(all_latents[i], vmin=0, vmax=512)
                    ax[1, 1].set_title('%04d R%sRA%s' % (i + 1, all_rewards[i], np.argmax(all_next_pred_rewards[i])))
                    ax[1, 2].imshow(all_next_latents[i], vmin=0, vmax=512)
                    ax[1, 2].set_title('%04d' % (i + 2))
                    # zero out unchanged latent codes to highlight what the transition changed
                    s_mask = (all_latents[i] - all_prev_latents[i]) == 0
                    diffnl = deepcopy(all_latents[i])
                    diffnl[s_mask] *= 0
                    ax[2, 1].imshow(diffnl, vmin=0, vmax=512)
                    ax[2, 1].set_title('diff %s-%s' % (i, i + 1))
                    s1_mask = (all_next_latents[i] - all_latents[i]) == 0
                    diffs1nl = deepcopy(all_next_latents[i])
                    diffs1nl[s1_mask] *= 0
                    ax[2, 2].imshow(diffs1nl, vmin=0, vmax=512)
                    ax[2, 2].set_title('diff %s-%s' % (i + 1, i + 2))
                    pname = os.path.join(forward_dir, '%s_frame%05d.png' % (dname, i + 1))
                    plt.savefig(pname)
                    plt.close()
                    if not i % 10:
                        print('plotting', i, pname)
                cmd = 'convert %s %s' % (os.path.join(forward_dir, '%s*.png' % dname),
                                         os.path.join(forward_dir, '%s.gif' % dname))
                print("!!!! creating gif")
                print(cmd)
                os.system(cmd)
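# Hedged loading sketch (not in the original script): the forward dataset written by
# np.savez above can be read back like this. The key names match the savez call; the
# path handling is an assumption.
def load_forward_dataset(npz_path):
    """Return the saved forward-model arrays as a dict keyed by the savez names."""
    d = np.load(npz_path)
    return {k: d[k] for k in d.files}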
#sample_batch(valid_episode_batch, episode_index, episode_reward, 'valid')

# Can work with any model, but it assumes that the model has a
# feature method, and a classifier method,
# as in the VGG models in torchvision.
grad_cam = GradCam(model=vqvae_model, target_layer_names=['10'], use_cuda=args.use_cuda)

# If None, returns the map for the highest scoring category.
# Otherwise, targets the requested index.
target_index = None
states, actions, rewards, values, pred_states, terminals, reset, relative_indexes = valid_episode_batch
input_data = (2 * reshape_input(states[:1]) - 1).to(DEVICE)
mask = grad_cam(input_data, target_index)
img = input_data[0, -1].cpu().numpy()
img = (img + 1) / 2.0
cimg = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
show_cam_on_image(cimg, mask)

#utils.save_image(torch.from_numpy(cam_gb), 'cam_gb.jpg')
#gb_model = GuidedBackpropReLUModel(model=models.vgg19(pretrained=True), use_cuda=args.use_cuda)
#gb = gb_model(input, index=target_index)
#utils.save_image(torch.from_numpy(gb), 'gb.jpg')
#cam_mask = np.zeros(gb.shape)
#for i in range(0, gb.shape[0]):
#    cam_mask[i, :, :] = mask
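# Hedged sketch (assumption): `show_cam_on_image` is not defined in this snippet. A
# typical Grad-CAM visualization overlays a JET-colormapped mask onto the (H, W, 3)
# float image in [0, 1] and writes the blend to disk, roughly like this; the output
# filename is illustrative only.
def _show_cam_on_image_sketch(img, mask, out_path='cam.jpg'):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255.0
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    cv2.imwrite(out_path, np.uint8(255 * cam))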
def sample_episode(data, episode_number, episode_reward, name):
    # rollout for number of steps and decode with vqvae decoder
    states, actions, rewards, values, next_states, terminals, reset, relative_indexes = data
    params = (episode_number, episode_reward, name, actions, rewards)
    pred_actions = []
    bs = min(states.shape[0], args.rollout_length)
    snp = reshape_input(deepcopy(states))
    # scale states from [0, 1] to [-1, 1] for the vqvae
    s = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(DEVICE)
    nsnp = reshape_input(next_states)
    gen_method = []
    actions = torch.LongTensor(actions).to(DEVICE)
    elen = actions.shape[0]
    #print("setting all actions to one")
    #actions[args.lead_in:] = 1
    #actions[40:] = 0
    #actions[args.lead_in:] = torch.LongTensor(np.random.randint(min(data_loader.action_space),
    #                                                            max(data_loader.action_space),
    #                                                            actions[args.lead_in:].shape[0])).to(DEVICE)
    # make channels for actions which is the size of the latents
    # (one one-hot plane per action; kept on DEVICE to match the latent buffers)
    channel_actions = torch.zeros((elen, forward_info['num_actions'],
                                   forward_info['hsize'], forward_info['hsize'])).to(DEVICE)
    for a in range(forward_info['num_actions']):
        channel_actions[actions == a, a] = 1.0
    all_tf_pred_latents = np.zeros((args.rollout_length, 10, 10))
    all_real_latents = torch.zeros((args.rollout_length, 10, 10)).to(DEVICE).float()
    # first pred index is zeros - since we cant predict it
    all_pred_latents = torch.zeros((args.rollout_length, 10, 10)).to(DEVICE).float()
    assert args.lead_in >= 2
    # at beginning:
    # 0th real action is actually for 1
    # all_real_latents represents i
    # tf_pred_latents is constructing i, given latents of i-1, i-2
    # from tf_pred_latents[i], you can find obs[i]
    for i in range(bs):
        x_d, z_e_x, z_q_x, real_latents, pred_actions, pred_signals = vqvae_model(s[i:i + 1])
        # for the ith index
        all_real_latents[i] = real_latents.float()
        gmethod = 'NOT'
        if i >= 2:
            # get a teacher force result - where past real latents are used to
            # predict this state
            assert torch.abs(all_real_latents[i - 2]).sum() > 0
            assert torch.abs(all_real_latents[i - 1]).sum() > 0
            tf_state_input = torch.cat((channel_actions[i][None, :],
                                        all_real_latents[i - 2][None, None],
                                        all_real_latents[i - 1][None, None]), dim=1)
            #tf_pred_next_latents, tf_pred_prev_actions, tf_pred_rewards = conv_forward_model(tf_state_input)
            tf_pred_next_latents = conv_forward_model(tf_state_input)
            # prediction for the i + 1 index
            tf_pred_next_latents = torch.argmax(tf_pred_next_latents, dim=1)
            # THIS is pred s+1 so the indexes are different
            all_tf_pred_latents[i] = deepcopy(tf_pred_next_latents[0].cpu().numpy())
            #print('tf', i, all_tf_pred_latents[i].sum())
            if i > args.lead_in:
                # use last prediction (self rollout)
                print('i', i)
                assert torch.abs(all_pred_latents[i - 2]).sum() > 0
                assert torch.abs(all_pred_latents[i - 1]).sum() > 0
                state_input = torch.cat((channel_actions[i][None, :],
                                         all_pred_latents[i - 2][None, None].float(),
                                         all_pred_latents[i - 1][None, None].float()), dim=1)
                gmethod = 'SLF'
            else:
                # use teacher forced version if we are in "lead in"
                gmethod = 'FTF'
                state_input = torch.cat((channel_actions[i][None, :],
                                         all_real_latents[i - 2][None, None].float(),
                                         all_real_latents[i - 1][None, None].float()), dim=1)
            out_pred_next_latents = conv_forward_model(state_input)
            # take argmax over channels axis
            pred_next_latents = torch.argmax(out_pred_next_latents, dim=1)
            # replace true with this
            all_pred_latents[i] = pred_next_latents[0].float()
            #print('pred', i, all_pred_latents[i].sum())
        else:
            print("feeding beginning of preds with real", i)
            all_pred_latents[i] = real_latents.float()
        gen_method.append(gmethod)
    all_pred_latents = all_pred_latents.cpu().numpy()
    all_real_latents = all_real_latents.cpu().numpy()
    plot_reconstructions(snp, nsnp, all_real_latents, all_pred_latents,
                         all_tf_pred_latents, params, gen_method)
    plot_latents(all_real_latents, all_pred_latents, all_tf_pred_latents,
                 params, gen_method)
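# Hedged helper sketch (not in the original script): the one-hot action planes built
# inside sample_episode can be expressed as a standalone function. `num_actions` and
# `hsize` correspond to forward_info['num_actions'] and forward_info['hsize'] above;
# everything else is illustrative.
def make_action_planes(actions, num_actions, hsize):
    """Return (len(actions), num_actions, hsize, hsize) one-hot action channels."""
    planes = torch.zeros((actions.shape[0], num_actions, hsize, hsize),
                         device=actions.device)
    for a in range(num_actions):
        planes[actions == a, a] = 1.0
    return planes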