def forward(self, obs, action):
    """
    Compute one scalar per time step for a batch of sub-trajectories.

    The concatenated (obs, action) sequence is run through ``self.lstm`` from
    an all-zero initial state. Each per-step LSTM output is concatenated with
    the raw observation and passed through ``fc1`` (ReLU); the action is then
    appended and the result passed through ``fc2`` (ReLU) and ``last_fc``.

    :param obs: torch Variable, [batch_size, sequence length, obs dim]
    :param action: torch Variable, [batch_size, sequence length, action dim]
    :return: torch Variable, [batch_size, sequence length, 1]
    """
    assert len(obs.size()) == 3
    inputs = torch.cat((obs, action), dim=2)
    batch_size, subsequence_length = obs.size()[:2]
    # Zero initial (hx, cx). A leading dimension of 1 presumably means a
    # single-layer, unidirectional LSTM -- TODO confirm against self.lstm.
    cx = Variable(
        ptu.FloatTensor(1, batch_size, self.hidden_size)
    )
    cx.data.fill_(0)
    hx = Variable(
        ptu.FloatTensor(1, batch_size, self.hidden_size)
    )
    hx.data.fill_(0)
    # NOTE(review): flattening LSTM outputs and obs with the same -1 row
    # count assumes self.lstm is batch_first, matching the docstring shapes.
    rnn_outputs, _ = self.lstm(inputs, (hx, cx))

    # contiguous() returns a tensor instead of modifying its receiver, so its
    # result is what must be passed to view(). obs/action get the same
    # treatment in case the caller hands in sliced (non-contiguous) tensors.
    rnn_outputs_flat = rnn_outputs.contiguous().view(-1, self.hidden_size)
    obs_flat = obs.contiguous().view(-1, self.obs_dim)
    action_flat = action.contiguous().view(-1, self.action_dim)

    h = torch.cat((rnn_outputs_flat, obs_flat), dim=1)
    h = F.relu(self.fc1(h))
    h = torch.cat((h, action_flat), dim=1)
    h = F.relu(self.fc2(h))
    outputs_flat = self.last_fc(h)
    return outputs_flat.view(batch_size, subsequence_length, 1)
def forward(self, obs, memory, action, write):
    """
    Compute one scalar per time step of a sub-trajectory with a recurrent head.

    The concatenated (obs, memory, action, write) sequence is run through
    ``self.rnn`` from an all-zero (hx, cx) state; every per-step RNN output is
    mapped through ``self.last_fc`` and then ``self.output_activation``.

    :param obs: torch Variable, [batch_size, sequence length, obs dim]
    :param memory: torch Variable, [batch_size, sequence length, memory dim]
    :param action: torch Variable, [batch_size, sequence length, action dim]
    :param write: torch Variable, [batch_size, sequence length, memory dim]
    :return: torch Variable, [batch_size, sequence length, 1]
    """
    rnn_inputs = torch.cat((obs, memory, action, write), dim=2)
    batch_size, subsequence_length, _ = obs.size()
    cx = Variable(
        ptu.FloatTensor(1, batch_size, self.hidden_size)
    )
    cx.data.fill_(0)
    hx = Variable(
        ptu.FloatTensor(1, batch_size, self.hidden_size)
    )
    hx.data.fill_(0)
    state = (hx, cx)
    rnn_outputs, _ = self.rnn(rnn_inputs, state)

    # contiguous() is not in-place, so keep its result for view(). Flatten by
    # the RNN's own output width (hidden_size, matching the state above),
    # rather than fc1.in_features: fc1 is not used in this method, and the
    # view could only succeed when the two are equal anyway.
    rnn_outputs_flat = rnn_outputs.contiguous().view(
        batch_size * subsequence_length,
        self.hidden_size,
    )
    outputs_flat = self.output_activation(self.last_fc(rnn_outputs_flat))
    return outputs_flat.view(batch_size, subsequence_length, 1)
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    Run the parent policy on the observation plus a one-hot encoding of tau.

    ``obs`` arrives with tau packed into it; ``split_tau`` separates the two.
    Tau becomes a one-hot row of width ``self.max_tau + 1`` (negative values
    are clamped to index 0), which is appended to the observation before the
    augmented input is forwarded to ``super().forward`` with every flag passed
    through unchanged.

    NOTE(review): taus are assumed to be shaped (batch, 1) for ``scatter_``
    and to be at most ``self.max_tau``; larger values are not clamped.
    """
    plain_obs, taus = split_tau(obs)
    one_hot = ptu.FloatTensor(plain_obs.size()[0], self.max_tau + 1).zero_()
    one_hot.scatter_(1, taus.data.long().clamp(min=0), 1)
    augmented_obs = torch.cat((plain_obs, ptu.Variable(one_hot)), dim=1)
    return super().forward(
        obs=augmented_obs,
        deterministic=deterministic,
        return_log_prob=return_log_prob,
        return_entropy=return_entropy,
        return_log_prob_of_mean=return_log_prob_of_mean,
    )
def forward(self, flat_obs, actions=None):
    """
    Evaluate the network on observations augmented with a one-hot tau.

    :param flat_obs: observations with tau packed in; ``split_tau`` separates
        them into ``(obs, taus)``.
    :param actions: optional actions; when given they are appended after the
        one-hot tau encoding.
    :return: ``-|last_fc(h)|``, so every output is <= 0.

    NOTE(review): taus are assumed to be shaped (batch, 1) for ``scatter_``
    and to be at most ``self.max_tau``; larger values are not clamped.
    """
    obs, taus = split_tau(flat_obs)
    batch_size = obs.size()[0]

    # One-hot encode tau over [0, max_tau]; negative taus map to index 0.
    y_binary = ptu.FloatTensor(batch_size, self.max_tau + 1)
    y_binary.zero_()
    t = torch.clamp(taus.data.long(), min=0)
    y_binary.scatter_(1, t, 1)

    parts = [obs, ptu.Variable(y_binary)]
    if actions is not None:
        parts.append(actions)
    h = torch.cat(parts, dim=1)

    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    return - torch.abs(self.last_fc(h))
def dump_latent_plots(vae_env, epoch):
    """
    Save a grid of per-latent-dimension heat maps over a sweep of env states.

    For every latent dimension three images are rendered over an
    ``vis_granularity x vis_granularity`` state sweep: the encoder mean, the
    std ``exp(0.5 * logvar)`` and a reparameterized (noisy) sample. The grid
    is written to ``z_<epoch>.png`` (``z.png`` when ``epoch`` is None) in the
    logger's snapshot directory, at most 16 images per row. Nothing happens
    when ``vae_env`` has no ``get_states_sweep``.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    if getattr(vae_env, "get_states_sweep", None) is None:
        return

    nx, ny = vae_env.vis_granularity, vae_env.vis_granularity
    states_sweep = vae_env.get_states_sweep(nx, ny)
    mu, logvar = vae_env.encode_states(states_sweep, clip_std=False)
    std = np.exp(0.5 * logvar)
    sample = vae_env.reparameterize(mu, logvar, noisy=True)

    imsize = 84
    # (latent values, vmin, vmax) for each of the three groups in the grid.
    panels = (
        (mu, -2.0, 2.0),
        (std, 0.0, 2.0),
        (sample, -3.0, 3.0),
    )
    rendered = [[] for _ in panels]
    for dim in range(mu.shape[1]):
        for (latents, vmin, vmax), bucket in zip(panels, rendered):
            plot = vae_env.get_image_plt(
                latents[:, dim].reshape((nx, ny)),
                vmin=vmin,
                vmax=vmax,
                draw_state=False,
                imsize=imsize,
            )
            bucket.append(vae_env.transform_image(plot))
    # All mean images first, then all std images, then all sample images.
    images = np.array([img for bucket in rendered for img in bucket])

    nrow = min(vae_env.representation_size, 16)
    if epoch is None:
        file_name = 'z.png'
    else:
        file_name = 'z_%d.png' % epoch
    save_path = osp.join(logger.get_snapshot_dir(), file_name)
    stacked = images.reshape(
        (vae_env.representation_size * 3, -1, imsize, imsize)
    )
    save_image(
        ptu.FloatTensor(ptu.from_numpy(stacked)),
        save_path,
        nrow=nrow,
    )
def forward(
        self,
        flat_obs,
        return_preactivations=False
):
    """
    Run the parent network on the observation plus a one-hot encoding of tau.

    ``split_tau`` unpacks ``flat_obs`` into ``(obs, taus)``; tau is turned
    into a one-hot row of width ``self.max_tau + 1`` (negative values clamped
    to index 0), appended to ``obs``, and the result is passed positionally to
    ``super().forward`` together with ``return_preactivations``.

    NOTE(review): taus are assumed to be shaped (batch, 1) for ``scatter_``
    and to be at most ``self.max_tau``; larger values are not clamped.
    """
    plain_obs, taus = split_tau(flat_obs)
    one_hot = ptu.FloatTensor(plain_obs.size()[0], self.max_tau + 1).zero_()
    one_hot.scatter_(1, taus.data.long().clamp(min=0), 1)
    augmented = torch.cat((plain_obs, ptu.Variable(one_hot)), dim=1)
    return super().forward(
        augmented,
        return_preactivations=return_preactivations,
    )
def dump_latent_histogram(vae_env, epoch, noisy=False, reproj=False,
                          use_true_prior=None, draw_dots=False):
    """
    Render the environment's latent histogram and save it as a PNG grid.

    The file goes into the logger's snapshot directory. Its name comes from
    the first matching mode (``h`` when ``noisy``, ``h_r`` when ``reproj``,
    otherwise ``h_mu``), with ``_<epoch>`` appended when ``epoch`` is given.
    The grid has ``int(sqrt(N))`` images per row, N being the number of images
    returned by ``get_image_latent_histogram``.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    images = vae_env.get_image_latent_histogram(
        noisy=noisy,
        reproj=reproj,
        draw_dots=draw_dots,
        use_true_prior=use_true_prior,
    )
    prefix = 'h' if noisy else ('h_r' if reproj else 'h_mu')
    suffix = '.png' if epoch is None else '_%d.png' % epoch
    save_dir = osp.join(logger.get_snapshot_dir(), prefix + suffix)
    per_row = int(np.sqrt(images.shape[0]))
    save_image(
        ptu.FloatTensor(ptu.from_numpy(images)),
        save_dir,
        nrow=per_row,
    )
def eval_model_np(state, action):
    """
    Evaluate the module-level ``model`` on a single state/action pair.

    Each input is wrapped as ``[[x]]`` in a FloatTensor Variable without
    gradients (presumably scalar inputs). ``model`` returns two outputs that
    are summed, and row 0 of the sum is returned as a numpy array.
    """
    def as_input(x):
        return ptu.Variable(ptu.FloatTensor([[x]]), requires_grad=False)

    # The two outputs are presumably advantage and value heads -- confirm
    # against the model definition.
    first_out, second_out = model(as_input(state), as_input(action))
    return ptu.get_numpy(first_out + second_out)[0]
def _subgoal_image_strip(obs, imsize):
    """Stack the observation, reprojected subgoal and goal images of one observation dict."""
    def as_images(value):
        return value.reshape((-1, 3, imsize, imsize))

    image_ob = as_images(obs['image_observation'])
    if 'image_desired_goal_annotated' in obs:
        image_goal = as_images(obs['image_desired_goal_annotated'])
    else:
        image_goal = as_images(obs['image_desired_goal'])
    image_sg = as_images(obs['image_desired_subgoals_reproj'])
    return np.concatenate((image_ob, image_sg, image_goal))


def dump_video(
        env,
        policy,
        filename,
        rollout_function,
        qf=None,
        vf=None,
        rows=3,
        columns=6,
        pad_length=0,
        pad_color=255,
        do_timer=True,
        imsize=84,
        epoch=None,
        vis_list=None,
        vis_blacklist=None,
):
    """
    Roll out ``policy`` ``rows * columns`` times and save the rollouts as one tiled video.

    Each rollout comes from ``rollout_function`` (called with
    ``animated=False``, ``epoch`` and ``rollout_num``). Every entry of
    ``path['full_observations']`` after the first is rendered with
    ``get_image`` from the observation keys in ``vis_list``; missing keys are
    passed as None. Rollouts are tiled into a ``rows x columns`` grid, filled
    column by column, and written to ``filename`` with ``skvideo.io.vwrite``.

    If a rollout's observation at index 1 contains
    ``'image_desired_subgoals_reproj'``, its observation, subgoal and goal
    images are collected too and saved to ``sg_<epoch>.png`` in the logger's
    snapshot directory.

    :param vis_list: observation keys to render; defaults to the list below.
    :param vis_blacklist: keys removed from ``vis_list``.
    :param do_timer: print the wall-clock time of every rollout.
    :raises ValueError: if ``rows * columns`` is not positive.

    NOTE(review): frames are reshaped with the length of the last rollout, so
    all rollouts are assumed to have the same number of frames.
    """
    if vis_list is None:
        vis_list = [
            'image_desired_goal',
            'image_observation',
            'reconstr_image_observation',
            'reconstr_image_reproj_observation',
            'image_desired_subgoal',
            'image_desired_subgoal_reproj',
            'image_plt',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_v_latent',
            'image_v',
            'image_v_noisy_state_and_goal',
            'image_v_noisy_state',
            'image_v_noisy_goal',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]
    if vis_blacklist is not None:
        vis_list = [x for x in vis_list if x not in vis_blacklist]

    N = rows * columns
    if N <= 0:
        # Without a single rollout there is no horizon to reshape frames by.
        raise ValueError(
            "rows * columns must be positive, got {} * {}".format(rows, columns)
        )
    num_channels = 1 if env.grayscale else 3
    # Rendering options are identical for every frame; build them once.
    get_image_kwargs = dict(
        pad_length=pad_length,
        pad_color=pad_color,
        imsize=imsize,
    )

    frames = []
    subgoal_images = []
    horizon = 0
    for i in range(N):
        start = time.time()
        path = rollout_function(
            env,
            policy,
            qf=qf,
            vf=vf,
            animated=False,
            epoch=epoch,
            rollout_num=i,
        )
        observations = path['full_observations']
        # Index 1 is the first observation that gets rendered below.
        if 'image_desired_subgoals_reproj' in observations[1]:
            subgoal_images.append(_subgoal_image_strip(observations[1], imsize))

        mini_frames = [
            get_image(
                *[d.get(key, None) for key in vis_list],
                **get_image_kwargs,
            )
            for d in observations[1:]
        ]
        horizon = len(mini_frames)
        frames += mini_frames
        if do_timer:
            print(i, time.time() - start)

    if subgoal_images:
        from railrl.core import logger
        import os.path as osp
        from torchvision.utils import save_image

        logdir = logger.get_snapshot_dir()
        filename_subgoals = osp.join(logdir, 'sg_{epoch}.png'.format(epoch=epoch))
        # Images per rollout strip, so each strip fills one grid row
        # (assuming every strip has the same length).
        nrow = subgoal_images[0].shape[0]
        save_image(
            ptu.FloatTensor(ptu.from_numpy(np.concatenate(subgoal_images))),
            filename_subgoals,
            nrow=nrow,
        )

    padded_size = imsize + 2 * pad_length
    frames = np.array(frames, dtype=np.uint8).reshape(
        (N, horizon, -1, padded_size, num_channels))
    # Rollout k = column * rows + row: rollouts are stacked along axis 1
    # within a grid column, and columns are joined along axis 2.
    grid_columns = [
        np.concatenate([frames[c * rows + r] for r in range(rows)], axis=1)
        for c in range(columns)
    ]
    outputdata = np.concatenate(grid_columns, axis=2)
    skvideo.io.vwrite(filename, outputdata)
    print("Saved video to ", filename)
mu_stds = np.std(np.vstack(latent_mus), axis=0) plt.bar(np.arange(len(mu_stds)), mu_stds) plt.title("X-axis puck sweep") plt.xlabel("latent dim") plt.ylabel("Mean std") plt.show() sigma_stds = np.mean(np.vstack(latent_sigmas), axis=0) plt.bar(np.arange(len(sigma_stds)), sigma_stds) plt.title("X-axis puck sweep") plt.xlabel("latent dim") plt.ylabel("Sigma std") plt.show() imgs = np.array(imgs) imgs = ptu.FloatTensor(imgs) save_image(imgs, 'x-puck-sweep.png') # ------------------ X axis - arm goals = [] for x in np.arange(env.hand_low[0], env.hand_high[0], 0.01): new_hand_xyz = hand_xyz.copy() new_hand_xyz[0] = x goals.append(np.hstack((new_hand_xyz, obj_xyz))) imgs, latent_mus, latent_sigmas = get_info(goals) mu_stds = np.std(np.vstack(latent_mus), axis=0) plt.bar(np.arange(len(mu_stds)), mu_stds) plt.title("X-axis arm sweep") plt.xlabel("latent dim") plt.ylabel("Mean std")
def contains_points_pytorch(self, points):
    """
    Test element-wise whether ``points`` lie strictly inside the box bounds.

    Each value is compared against ``[max_x, max_y]`` (strictly less than) and
    ``[min_x, min_y]`` (strictly greater than); the two comparison masks are
    multiplied together and returned.

    NOTE(review): assumes ``points`` has a trailing dimension of 2 (x, y).
    The result is per coordinate, not reduced per point; callers presumably
    combine the two columns.
    """
    upper_bound = ptu.Variable(ptu.FloatTensor([self.max_x, self.max_y]))
    below_upper = points < upper_bound
    lower_bound = ptu.Variable(ptu.FloatTensor([self.min_x, self.min_y]))
    above_lower = points > lower_bound
    return below_upper * above_lower
# return np.array(joint_angles), np.array(jacobians) return ( np.hstack((np.array(states), np.array(actions))), np.array(next_states) - np.array(states) ) train_x_np, train_y_np = generate_data(N_PATHS, PATH_LENGTH) test_x_np, test_y_np = generate_data(N_PATHS_TEST, PATH_LENGTH) train_x = ptu.np_to_var(train_x_np) train_y = ptu.np_to_var(train_y_np) test_x = ptu.np_to_var(test_x_np) test_y = ptu.np_to_var(test_y_np) train_dataset = TensorDataset( ptu.FloatTensor(train_x_np), ptu.FloatTensor(train_y_np), ) in_dim = train_x_np[0].size out_dim = train_y_np[0].size dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True) def train_network(net, title): train_losses = [] test_losses = [] times = [] optimizer = Adam(net.parameters(), lr=1e-3) criterion = nn.MSELoss()
def get_new_state(self, batch_size):
    """
    Return a fresh all-zero ``(hx, cx)`` pair for ``batch_size`` sequences.

    Both are Variables of shape ``(1, batch_size, self.hidden_size)``,
    allocated with ``ptu.FloatTensor`` and then zero-filled.
    """
    shape = (1, batch_size, self.hidden_size)

    def zero_state():
        state = Variable(ptu.FloatTensor(*shape))
        state.data.fill_(0)
        return state

    c_state = zero_state()
    h_state = zero_state()
    return h_state, c_state
def forward(self, obs, initial_memory):
    """
    Roll the memory cell over a sub-trajectory and produce actions and writes.

    Two modes, selected by ``self.feed_action_to_memory``:

    * True: at every step the action is computed from the current observation
      and the current memory, and ``self.rnn_cell`` then updates the memory
      from ``[obs, action]``.
    * False: the memory is first rolled forward from the observations alone.
      Afterwards each step's action is computed from the observation and the
      memory as it was *before* that step: the initial memory for step 0 and
      the previous write otherwise.

    When ``self.num_splits_for_rnn_internally > 1``, the memory is split along
    dim 1 into chunks of ``self.memory_dim // self.num_splits_for_rnn_internally``
    and handed to ``self.rnn_cell`` as a tuple. The tuple the cell returns is
    concatenated back into a single write. (Presumably this is the (hx, cx)
    pair of an LSTM cell -- TODO confirm.)

    :param obs: torch Variable, [batch_size, sequence length, obs dim]
    :param initial_memory: torch Variable, [batch_size, memory dim]
    :return: (actions, writes) tuple
        actions: [batch_size, sequence length, action dim]
        writes: [batch_size, sequence length, memory dim]; entry i is the
            memory after step i.
    """
    assert len(obs.size()) == 3
    assert len(initial_memory.size()) == 2
    batch_size, subsequence_length = obs.size()[:2]

    # Output buffers, filled one time step at a time by the slice
    # assignments below.
    subtraj_writes = Variable(ptu.FloatTensor(batch_size, subsequence_length, self.memory_dim), requires_grad=False)
    subtraj_actions = Variable(ptu.FloatTensor(batch_size, subsequence_length, self.action_dim), requires_grad=False)
    if self.feed_action_to_memory:
        if self.num_splits_for_rnn_internally > 1:
            # Memory as a tuple of equal-width chunks.
            state = torch.split(
                initial_memory,
                self.memory_dim // self.num_splits_for_rnn_internally,
                dim=1,
            )
            for i in range(subsequence_length):
                current_obs = obs[:, i, :]
                # Act on the memory *before* this step's update.
                augmented_state = torch.cat((current_obs, ) + state, dim=1)
                action = self.forward_action(augmented_state)
                rnn_input = torch.cat([current_obs, action], dim=1)
                state = self.rnn_cell(rnn_input, state)
                subtraj_writes[:, i, :] = torch.cat(state, dim=1)
                subtraj_actions[:, i, :] = action
        else:
            state = initial_memory
            for i in range(subsequence_length):
                current_obs = obs[:, i, :]
                augmented_state = torch.cat([current_obs, state], dim=1)
                action = self.forward_action(augmented_state)
                rnn_input = torch.cat([current_obs, action], dim=1)
                state = self.rnn_cell(rnn_input, state)
                subtraj_writes[:, i, :] = state
                subtraj_actions[:, i, :] = action
        return subtraj_actions, subtraj_writes

    # Remaining code: feed_action_to_memory is False. The memory is driven
    # by the observations only.
    """
    Create the new writes.
    """
    if self.num_splits_for_rnn_internally > 1:
        state = torch.split(
            initial_memory,
            self.memory_dim // self.num_splits_for_rnn_internally,
            dim=1,
        )
        for i in range(subsequence_length):
            state = self.rnn_cell(obs[:, i, :], state)
            subtraj_writes[:, i, :] = torch.cat(state, dim=1)
    else:
        state = initial_memory
        for i in range(subsequence_length):
            state = self.rnn_cell(obs[:, i, :], state)
            subtraj_writes[:, i, :] = state

    # The reason that using a LSTM doesn't work is that this gives you only
    # the FINAL hx and cx, not all of them :(

    """
    Create the new subtrajectory memories with the initial memories and the
    new writes.
    """
    # Memory seen at step i = memory before step i: the initial memory at
    # step 0, then writes shifted right by one step.
    expanded_init_memory = initial_memory.unsqueeze(1)
    if subsequence_length > 1:
        memories = torch.cat(
            (
                expanded_init_memory,
                subtraj_writes[:, :-1, :],
            ),
            dim=1,
        )
    else:
        memories = expanded_init_memory

    """
    Use new memories to create env actions.
    """
    all_subtraj_inputs = torch.cat([obs, memories], dim=2)
    for i in range(subsequence_length):
        augmented_state = all_subtraj_inputs[:, i, :]
        action = self.forward_action(augmented_state)
        subtraj_actions[:, i, :] = action

    return subtraj_actions, subtraj_writes