Example #1
from imitation.util import util


def test_endless_iter():
    x = range(2)
    it = util.endless_iter(x)
    assert next(it) == 0
    assert next(it) == 1
    assert next(it) == 0
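Example #1 exercises an iterator that restarts from the beginning of the underlying iterable once it is exhausted. Below is a minimal sketch of how such a helper could be written; the name `cycle_forever` is illustrative and not necessarily how `util.endless_iter` is implemented.

import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def cycle_forever(iterable: Iterable[T]) -> Iterator[T]:
    # Chain fresh passes over `iterable` forever. Unlike itertools.cycle, this
    # re-iterates the source on every pass instead of caching its elements,
    # which matters when the source reshuffles itself (e.g. a DataLoader with
    # shuffle=True, as used for the expert data in Example #2).
    return itertools.chain.from_iterable(itertools.repeat(iterable))
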
Example #2
    def __init__(
        self,
        venv: vec_env.VecEnv,
        gen_algo: on_policy_algorithm.OnPolicyAlgorithm,
        discrim: discrim_nets.DiscrimNet,
        expert_data: Union[Iterable[Mapping], types.Transitions],
        expert_batch_size: int,
        n_disc_updates_per_round: int = 2,
        *,
        log_dir: str = "output/",
        normalize_obs: bool = True,
        normalize_reward: bool = True,
        disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
        disc_opt_kwargs: Optional[Mapping] = None,
        gen_replay_buffer_capacity: Optional[int] = None,
        init_tensorboard: bool = False,
        init_tensorboard_graph: bool = False,
        debug_use_ground_truth: bool = False,
    ):
        """Builds AdversarialTrainer.

        Args:
            venv: The vectorized environment to train in.
            gen_algo: The generator RL algorithm that is trained to maximize
                discriminator confusion. The generator batch size
                `self.gen_batch_size` is inferred from `gen_algo.n_steps`.
            discrim: The discriminator network. This will be moved to the same
                device as `gen_algo`.
            expert_data: Either a `torch.utils.data.DataLoader`-like object or an
                instance of `Transitions` which is automatically converted into a
                shuffled version of the former type.

                If the argument passed is a `DataLoader`, then it must yield batches of
                expert data via its `__iter__` method. Each batch is a dictionary whose
                keys "obs", "acts", "next_obs", and "dones" map to Tensor or NumPy array
                values, each with a batch dimension equal to `expert_batch_size`. If any
                batch dimension does not equal `expert_batch_size`, then a `ValueError`
                is raised.

                If the argument is a `Transitions` instance, then `len(expert_data)`
                must be at least `expert_batch_size`.
            expert_batch_size: The number of samples in each batch yielded from
                the expert data loader. The discriminator batch size is twice this
                number because each discriminator batch contains a generator sample for
                every expert sample.
            n_disc_updates_per_round: The number of discriminator updates after each
                round of generator updates in AdversarialTrainer.learn().
            log_dir: Directory to store TensorBoard logs, plots, etc. in.
            normalize_obs: Whether to normalize observations with `VecNormalize`.
            normalize_reward: Whether to normalize rewards with `VecNormalize`.
            disc_opt_cls: The optimizer class to use for discriminator training.
            disc_opt_kwargs: Keyword arguments to pass to the discriminator optimizer
                constructor.
            gen_replay_buffer_capacity: The capacity of the
                generator replay buffer (the number of obs-action-obs samples from
                the generator that can be stored).

                By default this is equal to `self.gen_batch_size`, meaning that we
                sample only from the most recent batch of generator samples.
            init_tensorboard: If True, makes various discriminator
                TensorBoard summaries.
            init_tensorboard_graph: If both this and `init_tensorboard` are True,
                then write a TensorBoard graph summary to disk.
            debug_use_ground_truth: If True, use the ground truth reward for
                `self.train_env`.
                This disables the reward wrapping that would normally replace
                the environment reward with the learned reward. This is useful for
                sanity checking that the policy training is functional.
        """

        assert (
            logger.is_configured()
        ), "Requires call to imitation.util.logger.configure"
        self._global_step = 0
        self._disc_step = 0
        self.n_disc_updates_per_round = n_disc_updates_per_round

        if expert_batch_size <= 0:
            raise ValueError(f"expert_batch_size={expert_batch_size} must be positive.")

        self.expert_batch_size = expert_batch_size
        if isinstance(expert_data, types.Transitions):
            if len(expert_data) < expert_batch_size:
                raise ValueError(
                    "Provided Transitions instance as `expert_data` argument but "
                    "len(expert_data) < expert_batch_size. "
                    f"({len(expert_data)} < {expert_batch_size})."
                )

            self.expert_data_loader = th_data.DataLoader(
                expert_data,
                batch_size=expert_batch_size,
                collate_fn=types.transitions_collate_fn,
                shuffle=True,
                drop_last=True,
            )
        else:
            self.expert_data_loader = expert_data
        self._endless_expert_iterator = util.endless_iter(self.expert_data_loader)

        self.debug_use_ground_truth = debug_use_ground_truth
        self.venv = venv
        self.gen_algo = gen_algo
        self._log_dir = log_dir

        # Create graph for optimising/recording stats on discriminator
        self.discrim = discrim.to(self.gen_algo.device)
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs or {}
        self._init_tensorboard = init_tensorboard
        self._init_tensorboard_graph = init_tensorboard_graph
        self._disc_opt = self._disc_opt_cls(
            self.discrim.parameters(), **self._disc_opt_kwargs
        )

        if self._init_tensorboard:
            logging.info("building summary directory at " + self._log_dir)
            summary_dir = os.path.join(self._log_dir, "summary")
            os.makedirs(summary_dir, exist_ok=True)
            self._summary_writer = thboard.SummaryWriter(summary_dir)

        self.venv_buffering = wrappers.BufferingWrapper(self.venv)
        self.venv_norm_obs = vec_env.VecNormalize(
            self.venv_buffering,
            norm_reward=False,
            norm_obs=normalize_obs,
        )

        if debug_use_ground_truth:
            # Would use an identity reward fn here, but RewardFns can't see rewards.
            self.venv_wrapped = self.venv_norm_obs
            self.gen_callback = None
        else:
            self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
                self.venv_norm_obs, self.discrim.predict_reward_train
            )
            self.gen_callback = self.venv_wrapped.make_log_callback()
        self.venv_train = vec_env.VecNormalize(
            self.venv_wrapped, norm_obs=False, norm_reward=normalize_reward
        )

        self.gen_algo.set_env(self.venv_train)

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = self.gen_batch_size
        self._gen_replay_buffer = buffer.ReplayBuffer(
            gen_replay_buffer_capacity, self.venv
        )
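The `expert_data` contract documented above does not require a real `torch.utils.data.DataLoader`: any object whose `__iter__` yields dictionaries with "obs", "acts", "next_obs", and "dones" entries, each with a leading dimension equal to `expert_batch_size`, is used as-is. Below is a minimal sketch of such a loader; the class name and the `expert_arrays` argument are hypothetical, not part of the library.

import numpy as np


class DictExpertLoader:
    """Yields shuffled dict batches matching the expert_data contract above."""

    def __init__(self, expert_arrays, batch_size):
        # expert_arrays: dict with "obs", "acts", "next_obs", and "dones" arrays,
        # all sharing the same first dimension.
        self.expert_arrays = expert_arrays
        self.batch_size = batch_size

    def __iter__(self):
        n = len(self.expert_arrays["obs"])
        perm = np.random.permutation(n)
        # Emit only full batches so every batch dimension equals batch_size,
        # mirroring drop_last=True in the Transitions code path above.
        for start in range(0, n - self.batch_size + 1, self.batch_size):
            idx = perm[start : start + self.batch_size]
            yield {key: value[idx] for key, value in self.expert_arrays.items()}
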
Example #3
import pytest

from imitation.util import util


def test_endless_iter_error():
    x = []
    with pytest.raises(ValueError, match="no elements"):
        util.endless_iter(x)
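Example #3 checks that an empty iterable is rejected immediately instead of producing an iterator that never yields. A sketch of how such an eager check could be layered on top of the `cycle_forever` helper from Example #1 (again illustrative, not the library's actual code):

def checked_cycle_forever(iterable):
    # Probe for at least one element up front; without this check, the chained
    # iterator would spin forever inside next() without ever yielding a value.
    try:
        next(iter(iterable))
    except StopIteration:
        raise ValueError(f"iterable {iterable} had no elements to iterate over.")
    return cycle_forever(iterable)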