def train_pd_match_sd(self, dataset, bs, itr, outer_itr):
    sampler = self.sampler
    # sample expert trajectories from the dataset to initialize rollouts from
    expert_traj, _ = dataset.sample(bs)
    x, actdata = self.vae.splitobs(FloatTensor(expert_traj))
    z = Variable(torch.randn((bs, self.latent_dim)))
    pd_traj, sd_traj = self.forward(sampler, x, z)
    sd_traj_obs = get_numpy(sd_traj.mle)
    traj_3d_shape = (bs, -1, self.obs_dim)
    pd_traj_obs = np_to_var(pd_traj['obs'][:, 1:])
    # log-likelihood of the policy-decoder trajectory under the state decoder
    ll = sd_traj.reshape(traj_3d_shape).log_likelihood(pd_traj_obs)
    mse_sd_pd = self.compute_traj_mse(pd_traj_obs, sd_traj.mle, traj_3d_shape)
    # the state-decoder log-likelihood is the reward for the policy decoder
    pd_traj['rewards'] = get_numpy(ll)
    self.policy_algo.process_samples(0, pd_traj, augment_obs=get_numpy(z))
    self.policy_algo.optimize_policy(0, pd_traj)
    traj_sets = [sd_traj_obs, pd_traj['obs'][:, 1:]]
    pd_traj['stats']['mse_sd_pd'] = get_numpy(mse_sd_pd.mean()).item()
    pd_traj['stats']['ll'] = np.mean(get_numpy(ll))
    return pd_traj['stats']
def plot_compare(self, dataset, itr):
    trajs, _ = dataset.sample_hard(self.plot_size)
    x, actdata = self.vae.splitobs(np_to_var(trajs))
    latent_dist = self.vae.encode(x)
    latent = latent_dist.sample(deterministic=True)
    # sample state-decoder and policy-decoder trajectories for plotting
    pd_traj, sd_traj = self.forward(self.plot_sampler, FloatTensor(trajs), latent)
    traj_sets = [
        get_numpy(x)[:, self.step_dim:],
        get_numpy(sd_traj.mle),
        pd_traj['obs'][:, 1:]
    ]
    traj_sets = [
        traj.reshape((self.plot_size, self.max_path_length, -1))
        for traj in traj_sets
    ]
    traj_names = ['expert', 'sd', 'pd']
    plot_traj_sets([dataset.process(traj) for traj in traj_sets],
                   traj_names, itr, env_id=dataset.env_id)
    #dataset.plot_pd_compare([x[0] for x in traj_sets], traj_names, itr)
    for traj_no in range(5):
        dataset.plot_pd_compare([traj[traj_no, ...] for traj in traj_sets],
                                traj_names, itr,
                                name='Full_State_%d' % traj_no,
                                save_dir='pd_match_expert')
    self.zero_grad()
def process_samples(self, itr, sd, augment_obs=None):
    # rewards is (bs, path_len)
    # actions is (bs, path_len, action_dim)
    if augment_obs is not None:
        # append the (fixed) augmenting input to every timestep of the observations
        sd['obs'] = np.concatenate([
            sd['obs'],
            np.repeat(np.expand_dims(augment_obs, 1), sd['obs'].shape[1], 1)
        ], -1)
    sd['values'] = get_numpy(self.baseline.predict(sd['obs'].astype(np.float32)))
    sd['rewards'] = sd['rewards'].astype(np.float32)
    path_len = sd['rewards'].shape[-1]
    returns = np.zeros((sd['rewards'].shape[0], sd['rewards'].shape[1] + 1))
    returns[:, -2] = sd['values'][:, -1]
    rewards = sd['rewards']
    if self.use_gae:
        # generalized advantage estimation (GAE)
        gae = 0
        for step in reversed(range(rewards.shape[1])):
            mask = 1 if step != rewards.shape[1] - 1 else 0
            delta = rewards[:, step] + self.discount * sd['values'][:, step + 1] * mask \
                - sd['values'][:, step]
            gae = delta + self.discount * self.gae_tau * mask * gae
            returns[:, step] = gae + sd['values'][:, step]
    else:
        # plain discounted returns
        for step in reversed(range(rewards.shape[1])):
            returns[:, step] = returns[:, step + 1] * self.discount + rewards[:, step]
    sd['returns'] = np.cumsum(sd['rewards'][:, ::-1], axis=-1)[:, ::-1]
    sd['discount_returns'] = returns[:, :-1]
    sd['actions'] = sd['actions'].detach()
    sd['log_prob'] = sd['action_dist'].log_likelihood(sd['actions'])
    sd['entropy'] = sd['action_dist'].entropy()
    # logger.log('Fitting Baseline')
    #sd['adv'] = sd['returns'] - sd['values']
    if self.fit_baseline:
        self.baseline.fit(sd['obs'][:, :-1], sd['discount_returns'])
    if hasattr(self.policy, 'obs_filter') and self.policy.obs_filter is not None:
        self.policy.obs_filter.update(
            DoubleTensor(sd['obs'][:, :-1].reshape((-1, self.policy.obs_filter.shape))))
    if sd['values'].shape[1] > 1:
        sd['discount_adv'] = sd['discount_returns'] - sd['values'][:, :-1]
    else:
        sd['discount_adv'] = sd['discount_returns'] - sd['values']
    if self.center_adv:
        sd['discount_adv'] = (sd['discount_adv'] - sd['discount_adv'].mean()) / (
            sd['discount_adv'].std() + 1e-5)
    sd['stats'] = OrderedDict([
        ('Mean Return', sd['returns'][:, 0].mean()),
        ('Min Return', sd['returns'][:, 0].min()),
        ('Max Return', sd['returns'][:, 0].max()),
        ('Var Return', sd['returns'][:, 0].var()),
        ('Entropy', get_numpy(sd['entropy'].mean())[0]),
        #('Policy loss', -(get_numpy(sd['log_prob']) * sd['discount_adv']).sum(-1).mean()),
    ])
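# The GAE branch above can be hard to read interleaved with the bookkeeping. The snippet
# below is a minimal, self-contained sketch of the same recursion
# (advantage_t = delta_t + gamma * tau * advantage_{t+1}) on plain numpy arrays;
# the function name and shapes here are illustrative only, not part of the codebase.
import numpy as np

def gae_returns_sketch(rewards, values, discount=0.99, gae_tau=0.95):
    """rewards: (bs, T); values: (bs, T + 1), including a bootstrap value at the end."""
    bs, T = rewards.shape
    returns = np.zeros((bs, T))
    gae = np.zeros(bs)
    for step in reversed(range(T)):
        # the mask zeroes out bootstrapping at the final step, as in process_samples above
        mask = 1.0 if step != T - 1 else 0.0
        delta = rewards[:, step] + discount * values[:, step + 1] * mask - values[:, step]
        gae = delta + discount * gae_tau * mask * gae
        returns[:, step] = gae + values[:, step]
    return returns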
def optimize_policy(self, itr, samples_data):
    prev_param = get_numpy(self._target.get_params_flat())
    self.policy.zero_grad()
    loss_before = self.loss(samples_data)
    loss_before.backward()
    flat_g = self.policy.get_params_flat()
    loss_before = get_numpy(loss_before).item()

    # conjugate gradient to approximately solve H x = g for the descent direction
    Hx = self._hvp_approach.build_eval(samples_data)
    descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)

    # largest step size that satisfies the KL (trust-region) constraint
    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
    if np.isnan(initial_step_size):
        initial_step_size = 1.
    flat_descent_step = initial_step_size * descent_direction
    logger.log("descent direction computed")

    # backtracking line search
    n_iter = 0
    for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_params_flat(from_numpy(cur_param))
        loss, constraint_val = self.compute_loss_terms(samples_data)
        if self._debug_nan and np.isnan(constraint_val):
            import ipdb
            ipdb.set_trace()
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break
    if (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
            or constraint_val >= self._max_constraint_val) and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint %s is NaN" % self._constraint_name)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val >= self._max_constraint_val:
            logger.log("Violated because constraint %s is violated" % self._constraint_name)
        self._target.set_param_values(prev_param, trainable=True)
    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
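# For reference, the step-size formula above follows from the trust-region condition
# 0.5 * (beta * d)^T H (beta * d) <= max_constraint_val, which gives
# beta = sqrt(2 * max_constraint_val / (d^T H d)). A minimal numpy sketch, with a dense
# Hessian standing in for the Hx closure (names here are illustrative only):
import numpy as np

def kl_constrained_step_sketch(H, g, max_constraint_val=0.01, eps=1e-8):
    # solve H d = g directly; the method above uses conjugate gradient via krylov.cg
    d = np.linalg.solve(H, g)
    beta = np.sqrt(2.0 * max_constraint_val / (d.dot(H.dot(d)) + eps))
    return beta * d  # the full step, before backtracking line search shrinks it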
def get_opt_output(self, sd, penalty):
    self.policy.zero_grad()
    penalized_loss, surr_loss, mean_kl = self.compute_loss_terms(sd, penalty)
    penalized_loss.backward()
    params = self.policy.get_params()
    grads = [p.grad for p in params]
    flat_grad = torch.cat([g.view(-1) for g in grads])
    # L-BFGS expects (objective value, flat gradient) in double precision
    return [
        get_numpy(penalized_loss.double())[0],
        get_numpy(flat_grad.double())
    ]
def rollout(env, policy, max_path_length, add_input=None, volatile=False, reset_args=None):
    sd = dict(obs=[], rewards=[], actions=[], action_dist_lst=[])
    obs = env.reset(reset_args)
    for s in range(max_path_length):
        policy_input = Variable(from_numpy(np.array([obs])).float(), volatile=volatile)
        if add_input is not None:
            policy_input = torch.cat([policy_input, add_input], -1)
        if s == 0:
            policy.reset(1)
        if policy.recurrent():
            policy_input = policy_input.unsqueeze(0)
        action_dist = policy.forward(policy_input)
        action = action_dist.sample()
        x = env.step(get_numpy(action))
        next_obs = x[0]
        sd['obs'].append(obs)
        sd['rewards'].append(x[1])
        sd['actions'].append(action)
        obs = next_obs
    # append the final observation
    sd['obs'].append(obs)
    sd['obs'] = np.array(sd['obs'])          # (max_path_length + 1, obs_dim)
    sd['rewards'] = np.array(sd['rewards'])  # (max_path_length,)
    sd['actions'] = torch.stack(sd['actions'], 1)
    return sd
def rollout(self, max_path_length, add_input=None, volatile=False):
    sd = dict(obs=[], rewards=[], actions=[], action_dist_lst=[])
    obs = self.envs.reset()
    self.policy.reset(len(obs))
    for s in range(max_path_length):
        policy_input = Variable(from_numpy(np.stack(obs)).float(), volatile=volatile)
        if add_input is not None:
            policy_input = torch.cat([policy_input, add_input], -1)
        action_dist = self.policy.forward(policy_input)
        action = action_dist.sample()
        # take a random action with probability random_action_p
        if self.random_action_p > 0:
            flip = np.random.binomial(1, self.random_action_p, size=len(obs))
            if flip.sum() > 0:
                random_act = np.random.randint(0, int(self.env.action_space.flat_dim),
                                               size=flip.sum())
                action[from_numpy(flip).byte()] = from_numpy(random_act)
        next_obs, rewards, done, info = self.envs.step(get_numpy(action))
        sd['obs'].append(obs)
        sd['rewards'].append(rewards)
        sd['actions'].append(action)
        sd['action_dist_lst'].append(action_dist)
        obs = next_obs
    # Append last obs
    sd['obs'].append(obs)
    sd['obs'] = np.stack(sd['obs'], 1)          # (bs, max_path_length + 1, obs_dim)
    sd['rewards'] = np.stack(sd['rewards'], 1)  # (bs, max_path_length)
    sd['actions'] = torch.stack(sd['actions'], 1)
    sd['action_dist'] = sd['action_dist_lst'][0].combine(sd['action_dist_lst'],
                                                         torch.stack, axis=1)
    return sd
def plot_random(self, dataset, itr, sample_size=5):
    y_dist, latent = self.sample(dataset, sample_size)
    traj_sets = [dataset.unnormalize(get_numpy(y_dist.mle))]
    traj_names = ['sampled']
    plot_traj_sets([dataset.process(traj_set) for traj_set in traj_sets],
                   traj_names, itr, figname='sampled', env_id=dataset.env_id)
def sample_pd(self, sampler, latent, trajs):
    num_traj = latent.size()[0]
    print('sampling', num_traj, self.max_path_length)
    sd = sampler.obtain_samples(num_traj * self.max_path_length,
                                self.max_path_length,
                                latent,
                                reset_args=get_numpy(trajs))
    return sd
def test(self, dataset):
    data = FloatTensor(dataset.train_data)
    x, actdata = self.splitobs(data)
    y = x[:, self.step_dim:]
    z_dist = self.encode(Variable(x))
    z = z_dist.sample()
    y_dist = self.decode(x, z)
    # mean squared reconstruction error of the decoder MLE against the target states
    mse = torch.pow(y_dist.mle - Variable(y), 2).mean(-1).mean(0)
    return get_numpy(mse).item()
def sample(self, deterministic=False):
    if deterministic:
        return self.probs_3d.max(-1)[1].unsqueeze(-1)
    else:
        cat_size = self.probs_3d.size()[-1]
        onehot = np.zeros((self.bs * self.path_len, cat_size))
        idx = torch.multinomial(self.prob.view(-1, cat_size), 1)
        onehot[np.arange(self.bs * self.path_len), get_numpy(idx.squeeze())] = 1
        return np_to_var(onehot.reshape(self.probs_3d.size()))
def plot_interp(self, dataset, itr):
    x = dataset.sample(2)[0]
    x1 = np_to_var(np.expand_dims(x[0, ...], 0))
    x2 = np_to_var(np.expand_dims(x[1, ...], 0))
    l1 = get_numpy(self.encode(x1).sample(deterministic=True))
    l2 = get_numpy(self.encode(x2).sample(deterministic=True))
    # linearly interpolate between the two latents and decode each point
    num_interp = 7
    latents = np.zeros((num_interp, self.latent_dim))
    for i in range(self.latent_dim):
        latents[:, i] = np.interp(np.linspace(0, 1, num_interp), [0, 1],
                                  [l1[0, i], l2[0, i]])
    traj = dataset.unnormalize(
        get_numpy(self.decode(x1.repeat(num_interp, 1).data, np_to_var(latents)).mle))
    traj_sets = [traj[i, ...] for i in range(num_interp)]
    traj_names = range(num_interp)
    dataset.plot_pd_compare(traj_sets, traj_names, itr, save_dir='interp')
def rollout(self, max_path_length, add_input=None, reset_args=None, volatile=False):
    sd = dict(obs=[], rewards=[], actions=[], action_dist_lst=[], states=[])
    obs = self.reset_envs(reset_args)
    self.policy.reset(obs.shape[0])
    for s in range(max_path_length):
        state = self.policy.get_state().data if self.policy.recurrent() else None
        if self.ego:
            # egocentric observations: subtract the reset position from the ego indices
            obs_ego = obs.copy()
            obs_ego[:, self.egoidx] -= reset_args[:, self.egoidx]
            policy_input = Variable(from_numpy(obs_ego).float(), volatile=volatile)
        else:
            policy_input = Variable(from_numpy(obs).float(), volatile=volatile)
        if add_input is not None:
            policy_input = torch.cat([policy_input, add_input], -1)
        action_dist = self.policy.forward(policy_input)
        action = action_dist.sample()
        if self.random_action_p > 0:
            flip = np.random.binomial(1, self.random_action_p, size=len(obs))
            if flip.sum() > 0:
                random_act = np.random.randint(0, self.policy.output_dim, size=flip.sum())
                action[from_numpy(flip).byte()] = from_numpy(random_act)
        next_obs, rewards, done, info = self.step_envs(get_numpy(action))
        #env_step = self.step_envs(get_numpy(action))
        #next_obs = [x[0] for x in env_step]
        sd['obs'].append(obs)
        sd['rewards'].append(rewards)
        sd['actions'].append(action)
        sd['action_dist_lst'].append(action_dist)
        sd['states'].append(state)
        obs = next_obs
    # Append last obs
    sd['obs'].append(obs)
    sd['obs'] = np.stack(sd['obs'], 1)  # (bs, max_path_length + 1, obs_dim)
    sd['states'] = torch.stack(sd['states'], 2) if self.policy.recurrent() else None
    sd['rewards'] = np.stack(sd['rewards'], 1)  # (bs, max_path_length)
    sd['actions'] = torch.stack(sd['actions'], 1)
    sd['action_dist'] = sd['action_dist_lst'][0].combine(sd['action_dist_lst'],
                                                         torch.stack, axis=1)
    return sd
def plot_compare(self, dataset, itr, save_dir='trajs'):
    x = FloatTensor(dataset.sample_hard(5)[0])
    x, actdata = self.splitobs(x)
    target = x[:, self.step_dim:]
    y_dist = self.decode(x, self.encode(Variable(x)).sample())
    traj_sets = [dataset.unnormalize(get_numpy(traj_set))
                 for traj_set in [target, y_dist.mle]]
    traj_names = ['expert', 'sd']
    plot_traj_sets([dataset.process(traj_set) for traj_set in traj_sets],
                   traj_names, itr, env_id=dataset.env_id)
    for traj_no in range(5):
        dataset.plot_pd_compare([traj[traj_no, ...] for traj in traj_sets],
                                traj_names, itr,
                                name='Full_State_%d' % traj_no,
                                save_dir=save_dir)
def fit(self, obs_np, returns_np):
    # re-initialize the value network and regress it onto the empirical returns
    self.network.apply(xavier_init)
    bs, path_len, obs_dim = obs_np.shape
    obs = from_numpy(obs_np.reshape(-1, obs_dim).astype(np.float32))
    returns = from_numpy(returns_np.reshape(-1).astype(np.float32))
    dataloader = DataLoader(TensorDataset(obs, returns),
                            batch_size=self.batch_size, shuffle=True)
    for epoch in range(self.max_epochs):
        for x, y in dataloader:
            self.optimizer.zero_grad()
            x = Variable(x)
            y = Variable(y).float().view(-1, 1)
            loss = (self.network(x) - y).pow(2).mean()
            loss.backward()
            self.optimizer.step()
        print('loss %f' % get_numpy(loss).item())
def rollout_meta(self, latents, cur_obs, reward_fn, rstate):
    nbatch = latents.shape[1]
    state = cur_obs  # np.array([cur_obs] * nbatch)
    trajs = []
    # decode each latent in sequence, feeding the state decoder's final predicted
    # state back in as the initial state for the next segment
    for lat in latents:
        latent_v = np_to_var(lat)
        state_v = from_numpy(state).float()
        sd_traj = self.vae.decode(state_v, latent_v)
        self.vae.decoder.zero_grad()
        decoded_traj = get_numpy(sd_traj.mle).reshape((nbatch, -1, cur_obs.shape[1]))
        state = decoded_traj[:, -1]
        trajs.append(decoded_traj)
    combo_traj = np.concatenate(trajs, axis=1)
    rewards, rstate = self.eval_rewards(combo_traj, reward_fn, rstate, discount=True)
    return rewards, combo_traj
def train_epoch(self, dataset, epoch=0, train=True, max_steps=1e99):
    full_stats = dict([('MSE', 0), ('Total Loss', 0), ('LL', 0),
                       ('KL Loss', 0), ('BC Loss', 0)])
    n_batch = 0
    self.optimizer.zero_grad()
    for loss, stats in self.loss_generator(dataset):
        if train:
            loss.backward()
            self.optimizer.step()
        for k in stats.keys():
            full_stats[k] += get_numpy(stats[k]).item()
        n_batch += 1
        if n_batch >= max_steps:
            break
        # clear gradients before the next batch so they do not accumulate
        self.optimizer.zero_grad()
    for k in full_stats.keys():
        full_stats[k] /= n_batch
    return full_stats
def train_vae_joint(self, dataset, other_dataset, test_dataset, outer_itr, itr):
    all_stats = {}
    stat_list = [(' T', defaultdict(list)), (' N', defaultdict(list))]
    data_gen = [self.vae.loss_generator(dataset)]  # , self.vae.loss_generator(other_dataset)]
    for i in range(self.vae_train_steps):
        losses = []
        self.vae.optimizer.zero_grad()
        for (_, stats), gen in zip(stat_list, data_gen):
            loss, stat_var = next(gen, (None, None))
            if loss is not None:
                losses.append(loss)
                for k, v in stat_var.items():
                    stats[k].append(get_numpy(v).item())
        if len(losses) > 0:
            total_loss = sum(losses) / len(losses)
            total_loss.backward()
            self.vae.optimizer.step()
    for _, stats in stat_list:
        for k, v in stats.items():
            stats[k] = np.mean(v)
    if test_dataset is not None:
        test = self.vae.train_epoch(test_dataset, itr, train=False,
                                    max_steps=self.vae_train_steps // 10)
        stat_list.append((' V', test))
    stat_list.append(
        (' PD', self.train_pd_match_sd(dataset, 20, outer_itr, outer_itr)))
    for prefix, stat in stat_list:
        for k, v in stat.items():
            all_stats[k + prefix] = v
    return all_stats
def forward(self, z, initial_input=None):
    # z is (bs, latent_dim)
    bs = z.size()[0]
    self.recurrent_network.init_hidden(bs)
    if initial_input is None:
        initial_input = self.init_input(bs)
    z = z.unsqueeze(0)              # (1, bs, latent_dim)
    x = initial_input.unsqueeze(0)  # (1, bs, sum(cat_sizes))
    probs, argmaxs = [], []
    for s in range(self.path_len):
        # condition every step on the latent and feed the previous output back in
        x = torch.cat([x, z], -1)
        prob, argmax = self.step(x)
        probs.append(prob)
        argmaxs.append(argmax)
        x = prob.unsqueeze(0)
    probs = torch.stack(probs, 1)      # (bs, path_len, sum(cat_sizes))
    argmaxs = torch.stack(argmaxs, 1)  # (bs, path_len, len(cat_sizes))
    onehot = np_to_var(np.eye(self.output_dim)[get_numpy(argmaxs)])
    dist = RecurrentCategorical(probs, self.path_len, onehot)
    return dist
def rollout(policy, env, max_path_length, add_input=None, plot=False):
    obs = env.reset()
    sd = dict(obs=[], rewards=[], actions=[], action_dist_lst=[])
    for s in range(max_path_length):
        if add_input is not None:
            policy_input = torch.cat([np_to_var(obs), add_input], -1).view(1, -1)
        else:
            policy_input = np_to_var(obs).unsqueeze(0)
        action_dist = policy.forward(policy_input)
        action = action_dist.sample()
        next_obs, reward, done, info = env.step(get_numpy(action.squeeze()))
        sd['obs'].append(obs)
        sd['rewards'].append(reward)
        sd['actions'].append(action)
        sd['action_dist_lst'].append(action_dist)
        obs = next_obs
        if plot:
            env.render()
    return sd
def forward(self, z, initial_input=None):
    # z is (bs, latent_dim)
    # initial_input is the initial obs (bs, obs_dim)
    bs = z.size()[0]
    self.recurrent_network.init_hidden(bs)
    if initial_input is None:
        initial_input = self.init_input(bs)
    z = z.unsqueeze(0)                          # (1, bs, latent_dim)
    initial_input = initial_input.unsqueeze(0)  # (1, bs, obs_dim)
    means, log_vars, probs, onehots = [], [], [], []
    x = initial_input
    for s in range(self.path_len):
        x = torch.cat([x, z], -1)
        mean, log_var, prob = self.step(x)
        # one-hot encode the argmax of the categorical head
        onehot = np.zeros(prob.size()[1:])
        onehot[np.arange(0, bs),
               get_numpy(torch.max(prob.squeeze(0), -1)[1]).astype(np.int32)] = 1
        onehot = np_to_var(onehot).unsqueeze(0)
        # feed the predicted mean and one-hot categorical back in as the next input
        x = torch.cat([mean, onehot], -1)
        #x = Variable(torch.randn(mean.size())) * torch.exp(log_var) + mean
        means.append(mean.squeeze(dim=0))
        log_vars.append(log_var.squeeze(dim=0))
        probs.append(prob.squeeze(dim=0))
        onehots.append(onehot.squeeze(dim=0))
    means = torch.stack(means, 1).view(bs, -1)
    log_vars = torch.stack(log_vars, 1).view(bs, -1)
    probs = torch.stack(probs, 1).view(bs, -1)
    onehots = torch.stack(onehots, 1).view(bs, -1)
    gauss_dist = Normal(means, log_var=log_vars)
    cat_dist = RecurrentCategorical(probs, self.path_len, onehots)
    return Mixed(gauss_dist, cat_dist, self.path_len)
def train_explorer(self, dataset, test_dataset, dummy_dataset, itr):
    bs = self.batch_size
    # load fixed initial state and goals from config
    init_state = self.block_config[0]
    goals = np.array(self.block_config[1])
    # functions for computing the reward and initializing the reward state (rstate);
    # rstate keeps track of things such as which goal you are currently on
    reward_fn, init_rstate = self.reward_fn
    # total actual reward collected by the MPC agent so far
    total_mpc_rew = np.zeros(self.mpc_batch)
    # keep track of states visited by MPC to initialize the explorer from
    all_inits = []
    # current state of the MPC batch
    cur_state = np.array([init_state] * self.mpc_batch)
    # initialize the reward state for the MPC batch
    rstate = init_rstate(self.mpc_batch)
    # for visualization purposes
    mpc_preds = []
    mpc_actual = []
    mpc_span = []
    rstates = []
    # Perform MPC over max_horizon
    for T in range(self.max_horizon):
        print(T)
        # for goal visualization
        rstates.append(rstate)
        # roll out imaginary trajectories using the state decoder
        rollouts = self.mpc(cur_state,
                            min(self.plan_horizon, self.max_horizon - T),
                            self.mpc_explore, self.mpc_explore_batch,
                            reward_fn, rstate)
        # get the first latent of the best trajectory for each batch
        np_latents = rollouts[2][:, 0]
        # roll out the first latent in the simulator
        mpc_traj = self.sampler_mpc.obtain_samples(
            self.mpc_batch * self.max_path_length,
            self.max_path_length,
            np_to_var(np_latents),
            reset_args=cur_state)
        # update reward and reward state based on the trajectory from the simulator
        mpc_rew, rstate = self.eval_rewards(mpc_traj['obs'], reward_fn, rstate)
        # for logging and visualization purposes
        futures = rollouts[0] + total_mpc_rew
        total_mpc_rew += mpc_rew
        mpc_preds.append(rollouts[1][0])
        mpc_span.append(rollouts[3])
        mpc_stats = {
            'mean futures': np.mean(futures),
            'std futures': np.std(futures),
            'mean actual': np.mean(total_mpc_rew),
            'std actual': np.std(total_mpc_rew),
        }
        mpc_actual.append(mpc_traj['obs'][0])
        with logger.prefix('itr #%d mpc step #%d | ' % (itr, T)):
            self.vae.print_diagnostics(mpc_stats)
            record_tabular(mpc_stats, 'mpc_stats.csv')
        # add current state to the list of states the explorer can initialize from
        all_inits.append(cur_state)
        # update current state to the current state of the simulator
        cur_state = mpc_traj['obs'][:, -1]
    # for visualization
    for idx, (actual, pred, rs, span) in enumerate(
            zip(mpc_actual, mpc_preds, rstates, mpc_span)):
        dataset.plot_pd_compare(
            [actual, pred, span[:100], span[:100, :dataset.path_len]],
            ['actual', 'pred', 'imagined', 'singlestep'],
            itr, save_dir='mpc_match', name='Pred' + str(idx),
            goals=goals, goalidx=rs[0])
    # compute reward at the final state, for tasks that care about final-state reward
    final_reward, _ = reward_fn(cur_state, rstate)
    print(total_mpc_rew)
    print(final_reward)
    # randomly select states for the explorer to start from
    start_states = np.concatenate(all_inits, axis=0)
    start_states = start_states[np.random.choice(
        start_states.shape[0], self.rand_per_mpc_step,
        replace=self.rand_per_mpc_step > start_states.shape[0])]
    # run the explorer from those states
    explore_len = ((self.max_path_length + 1) * self.mpc_explore_len) - 1
    self.policy_ex_algo.max_path_length = explore_len
    ex_trajs = self.sampler_ex.obtain_samples(start_states.shape[0] * explore_len,
                                              explore_len, None,
                                              reset_args=start_states)
    # Now concat actions taken by the explorer with observations for adding to the dataset
    trajs = ex_trajs['obs']
    obs = trajs[:, -1]
    if hasattr(self.action_space, 'shape') and len(self.action_space.shape) > 0:
        acts = get_numpy(ex_trajs['actions'])
    else:
        # convert discrete actions into one-hot vectors
        act_idx = get_numpy(ex_trajs['actions'])
        acts = np.zeros((trajs.shape[0], trajs.shape[1] - 1, dataset.action_dim))
        acts_reshape = acts.reshape((-1, dataset.action_dim))
        acts_reshape[range(acts_reshape.shape[0]), act_idx.reshape(-1)] = 1.0
    # concat actions with obs; repeat the last action so the lengths match
    acts = np.concatenate((acts, acts[:, -1:, :]), 1)
    trajacts = np.concatenate((ex_trajs['obs'], acts), axis=-1)
    trajacts = trajacts.reshape((-1, self.max_path_length + 1, trajacts.shape[-1]))
    # compute train/val split
    ntrain = min(int(0.9 * trajacts.shape[0]), dataset.buffer_size // self.add_frac)
    if dataset.n < dataset.batch_size and ntrain < dataset.batch_size:
        ntrain = dataset.batch_size
    nvalid = min(trajacts.shape[0] - ntrain, test_dataset.buffer_size // self.add_frac)
    if test_dataset.n < test_dataset.batch_size and nvalid < test_dataset.batch_size:
        nvalid = test_dataset.batch_size
    print("Adding ", ntrain, ", Valid: ", nvalid)
    dataset.add_samples(trajacts[:ntrain].reshape((ntrain, -1)))
    test_dataset.add_samples(trajacts[-nvalid:].reshape((nvalid, -1)))
    # the dummy dataset stores only data from this iteration
    dummy_dataset.clear()
    dummy_dataset.add_samples(trajacts[:-nvalid].reshape((trajacts.shape[0] - nvalid, -1)))
    # compute the negative ELBO on the explorer's trajectories
    neg_elbos = []
    cur_batch = from_numpy(trajacts).float()
    for i in range(0, trajacts.shape[0], self.batch_size):
        mse, neg_ll, kl, bcloss, z_dist = self.vae.forward_batch(
            cur_batch[i:i + self.batch_size])
        neg_elbo = (get_numpy(neg_ll) + get_numpy(kl))
        neg_elbos.append(neg_elbo)
    # reward the explorer with the negative ELBO
    rewards = np.zeros_like(ex_trajs['rewards'])
    neg_elbos = np.concatenate(neg_elbos, axis=0)
    neg_elbos = neg_elbos.reshape((rewards.shape[0], -1))
    # skip the first iteration, since the VAE has not been fitted yet
    if itr != 1:
        rewidx = list(range(self.max_path_length, explore_len,
                            self.max_path_length + 1)) + [explore_len - 1]
        for i in range(rewards.shape[0]):
            rewards[i, rewidx] = neg_elbos[i]
    # add in the true reward for the explorer if desired
    if self.true_reward_scale != 0:
        rstate = init_rstate(rewards.shape[0])
        for oidx in range(rewards.shape[1]):
            r, rstate = reward_fn(ex_trajs['obs'][:, oidx], rstate)
            rewards[:, oidx] += r * self.true_reward_scale
    ex_trajs['rewards'] = rewards
    # train the explorer with PPO using the negative ELBO reward
    self.policy_ex_algo.process_samples(0, ex_trajs)  # , augment_obs=get_numpy(z))
    if itr != 1:
        self.policy_ex_algo.optimize_policy(0, ex_trajs)
    ex_trajs['stats']['MPC Actual'] = np.mean(total_mpc_rew)
    ex_trajs['stats']['Final Reward'] = np.mean(final_reward)
    # reset the explorer if its policy entropy has collapsed
    if ex_trajs['stats']['Entropy'] < self.reset_ent:
        if hasattr(self.policy_ex, "prob_network"):
            self.policy_ex.prob_network.apply(xavier_init)
        else:
            self.policy_ex.apply(xavier_init)
            self.policy_ex.log_var_network.params_var.data = \
                self.policy_ex.log_var_network.param_init
    # for visualization purposes
    colors = ['purple', 'magenta', 'green', 'black', 'yellow', 'black']
    fig, ax = plt.subplots(3, 2, figsize=(10, 10))
    for i in range(6):
        if i * 2 + 1 < obs.shape[1]:
            axx = ax[i // 2][i % 2]
            if i == 5:
                axx.scatter(obs[:, -3], obs[:, -2], color=colors[i], s=10)
            else:
                axx.scatter(obs[:, i * 2], obs[:, i * 2 + 1], color=colors[i], s=10)
            axx.set_xlim(-3, 3)
            axx.set_ylim(-3, 3)
    path = logger.get_snapshot_dir() + '/final_dist'
    if not os.path.exists(path):
        os.makedirs(path)
    plt.savefig('%s/%d.png' % (path, itr))
    np.save(path + "/" + str(itr), obs)
    return ex_trajs['stats']
def optimize_policy(self, itr, samples_data, add_input_fn=None, add_input_input=None,
                    add_loss_fn=None, print=True):
    advantages = from_numpy(samples_data['discount_adv'].astype(np.float32))
    n_traj = samples_data['obs'].shape[0]
    n_obs = n_traj * self.max_path_length
    #add_input_obs = from_numpy(samples_data['obs'][:, :, :self.obs_dim].astype(np.float32)).view(n_traj, -1)
    if add_input_fn is not None:
        obs = from_numpy(samples_data['obs'][:, :self.max_path_length, :self.obs_dim]
                         .astype(np.float32)).view(n_obs, -1)
    else:
        obs = from_numpy(samples_data['obs'][:, :self.max_path_length, :]
                         .astype(np.float32)).view(n_obs, -1)
    #obs = from_numpy(samples_data['obs'][:, :self.max_path_length, :].astype(np.float32)).view(n_obs, -1)
    actions = samples_data['actions'].view(n_obs, -1).data
    returns = from_numpy(samples_data['discount_returns'].copy()).view(-1, 1).float()
    old_action_log_probs = samples_data['log_prob'].view(n_obs, -1).data
    states = samples_data['states'].view(samples_data['states'].size()[0], n_obs, -1) \
        if self.policy.recurrent() else None
    for epoch_itr in range(self.epoch):
        sampler = BatchSampler(SubsetRandomSampler(range(n_obs)),
                               self.ppo_batch_size, drop_last=False)
        for indices in sampler:
            indices = LongTensor(indices)
            obs_batch = Variable(obs[indices])
            actions_batch = actions[indices]
            return_batch = returns[indices]
            old_action_log_probs_batch = old_action_log_probs[indices]
            if states is not None:
                self.policy.set_state(Variable(states[:, indices]))
            if add_input_fn is not None:
                # sample the additional input (e.g. a latent) and tile it over the path
                add_input_dist = add_input_fn(Variable(add_input_input))
                add_input = add_input_dist.sample()
                add_input_rep = torch.unsqueeze(add_input, 1).repeat(
                    1, self.max_path_length, 1).view(n_obs, -1)
                #add_input_batch = add_input[indices/add_input.size()[0]]
                add_input_batch = add_input_rep[indices]
                obs_batch = torch.cat([obs_batch, add_input_batch], -1)
            values = self.baseline.forward(obs_batch.detach())
            action_dist = self.policy.forward(obs_batch)
            action_log_probs = action_dist.log_likelihood(
                Variable(actions_batch)).unsqueeze(-1)
            dist_entropy = action_dist.entropy().mean()
            # probability ratio between the new and old policies
            ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))
            adv_targ = Variable(advantages.view(-1, 1)[indices])
            surr1 = ratio * adv_targ
            surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                1.0 + self.clip_param) * adv_targ
            # PPO's pessimistic clipped surrogate (L^CLIP)
            action_loss = -torch.min(surr1, surr2).mean()
            value_loss = (Variable(return_batch) - values).pow(2).mean()
            self.optimizer.zero_grad()
            total_loss = (value_loss + action_loss - dist_entropy * self.entropy_bonus)
            if add_loss_fn is not None:
                total_loss += add_loss_fn(add_input_dist, add_input, add_input_input)
            total_loss.backward()
            self.optimizer.step()
        if print:
            stats = {
                'total loss': get_numpy(total_loss)[0],
                'action loss': get_numpy(action_loss)[0],
                'value loss': get_numpy(value_loss)[0],
                'entropy': get_numpy(dist_entropy)[0],
            }
            with logger.prefix('Train PPO itr %d epoch itr %d | ' % (itr, epoch_itr)):
                self.print_diagnostics(stats)
    return total_loss
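# The clipped surrogate used in the loop above, isolated for clarity. This is a minimal
# sketch of PPO's L^CLIP on plain tensors; the function name and argument shapes are
# illustrative only, not part of the codebase.
import torch

def ppo_clip_loss_sketch(log_probs, old_log_probs, advantages, clip_param=0.2):
    ratio = torch.exp(log_probs - old_log_probs)  # pi_new(a|s) / pi_old(a|s)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # pessimistic (clipped) bound, negated because the optimizer minimizes
    return -torch.min(surr1, surr2).mean()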
def optimize_policy(self, itr, samples_data):
    try_penalty = float(np.clip(self._penalty, self._min_penalty, self._max_penalty))
    penalty_scale_factor = None

    def gen_f_opt(penalty):
        def f(flat_params):
            self.policy.set_params_flat(from_numpy(flat_params))
            return self.get_opt_output(samples_data, penalty)
        return f

    cur_params = get_numpy(self.policy.get_params_flat().double())
    opt_params = cur_params

    # Save flattened views of objects for efficiency
    samples_data['obs_flat_var'] = np_to_var(samples_data['obs_flat'])
    samples_data['action_dist_flat'] = samples_data['action_dist'].detach().reshape(
        (-1, samples_data['action_dist'].dim))
    samples_data['actions_flat'] = samples_data['actions'].view(-1, self.action_dim)
    samples_data['discount_adv_var'] = np_to_var(samples_data['discount_adv'])

    for penalty_itr in range(self._max_penalty_itr):
        logger.log('trying penalty=%.3f...' % try_penalty)
        itr_opt_params, _, _ = scipy.optimize.fmin_l_bfgs_b(
            func=gen_f_opt(try_penalty), x0=cur_params, maxiter=self._max_opt_itr)
        _, try_loss, try_constraint_val = self.compute_loss_terms(
            samples_data, try_penalty)
        try_loss = get_numpy(try_loss)[0]
        try_constraint_val = get_numpy(try_constraint_val)[0]
        logger.log('penalty %f => loss %f, %s %f' %
                   (try_penalty, try_loss, self._constraint_name, try_constraint_val))

        if try_constraint_val < self._max_constraint_val or \
                (penalty_itr == self._max_penalty_itr - 1 and opt_params is None):
            opt_params = itr_opt_params

        if not self._adapt_penalty:
            break

        # Decide scale factor on the first iteration, or if constraint violation
        # yields numerical error
        if penalty_scale_factor is None or np.isnan(try_constraint_val):
            # Increase penalty if constraint violated, or if constraint term is NaN
            if try_constraint_val > self._max_constraint_val or np.isnan(try_constraint_val):
                penalty_scale_factor = self._increase_penalty_factor
            else:
                # Otherwise (i.e. constraint satisfied), shrink penalty
                penalty_scale_factor = self._decrease_penalty_factor
                opt_params = itr_opt_params
        else:
            if penalty_scale_factor > 1 and \
                    try_constraint_val <= self._max_constraint_val:
                break
            elif penalty_scale_factor < 1 and \
                    try_constraint_val >= self._max_constraint_val:
                break
        try_penalty *= penalty_scale_factor
        try_penalty = float(np.clip(try_penalty, self._min_penalty, self._max_penalty))

    self._penalty = try_penalty
    self.policy.set_params_flat(from_numpy(opt_params))
def compute_mle(self):
    # one-hot encode the argmax of the categorical probabilities
    onehot = np.zeros(self.prob.size())
    onehot[np.arange(0, self.bs),
           get_numpy(torch.max(self.prob, -1)[1]).astype(np.int32)] = 1
    return np_to_var(onehot)