def init(self, policy, args):
    if args.bco_alpha != 0:
        # Adjust the amount of online experience based on args.
        args.num_env_steps = args.bco_alpha_size * args.bco_alpha
        args.num_steps = args.bco_alpha_size // args.num_processes
        print(f"Adjusted # steps to {args.num_steps}")
        print(f"Adjusted # env interactions to {args.num_env_steps}")
    super().init(policy, args)

    if self._arg('lr_env_steps') is None and args.bco_alpha != 0:
        # Adjust the learning rate decay steps based on the number of BC
        # updates performed.
        bc_updates = super().get_num_updates()
        bco_full_updates = self.get_num_updates() + 1
        # We perform a full BC update for each "BCO update". One "BCO update"
        # comes from the initial training and one from each round of online
        # experience, as determined by alpha.
        self.lr_updates = bc_updates * bco_full_updates
        print(f"Adjusted # lr updates to {self.lr_updates}")

    get_state_enc = partial(self.policy.get_base_net_fn,
                            rutils.get_obs_shape(self.policy.obs_space))
    self.inv_func = InvFunc(get_state_enc,
                            rutils.get_obs_shape(self.policy.obs_space),
                            rutils.get_ac_dim(self.policy.action_space))
    self.inv_func = self.inv_func.to(self.args.device)
    self.inv_opt = optim.Adam(self.inv_func.parameters(),
                              lr=self.args.bco_inv_lr)
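# Hedged sketch (standalone, hypothetical values): how the schedule adjustment
# above plays out. With bco_alpha_size = 4096, bco_alpha = 10, and
# num_processes = 8, each of the 10 online rounds collects 4096 transitions,
# split across the parallel workers.
bco_alpha_size = 4096
bco_alpha = 10
num_processes = 8

num_env_steps = bco_alpha_size * bco_alpha    # 40960 total env interactions
num_steps = bco_alpha_size // num_processes   # 512 steps per worker per round
assert num_steps * num_processes * bco_alpha == num_env_steps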
def init(self, obs_space, action_space, args):
    super().init(obs_space, action_space, args)
    obs_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
    self.critic = self.get_critic_fn(obs_shape, self._get_base_out_shape(),
                                     action_space)
    log_std_bounds = [float(x) for x in self.args.log_std_bounds.split(',')]
    self.actor = self.get_actor_fn(
        rutils.get_obs_shape(obs_space, args.policy_ob_key),
        self._get_base_out_shape(), action_space, log_std_bounds)
    self.ac_low_bound = torch.tensor(self.action_space.low).to(args.device)
    self.ac_high_bound = torch.tensor(self.action_space.high).to(args.device)
def get_env_settings(self, args):
    settings = super().get_env_settings(args)
    settings.include_info_keys.extend([
        ('final_obs', lambda env: rutils.get_obs_shape(env.observation_space))
    ])
    return settings
def _get_obs(self, obs_dict):
    obs = obs_dict['image']
    obs = obs.reshape(-1, 3)
    obs = np.array(list(map(lambda x: NODE_TO_ONE_HOT[tuple(x)], obs)))
    obs = obs.reshape(*rutils.get_obs_shape(self.observation_space))
    return obs
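# Hedged sketch (standalone; PALETTE and one_hot_grid are illustrative, not the
# repo's NODE_TO_ONE_HOT): a per-cell one-hot conversion of the kind _get_obs
# performs, written against plain numpy.
import numpy as np

PALETTE = {(255, 0, 0): 0, (0, 255, 0): 1, (0, 0, 255): 2}  # hypothetical colors

def one_hot_grid(image):
    # image: (H, W, 3) array of RGB cells; returns (H, W, len(PALETTE)) one-hots.
    flat = image.reshape(-1, 3)
    idx = np.array([PALETTE[tuple(px)] for px in flat])
    return np.eye(len(PALETTE), dtype=np.float32)[idx].reshape(
        *image.shape[:2], len(PALETTE))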
def init(self, obs_space, action_space, args):
    super().init(obs_space, action_space, args)
    self.actor = self.get_actor_fn(
        rutils.get_obs_shape(obs_space, args.policy_ob_key),
        self._get_base_out_shape())
    self.dist = self.get_dist_fn(self.actor.output_shape, self.action_space)
def init(self, obs_space, action_space, args):
    super().init(obs_space, action_space, args)
    obs_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
    self.critic = self.get_critic_fn(obs_shape, self._get_base_out_shape(),
                                     action_space)
    self.critic_head = self.get_critic_head_fn(self.critic.output_shape[0])
def _create_discrim(self):
    ob_shape = rutils.get_obs_shape(self.policy.obs_space)
    ac_dim = rutils.get_ac_dim(self.action_space)
    base_net = self.policy.get_base_net_fn(ob_shape)
    discrim, dhidden_dim = self.get_discrim()
    discrim_head = InjectNet(base_net.net, discrim,
                             base_net.output_shape[0], dhidden_dim,
                             ac_dim, self.args.action_input)
    return discrim_head.to(self.args.device)
def init(self, obs_space, action_space, args):
    super().init(obs_space, action_space, args)
    if 'recurrent' in inspect.getfullargspec(self.get_base_net_fn).args:
        self.get_base_net_fn = partial(self.get_base_net_fn,
                                       recurrent=self.args.recurrent_policy)

    if self.use_goal:
        use_obs_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
        if len(use_obs_shape) != 1:
            raise ValueError('Goal conditioning only works with flat '
                             'state representation')
        use_obs_shape = (use_obs_shape[0] +
                         obs_space['desired_goal'].shape[0],)
    else:
        use_obs_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)

    self.base_net = self.get_base_net_fn(use_obs_shape)
    base_out_dim = self.base_net.output_shape[0]
    for k in self.fuse_states:
        if len(obs_space.spaces[k].shape) != 1:
            raise ValueError('Can only fuse 1D states')
        base_out_dim += obs_space.spaces[k].shape[0]
    self.base_out_shape = (base_out_dim,)
def _get_sampler(self, storage):
    obs = storage.get_def_obs_seq()
    ob_shape = rutils.get_obs_shape(self.policy.obs_space)
    self.agent_obs_pairs = {
        'state': obs[:-1].view(-1, *ob_shape),
        'next_state': obs[1:].view(-1, *ob_shape),
        'mask': storage.masks[:-1].view(-1, 1),
    }
    failure_sampler = BatchSampler(
        SubsetRandomSampler(range(self.args.num_steps)),
        self.args.traj_batch_size, drop_last=True)
    return self.expert_train_loader, failure_sampler
def _create_discrim(self):
    new_shape = list(rutils.get_obs_shape(self.policy.obs_space))
    new_shape[0] *= 2
    base_net = self.policy.get_base_net_fn(new_shape)
    return DoubleStateDiscrim(base_net).to(self.args.device)
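# Hedged sketch (assumed input convention, not the actual DoubleStateDiscrim):
# the doubled leading observation dimension above suggests the discriminator
# scores (state, next_state) pairs concatenated along the first observation
# axis, roughly as follows.
import torch

def score_state_pair(discrim_net, state, next_state):
    # state, next_state: (batch, *obs_shape); cat along dim 1 doubles obs dim 0.
    return discrim_net(torch.cat([state, next_state], dim=1))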
def _infer_inv_accuracy(self, trans_sampler, dataset):
    total_count = 0
    num_correct = 0
    with torch.no_grad():
        for trans_idx in trans_sampler:
            use_state_0, use_state_1, true_action = select_batch(
                trans_idx, dataset, self.args.device,
                rutils.get_obs_shape(self.policy.obs_space))
            pred_action = self.inv_func(use_state_0, use_state_1)
            pred_class = torch.argmax(pred_action, dim=-1)
            num_correct += (pred_class == true_action.view(-1)).float().sum()
            total_count += float(use_state_0.shape[0])
    return 100.0 * (num_correct / total_count)
def _train_inv_func(self, trans_sampler, dataset):
    infer_ac_losses = []
    for i in tqdm(range(self.args.bco_inv_epochs)):
        for trans_idx in trans_sampler:
            use_state_0, use_state_1, true_action = select_batch(
                trans_idx, dataset, self.args.device,
                rutils.get_obs_shape(self.policy.obs_space))
            pred_action = self.inv_func(use_state_0, use_state_1)
            loss = autils.compute_ac_loss(pred_action, true_action,
                                          self.policy.action_space)
            infer_ac_losses.append(loss.item())

            self.inv_opt.zero_grad()
            loss.backward()
            self.inv_opt.step()
    return infer_ac_losses
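# Hedged sketch (standalone, discrete-action case only; inverse_model_step is
# illustrative, not part of the repo): for a discrete action space, an
# inverse-model update of this form amounts to a cross-entropy fit of predicted
# action logits to the true actions, consistent with the argmax-based accuracy
# check in _infer_inv_accuracy.
import torch.nn.functional as F

def inverse_model_step(inv_func, inv_opt, s0, s1, true_action):
    logits = inv_func(s0, s1)                      # (batch, num_actions)
    loss = F.cross_entropy(logits, true_action.view(-1).long())
    inv_opt.zero_grad()
    loss.backward()
    inv_opt.step()
    return loss.item()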