Example #1
def adversarial_learning(
    venv,
    expert=None,
    expert_venv=None,
    expert_trajectories=None,
    state_only=False,
    policy_fn=get_ppo,
    total_timesteps=20000,
    gen_batch_size=200,
    disc_batch_size=100,
    updates_per_batch=2,
    policy_lr=1e-3,
    reward_lr=1e-3,
    is_airl=True,
    **kwargs,
):
    # Set up generator
    gen_policy = policy_fn(venv, learning_rate=policy_lr)
    policy = gen_policy

    # Set up discriminator
    if is_airl:
        rn = BasicShapedRewardNet(
            venv.observation_space,
            venv.action_space,
            theta_units=[32, 32],
            phi_units=[32, 32],
            scale=True,
            state_only=state_only,
        )
        discrim = DiscrimNetAIRL(rn, entropy_weight=1.0)
    else:
        rn = None
        discrim = DiscrimNetGAIL(venv.observation_space, venv.action_space)

    # Set up optimizer
    train_op = tf.train.AdamOptimizer(learning_rate=reward_lr).minimize(
        tf.reduce_mean(discrim.disc_loss))

    # Set up environment reward
    reward_train = functools.partial(
        discrim.reward_train, gen_log_prob_fn=gen_policy.action_probability)
    venv_train = reward_wrapper.RewardVecEnvWrapper(venv, reward_train)
    venv_train_buffering = BufferingWrapper(venv_train)
    gen_policy.set_env(venv_train_buffering)  # possibly redundant

    # Set up replay buffers
    gen_replay_buffer_capacity = 20 * gen_batch_size
    gen_replay_buffer = buffer.ReplayBuffer(gen_replay_buffer_capacity, venv)

    if expert_trajectories is not None:
        expert_transitions = flatten_trajectories(expert_trajectories)
        exp_replay_buffer = buffer.ReplayBuffer.from_data(expert_transitions)
    else:
        exp_replay_buffer = buffer.ReplayBuffer(gen_replay_buffer_capacity,
                                                venv)

    # Start training
    sess = tf.get_default_session()
    sess.run(tf.global_variables_initializer())

    num_epochs = int(np.ceil(total_timesteps / gen_batch_size))

    for epoch in range(num_epochs):
        # Train gen
        gen_policy.learn(total_timesteps=gen_batch_size,
                         reset_num_timesteps=True)
        gen_replay_buffer.store(venv_train_buffering.pop_transitions())

        if expert_trajectories is None:
            exp_replay_buffer.store(
                flatten_trajectories(
                    sample_trajectories(expert_venv,
                                        expert,
                                        n_timesteps=gen_batch_size)))

        # Train disc
        for _ in range(updates_per_batch):
            disc_minibatch_size = disc_batch_size // updates_per_batch
            half_minibatch = disc_minibatch_size // 2

            gen_samples = gen_replay_buffer.sample(half_minibatch)
            expert_samples = exp_replay_buffer.sample(half_minibatch)

            obs = np.concatenate([gen_samples.obs, expert_samples.obs])
            acts = np.concatenate([gen_samples.acts, expert_samples.acts])
            next_obs = np.concatenate(
                [gen_samples.next_obs, expert_samples.next_obs])
            labels = np.concatenate(
                [np.ones(half_minibatch),
                 np.zeros(half_minibatch)])

            log_act_prob = gen_policy.action_probability(obs,
                                                         actions=acts,
                                                         logp=True)
            log_act_prob = log_act_prob.reshape((disc_minibatch_size, ))

            _, logits_v, loss_v = sess.run(
                [
                    train_op,
                    discrim._disc_logits_gen_is_high,
                    discrim._disc_loss,
                ],
                feed_dict={
                    discrim.obs_ph: obs,
                    discrim.act_ph: acts,
                    discrim.next_obs_ph: next_obs,
                    discrim.labels_gen_is_one_ph: labels,
                    discrim.log_policy_act_prob_ph: log_act_prob,
                },
            )

    results = {}
    results["reward_model"] = rn
    results["discrim"] = discrim
    results["policy"] = gen_policy

    return results
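
A minimal usage sketch for the function above, not taken from the source. It assumes the surrounding project provides `adversarial_learning` and `get_ppo` as defined here, and that expert trajectories were saved to a hypothetical `expert_trajectories.pkl`; the environment name and hyperparameters are illustrative only.

import pickle

import gym
import tensorflow as tf
from stable_baselines.common.vec_env import DummyVecEnv

# Hypothetical file produced elsewhere by rolling out an expert policy.
with open("expert_trajectories.pkl", "rb") as f:
    expert_trajectories = pickle.load(f)

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")] * 4)

# adversarial_learning() reads the default TF session, so enter one first.
with tf.Session():
    results = adversarial_learning(
        venv,
        expert_trajectories=expert_trajectories,
        total_timesteps=20000,
        is_airl=True,
    )
    imitation_policy = results["policy"]
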
Example #2
    def __init__(
        self,
        venv: VecEnv,
        gen_policy: BaseRLModel,
        discrim: discrim_net.DiscrimNet,
        expert_demos: rollout.Transitions,
        *,
        log_dir: str = 'output/',
        disc_batch_size: int = 2048,
        disc_minibatch_size: int = 256,
        disc_opt_cls: tf.train.Optimizer = tf.train.AdamOptimizer,
        disc_opt_kwargs: dict = {},
        gen_replay_buffer_capacity: Optional[int] = None,
        init_tensorboard: bool = False,
        init_tensorboard_graph: bool = False,
        debug_use_ground_truth: bool = False,
    ):
        """Builds Trainer.

    Args:
        venv: The vectorized environment to train in.
        gen_policy: The generator policy that is trained to maximize
          discriminator confusion. The generator batch size
          `self.gen_batch_size` is inferred from `gen_policy.n_batch`.
        discrim: The discriminator network.
          For GAIL, use a DiscrimNetGAIL. For AIRL, use a DiscrimNetAIRL.
        expert_demos: Transitions from an expert dataset.
        log_dir: Directory to store TensorBoard logs, plots, etc. in.
        disc_batch_size: The default number of expert and generator transition
          samples to feed to the discriminator in each call to
          `self.train_disc()`. (Half of the samples are expert and half are
          generator samples.)
        disc_minibatch_size: The discriminator minibatch size. Each
          discriminator batch is split into minibatches, and an Adam update is
          applied to the gradient resulting from each minibatch. Must evenly
          divide `disc_batch_size`. Must be an even number.
        disc_opt_cls: The optimizer for discriminator training.
        disc_opt_kwargs: Parameters for discriminator training.
        gen_replay_buffer_capacity: The capacity of the
          generator replay buffer (the number of obs-action-obs samples from
          the generator that can be stored).

          By default this is equal to `20 * self.gen_batch_size`.
        init_tensorboard: If True, makes various discriminator
          TensorBoard summaries.
        init_tensorboard_graph: If both this and `init_tensorboard` are True,
          then write a Tensorboard graph summary to disk.
        debug_use_ground_truth: If True, use the ground truth reward for
          `self.train_env`.
          This disables the reward wrapping that would normally replace
          the environment reward with the learned reward. This is useful for
          sanity checking that the policy training is functional.
    """
        assert util.logger.is_configured(), ("Requires call to "
                                             "imitation.util.logger.configure")
        self._sess = tf.get_default_session()
        self._global_step = tf.train.create_global_step()

        assert disc_batch_size % disc_minibatch_size == 0
        assert disc_minibatch_size % 2 == 0, (
            "discriminator minibatch size must be even "
            "(equal split between generator and expert samples)")
        self.disc_batch_size = disc_batch_size
        self.disc_minibatch_size = disc_minibatch_size

        self.debug_use_ground_truth = debug_use_ground_truth

        self.venv = venv
        self._expert_demos = expert_demos
        self._gen_policy = gen_policy

        self._log_dir = log_dir

        # Create graph for optimising/recording stats on discriminator
        self._discrim = discrim
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs
        self._init_tensorboard = init_tensorboard
        self._init_tensorboard_graph = init_tensorboard_graph
        self._build_graph()
        self._sess.run(tf.global_variables_initializer())

        if debug_use_ground_truth:
            # Would use an identity reward fn here, but RewardFns can't see rewards.
            self.reward_train = self.reward_test = None
            self.venv_train = self.venv_test = self.venv
        else:
            self.reward_train = partial(
                self.discrim.reward_train,
                gen_log_prob_fn=self._gen_policy.action_probability)
            self.reward_test = self.discrim.reward_test
            self.venv_train = reward_wrapper.RewardVecEnvWrapper(
                self.venv, self.reward_train)
            self.venv_test = reward_wrapper.RewardVecEnvWrapper(
                self.venv, self.reward_test)

        self.venv_train_norm = VecNormalize(self.venv_train)
        self.venv_train_norm_buffering = BufferingWrapper(self.venv_train_norm)
        self.gen_policy.set_env(self.venv_train_norm_buffering)

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = 20 * self.gen_batch_size
        self._gen_replay_buffer = buffer.ReplayBuffer(
            gen_replay_buffer_capacity, self.venv)
        self._exp_replay_buffer = buffer.ReplayBuffer.from_data(expert_demos)
        if self.disc_batch_size // 2 > len(self._exp_replay_buffer):
            warn(
                "The discriminator batch size is more than twice the number of "
                "expert samples. This means that we will be reusing samples every "
                "discrim batch.")
Example #3
def test_pop(episode_lengths: Sequence[int], n_steps: int,
             extra_pop_timesteps: Sequence[int]):
    """Check pop_transitions() results for BufferWrapper.

  To make things easier to test, we use _CountingEnv where the observation
  is simply the episode timestep. The reward is 10x the timestep. Our action
  is 2.1x the timestep. There is a confusing offset for the observation because
  it has timestep 0 (due to reset()) and the other quantities don't, so here is
  an example of environment outputs and associated actions:

  ```
  episode_length = 5
  obs = [0, 1, 2, 3, 4, 5]  (len=6)
  acts = [0, 2.1, 4.2, ..., 8.4]  (len=5)
  rews = [10, ..., 50]  (len=5)
  ```

  Converted to `Transition`-format, this looks like:
  ```
  episode_length = 5
  obs = [0, 1, 2, 3, 4]  (len=5)
  next_obs = [1, 2, 3, 4, 5]  (len=5)
  acts = [0, 2.1, 4.2, ..., 8.4]  (len=5)
  rews = [10, ..., 50]  (len=5)
  ```

  Args:
    episode_lengths: The number of timesteps before episode end in each dummy
      environment.
    n_steps: Number of times to call `step()` on the dummy environment.
    extra_pop_timesteps: By default, we only call `pop_*()` after `n_steps`
      calls to `step()`. For every unique positive `x` in `extra_pop_timesteps`,
      we also call `pop_*()` after the `x`th call to `step()`. All popped
      samples are concatenated before validating results at the end of this
      test case. All `x` in `extra_pop_timesteps` must be in range(1, n_steps).
      (`x == 0` is not valid because there are no transitions to pop at timestep
      0).
  """
    if not n_steps >= 1:  # pragma: no cover
        raise ValueError(n_steps)
    for t in extra_pop_timesteps:  # pragma: no cover
        if t < 1:
            raise ValueError(t)
        if not 1 <= t < n_steps:
            pytest.skip("pop timesteps out of bounds for this test case")

    def make_env(ep_len):
        return lambda: _CountingEnv(episode_length=ep_len)

    venv = DummyVecEnv([make_env(ep_len) for ep_len in episode_lengths])
    venv_buffer = BufferingWrapper(venv)

    # To test `pop_transitions`, we will check that every obs, act, and rew
    # returned by `.reset()` and `.step()` is also returned by one of the
    # calls to `pop_transitions()`.
    transitions_list = []  # type: List[rollout.Transitions]

    # Initial observation (only matters for pop_transitions()).
    obs = venv_buffer.reset()
    np.testing.assert_array_equal(obs, [0] * venv.num_envs)

    for t in range(1, n_steps + 1):
        acts = obs * 2.1
        venv_buffer.step_async(acts)
        obs, *_ = venv_buffer.step_wait()

        if t in extra_pop_timesteps:
            transitions_list.append(venv_buffer.pop_transitions())

    transitions_list.append(venv_buffer.pop_transitions())

    # Build expected transitions
    expect_obs = []
    for ep_len in episode_lengths:
        n_complete, remainder = divmod(n_steps, ep_len)
        expect_obs.extend([np.arange(ep_len)] * n_complete)
        expect_obs.append(np.arange(remainder))

    expect_obs = np.concatenate(expect_obs)
    expect_next_obs = expect_obs + 1
    expect_acts = expect_obs * 2.1
    expect_rews = expect_next_obs * 10

    # Check `pop_transitions()`
    trans = _join_transitions(transitions_list)

    _assert_equal_scrambled_vectors(trans.obs, expect_obs)
    _assert_equal_scrambled_vectors(trans.next_obs, expect_next_obs)
    _assert_equal_scrambled_vectors(trans.acts, expect_acts)
    _assert_equal_scrambled_vectors(trans.rews, expect_rews)
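
A condensed sketch of the pattern this test exercises, reusing `_CountingEnv`, `DummyVecEnv`, and `BufferingWrapper` from the test module above; the expected values in the comments just restate the docstring's arithmetic (obs = timestep, acts = 2.1 * obs, rews = 10 * next_obs).

# Single env whose episode outlasts the three steps taken here.
venv = DummyVecEnv([lambda: _CountingEnv(episode_length=5)])
venv_buffer = BufferingWrapper(venv)

obs = venv_buffer.reset()            # [0] for the single env
for _ in range(3):
    venv_buffer.step_async(obs * 2.1)
    obs, *_ = venv_buffer.step_wait()

trans = venv_buffer.pop_transitions()
# Following the docstring's arithmetic:
#   trans.obs      == [0, 1, 2]
#   trans.next_obs == [1, 2, 3]
#   trans.acts     == [0.0, 2.1, 4.2]
#   trans.rews     == [10., 20., 30.]
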
Example #4
class AdversarialTrainer:
    """Trainer for GAIL and AIRL."""

    venv: VecEnv
    """The original vectorized environment."""

    venv_train: VecEnv
    """Like `self.venv`, but wrapped with train reward unless in debug mode.

  If `debug_use_ground_truth=True` was passed into the initializer then
  `self.venv_train` is the same as `self.venv`.
  """

    venv_test: VecEnv
    """Like `self.venv`, but wrapped with test reward unless in debug mode.

  If `debug_use_ground_truth=True` was passed into the initializer then
  `self.venv_test` is the same as `self.venv`.
  """
    def __init__(
        self,
        venv: VecEnv,
        gen_policy: BaseRLModel,
        discrim: discrim_net.DiscrimNet,
        expert_demos: rollout.Transitions,
        *,
        log_dir: str = 'output/',
        disc_batch_size: int = 2048,
        disc_minibatch_size: int = 256,
        disc_opt_cls: tf.train.Optimizer = tf.train.AdamOptimizer,
        disc_opt_kwargs: dict = {},
        gen_replay_buffer_capacity: Optional[int] = None,
        init_tensorboard: bool = False,
        init_tensorboard_graph: bool = False,
        debug_use_ground_truth: bool = False,
    ):
        """Builds Trainer.

    Args:
        venv: The vectorized environment to train in.
        gen_policy: The generator policy that is trained to maximize
          discriminator confusion. The generator batch size
          `self.gen_batch_size` is inferred from `gen_policy.n_batch`.
        discrim: The discriminator network.
          For GAIL, use a DiscrimNetGAIL. For AIRL, use a DiscrimNetAIRL.
        expert_demos: Transitions from an expert dataset.
        log_dir: Directory to store TensorBoard logs, plots, etc. in.
        disc_batch_size: The default number of expert and generator transition
          samples to feed to the discriminator in each call to
          `self.train_disc()`. (Half of the samples are expert and half are
          generator samples.)
        disc_minibatch_size: The discriminator minibatch size. Each
          discriminator batch is split into minibatches, and an Adam update is
          applied to the gradient resulting from each minibatch. Must evenly
          divide `disc_batch_size`. Must be an even number.
        disc_opt_cls: The optimizer for discriminator training.
        disc_opt_kwargs: Parameters for discriminator training.
        gen_replay_buffer_capacity: The capacity of the
          generator replay buffer (the number of obs-action-obs samples from
          the generator that can be stored).

          By default this is equal to `20 * self.gen_batch_size`.
        init_tensorboard: If True, makes various discriminator
          TensorBoard summaries.
        init_tensorboard_graph: If both this and `init_tensorboard` are True,
          then write a Tensorboard graph summary to disk.
        debug_use_ground_truth: If True, use the ground truth reward for
          `self.train_env`.
          This disables the reward wrapping that would normally replace
          the environment reward with the learned reward. This is useful for
          sanity checking that the policy training is functional.
    """
        assert util.logger.is_configured(), ("Requires call to "
                                             "imitation.util.logger.configure")
        self._sess = tf.get_default_session()
        self._global_step = tf.train.create_global_step()

        assert disc_batch_size % disc_minibatch_size == 0
        assert disc_minibatch_size % 2 == 0, (
            "discriminator minibatch size must be even "
            "(equal split between generator and expert samples)")
        self.disc_batch_size = disc_batch_size
        self.disc_minibatch_size = disc_minibatch_size

        self.debug_use_ground_truth = debug_use_ground_truth

        self.venv = venv
        self._expert_demos = expert_demos
        self._gen_policy = gen_policy

        self._log_dir = log_dir

        # Create graph for optimising/recording stats on discriminator
        self._discrim = discrim
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs
        self._init_tensorboard = init_tensorboard
        self._init_tensorboard_graph = init_tensorboard_graph
        self._build_graph()
        self._sess.run(tf.global_variables_initializer())

        if debug_use_ground_truth:
            # Would use an identity reward fn here, but RewardFns can't see rewards.
            self.reward_train = self.reward_test = None
            self.venv_train = self.venv_test = self.venv
        else:
            self.reward_train = partial(
                self.discrim.reward_train,
                gen_log_prob_fn=self._gen_policy.action_probability)
            self.reward_test = self.discrim.reward_test
            self.venv_train = reward_wrapper.RewardVecEnvWrapper(
                self.venv, self.reward_train)
            self.venv_test = reward_wrapper.RewardVecEnvWrapper(
                self.venv, self.reward_test)

        self.venv_train_norm = VecNormalize(self.venv_train)
        self.venv_train_norm_buffering = BufferingWrapper(self.venv_train_norm)
        self.gen_policy.set_env(self.venv_train_norm_buffering)

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = 20 * self.gen_batch_size
        self._gen_replay_buffer = buffer.ReplayBuffer(
            gen_replay_buffer_capacity, self.venv)
        self._exp_replay_buffer = buffer.ReplayBuffer.from_data(expert_demos)
        if self.disc_batch_size // 2 > len(self._exp_replay_buffer):
            warn(
                "The discriminator batch size is more than twice the number of "
                "expert samples. This means that we will be reusing samples every "
                "discrim batch.")

    @property
    def gen_batch_size(self) -> int:
        return self.gen_policy.n_batch

    @property
    def discrim(self) -> discrim_net.DiscrimNet:
        """Discriminator being trained, used to compute reward for policy."""
        return self._discrim

    @property
    def expert_demos(self) -> util.rollout.Transitions:
        """The expert demonstrations that are being imitated."""
        return self._expert_demos

    @property
    def gen_policy(self) -> BaseRLModel:
        """Policy (i.e. the generator) being trained."""
        return self._gen_policy

    def train_disc(self, n_samples: Optional[int] = None) -> None:
        """Trains the discriminator to minimize classification cross-entropy.

    Must call `train_gen` first; otherwise there will be no saved generator
    samples to train on, and this method will raise a `RuntimeError`.

    Args:
      n_samples: The number of transitions to sample from the generator
        replay buffer and the expert demonstration dataset. (Half of the
        samples are from each source). By default, `self.disc_batch_size`.
        `n_samples` must be a positive multiple of `self.disc_minibatch_size`.
    """
        if len(self._gen_replay_buffer) == 0:
            raise RuntimeError("No generator samples for training. "
                               "Call `train_gen()` first.")

        if n_samples is None:
            n_samples = self.disc_batch_size
        n_updates = n_samples // self.disc_minibatch_size
        assert n_samples % self.disc_minibatch_size == 0
        assert n_updates >= 1
        for _ in range(n_updates):
            gen_samples = self._gen_replay_buffer.sample(
                self.disc_minibatch_size)
            self.train_disc_step(gen_samples=gen_samples)

    def train_disc_step(
        self,
        *,
        gen_samples: Optional[rollout.Transitions] = None,
        expert_samples: Optional[rollout.Transitions] = None,
    ) -> None:
        """Perform a single discriminator update, optionally using provided samples.

    Args:
      gen_samples: Transition samples from the generator policy. If not
        provided, then take `self.disc_batch_size // 2` samples from the
        generator replay buffer.
      expert_samples: Transition samples from the expert. If not
        provided, then take `n_gen` expert samples from the expert
        dataset, where `n_gen` is the number of samples in `gen_samples`.
    """
        with logger.accumulate_means("disc"):
            fetches = {
                'train_op_out': self._disc_train_op,
                'train_stats': self._discrim.train_stats,
            }
            # optionally write TB summaries for collected ops
            step = self._sess.run(self._global_step)
            write_summaries = self._init_tensorboard and step % 20 == 0
            if write_summaries:
                fetches['events'] = self._summary_op

            # do actual update
            fd = self._build_disc_feed_dict(gen_samples=gen_samples,
                                            expert_samples=expert_samples)
            fetched = self._sess.run(fetches, feed_dict=fd)

            if write_summaries:
                self._summary_writer.add_summary(fetched['events'], step)

            logger.logkv("step", step)
            for k, v in fetched['train_stats'].items():
                logger.logkv(k, v)
            logger.dumpkvs()

    def eval_disc_loss(self, **kwargs) -> float:
        """Evaluates the discriminator loss.

    Args:
      gen_samples (Optional[rollout.Transitions]): Same as in `train_disc_step`.
      expert_samples (Optional[rollout.Transitions]): Same as in
        `train_disc_step`.

    Returns:
      The total cross-entropy error in the discriminator's classification.
    """
        fd = self._build_disc_feed_dict(**kwargs)
        return np.mean(self._sess.run(self.discrim.disc_loss, feed_dict=fd))

    def train_gen(self,
                  total_timesteps: Optional[int] = None,
                  learn_kwargs: Optional[dict] = None):
        """Trains the generator to maximize the discriminator loss.

    After training, populates the generator replay buffer (used in
    discriminator training) with the transitions collected during training.

    Args:
      total_timesteps: The number of transitions to sample from
        `self.venv_train_norm` during training. By default,
        `self.gen_batch_size`.
      learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
        method.
    """
        if total_timesteps is None:
            total_timesteps = self.gen_batch_size
        if learn_kwargs is None:
            learn_kwargs = {}

        with logger.accumulate_means("gen"):
            self.gen_policy.learn(total_timesteps=total_timesteps,
                                  reset_num_timesteps=False,
                                  **learn_kwargs)

        with logger.accumulate_means("gen_buffer"):
            # Log stats for finished trajectories stored in the BufferingWrapper. This
            # will bias toward shorter trajectories because trajectories that
            # are partially finished at the time of this log are popped from
            # the buffer a few lines down.
            #
            # This is useful for getting some statistics for unnormalized rewards.
            # (The rewards logged during the call to `.learn()` are the ground truth
            # rewards, retrieved from Monitor.).
            trajs = self.venv_train_norm_buffering._trajectories
            if len(trajs) > 0:
                stats = rollout.rollout_stats(trajs)
                for k, v in stats.items():
                    util.logger.logkv(k, v)

        gen_samples = self.venv_train_norm_buffering.pop_transitions()
        self._gen_replay_buffer.store(gen_samples)

    def train(
        self,
        total_timesteps: int,
        callback: Optional[Callable[[int], None]] = None,
    ) -> None:
        """Alternates between training the generator and discriminator.

    Every epoch consists of a call to `train_gen(self.gen_batch_size)`,
    a call to `train_disc(self.disc_batch_size)`, and
    finally a call to `callback(epoch)`.

    Training ends once an additional epoch would cause the number of transitions
    sampled from the environment to exceed `total_timesteps`.

    Args:
      total_timesteps: An upper bound on the number of transitions to sample
        from the environment during training.
      callback: A function called at the end of every epoch which takes in a
        single argument, the epoch number. Epoch numbers are in
        `range(total_timesteps // self.gen_batch_size)`.
    """
        n_epochs = total_timesteps // self.gen_batch_size
        assert n_epochs >= 1, ("No updates (need at least "
                               f"{self.gen_batch_size} timesteps, have only "
                               f"total_timesteps={total_timesteps})!")
        for epoch in tqdm.tqdm(range(0, n_epochs), desc="epoch"):
            self.train_gen(self.gen_batch_size)
            self.train_disc(self.disc_batch_size)
            if callback:
                callback(epoch)
            util.logger.dumpkvs()

    def _build_graph(self):
        # Build necessary parts of the TF graph. Most of the real action happens in
        # constructors for self.discrim and self.gen_policy.
        with tf.variable_scope("trainer"):
            with tf.variable_scope("discriminator"):
                disc_opt = self._disc_opt_cls(**self._disc_opt_kwargs)
                self._disc_train_op = disc_opt.minimize(
                    tf.reduce_mean(self.discrim.disc_loss),
                    global_step=self._global_step)

        if self._init_tensorboard:
            with tf.name_scope("summaries"):
                tf.logging.info("building summary directory at " +
                                self._log_dir)
                graph = self._sess.graph if self._init_tensorboard_graph else None
                summary_dir = os.path.join(self._log_dir, 'summary')
                os.makedirs(summary_dir, exist_ok=True)
                self._summary_writer = tf.summary.FileWriter(summary_dir,
                                                             graph=graph)
                self._summary_op = tf.summary.merge_all()

    def _build_disc_feed_dict(
        self,
        *,
        gen_samples: Optional[rollout.Transitions] = None,
        expert_samples: Optional[rollout.Transitions] = None,
    ) -> dict:
        """Build and return feed dict for the next discriminator training update.

    Args:
      gen_samples: Same as in `train_disc_step`.
      expert_samples: Same as in `train_disc_step`.
    """
        if gen_samples is None:
            if len(self._gen_replay_buffer) == 0:
                raise RuntimeError("No generator samples for training. "
                                   "Call `train_gen()` first.")
            gen_samples = self._gen_replay_buffer.sample(
                self.disc_batch_size // 2)
        n_gen = len(gen_samples.obs)

        if expert_samples is None:
            expert_samples = self._exp_replay_buffer.sample(n_gen)
        n_expert = len(expert_samples.obs)

        # Check dimensions.
        n_samples = n_expert + n_gen
        assert n_expert == len(expert_samples.acts)
        assert n_expert == len(expert_samples.next_obs)
        assert n_gen == len(gen_samples.acts)
        assert n_gen == len(gen_samples.next_obs)

        # Normalize expert observations to match generator observations.
        expert_obs_norm = self.venv_train_norm.normalize_obs(
            expert_samples.obs)

        # Concatenate rollouts, and label each row as expert or generator.
        obs = np.concatenate([expert_obs_norm, gen_samples.obs])
        acts = np.concatenate([expert_samples.acts, gen_samples.acts])
        next_obs = np.concatenate(
            [expert_samples.next_obs, gen_samples.next_obs])
        labels_gen_is_one = np.concatenate(
            [np.zeros(n_expert, dtype=int),
             np.ones(n_gen, dtype=int)])

        # Calculate generator-policy log probabilities.
        log_act_prob = self._gen_policy.action_probability(obs,
                                                           actions=acts,
                                                           logp=True)
        assert len(log_act_prob) == n_samples
        log_act_prob = log_act_prob.reshape((n_samples, ))

        fd = {
            self.discrim.obs_ph: obs,
            self.discrim.act_ph: acts,
            self.discrim.next_obs_ph: next_obs,
            self.discrim.labels_gen_is_one_ph: labels_gen_is_one,
            self.discrim.log_policy_act_prob_ph: log_act_prob,
        }
        return fd
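
A hedged sketch of running the alternating training loop, assuming a `trainer` constructed as in the sketch after Example #2; the epoch count and callback are illustrative only.

# One epoch = gen_batch_size sampled transitions plus one discriminator batch.
trainer.train(
    total_timesteps=50 * trainer.gen_batch_size,
    callback=lambda epoch: print(f"finished epoch {epoch}"),
)

# Afterwards the learned reward and the imitation policy are available:
learned_reward_fn = trainer.reward_test  # e.g. for retraining a policy from scratch
imitation_policy = trainer.gen_policy

# Optional sanity check: discriminator loss on freshly sampled transitions.
print("disc loss:", trainer.eval_disc_loss())
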
Example #5
def _make_buffering_venv(error_on_premature_reset: bool) -> BufferingWrapper:
    venv = DummyVecEnv([_CountingEnv] * 2)
    venv = BufferingWrapper(venv, error_on_premature_reset)
    venv.reset()
    return venv
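
A hedged usage note, not from the source: tests presumably call the helper roughly like this, mirroring the step/pop pattern from Example #3.

import numpy as np

venv = _make_buffering_venv(error_on_premature_reset=True)  # already reset()
venv.step_async(np.zeros((2,)))  # two _CountingEnv instances, so two actions
venv.step_wait()
trans = venv.pop_transitions()   # one buffered transition per env
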