def _do_training(self):
    batch, weights, idxes = self.get_batch()
    self.t += 1
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    # Importance-sampling weights from the prioritized replay buffer.
    weights = torch.autograd.Variable(ptu.from_numpy(
        weights.reshape(weights.shape + (1,))
    ), requires_grad=False)

    """
    Critic operations.
    """
    # Target policy smoothing (TD3): add clipped noise to the target actions.
    next_actions = self.target_policy(next_obs)
    noise = torch.normal(
        torch.zeros_like(next_actions),
        self.target_policy_noise,
    )
    noise = torch.clamp(
        noise,
        -self.target_policy_noise_clip,
        self.target_policy_noise_clip,
    )
    noisy_next_actions = next_actions + noise

    target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
    target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
    target_q_values = torch.min(target_q1_values, target_q2_values)
    q_target = rewards + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()

    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target) ** 2
    # IS-weighted Bellman error
    bellman_errors_1 = weights * bellman_errors_1
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target) ** 2
    # IS-weighted Bellman error
    bellman_errors_2 = weights * bellman_errors_2
    qf2_loss = bellman_errors_2.mean()

    """
    Update Networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    # Update priorities with the mean absolute TD error of the two critics.
    EPS = 10 ** -6
    td_error_1 = np.abs(ptu.get_numpy(q1_pred) - ptu.get_numpy(q_target))
    td_error_2 = np.abs(ptu.get_numpy(q2_pred) - ptu.get_numpy(q_target))
    new_priorities = (td_error_1 + td_error_2) / 2 + EPS
    self.replay_buffer.update_priorities(idxes, new_priorities)

    # Delayed policy and target-network updates (TD3).
    policy_actions = policy_loss = None
    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        policy_actions = self.policy(obs)
        q_output = self.qf1(obs, policy_actions)
        policy_loss = -q_output.mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        if policy_loss is None:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 1',
            ptu.get_numpy(bellman_errors_1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 2',
            ptu.get_numpy(bellman_errors_2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
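# Note: ptu.soft_update_from_to is assumed to implement the standard Polyak
# (exponential moving average) target update; a minimal, illustrative sketch:
def soft_update_from_to_sketch(source, target, tau):
    """Move every target parameter a fraction tau toward the source parameter."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )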
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple)
    return ptu.from_numpy(elem_or_tuple).float()
def encode_np(self, imgs, cond):
    return ptu.get_numpy(
        self.encode(ptu.from_numpy(imgs), ptu.from_numpy(cond)))
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    else:
        return np_array_or_other
def v_function(obs):
    action = policy.get_actions(obs)
    obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
    return qf1(obs, action)
def get_train_dict(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    batch_size = obs.size()[0]

    """
    Policy operations.
    """
    policy_actions = self.policy(obs)
    z_output = self.qf(obs, policy_actions)  # BATCH_SIZE x NUM_BINS
    q_output = (z_output * self.create_atom_values(batch_size)).sum(1)
    policy_loss = -q_output.mean()

    """
    Critic operations.
    """
    next_actions = self.target_policy(next_obs)
    target_qf_histogram = self.target_qf(
        next_obs,
        next_actions,
    )

    # Project the Bellman-backed-up atom values onto the fixed bin support.
    # z_target = ptu.Variable(torch.zeros(self.batch_size, self.num_bins))
    rewards_batch = rewards.repeat(1, self.num_bins)
    terminals_batch = terminals.repeat(1, self.num_bins)
    projected_returns = (self.reward_scale * rewards_batch +
                         (1. - terminals_batch) * self.discount *
                         self.create_atom_values(batch_size))
    projected_returns = torch.clamp(projected_returns,
                                    self.returns_min,
                                    self.returns_max)
    bin_values = (projected_returns - self.returns_min) / self.bin_width
    lower_bin_indices = torch.floor(bin_values)
    upper_bin_indices = torch.ceil(bin_values)
    lower_bin_deltas = target_qf_histogram * (upper_bin_indices - bin_values)
    upper_bin_deltas = target_qf_histogram * (bin_values - lower_bin_indices)

    z_target_np = np.zeros((batch_size, self.num_bins))
    lower_deltas_np = lower_bin_deltas.data.numpy()
    upper_deltas_np = upper_bin_deltas.data.numpy()
    lower_idxs_np = lower_bin_indices.data.numpy().astype(int)
    upper_idxs_np = upper_bin_indices.data.numpy().astype(int)
    for batch_i in range(batch_size):
        for bin_i in range(self.num_bins):
            z_target_np[batch_i, bin_i] += (
                lower_deltas_np[batch_i, lower_idxs_np[batch_i, bin_i]])
            z_target_np[batch_i, bin_i] += (
                upper_deltas_np[batch_i, upper_idxs_np[batch_i, bin_i]])
    z_target = ptu.Variable(ptu.from_numpy(z_target_np).float())

    # for j in range(self.num_bins):
    #     import ipdb; ipdb.set_trace()
    #     atom_value = self.atom_values_batch[:, j:j+1]
    #     projected_returns = self.reward_scale * rewards + (1. - terminals) * self.discount * (
    #         atom_value
    #     )
    #     bin_values = (projected_returns - self.returns_min) / self.bin_width
    #     lower_bin_indices = torch.floor(bin_values)
    #     upper_bin_indices = torch.ceil(bin_values)
    #     lower_bin_deltas = target_qf_histogram[:, j:j+1] * (
    #         upper_bin_indices - bin_values
    #     )
    #     upper_bin_deltas = target_qf_histogram[:, j:j+1] * (
    #         bin_values - lower_bin_indices
    #     )
    #     new_lower_bin_values = torch.gather(
    #         z_target, 1, lower_bin_indices.long().data
    #     ) + lower_bin_deltas
    #     new_upper_bin_values = torch.gather(
    #         z_target, 1, upper_bin_indices.long().data
    #     ) + upper_bin_deltas

    # noinspection PyUnresolvedReferences
    z_pred = self.qf(obs, actions)
    qf_loss = -(z_target * torch.log(z_pred)).sum(1).mean(0)

    return OrderedDict([
        ('Policy Actions', policy_actions),
        ('Policy Loss', policy_loss),
        ('QF Outputs', q_output),
        ('Z targets', z_target),
        ('Z predictions', z_pred),
        ('QF Loss', qf_loss),
    ])
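# For reference, a self-contained numpy sketch of the standard categorical-DQN
# (C51) projection of backed-up returns onto a fixed atom support. It uses the
# usual "scatter" form, so it is illustrative rather than a drop-in replacement
# for the gather-style loop in get_train_dict above; all names are local.
import numpy as np

def project_categorical(rewards, terminals, next_probs, atoms, gamma):
    """rewards, terminals: (batch,); next_probs: (batch, num_bins); atoms: (num_bins,)."""
    v_min, v_max = atoms[0], atoms[-1]
    delta = atoms[1] - atoms[0]
    target = np.zeros_like(next_probs)
    backed_up = rewards[:, None] + (1.0 - terminals[:, None]) * gamma * atoms[None, :]
    backed_up = np.clip(backed_up, v_min, v_max)
    b = (backed_up - v_min) / delta  # fractional bin position
    lower = np.floor(b).astype(int)
    upper = np.ceil(b).astype(int)
    for i in range(target.shape[0]):
        for j in range(target.shape[1]):
            if lower[i, j] == upper[i, j]:
                # backed-up atom landed exactly on a bin
                target[i, lower[i, j]] += next_probs[i, j]
            else:
                target[i, lower[i, j]] += next_probs[i, j] * (upper[i, j] - b[i, j])
                target[i, upper[i, j]] += next_probs[i, j] * (b[i, j] - lower[i, j])
    return target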
def _kl_np_to_np(self, np_imgs):
    torch_input = ptu.from_numpy(normalize_image(np_imgs))
    mu, log_var = self.model.encode(torch_input)
    return ptu.get_numpy(
        -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
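# For reference: the closed-form KL between the diagonal-Gaussian posterior
# N(mu, exp(log_var)) and a standard-normal prior is
#     KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)),
# i.e. the quantity returned above appears to drop the usual 0.5 factor.
# Illustrative helper (torch assumed imported as in this module):
def gaussian_kl_to_standard_normal(mu, log_var):
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)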
def visualize_rollout( env, world_model, logdir, max_path_length, low_level_primitives, policy=None, img_size=64, num_rollouts=4, use_raps_obs=False, use_true_actions=True, ): file_path = logdir + "/" os.makedirs(file_path, exist_ok=True) print( f"Generating Imagination Reconstructions No Intermediate Obs: {use_raps_obs} Actual Actions: {use_true_actions}" ) file_suffix = f"imagination_reconstructions_raps_obs_{use_raps_obs}_actual_actions_{use_true_actions}.png" file_path += file_suffix img_shape = (img_size, img_size, 3) reconstructions = np.zeros( (num_rollouts, max_path_length + 1, *img_shape), dtype=np.uint8, ) obs = np.zeros( (num_rollouts, max_path_length + 1, *img_shape), dtype=np.uint8, ) for rollout in range(num_rollouts): for step in range(0, max_path_length + 1): if step == 0: observation = env.reset() new_img = ptu.from_numpy(observation) policy.reset(observation.reshape(1, -1)) reward = 0 if low_level_primitives: policy_o = (None, observation.reshape(1, -1)) else: policy_o = observation.reshape(1, -1) # hack to pass typing checks vis = convert_img_to_save( world_model.get_image_from_obs( torch.from_numpy(observation.reshape(1, -1))).numpy()) add_text(vis, "Ground Truth", (1, 60), 0.25, (0, 255, 0)) else: high_level_action, state = policy.get_action( policy_o, use_raps_obs, use_true_actions) state = state["state"] observation, reward, done, info = env.step( high_level_action[0], ) if low_level_primitives: low_level_obs = np.expand_dims( np.array(info["low_level_obs"]), 0) low_level_action = np.expand_dims( np.array(info["low_level_action"]), 0) policy_o = (low_level_action, low_level_obs) else: policy_o = observation.reshape(1, -1) ( primitive_name, _, _, ) = env.get_primitive_info_from_high_level_action( high_level_action[0]) # hack to pass typing checks vis = convert_img_to_save( world_model.get_image_from_obs( torch.from_numpy(observation.reshape(1, -1))).numpy()) add_text(vis, primitive_name, (1, 60), 0.25, (0, 255, 0)) add_text(vis, f"r: {reward}", (35, 7), 0.3, (0, 0, 0)) obs[rollout, step] = vis if step != 0: new_img = reconstruct_from_state(state, world_model) if step == 1: add_text(new_img, "Reconstruction", (1, 60), 0.25, (0, 255, 0)) reconstructions[rollout, step - 1] = new_img reward_pred = (world_model.reward( world_model.get_features( state)).detach().cpu().numpy().item()) discount_pred = (world_model.get_dist( world_model.pred_discount(world_model.get_features(state)), std=None, normal=False, ).mean.detach().cpu().numpy().item(), )[0] print( f"Rollout {rollout} Step {step - 1} Predicted Reward {reward_pred}" ) print(f"Rollout {rollout} Step {step - 1} Reward {prev_r}") print( f"Rollout {rollout} Step {step - 1} Predicted Discount {discount_pred}" ) print() prev_r = reward _, state = policy.get_action(policy_o) state = state["state"] new_img = reconstruct_from_state(state, world_model) reconstructions[rollout, max_path_length] = new_img reward_pred = (world_model.reward( world_model.get_features(state)).detach().cpu().numpy().item()) discount_pred = (world_model.get_dist( world_model.pred_discount(world_model.get_features(state)), std=None, normal=False, ).mean.detach().cpu().numpy().item(), )[0] print(f"Rollout {rollout} Final Predicted Reward {reward_pred}") print(f"Rollout {rollout} Final Reward {reward}") print(f"Rollout {rollout} Final Predicted Discount {discount_pred}") print() im = np.zeros( (img_size * 2 * num_rollouts, (max_path_length + 1) * img_size, 3), dtype=np.uint8, ) for rollout in range(num_rollouts): for step in range(max_path_length + 1): 
im[img_size * 2 * rollout:img_size * 2 * rollout + img_size, img_size * step:img_size * (step + 1), ] = obs[rollout, step] im[img_size * 2 * rollout + img_size:img_size * 2 * (rollout + 1), img_size * step:img_size * (step + 1), ] = reconstructions[rollout, step] cv2.imwrite(file_path, im) print(f"Saved Rollout Visualization to {file_path}") print()
def get_action(self, obs, deterministic=False):
    ''' sample action from the policy, conditioned on the task embedding '''
    z = self.z
    obs = ptu.from_numpy(obs[None])
    in_ = torch.cat([obs, z], dim=1)
    return self.policy.get_action(in_, deterministic=deterministic)
def get_action(self, obs):
    obs = np.expand_dims(obs, axis=0)
    obs = Variable(ptu.from_numpy(obs).float(), requires_grad=False)
    action, _, _ = self.__call__(obs, None)
    action = action.squeeze(0)
    return ptu.get_numpy(action), {}
def post_epoch_video_func(
    algorithm,
    epoch,
    policy,
    img_size=256,
    mode="eval",
):
    if epoch == -1 or epoch % 100 == 0:
        print("Generating Video: ")
        env = algorithm.eval_env
        file_path = osp.join(logger.get_snapshot_dir(),
                             mode + "_" + str(epoch) + "_video.avi")

        # Roll out the policy four times, recording frames for a 2x2 video grid.
        img_array1 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        obs = np.zeros(
            (4, algorithm.max_path_length, env.observation_space.shape[0]),
            dtype=np.uint8,
        )
        actions = np.zeros(
            (4, algorithm.max_path_length, env.action_space.shape[0]))
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation)
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array1.extend(env.envs[0].img_array)
            obs[0, path_length] = observation
            actions[0, path_length] = action
            path_length += 1

        img_array2 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation)
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array2.extend(env.envs[0].img_array)
            obs[1, path_length] = observation
            actions[1, path_length] = action
            path_length += 1

        img_array3 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation)
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array3.extend(env.envs[0].img_array)
            obs[2, path_length] = observation
            actions[2, path_length] = action
            path_length += 1

        img_array4 = []
        path_length = 0
        observation = env.reset()
        policy.reset(observation)
        while path_length < algorithm.max_path_length:
            action, agent_info = policy.get_action(observation)
            observation, reward, done, info = env.step(
                action,
                render_every_step=True,
                render_mode="rgb_array",
                render_im_shape=(img_size, img_size),
            )
            img_array4.extend(env.envs[0].img_array)
            obs[3, path_length] = observation
            actions[3, path_length] = action
            path_length += 1

        # Tile the four rollouts into a 2x2 grid and write the video.
        fourcc = cv2.VideoWriter_fourcc(*"DIVX")
        out = cv2.VideoWriter(file_path, fourcc, 100.0,
                              (img_size * 2, img_size * 2))
        max_len = max(len(img_array1), len(img_array2), len(img_array3),
                      len(img_array4))
        gif_clip = []
        for i in range(max_len):
            if i >= len(img_array1):
                im1 = img_array1[-1]
            else:
                im1 = img_array1[i]
            if i >= len(img_array2):
                im2 = img_array2[-1]
            else:
                im2 = img_array2[i]
            if i >= len(img_array3):
                im3 = img_array3[-1]
            else:
                im3 = img_array3[i]
            if i >= len(img_array4):
                im4 = img_array4[-1]
            else:
                im4 = img_array4[i]
            im12 = np.concatenate((im1, im2), 1)
            im34 = np.concatenate((im3, im4), 1)
            im = np.concatenate((im12, im34), 0)
            out.write(im)
            gif_clip.append(im)
        out.release()
        print("video saved to :", file_path)
        # gif_file_path = osp.join(
        #     logger.get_snapshot_dir(), mode + "_" + str(epoch) + ".gif"
        # )
        # clip = ImageSequenceClip(list(gif_clip), fps=20)
        # clip.write_gif(gif_file_path, fps=20)  # takes way too much space

        # Reconstruct the recorded trajectories with the world model and save
        # ground-truth / reconstruction image strips.
        obs, actions = ptu.from_numpy(obs), ptu.from_numpy(actions)
        (
            post,
            prior,
            post_dist,
            prior_dist,
            image_dist,
            reward_dist,
            pred_discount_dist,
            embed,
        ) = algorithm.trainer.world_model(obs.detach(), actions.detach())
        if isinstance(image_dist, tuple):
            image_dist, _ = image_dist
        image_dist_mean = image_dist.mean.detach()
        reconstructions = image_dist_mean[:, :3, :, :]
        reconstructions = (torch.clamp(
            reconstructions.permute(0, 2, 3, 1).reshape(
                4, algorithm.max_path_length, 64, 64, 3) + 0.5,
            0,
            1,
        ) * 255.0)
        reconstructions = ptu.get_numpy(reconstructions).astype(np.uint8)
        obs_np = ptu.get_numpy(obs[:, :, :64 * 64 * 3].reshape(
            4, algorithm.max_path_length, 3, 64,
            64).permute(0, 1, 3, 4, 2)).astype(np.uint8)
        file_path = osp.join(logger.get_snapshot_dir(),
                             mode + "_" + str(epoch) + "_reconstructions.png")
        im = np.zeros((128 * 4, algorithm.max_path_length * 64, 3),
                      dtype=np.uint8)
        for i in range(4):
            for j in range(algorithm.max_path_length):
                im[128 * i:128 * i + 64,
                   64 * j:64 * (j + 1)] = obs_np[i, j]
                im[128 * i + 64:128 * (i + 1),
                   64 * j:64 * (j + 1)] = reconstructions[i, j]
        cv2.imwrite(file_path, im)

        if image_dist_mean.shape[1] == 6:
            reconstructions = image_dist_mean[:, 3:6, :, :]
            reconstructions = (torch.clamp(
                reconstructions.permute(0, 2, 3, 1).reshape(
                    4, algorithm.max_path_length, 64, 64, 3) + 0.5,
                0,
                1,
            ) * 255.0)
            reconstructions = ptu.get_numpy(reconstructions).astype(np.uint8)
            file_path = osp.join(
                logger.get_snapshot_dir(),
                mode + "_" + str(epoch) + "_reconstructions_wrist_cam.png",
            )
            obs_np = ptu.get_numpy(obs[:, :, 64 * 64 * 3:64 * 64 * 6].reshape(
                4, algorithm.max_path_length, 3, 64,
                64).permute(0, 1, 3, 4, 2)).astype(np.uint8)
            im = np.zeros((128 * 4, algorithm.max_path_length * 64, 3),
                          dtype=np.uint8)
            for i in range(4):
                for j in range(algorithm.max_path_length):
                    im[128 * i:128 * i + 64,
                       64 * j:64 * (j + 1)] = obs_np[i, j]
                    im[128 * i + 64:128 * (i + 1),
                       64 * j:64 * (j + 1)] = reconstructions[i, j]
            cv2.imwrite(file_path, im)
def forward(self, context=None, mask=None, r=None): if r is None: # hack for now to make things efficient min_context_len = min([ d['observations'].shape[0] for task_trajs in context for d in task_trajs ]) obs = np.array( [[d['observations'][:min_context_len] for d in task_trajs] for task_trajs in context]) next_obs = np.array([[ d['next_observations'][:min_context_len] for d in task_trajs ] for task_trajs in context]) if not self.state_only: acts = np.array( [[d['actions'][:min_context_len] for d in task_trajs] for task_trajs in context]) all_timesteps = np.concatenate([obs, acts, next_obs], axis=-1) else: all_timesteps = np.concatenate([obs, next_obs], axis=-1) # FOR DEBUGGING THE ENCODER # all_timesteps = all_timesteps[:,:,-1:,:] # print(all_timesteps) # print(all_timesteps.shape) # print(all_timesteps.dtype) # print('----') # print(acts.shape) # print(obs.shape) # print(next_obs.shape) # if acts.shape[0] == 10: # print(acts) # print(obs) # print(next_obs) # if acts.shape[0] all_timesteps = Variable(ptu.from_numpy(all_timesteps), requires_grad=False) # N_tasks x N_trajs x Len x Dim N_tasks, N_trajs, Len, Dim = all_timesteps.size( 0), all_timesteps.size(1), all_timesteps.size( 2), all_timesteps.size(3) all_timesteps = all_timesteps.view(-1, Dim) embeddings = self.timestep_encoder(all_timesteps) embeddings = embeddings.view(N_tasks, N_trajs, Len, self.r_dim) if self.use_sum_for_traj_agg: traj_embeddings = torch.sum(embeddings, dim=2) else: traj_embeddings = torch.mean(embeddings, dim=2) # get r if mask is None: r = self.agg(traj_embeddings) else: r = self.agg_masked(traj_embeddings, mask) post_mean, post_log_sig_diag = self.r2z_map(r) return ReparamMultivariateNormalDiag(post_mean, post_log_sig_diag)
def fit(self, data, weights=None): if weights is None: weights = np.ones(len(data)) sum_of_weights = weights.flatten().sum() weights = weights / sum_of_weights all_weights_pt = ptu.from_numpy(weights) indexed_train_data = IndexedData(data) if self.skew_sampling: base_sampler = WeightedRandomSampler(weights, len(weights)) else: base_sampler = RandomSampler(indexed_train_data) train_dataloader = DataLoader( indexed_train_data, sampler=BatchSampler( base_sampler, batch_size=self.batch_size, drop_last=False, ), ) if self.reset_vae_every_epoch: raise NotImplementedError() epoch_stats_list = defaultdict(list) for _ in range(self.num_inner_vae_epochs): for _, indexed_batch in enumerate(train_dataloader): idxs, batch = indexed_batch batch = batch[0].float().to(ptu.device) latents, means, log_vars, stds = ( self.encoder.get_encoding_and_suff_stats(batch)) beta = 1 kl = self.kl_to_prior(means, log_vars, stds) reconstruction_log_prob = self.compute_log_prob( batch, self.decoder, latents) elbo = -kl * beta + reconstruction_log_prob if self.weight_loss: idxs = torch.cat(idxs) batch_weights = all_weights_pt[idxs].unsqueeze(1) loss = -(batch_weights * elbo).sum() else: loss = -elbo.mean() self.encoder_opt.zero_grad() self.decoder_opt.zero_grad() loss.backward() self.encoder_opt.step() self.decoder_opt.step() epoch_stats_list['losses'].append(ptu.get_numpy(loss)) epoch_stats_list['kls'].append(ptu.get_numpy(kl.mean())) epoch_stats_list['log_probs'].append( ptu.get_numpy(reconstruction_log_prob.mean())) epoch_stats_list['latent-mean'].append( ptu.get_numpy(latents.mean())) epoch_stats_list['latent-std'].append( ptu.get_numpy(latents.std())) for k, v in create_stats_ordered_dict( 'weights', ptu.get_numpy(all_weights_pt)).items(): epoch_stats_list[k].append(v) self._epoch_stats = { 'unnormalized weight sum': sum_of_weights, } for k in epoch_stats_list: self._epoch_stats[k] = np.mean(epoch_stats_list[k])
def reconstruct(self, data):
    latents = self.encoder.encode(ptu.from_numpy(data))
    return ptu.get_numpy(self.decoder.reconstruct(latents))
def _decode(self, latents):
    reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
    decoded = ptu.get_numpy(reconstructions)
    return decoded
def __init__( self, env, policy, discriminator, policy_optimizer, expert_replay_buffer, disc_optim_batch_size=1024, policy_optim_batch_size=1024, num_update_loops_per_train_call=1000, num_disc_updates_per_loop_iter=1, num_policy_updates_per_loop_iter=1, # initial_only_disc_train_epochs=0, pretrain_disc=False, num_disc_pretrain_iters=1000, disc_lr=1e-3, disc_momentum=0.0, disc_optimizer_class=optim.Adam, use_grad_pen=True, grad_pen_weight=10, plotter=None, render_eval_paths=False, eval_deterministic=False, train_objective='airl', num_disc_input_dims=2, plot_reward_surface=True, use_survival_reward=False, use_ctrl_cost=False, ctrl_cost_weight=0.0, **kwargs): assert disc_lr != 1e-3, 'Just checking that this is being taken from the spec file' if eval_deterministic: eval_policy = MakeDeterministic(policy) else: eval_policy = policy super().__init__(env=env, exploration_policy=policy, eval_policy=eval_policy, expert_replay_buffer=expert_replay_buffer, **kwargs) self.policy_optimizer = policy_optimizer self.discriminator = discriminator self.rewardf_eval_statistics = None self.disc_optimizer = disc_optimizer_class( self.discriminator.parameters(), lr=disc_lr, betas=(disc_momentum, 0.999)) print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum) self.disc_optim_batch_size = disc_optim_batch_size self.policy_optim_batch_size = policy_optim_batch_size assert train_objective in ['airl', 'fairl', 'gail', 'w1'] self.train_objective = train_objective self.bce = nn.BCEWithLogitsLoss() target_batch_size = self.disc_optim_batch_size self.bce_targets = torch.cat([ torch.ones(target_batch_size, 1), torch.zeros(target_batch_size, 1) ], dim=0) self.bce_targets = Variable(self.bce_targets) if ptu.gpu_enabled(): self.bce.cuda() self.bce_targets = self.bce_targets.cuda() self.use_grad_pen = use_grad_pen self.grad_pen_weight = grad_pen_weight self.num_update_loops_per_train_call = num_update_loops_per_train_call self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter self.num_policy_updates_per_loop_iter = num_policy_updates_per_loop_iter # self.initial_only_disc_train_epochs = initial_only_disc_train_epochs self.pretrain_disc = pretrain_disc self.did_disc_pretraining = False self.num_disc_pretrain_iters = num_disc_pretrain_iters self.cur_epoch = -1 self.plot_reward_surface = plot_reward_surface if plot_reward_surface: d = 6 self._d = d self._d_len = np.arange(-d, d + 0.25, 0.25).shape[0] self.xy_var = [] for i in np.arange(d, -d - 0.25, -0.25): for j in np.arange(-d, d + 0.25, 0.25): self.xy_var.append([float(j), float(i)]) self.xy_var = np.array(self.xy_var) self.xy_var = Variable(ptu.from_numpy(self.xy_var), requires_grad=False) # d = 20 # self._d = d # # self._d_len = np.arange(0.98697072 - 0.3, 0.98697072 + 0.175, 0.02).shape[0] # self._d_len_rows = np.arange(0.74914774 - 0.35, 0.74914774 + 0.45, 0.01).shape[0] # self._d_len_cols = np.arange(0.98697072 - 0.3, 0.98697072 + 0.175, 0.01).shape[0] # self.xy_var = [] # for i in np.arange(0.74914774 + 0.45, 0.74914774 - 0.35, -0.01): # for j in np.arange(0.98697072 - 0.3, 0.98697072 + 0.175, 0.01): # self.xy_var.append([float(j), float(i)]) # self.xy_var = np.array(self.xy_var) # self.xy_var = Variable(ptu.from_numpy(self.xy_var), requires_grad=False) self.num_disc_input_dims = num_disc_input_dims self.use_survival_reward = use_survival_reward self.use_ctrl_cost = use_ctrl_cost self.ctrl_cost_weight = ctrl_cost_weight
def _encode(self, imgs):
    latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
    return ptu.get_numpy(latent_distribution_params[0])
def _do_training(self): ''' ''' # train the discriminator (and the encoder) # print('$$$$$$$$$') # print(self.num_disc_updates_per_epoch) for i in range(self.num_disc_updates_per_epoch): self.encoder_optimizer.zero_grad() self.disc_optimizer.zero_grad() context_batch, context_pred_batch, test_pred_batch, policy_test_pred_batch, traj_len = self._get_disc_training_batch( ) # convert it to a pytorch tensor # note that our objective says we should maximize likelihood of # BOTH the context_batch and the test_batch exp_obs_batch = np.concatenate((context_pred_batch['observations'], test_pred_batch['observations']), axis=0) exp_obs_batch = Variable(ptu.from_numpy(exp_obs_batch), requires_grad=False) exp_acts_batch = np.concatenate( (context_pred_batch['actions'], test_pred_batch['actions']), axis=0) exp_acts_batch = Variable(ptu.from_numpy(exp_acts_batch), requires_grad=False) policy_obs_batch = Variable(ptu.from_numpy( policy_test_pred_batch['observations']), requires_grad=False) policy_acts_batch = Variable(ptu.from_numpy( policy_test_pred_batch['actions']), requires_grad=False) post_dist = self.encoder(context_batch) # z = post_dist.sample() # N_tasks x Dim z = post_dist.mean # z_reg_loss = 0.0001 * z.norm(2, dim=1).mean() z_reg_loss = 0.0 # make z's for expert samples context_pred_z = z.repeat( 1, traj_len * self.num_context_trajs_for_training).view( -1, z.size(1)) test_pred_z = z.repeat(1, traj_len * self.num_test_trajs_for_training).view( -1, z.size(1)) z_batch = torch.cat([context_pred_z, test_pred_z], dim=0) positive_obs_batch = torch.cat([exp_obs_batch, z_batch], dim=1) positive_acts_batch = exp_acts_batch # make z's for policy samples z_policy = z_batch negative_obs_batch = torch.cat([policy_obs_batch, z_policy], dim=1) negative_acts_batch = policy_acts_batch # compute the loss for the discriminator obs_batch = torch.cat([positive_obs_batch, negative_obs_batch], dim=0) acts_batch = torch.cat([positive_acts_batch, negative_acts_batch], dim=0) disc_logits = self.disc(obs_batch, acts_batch) disc_preds = (disc_logits > 0).type(torch.FloatTensor) # disc_percent_policy_preds_one = disc_preds[z.size(0):].mean() disc_loss = self.bce(disc_logits, self.bce_targets) accuracy = (disc_preds == self.bce_targets).type( torch.FloatTensor).mean() if self.use_grad_pen: eps = Variable(torch.rand(positive_obs_batch.size(0), 1), requires_grad=True) if ptu.gpu_enabled(): eps = eps.cuda() # old and probably has a bad weird effect on the encoder # difference is that before I was also taking into account norm of grad of disc # wrt the z # interp_obs = eps*positive_obs_batch + (1-eps)*negative_obs_batch # permute the exp_obs_batch (not just within a single traj, but overall) # This is actually a really tricky question how to permute the batches # 1) permute within each of trajectories # z's will be matched, colors won't be matched anyways # 2) permute within trajectories corresponding to a single context set # z's will be matched, colors will be "more unmatched" # 3) just shuffle everything up # Also, the z's need to be handled appropriately interp_obs = eps * exp_obs_batch + (1 - eps) * policy_obs_batch # interp_z = z_batch.detach() # interp_obs = torch.cat([interp_obs, interp_z], dim=1) interp_obs.detach() # interp_obs.requires_grad = True interp_actions = eps * positive_acts_batch + ( 1 - eps) * negative_acts_batch interp_actions.detach() # interp_actions.requires_grad = True gradients = autograd.grad( outputs=self.disc( torch.cat([interp_obs, z_batch.detach()], dim=1), interp_actions).sum(), 
inputs=[interp_obs, interp_actions], # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(), create_graph=True, retain_graph=True, only_inputs=True) # print(gradients[0].size()) # z_norm = gradients[0][:,-50:].norm(2, dim=1) # print('Z grad norm: %.4f +/- %.4f' % (torch.mean(z_norm), torch.std(z_norm))) # print(gradients[0][:,-50:].size()) # o_norm = gradients[0][:,:-50].norm(2, dim=1) # o_norm = gradients[0].norm(2, dim=1) # print('Obs grad norm: %.4f +/- %.4f' % (torch.mean(o_norm), torch.std(o_norm))) # print(gradients[0].size()) # print(gradients[0][:,:50].norm(2, dim=1)) total_grad = torch.cat([gradients[0], gradients[1]], dim=1) # print(total_grad.size()) gradient_penalty = ((total_grad.norm(2, dim=1) - 1)**2).mean() # another form of grad pen # gradient_penalty = (total_grad.norm(2, dim=1) ** 2).mean() disc_loss = disc_loss + gradient_penalty * self.grad_pen_weight total_reward_loss = z_reg_loss + disc_loss total_reward_loss.backward() self.disc_optimizer.step() self.encoder_optimizer.step() # print(self.disc.fc0.bias[0]) # print(self.encoder.traj_encoder.traj_enc_mlp.fc0.bias[0]) # train the policy # print('--------') # print(self.num_policy_updates_per_epoch) for i in range(self.num_policy_updates_per_epoch): context_batch, policy_batch = self._get_policy_training_batch() policy_batch = np_to_pytorch_batch(policy_batch) post_dist = self.encoder(context_batch) # z = post_dist.sample() # N_tasks x Dim z = post_dist.mean z = z.detach() # repeat z to have the right size z = z.repeat(1, self.policy_batch_size_per_task).view( self.num_tasks_used_per_update * self.policy_batch_size_per_task, -1).detach() # now augment the obs with the latent sample z policy_batch['observations'] = torch.cat( [policy_batch['observations'], z], dim=1) policy_batch['next_observations'] = torch.cat( [policy_batch['next_observations'], z], dim=1) # compute the rewards # If you compute log(D) - log(1-D) then you just get the logits policy_rewards = self.disc(policy_batch['observations'], policy_batch['actions']).detach() policy_batch['rewards'] = policy_rewards # rew_more_than_zero = (rewards > 0).type(torch.FloatTensor).mean() # print(rew_more_than_zero.data[0]) # do a policy update (the zeroing of grads etc. should be handled internally) # print(policy_rewards.size()) self.policy_optimizer.train_step(policy_batch) # print(self.main_policy.fc0.bias[0]) if self.eval_statistics is None: """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.eval_statistics = OrderedDict() self.eval_statistics['Disc Loss'] = np.mean( ptu.get_numpy(disc_loss)) self.eval_statistics['Disc Acc'] = np.mean(ptu.get_numpy(accuracy)) # self.eval_statistics['Disc Percent Policy Preds 1'] = np.mean(ptu.get_numpy(disc_percent_policy_preds_one)) self.eval_statistics['Disc Rewards Mean'] = np.mean( ptu.get_numpy(policy_rewards)) self.eval_statistics['Disc Rewards Std'] = np.std( ptu.get_numpy(policy_rewards)) self.eval_statistics['Disc Rewards Max'] = np.max( ptu.get_numpy(policy_rewards)) self.eval_statistics['Disc Rewards Min'] = np.min( ptu.get_numpy(policy_rewards)) # self.eval_statistics['Disc Rewards GT Zero'] = np.mean(ptu.get_numpy(rew_more_than_zero)) z_norm = z.norm(2, dim=1).mean() self.eval_statistics['Z Norm'] = np.mean(ptu.get_numpy(z_norm)) if self.policy_optimizer.eval_statistics is not None: self.eval_statistics.update( self.policy_optimizer.eval_statistics)
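# A minimal, self-contained sketch of the WGAN-GP style gradient penalty
# computed in the discriminator update above, written against a generic
# discriminator D(obs, act); the function and argument names here are
# illustrative, not the trainer's.
import torch
from torch import autograd

def gradient_penalty(D, expert_obs, policy_obs, expert_acts, policy_acts, weight=10.0):
    # Interpolate between expert and policy samples with a per-example epsilon.
    eps = torch.rand(expert_obs.size(0), 1, device=expert_obs.device)
    interp_obs = (eps * expert_obs + (1 - eps) * policy_obs).requires_grad_(True)
    interp_acts = (eps * expert_acts + (1 - eps) * policy_acts).requires_grad_(True)
    grads = autograd.grad(
        outputs=D(interp_obs, interp_acts).sum(),
        inputs=[interp_obs, interp_acts],
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    # Penalize deviation of the joint (obs, act) gradient norm from 1.
    total_grad = torch.cat([grads[0], grads[1]], dim=1)
    return weight * ((total_grad.norm(2, dim=1) - 1) ** 2).mean()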
def __init__(
        self,
        input_width,
        input_height,
        input_channels,
        output_size,
        kernel_sizes,
        n_channels,
        strides,
        paddings,
        hidden_sizes=None,
        added_fc_input_size=0,
        batch_norm_conv=False,
        batch_norm_fc=False,
        init_w=1e-4,
        hidden_init=nn.init.xavier_uniform_,
        hidden_activation=nn.ReLU(),
        output_activation=identity,
):
    if hidden_sizes is None:
        hidden_sizes = []
    assert len(kernel_sizes) == \
        len(n_channels) == \
        len(strides) == \
        len(paddings)
    super().__init__()

    self.hidden_sizes = hidden_sizes
    self.input_width = input_width
    self.input_height = input_height
    self.input_channels = input_channels
    self.output_size = output_size
    self.output_activation = output_activation
    self.hidden_activation = hidden_activation
    self.batch_norm_conv = batch_norm_conv
    self.batch_norm_fc = batch_norm_fc
    self.added_fc_input_size = added_fc_input_size
    self.conv_input_length = self.input_width * self.input_height * self.input_channels

    self.conv_layers = nn.ModuleList()
    self.conv_norm_layers = nn.ModuleList()
    self.fc_layers = nn.ModuleList()
    self.fc_norm_layers = nn.ModuleList()

    for out_channels, kernel_size, stride, padding in \
            zip(n_channels, kernel_sizes, strides, paddings):
        conv = nn.Conv2d(input_channels,
                         out_channels,
                         kernel_size,
                         stride=stride,
                         padding=padding)
        hidden_init(conv.weight)
        conv.bias.data.fill_(0)

        conv_layer = conv
        self.conv_layers.append(conv_layer)
        input_channels = out_channels

    xcoords = np.expand_dims(np.linspace(-1, 1, self.input_width),
                             0).repeat(self.input_height, 0)
    ycoords = np.repeat(np.linspace(-1, 1, self.input_height),
                        self.input_width).reshape(
                            (self.input_height, self.input_width))

    self.coords = from_numpy(
        np.expand_dims(np.stack([xcoords, ycoords], 0), 0))
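# A minimal sketch of how the coordinate channels built in __init__ are
# typically consumed in a CoordConv-style forward pass. The broadcast-and-
# concatenate pattern below is an assumption about this network's forward(),
# not a copy of it.
import torch

def prepend_coord_channels(conv_input, coords):
    """conv_input: (N, C, H, W); coords: (1, 2, H, W) as built above."""
    batched_coords = coords.expand(conv_input.size(0), -1, -1, -1)
    return torch.cat([conv_input, batched_coords], dim=1)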
def test_vae_traj(vae, env_id, save_path=None, save_name=None):
    pjhome = os.environ['PJHOME']

    # optimal traj
    data_path = osp.join(pjhome,
                         'data/local/env/{}-optimal-traj.npy'.format(env_id))
    if not osp.exists(data_path):
        return
    imgs = np.load(data_path)
    traj_len, batch_size, imlen = imgs.shape
    imgs = imgs.reshape((-1, imlen))
    latents, _ = vae.encode(ptu.from_numpy(imgs))
    latents = ptu.get_numpy(latents)
    latent_distances = np.linalg.norm(latents - latents[-1], axis=1)
    puck_distances = np.load(
        osp.join(
            pjhome,
            'data/local/env/{}-optimal-traj-puck-distance.npy'.format(env_id)))

    fig, axs = plt.subplots(1, 2, figsize=(14, 5))
    axs = axs.reshape(-1)
    ax = axs[0]
    ax2 = ax.twinx()
    ax.plot(puck_distances, label='puck distance', color='r')
    ax2.plot(latent_distances, label='vae distance', color='b')
    ax.legend(loc='upper right')
    ax2.legend(loc='center right')

    # sub-optimal
    imgs = np.load(
        osp.join(pjhome,
                 'data/local/env/{}-local-optimal-traj.npy'.format(env_id)))
    goal_image = np.load(
        osp.join(
            pjhome,
            'data/local/env/{}-local-optimal-traj-goal.npy'.format(env_id)))
    traj_len, batch_size, imlen = imgs.shape
    puck_distances = np.load(
        osp.join(
            pjhome,
            'data/local/env/{}-local-optimal-traj-puck-distance.npy'.format(
                env_id)))
    imgs = imgs.reshape((-1, imlen))
    goal_image = goal_image.reshape((-1, imlen))
    latents = vae.encode(ptu.from_numpy(imgs))[0]
    latent_goal = vae.encode(ptu.from_numpy(goal_image))[0]
    latents = ptu.get_numpy(latents).reshape(traj_len, -1)
    latent_goal = ptu.get_numpy(latent_goal).flatten()
    latent_distances = np.linalg.norm(latents - latent_goal, axis=1)

    ax = axs[1]
    ax2 = ax.twinx()
    ax.plot(puck_distances, label='puck distance', color='r')
    ax2.plot(latent_distances, label='vae distance', color='b')
    ax.legend(loc='upper right')
    ax2.legend(loc='center right')

    plt.savefig(osp.join(save_path, save_name))
    plt.close('all')
def _reconstruction_squared_error_np_to_np(self, np_imgs):
    torch_input = ptu.from_numpy(normalize_image(np_imgs))
    recons, *_ = self.model(torch_input)
    error = torch_input - recons
    return ptu.get_numpy((error ** 2).sum(dim=1))
def _decode(self, latents, vae=None):
    if vae is None:
        vae = self.vae_original
    reconstructions, _ = vae.decode(ptu.from_numpy(latents))
    decoded = ptu.get_numpy(reconstructions)
    return decoded
def set_param_values_np(self, param_values):
    torch_dict = OrderedDict()
    for key, tensor in param_values.items():
        torch_dict[key] = ptu.from_numpy(tensor)
    self.load_state_dict(torch_dict)
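# Hedged usage sketch: the numpy setter above pairs naturally with a getter
# like the following (illustrative; the surrounding codebase may already
# provide an equivalent helper):
def get_param_values_np(module):
    return OrderedDict(
        (key, ptu.get_numpy(tensor)) for key, tensor in module.state_dict().items()
    )

# Round trip between two modules with identical architectures:
#     module_b.set_param_values_np(get_param_values_np(module_a))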
def _encode(self, imgs, vae=None):
    if vae is None:
        vae = self.vae_original
    latent_distribution_params = vae.encode(ptu.from_numpy(imgs))
    return ptu.get_numpy(latent_distribution_params[0])
def v_function(obs):
    action = policy.get_actions(obs)
    obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
    return qf1(obs, action, return_individual_q_vals=True)
def __init__( self, env, policy, qf1, qf2, target_qf1, target_qf2, bonus_network, beta, use_bonus_critic, use_bonus_policy, use_log, bonus_norm_param, rewards_shift_param, device, discount=0.99, reward_scale=1.0, policy_lr=1e-3, qf_lr=1e-3, alpha_lr=3e-5, optimizer_class=optim.Adam, soft_target_tau=1e-2, target_update_period=1, plotter=None, render_eval_paths=False, use_automatic_entropy_tuning=True, target_entropy=None, ): super().__init__() self.env = env self.policy = policy self.qf1 = qf1 self.qf2 = qf2 self.target_qf1 = target_qf1 self.target_qf2 = target_qf2 self.device = device self.bonus_network = bonus_network self.beta = beta # type of adding bonus to critic or policy self.use_bonus_critic = use_bonus_critic self.use_bonus_policy = use_bonus_policy # use log in the bonus # if use_log : log(bonus) # else 1 - bonus self.use_log = use_log # normlization self.obs_mu, self.obs_std = bonus_norm_param self.normalize = self.obs_mu is not None if self.normalize: print('.......Using normailization in bonus........') self.obs_mu = ptu.from_numpy(self.obs_mu).to(device) self.obs_std = ptu.from_numpy(self.obs_std).to(device) # self.actions_mu = ptu.from_numpy(self.actions_mu).to(device) # self.actions_std = ptu.from_numpy(self.actions_std).to(device) self.rewards_shift_param = rewards_shift_param self.soft_target_tau = soft_target_tau self.target_update_period = target_update_period self.use_automatic_entropy_tuning = use_automatic_entropy_tuning if self.use_automatic_entropy_tuning: if target_entropy: self.target_entropy = target_entropy else: self.target_entropy = -np.prod( self.env.action_space.shape).item( ) # heuristic value from Tuomas self.log_alpha = ptu.zeros(1, requires_grad=True) self.alpha_optimizer = optimizer_class( [self.log_alpha], lr=alpha_lr, ) self.plotter = plotter self.render_eval_paths = render_eval_paths self.qf_criterion = nn.MSELoss() self.vf_criterion = nn.MSELoss() self.policy_optimizer = optimizer_class( self.policy.parameters(), lr=policy_lr, ) self.qf1_optimizer = optimizer_class( self.qf1.parameters(), lr=qf_lr, ) self.qf2_optimizer = optimizer_class( self.qf2.parameters(), lr=qf_lr, ) self.discrete = False self.discount = discount self.reward_scale = reward_scale self.eval_statistics = OrderedDict() self._n_train_steps_total = 0 self._need_to_update_eval_statistics = True self.clip_val = 1e3
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple)
    return Variable(ptu.from_numpy(elem_or_tuple).float(), requires_grad=False)
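# Hedged usage sketch of the helper above: converting a replay-buffer batch of
# numpy arrays (or tuples of arrays) in one go. np_to_pytorch_batch-style
# wrappers typically look like this, but the name here is illustrative.
def np_batch_to_variable_batch(np_batch):
    return {k: _elem_or_tuple_to_variable(v) for k, v in np_batch.items()}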
def train(self): for t in range(self.num_updates): self.mus = Variable(ptu.from_numpy(self.np_mus), requires_grad=True) self.log_sigmas = Variable(ptu.from_numpy(self.np_log_sigmas), requires_grad=True) # generate rollouts for each task normal = ReparamMultivariateNormalDiag(self.mus, self.log_sigmas) all_samples = [] all_log_probs = [] all_rewards = [] for sample_num in range(self.num_z_samples_per_update): sample = normal.sample() all_samples.append(sample) all_log_probs.append(normal.log_prob(sample)) # evaluate each of the z's for each task rewards = [] for i in range(len(self.obs_task_params)): successes = self.gen_rollout( self.obs_task_params[i], self.task_params[i], sample[i], self.num_trajs_per_z_sample ) rewards.append(np.mean(successes)) all_rewards.append(rewards) all_log_probs = torch.cat(all_log_probs, dim=-1) # num_tasks x num_samples np_all_rewards = np.array(all_rewards).T # num_tasks x num_samples all_rewards = Variable(ptu.from_numpy(np_all_rewards)) all_rewards = all_rewards - torch.mean(all_rewards, dim=-1, keepdim=True) all_rewards = self.reward_scale * all_rewards # compute gradients wrt mus and sigmas pg_loss = -1.0 * torch.sum( torch.mean(all_log_probs * all_rewards, dim=-1) ) grads = autograd.grad( outputs=pg_loss, inputs=[self.mus, self.log_sigmas], only_inputs=True ) # update the mus and sigmas if self.use_nat_grad: print('Nat Grad') mu_grad = grads[0] * (torch.exp(2 * self.log_sigmas)).detach() log_sig_grad = 0.5 * grads[1] else: print('Normal Grad') mu_grad = grads[0] log_sig_grad = grads[1] self.mus = self.mus - self.mu_lr * mu_grad self.log_sigmas = self.log_sigmas - self.log_sig_lr * log_sig_grad self.np_mus = ptu.get_numpy(self.mus) self.np_log_sigmas = ptu.get_numpy(self.log_sigmas) # logging np_all_rewards = np.mean(np_all_rewards, axis=-1) print('\n-----------------------------------------------') # print('Avg Reward: {}'.format(np.mean(np_all_rewards))) # print('Std Reward: {}'.format(np.std(np_all_rewards))) # print('Max Reward: {}'.format(np.max(np_all_rewards))) # print('Min Reward: {}'.format(np.min(np_all_rewards))) print(np_all_rewards) print(self.np_mus) print(np.exp(2*self.np_log_sigmas))
def decode_np(self, latents):
    reconstructions = self.decode(ptu.from_numpy(latents))
    decoded = ptu.get_numpy(reconstructions)
    return decoded
def __init__( self, discriminator, exp_data, pol_data, disc_optim_batch_size=1024, num_update_loops_per_train_call=1, num_disc_updates_per_loop_iter=1, disc_lr=1e-3, disc_momentum=0.0, disc_optimizer_class=optim.Adam, use_grad_pen=True, grad_pen_weight=10, train_objective='airl', ): assert disc_lr != 1e-3, 'Just checking that this is being taken from the spec file' self.exp_data, self.pol_data = exp_data, pol_data self.discriminator = discriminator self.rewardf_eval_statistics = None self.disc_optimizer = disc_optimizer_class( self.discriminator.parameters(), lr=disc_lr, betas=(disc_momentum, 0.999) ) print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum) self.disc_optim_batch_size = disc_optim_batch_size assert train_objective in ['airl', 'fairl', 'gail', 'w1'] self.train_objective = train_objective self.bce = nn.BCEWithLogitsLoss() target_batch_size = self.disc_optim_batch_size self.bce_targets = torch.cat( [ torch.zeros(target_batch_size), torch.ones(target_batch_size), 2*torch.ones(target_batch_size), ], dim=0 ).type(torch.LongTensor) self.bce_targets = Variable(self.bce_targets) if ptu.gpu_enabled(): self.bce.cuda() self.bce_targets = self.bce_targets.cuda() self.use_grad_pen = use_grad_pen self.grad_pen_weight = grad_pen_weight self.num_update_loops_per_train_call = num_update_loops_per_train_call self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter d = 5.0 self._d = d self._d_len = np.arange(-d,d+0.25,0.25).shape[0] self.xy_var = [] for i in np.arange(d,-d-0.25,-0.25): for j in np.arange(-d,d+0.25,0.25): self.xy_var.append([float(j),float(i)]) self.xy_var = np.array(self.xy_var) self.xy_var = Variable(ptu.from_numpy(self.xy_var), requires_grad=False)