Example #1
0
    def init(self, policy, args):
        """Set up the BCO updater and its inverse model.

        When ``args.bco_alpha`` is non-zero, ``args.num_env_steps`` and
        ``args.num_steps`` are overwritten from ``args.bco_alpha_size`` before
        the parent ``init`` runs. Afterwards, unless ``lr_env_steps`` is set,
        the learning-rate decay horizon ``self.lr_updates`` is rescaled.
        Finally the inverse model ``self.inv_func`` and its Adam optimiser
        ``self.inv_opt`` are created.

        Args:
            policy: Policy forwarded to the parent ``init``.
            args: Run configuration; mutated in place when ``bco_alpha != 0``.
        """
        if args.bco_alpha != 0:
            # Must run before super().init so the parent sees the new budget.
            args.num_env_steps = args.bco_alpha_size * args.bco_alpha
            args.num_steps = args.bco_alpha_size // args.num_processes
            print(f"Adjusted # steps to {args.num_steps}")
            print(f"Adjusted # env interactions to {args.num_env_steps}")

        super().init(policy, args)

        rescale_lr = self._arg('lr_env_steps') is None and args.bco_alpha != 0
        if rescale_lr:
            # One full BC training pass happens initially and again after each
            # online "BCO update", so the decay horizon is the product of the
            # BC updates per pass and the number of passes.
            updates_per_pass = super().get_num_updates()
            num_passes = self.get_num_updates() + 1
            self.lr_updates = updates_per_pass * num_passes
            print(f"Adjusted # lr updates to {self.lr_updates}")

        ob_shape = rutils.get_obs_shape(self.policy.obs_space)
        # The inverse model builds its state encoder from the policy's
        # base-net factory, pre-bound to the observation shape.
        state_encoder_fn = partial(self.policy.get_base_net_fn, ob_shape)
        action_dim = rutils.get_ac_dim(self.policy.action_space)
        self.inv_func = InvFunc(state_encoder_fn, ob_shape, action_dim)
        self.inv_func = self.inv_func.to(self.args.device)
        self.inv_opt = optim.Adam(self.inv_func.parameters(),
                                  lr=self.args.bco_inv_lr)
Example #2
0
    def init(self, obs_space, action_space, args):
        """Build the critic and actor networks and cache the action bounds.

        Args:
            obs_space: Observation space; the sub-space selected by
                ``args.policy_ob_key`` determines the network input shape.
            action_space: Action space passed to both networks.
            args: Run configuration. ``args.log_std_bounds`` is a
                comma-separated string of floats handed to the actor;
                ``args.device`` is where the action-bound tensors are placed.
        """
        super().init(obs_space, action_space, args)

        # Compute the input shapes once and share them, so the actor and the
        # critic are always built from the same definition of their inputs.
        obs_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
        base_out_shape = self._get_base_out_shape()

        self.critic = self.get_critic_fn(obs_shape, base_out_shape,
                                         action_space)

        log_std_bounds = [float(x) for x in self.args.log_std_bounds.split(',')]
        self.actor = self.get_actor_fn(obs_shape, base_out_shape, action_space,
                                       log_std_bounds)

        self.ac_low_bound = torch.tensor(self.action_space.low).to(args.device)
        self.ac_high_bound = torch.tensor(self.action_space.high).to(args.device)
Example #3
0
 def get_env_settings(self, args):
     """Return the parent's env settings with an extra ``final_obs`` entry.

     The added entry pairs the key ``'final_obs'`` with a callable that maps
     an env to the shape of its observation space.
     """
     final_obs_entry = (
         'final_obs',
         lambda env: rutils.get_obs_shape(env.observation_space),
     )
     settings = super().get_env_settings(args)
     settings.include_info_keys.extend([final_obs_entry])
     return settings
Example #4
0
    def _get_obs(self, obs_dict):
        """One-hot encode ``obs_dict['image']`` into this env's obs shape.

        The image is split into length-3 cells. Each cell is looked up (as a
        tuple) in ``NODE_TO_ONE_HOT``, and the stacked encodings are reshaped
        to the shape of ``self.observation_space``.
        """
        cells = obs_dict['image'].reshape(-1, 3)
        encoded = np.array([NODE_TO_ONE_HOT[tuple(cell)] for cell in cells])
        return encoded.reshape(*rutils.get_obs_shape(self.observation_space))
Example #5
0
 def init(self, obs_space, action_space, args):
     """Build the actor network and an action distribution on its output."""
     super().init(obs_space, action_space, args)
     actor_in_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
     self.actor = self.get_actor_fn(actor_in_shape, self._get_base_out_shape())
     # The distribution head is sized from the actor's output shape.
     self.dist = self.get_dist_fn(self.actor.output_shape, self.action_space)
Example #6
0
    def init(self, obs_space, action_space, args):
        """Build the critic body and a critic head sized to its output."""
        super().init(obs_space, action_space, args)

        critic_in_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
        self.critic = self.get_critic_fn(
            critic_in_shape, self._get_base_out_shape(), action_space)

        # The head is built from the first dimension of the body's output.
        critic_out_dim = self.critic.output_shape[0]
        self.critic_head = self.get_critic_head_fn(critic_out_dim)
Example #7
0
    def _create_discrim(self):
        """Build the discriminator and move it to ``self.args.device``.

        A base network is built with the policy's base-net factory for the
        policy observation shape. It is combined via ``InjectNet`` with the
        network returned by ``self.get_discrim()``, the action dimension and
        ``self.args.action_input``.
        """
        obs_shape = rutils.get_obs_shape(self.policy.obs_space)
        action_dim = rutils.get_ac_dim(self.action_space)
        base = self.policy.get_base_net_fn(obs_shape)
        discrim_net, discrim_hidden_dim = self.get_discrim()
        head = InjectNet(
            base.net,
            discrim_net,
            base.output_shape[0],
            discrim_hidden_dim,
            action_dim,
            self.args.action_input,
        )
        return head.to(self.args.device)
Example #8
0
    def init(self, obs_space, action_space, args):
        """Build the base network and compute the fused output shape.

        Args:
            obs_space: Dict-like observation space. ``args.policy_ob_key``
                selects the base-net input. With ``self.use_goal`` the
                ``'desired_goal'`` sub-space is appended to it. Each key in
                ``self.fuse_states`` is added to the base output size.
            action_space: Forwarded to the parent ``init``.
            args: Run configuration.

        Raises:
            ValueError: If goal conditioning is used with a non-flat
                observation, or if a fused state is not 1-D.
        """
        super().init(obs_space, action_space, args)

        # Only bind ``recurrent`` when the base-net factory accepts it.
        # ``inspect.signature`` is used because ``inspect.getargspec`` was
        # removed in Python 3.11 and rejects annotated / keyword-only callables.
        if 'recurrent' in inspect.signature(self.get_base_net_fn).parameters:
            self.get_base_net_fn = partial(
                self.get_base_net_fn, recurrent=self.args.recurrent_policy)

        use_obs_shape = rutils.get_obs_shape(obs_space, args.policy_ob_key)
        if self.use_goal:
            if len(use_obs_shape) != 1:
                raise ValueError(
                    'Goal conditioning only works with flat state representation')
            # Flat observation: the goal vector is appended to it.
            use_obs_shape = (
                use_obs_shape[0] + obs_space['desired_goal'].shape[0],)

        self.base_net = self.get_base_net_fn(use_obs_shape)
        base_out_dim = self.base_net.output_shape[0]
        # Fused states are concatenated onto the base output, widening it.
        for k in self.fuse_states:
            if len(obs_space.spaces[k].shape) != 1:
                raise ValueError('Can only fuse 1D states')
            base_out_dim += obs_space.spaces[k].shape[0]
        self.base_out_shape = (base_out_dim,)
Example #9
0
 def _get_sampler(self, storage):
     """Record agent transitions and return ``(expert loader, failure sampler)``.

     Side effect: sets ``self.agent_obs_pairs`` to flattened ``'state'``,
     ``'next_state'`` and ``'mask'`` tensors built from ``storage``.

     Returns:
         ``(self.expert_train_loader, failure_sampler)``. ``failure_sampler``
         yields batches of ``args.traj_batch_size`` indices, drawn in random
         order from ``range(args.num_steps)``, dropping the last partial batch.
     """
     all_obs = storage.get_def_obs_seq()
     ob_shape = rutils.get_obs_shape(self.policy.obs_space)

     # Consecutive observations form (state, next_state) transitions.
     states = all_obs[:-1].view(-1, *ob_shape)
     next_states = all_obs[1:].view(-1, *ob_shape)
     masks = storage.masks[:-1].view(-1, 1)
     self.agent_obs_pairs = {
         'state': states,
         'next_state': next_states,
         'mask': masks,
     }

     # NOTE(review): indices come from range(num_steps), while the tensors
     # above flatten all leading dims and may hold more rows than that;
     # confirm this sampler is not meant to index agent_obs_pairs.
     index_source = SubsetRandomSampler(range(self.args.num_steps))
     failure_sampler = BatchSampler(index_source,
                                    self.args.traj_batch_size,
                                    drop_last=True)
     return self.expert_train_loader, failure_sampler
Example #10
0
 def _create_discrim(self):
     """Build a ``DoubleStateDiscrim`` on ``self.args.device``.

     Its base network is created for the policy observation shape with the
     first dimension doubled (presumably two states stacked along that
     dimension -- confirm against ``DoubleStateDiscrim``).
     """
     obs_shape = rutils.get_obs_shape(self.policy.obs_space)
     doubled_shape = [obs_shape[0] * 2, *obs_shape[1:]]
     discrim = DoubleStateDiscrim(self.policy.get_base_net_fn(doubled_shape))
     return discrim.to(self.args.device)
Example #11
0
 def _infer_inv_accuracy(self, trans_sampler, dataset):
     """Return the inverse model's action-classification accuracy in percent.

     Runs ``self.inv_func`` on every batch from ``trans_sampler`` without
     gradient tracking. The argmax of its output is compared with the true
     actions.

     Args:
         trans_sampler: Iterable of index batches passed to ``select_batch``.
         dataset: Transition data passed to ``select_batch``.

     Returns:
         ``100 * correct / total`` over all sampled transitions. This is a
         tensor, because the correct count is accumulated from tensors.
         NOTE(review): raises ZeroDivisionError when ``trans_sampler`` yields
         no batches.
     """
     # The observation shape does not change between batches; compute once.
     ob_shape = rutils.get_obs_shape(self.policy.obs_space)
     total_count = 0
     num_correct = 0
     with torch.no_grad():
         for trans_idx in trans_sampler:
             use_state_0, use_state_1, true_action = select_batch(
                 trans_idx, dataset, self.args.device, ob_shape)
             pred_action = self.inv_func(use_state_0, use_state_1)
             pred_class = torch.argmax(pred_action, dim=-1)
             num_correct += (pred_class ==
                             true_action.view(-1)).float().sum()
             total_count += float(use_state_0.shape[0])
     return 100.0 * (num_correct / total_count)
Example #12
0
    def _train_inv_func(self, trans_sampler, dataset):
        """Train ``self.inv_func`` for ``self.args.bco_inv_epochs`` epochs.

        Each epoch visits every batch from ``trans_sampler``. The model
        predicts actions from ``(use_state_0, use_state_1)``, the loss is
        computed with ``autils.compute_ac_loss``, and one ``self.inv_opt``
        step is taken per batch.

        Returns:
            list: The scalar loss (``loss.item()``) of every optimisation
            step, in order.
        """
        # The observation shape does not change between batches; compute once.
        ob_shape = rutils.get_obs_shape(self.policy.obs_space)
        infer_ac_losses = []
        for _ in tqdm(range(self.args.bco_inv_epochs)):
            for trans_idx in trans_sampler:
                use_state_0, use_state_1, true_action = select_batch(
                    trans_idx, dataset, self.args.device, ob_shape)
                pred_action = self.inv_func(use_state_0, use_state_1)
                loss = autils.compute_ac_loss(pred_action, true_action,
                                              self.policy.action_space)
                infer_ac_losses.append(loss.item())

                self.inv_opt.zero_grad()
                loss.backward()
                self.inv_opt.step()
        return infer_ac_losses