Пример #1
0
    def __init__(
        self,
        venv: vec_env.VecEnv,
        expert_data: Union[Iterable[Mapping], types.Transitions],
        expert_batch_size: int,
        gen_algo: on_policy_algorithm.OnPolicyAlgorithm,
        *,
        # FIXME(sam) pass in discrim net directly; don't ask for kwargs indirectly
        discrim_kwargs: Optional[Mapping] = None,
        **kwargs,
    ):
        """Set up a Generative Adversarial Imitation Learning trainer.

        A `DiscrimNetGAIL` is built from the observation and action spaces of
        `venv`, and then everything is handed to the superclass initializer
        (`AdversarialTrainer.__init__`, where the shared parameters are
        described).

        Args:
            venv: Vectorized environment; its `observation_space` and
                `action_space` are used to construct the discriminator.
            expert_data: Forwarded unchanged to the superclass initializer.
            expert_batch_size: Forwarded unchanged to the superclass initializer.
            gen_algo: Forwarded unchanged to the superclass initializer.
            discrim_kwargs: Optional keyword arguments to use while constructing
                the DiscrimNetGAIL. `None` (or any empty mapping) means "no
                extra arguments".
            **kwargs: Forwarded unchanged to the superclass initializer.
        """
        # Treat a missing/empty mapping as "no extra discriminator arguments".
        if not discrim_kwargs:
            discrim_kwargs = {}

        discriminator = discrim_nets.DiscrimNetGAIL(
            venv.observation_space,
            venv.action_space,
            **discrim_kwargs,
        )

        super().__init__(
            venv,
            gen_algo,
            discriminator,
            expert_data,
            expert_batch_size,
            **kwargs,
        )
Пример #2
0
    def __init__(
        self,
        venv: vec_env.VecEnv,
        expert_data: Union[Iterable[Mapping], types.Transitions],
        expert_batch_size: int,
        gen_algo: on_policy_algorithm.OnPolicyAlgorithm,
        *,
        # FIXME(sam) pass in discrim net directly; don't ask for kwargs indirectly
        discrim_kwargs: Optional[Mapping] = None,
        **kwargs,
    ):
        """Generative Adversarial Imitation Learning.

        Most parameters are described in and passed to `AdversarialTrainer.__init__`.
        Additional parameters that `GAIL` adds on top of its superclass initializer are
        as follows:

        Args:
            venv: Vectorized environment; its `observation_space` and
                `action_space` are used to construct the discriminator.
            expert_data: Forwarded unchanged to the superclass initializer.
            expert_batch_size: Forwarded unchanged to the superclass initializer.
            gen_algo: Forwarded unchanged to the superclass initializer.
            discrim_kwargs: Optional keyword arguments to use while constructing the
                DiscrimNetGAIL. If it holds a truthy `"discriminator_hid_sizes"`
                entry, an `ActObsMLP` with those hidden sizes is built and passed
                to `DiscrimNetGAIL` as `discrim_net`; in that case the remaining
                entries of `discrim_kwargs` are not forwarded. Otherwise the whole
                mapping is expanded into the `DiscrimNetGAIL` call.
            **kwargs: Forwarded unchanged to the superclass initializer.

        """
        # Normalise before any lookup: the default is None, and calling `.get`
        # on None used to raise AttributeError for callers relying on the default.
        discrim_kwargs = discrim_kwargs or {}
        hid_sizes = discrim_kwargs.get("discriminator_hid_sizes", {})

        if hid_sizes:
            # NOTE(review): diagnostic output kept so stdout is unchanged for
            # existing users; consider routing it through `logging` instead.
            print(hid_sizes)
            discrim_net = ActObsMLP(
                action_space=venv.action_space,
                observation_space=venv.observation_space,
                hid_sizes=hid_sizes,
            )
            discrim = discrim_nets.DiscrimNetGAIL(venv.observation_space,
                                                  venv.action_space,
                                                  discrim_net=discrim_net)
        else:
            discrim = discrim_nets.DiscrimNetGAIL(venv.observation_space,
                                                  venv.action_space,
                                                  **discrim_kwargs)
        super().__init__(venv, gen_algo, discrim, expert_data,
                         expert_batch_size, **kwargs)
Пример #3
0
    def __init__(
        self,
        venv: vec_env.VecEnv,
        expert_data: Union[Iterable[Mapping], types.Transitions],
        expert_batch_size: int,
        gen_algo: on_policy_algorithm.OnPolicyAlgorithm,
        *,
        # FIXME(sam) pass in discrim net directly; don't ask for kwargs indirectly
        discrim_kwargs: Optional[Mapping] = None,
        policy,
        _init_setup_model,
        **kwargs,
    ):
        """Set up a GAIL trainer on a fixed 'ReachObjectUR5Sim-v0' environment.

        The vectorized environment is always created here through
        `util.make_vec_env`; a `DiscrimNetGAIL` is built from its spaces and
        both are handed to the superclass initializer by keyword.

        Args:
            venv: Accepted for signature compatibility, but replaced by the
                environment constructed inside this method.
            expert_data: Forwarded to the superclass initializer.
            expert_batch_size: Forwarded to the superclass initializer.
            gen_algo: Forwarded to the superclass initializer.
            discrim_kwargs: Optional keyword arguments to use while constructing
                the DiscrimNetGAIL.
            policy: Required keyword; not used by this method.
                NOTE(review): presumably present so an external loader can pass
                it -- confirm against callers.
            _init_setup_model: Required keyword; not used by this method.
            **kwargs: Accepted but not forwarded to the superclass initializer.
        """
        # Hard-coded environment: one copy, episodes capped at 1000 steps.
        # seed / parallel / log_dir are intentionally not passed along.
        env_id = 'ReachObjectUR5Sim-v0'
        venv = util.make_vec_env(env_id, 1, max_episode_steps=1000)

        # Treat a missing/empty mapping as "no extra discriminator arguments".
        if not discrim_kwargs:
            discrim_kwargs = {}
        discriminator = discrim_nets.DiscrimNetGAIL(
            venv.observation_space,
            venv.action_space,
            **discrim_kwargs,
        )

        super().__init__(
            venv=venv,
            gen_algo=gen_algo,
            discrim=discriminator,
            expert_data=expert_data,
            expert_batch_size=expert_batch_size,
        )

        # Attributes assigned only after the superclass has been initialised.
        self.verbose = False
        self.tensorboard_log = None
        self.use_sde = False
Пример #4
0
def _setup_gail_provide_discriminator(venv):
    """Return a `DiscrimNetGAIL` for `venv` wrapping a hand-built `ActObsMLP`.

    The MLP is created from the environment's action and observation spaces
    with hidden sizes `(4, 4, 4)` and passed to `DiscrimNetGAIL` as its third
    positional argument.
    """
    act_space = venv.action_space
    obs_space = venv.observation_space
    mlp = discrim_nets.ActObsMLP(act_space, obs_space, hid_sizes=(4, 4, 4))
    return discrim_nets.DiscrimNetGAIL(obs_space, act_space, mlp)
Пример #5
0
def _setup_gail(venv):
    """Return a `DiscrimNetGAIL` built from `venv`'s observation and action spaces.

    No further arguments are passed to the `DiscrimNetGAIL` constructor.
    """
    obs_space = venv.observation_space
    act_space = venv.action_space
    return discrim_nets.DiscrimNetGAIL(obs_space, act_space)