def adversarial_learning(
    venv,
    expert=None,
    expert_venv=None,
    expert_trajectories=None,
    state_only=False,
    policy_fn=get_ppo,
    total_timesteps=20000,
    gen_batch_size=200,
    disc_batch_size=100,
    updates_per_batch=2,
    policy_lr=1e-3,
    reward_lr=1e-3,
    is_airl=True,
    **kwargs,
):
    # Set up generator
    gen_policy = policy_fn(venv, learning_rate=policy_lr)
    policy = gen_policy

    # Set up discriminator
    if is_airl:
        rn = BasicShapedRewardNet(
            venv.observation_space,
            venv.action_space,
            theta_units=[32, 32],
            phi_units=[32, 32],
            scale=True,
            state_only=state_only,
        )
        discrim = DiscrimNetAIRL(rn, entropy_weight=1.0)
    else:
        rn = None
        discrim = DiscrimNetGAIL(venv.observation_space, venv.action_space)

    # Set up optimizer
    train_op = tf.train.AdamOptimizer(learning_rate=reward_lr).minimize(
        tf.reduce_mean(discrim.disc_loss))

    # Set up environment reward
    reward_train = functools.partial(
        discrim.reward_train, gen_log_prob_fn=gen_policy.action_probability)
    venv_train = reward_wrapper.RewardVecEnvWrapper(venv, reward_train)
    venv_train_buffering = BufferingWrapper(venv_train)
    gen_policy.set_env(venv_train_buffering)  # possibly redundant

    # Set up replay buffers
    gen_replay_buffer_capacity = 20 * gen_batch_size
    gen_replay_buffer = buffer.ReplayBuffer(gen_replay_buffer_capacity, venv)
    if expert_trajectories is not None:
        expert_transitions = flatten_trajectories(expert_trajectories)
        exp_replay_buffer = buffer.ReplayBuffer.from_data(expert_transitions)
    else:
        exp_replay_buffer = buffer.ReplayBuffer(gen_replay_buffer_capacity, venv)

    # Start training
    sess = tf.get_default_session()
    sess.run(tf.global_variables_initializer())

    num_epochs = int(np.ceil(total_timesteps / gen_batch_size))
    for epoch in range(num_epochs):
        # Train gen
        gen_policy.learn(total_timesteps=gen_batch_size,
                         reset_num_timesteps=True)
        gen_replay_buffer.store(venv_train_buffering.pop_transitions())
        if expert_trajectories is None:
            exp_replay_buffer.store(
                flatten_trajectories(
                    sample_trajectories(expert_venv, expert,
                                        n_timesteps=gen_batch_size)))

        # Train disc
        for _ in range(updates_per_batch):
            disc_minibatch_size = disc_batch_size // updates_per_batch
            half_minibatch = disc_minibatch_size // 2
            gen_samples = gen_replay_buffer.sample(half_minibatch)
            expert_samples = exp_replay_buffer.sample(half_minibatch)

            obs = np.concatenate([gen_samples.obs, expert_samples.obs])
            acts = np.concatenate([gen_samples.acts, expert_samples.acts])
            next_obs = np.concatenate(
                [gen_samples.next_obs, expert_samples.next_obs])
            labels = np.concatenate(
                [np.ones(half_minibatch), np.zeros(half_minibatch)])
            log_act_prob = gen_policy.action_probability(obs, actions=acts,
                                                         logp=True)
            log_act_prob = log_act_prob.reshape((disc_minibatch_size,))

            _, logits_v, loss_v = sess.run(
                [
                    train_op,
                    discrim._disc_logits_gen_is_high,
                    discrim._disc_loss,
                ],
                feed_dict={
                    discrim.obs_ph: obs,
                    discrim.act_ph: acts,
                    discrim.next_obs_ph: next_obs,
                    discrim.labels_gen_is_one_ph: labels,
                    discrim.log_policy_act_prob_ph: log_act_prob,
                },
            )

    results = {}
    results["reward_model"] = rn
    results["discrim"] = discrim
    results["policy"] = gen_policy
    return results
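
# The helper below is an illustrative sketch (not part of the original module)
# showing one way to drive `adversarial_learning`. The caller is assumed to
# supply a vectorized training environment, an expert policy (e.g. one trained
# with `get_ppo` or loaded from disk), and a matching `expert_venv` to roll the
# expert out in.
def _example_adversarial_learning_usage(venv, expert, expert_venv):
    results = adversarial_learning(
        venv,
        expert=expert,
        expert_venv=expert_venv,
        total_timesteps=20000,
        gen_batch_size=200,
        disc_batch_size=100,
        is_airl=True,  # set False to use the GAIL discriminator instead
    )
    # `reward_model` is the learned AIRL reward net (None when is_airl=False);
    # `policy` is the imitation policy trained against the learned reward.
    return results["reward_model"], results["policy"]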
def test_pop(episode_lengths: Sequence[int],
             n_steps: int,
             extra_pop_timesteps: Sequence[int]):
    """Check pop_transitions() results for BufferingWrapper.

    To make things easier to test, we use _CountingEnv, where the observation
    is simply the episode timestep. The reward is 10x the timestep. Our action
    is 2.1x the timestep. There is a confusing offset for the observation
    because it has timestep 0 (due to reset()) and the other quantities don't,
    so here is an example of environment outputs and associated actions:

    ```
    episode_length = 5
    obs = [0, 1, 2, 3, 4, 5]        (len=6)
    acts = [0, 2.1, 4.2, ..., 8.4]  (len=5)
    rews = [10, ..., 50]            (len=5)
    ```

    Converted to `Transitions` format, this looks like:

    ```
    episode_length = 5
    obs = [0, 1, 2, 3, 4]           (len=5)
    next_obs = [1, 2, 3, 4, 5]      (len=5)
    acts = [0, 2.1, 4.2, ..., 8.4]  (len=5)
    rews = [10, ..., 50]            (len=5)
    ```

    Args:
        episode_lengths: The number of timesteps before episode end in each
            dummy environment.
        n_steps: Number of times to call `step()` on the dummy environment.
        extra_pop_timesteps: By default, we only call `pop_*()` after `n_steps`
            calls to `step()`. For every unique positive `x` in
            `extra_pop_timesteps`, we also call `pop_*()` after the `x`th call
            to `step()`. All popped samples are concatenated before validating
            results at the end of this test case. All `x` in
            `extra_pop_timesteps` must be in `range(1, n_steps)`. (`x == 0` is
            not valid because there are no transitions to pop at timestep 0.)
    """
    if not n_steps >= 1:  # pragma: no cover
        raise ValueError(n_steps)
    for t in extra_pop_timesteps:  # pragma: no cover
        if t < 1:
            raise ValueError(t)
        if not 1 <= t < n_steps:
            pytest.skip("pop timesteps out of bounds for this test case")

    def make_env(ep_len):
        return lambda: _CountingEnv(episode_length=ep_len)

    venv = DummyVecEnv([make_env(ep_len) for ep_len in episode_lengths])
    venv_buffer = BufferingWrapper(venv)

    # To test `pop_transitions`, we will check that every obs, act, and rew
    # returned by `.reset()` and `.step()` is also returned by one of the
    # calls to `pop_transitions()`.
    transitions_list = []  # type: List[rollout.Transitions]

    # Initial observation (only matters for pop_transitions()).
    obs = venv_buffer.reset()
    np.testing.assert_array_equal(obs, [0] * venv.num_envs)

    for t in range(1, n_steps + 1):
        acts = obs * 2.1
        venv_buffer.step_async(acts)
        obs, *_ = venv_buffer.step_wait()
        if t in extra_pop_timesteps:
            transitions_list.append(venv_buffer.pop_transitions())
    transitions_list.append(venv_buffer.pop_transitions())

    # Build expected transitions.
    expect_obs = []
    for ep_len in episode_lengths:
        n_complete, remainder = divmod(n_steps, ep_len)
        expect_obs.extend([np.arange(ep_len)] * n_complete)
        expect_obs.append(np.arange(remainder))
    expect_obs = np.concatenate(expect_obs)
    expect_next_obs = expect_obs + 1
    expect_acts = expect_obs * 2.1
    expect_rews = expect_next_obs * 10

    # Check `pop_transitions()`.
    trans = _join_transitions(transitions_list)
    _assert_equal_scrambled_vectors(trans.obs, expect_obs)
    _assert_equal_scrambled_vectors(trans.next_obs, expect_next_obs)
    _assert_equal_scrambled_vectors(trans.acts, expect_acts)
    _assert_equal_scrambled_vectors(trans.rews, expect_rews)
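
# For reference, a minimal sketch of what the `_CountingEnv` fixture used by
# the tests could look like, based only on the docstring above (observation =
# episode timestep, reward = 10 * next timestep, actions ignored). The real
# `_CountingEnv` is defined elsewhere in this test module and may differ in
# details such as space bounds and dtypes; this stand-in is not used by the
# tests.
import gym  # assumed available; likely already imported at module top


class _CountingEnvSketch(gym.Env):
    """Hypothetical stand-in for `_CountingEnv`."""

    def __init__(self, episode_length: int = 5):
        self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=())
        self.action_space = gym.spaces.Box(low=0, high=np.inf, shape=())
        self.episode_length = episode_length
        self.timestep = None

    def reset(self):
        self.timestep = 0
        return float(self.timestep)

    def step(self, action):
        # Advance the counter; the action itself does not affect the dynamics.
        self.timestep += 1
        done = self.timestep >= self.episode_length
        return float(self.timestep), 10.0 * self.timestep, done, {}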
class AdversarialTrainer:
    """Trainer for GAIL and AIRL."""

    venv: VecEnv
    """The original vectorized environment."""

    venv_train: VecEnv
    """Like `self.venv`, but wrapped with the train reward unless in debug mode.

    If `debug_use_ground_truth=True` was passed into the initializer then
    `self.venv_train` is the same as `self.venv`.
    """

    venv_test: VecEnv
    """Like `self.venv`, but wrapped with the test reward unless in debug mode.

    If `debug_use_ground_truth=True` was passed into the initializer then
    `self.venv_test` is the same as `self.venv`.
    """

    def __init__(
        self,
        venv: VecEnv,
        gen_policy: BaseRLModel,
        discrim: discrim_net.DiscrimNet,
        expert_demos: rollout.Transitions,
        *,
        log_dir: str = 'output/',
        disc_batch_size: int = 2048,
        disc_minibatch_size: int = 256,
        disc_opt_cls: tf.train.Optimizer = tf.train.AdamOptimizer,
        disc_opt_kwargs: dict = {},
        gen_replay_buffer_capacity: Optional[int] = None,
        init_tensorboard: bool = False,
        init_tensorboard_graph: bool = False,
        debug_use_ground_truth: bool = False,
    ):
        """Builds Trainer.

        Args:
            venv: The vectorized environment to train in.
            gen_policy: The generator policy that is trained to maximize
                discriminator confusion. The generator batch size
                `self.gen_batch_size` is inferred from `gen_policy.n_batch`.
            discrim: The discriminator network. For GAIL, use a DiscrimNetGAIL.
                For AIRL, use a DiscrimNetAIRL.
            expert_demos: Transitions from an expert dataset.
            log_dir: Directory to store TensorBoard logs, plots, etc. in.
            disc_batch_size: The default number of expert and generator
                transition samples to feed to the discriminator in each call to
                `self.train_disc()`. (Half of the samples are expert and half
                of the samples are generator.)
            disc_minibatch_size: The discriminator minibatch size. Each
                discriminator batch is split into minibatches and an Adam
                update is applied on the gradient resulting from each
                minibatch. Must evenly divide `disc_batch_size`. Must be an
                even number.
            disc_opt_cls: The optimizer class for discriminator training.
            disc_opt_kwargs: Keyword arguments for the discriminator optimizer.
            gen_replay_buffer_capacity: The capacity of the generator replay
                buffer (the number of obs-action-obs samples from the generator
                that can be stored). By default this is equal to
                `20 * self.gen_batch_size`.
            init_tensorboard: If True, makes various discriminator TensorBoard
                summaries.
            init_tensorboard_graph: If both this and `init_tensorboard` are
                True, then write a TensorBoard graph summary to disk.
            debug_use_ground_truth: If True, use the ground truth reward for
                `self.venv_train`. This disables the reward wrapping that would
                normally replace the environment reward with the learned
                reward. This is useful for sanity checking that the policy
                training is functional.
        """
        assert util.logger.is_configured(), ("Requires call to "
                                             "imitation.util.logger.configure")
        self._sess = tf.get_default_session()
        self._global_step = tf.train.create_global_step()

        assert disc_batch_size % disc_minibatch_size == 0
        assert disc_minibatch_size % 2 == 0, (
            "discriminator minibatch size must be even "
            "(equal split between generator and expert samples)")
        self.disc_batch_size = disc_batch_size
        self.disc_minibatch_size = disc_minibatch_size
        self.debug_use_ground_truth = debug_use_ground_truth

        self.venv = venv
        self._expert_demos = expert_demos
        self._gen_policy = gen_policy
        self._log_dir = log_dir

        # Create graph for optimising/recording stats on discriminator
        self._discrim = discrim
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs
        self._init_tensorboard = init_tensorboard
        self._init_tensorboard_graph = init_tensorboard_graph
        self._build_graph()
        self._sess.run(tf.global_variables_initializer())

        if debug_use_ground_truth:
            # Would use an identity reward fn here, but RewardFns can't see
            # rewards.
            self.reward_train = self.reward_test = None
            self.venv_train = self.venv_test = self.venv
        else:
            self.reward_train = partial(
                self.discrim.reward_train,
                gen_log_prob_fn=self._gen_policy.action_probability)
            self.reward_test = self.discrim.reward_test
            self.venv_train = reward_wrapper.RewardVecEnvWrapper(
                self.venv, self.reward_train)
            self.venv_test = reward_wrapper.RewardVecEnvWrapper(
                self.venv, self.reward_test)

        self.venv_train_norm = VecNormalize(self.venv_train)
        self.venv_train_norm_buffering = BufferingWrapper(self.venv_train_norm)
        self.gen_policy.set_env(self.venv_train_norm_buffering)

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = 20 * self.gen_batch_size
        self._gen_replay_buffer = buffer.ReplayBuffer(
            gen_replay_buffer_capacity, self.venv)
        self._exp_replay_buffer = buffer.ReplayBuffer.from_data(expert_demos)
        if self.disc_batch_size // 2 > len(self._exp_replay_buffer):
            warn(
                "The discriminator batch size is more than twice the number of "
                "expert samples. This means that we will be reusing samples "
                "every discrim batch.")

    @property
    def gen_batch_size(self) -> int:
        return self.gen_policy.n_batch

    @property
    def discrim(self) -> discrim_net.DiscrimNet:
        """Discriminator being trained, used to compute reward for policy."""
        return self._discrim

    @property
    def expert_demos(self) -> util.rollout.Transitions:
        """The expert demonstrations that are being imitated."""
        return self._expert_demos

    @property
    def gen_policy(self) -> BaseRLModel:
        """Policy (i.e. the generator) being trained."""
        return self._gen_policy

    def train_disc(self, n_samples: Optional[int] = None) -> None:
        """Trains the discriminator to minimize classification cross-entropy.

        Must call `train_gen` first (otherwise there will be no saved generator
        samples for training, and this method will error).

        Args:
            n_samples: A number of transitions to sample from the generator
                replay buffer and the expert demonstration dataset. (Half of
                the samples are from each source.) By default,
                `self.disc_batch_size`. `n_samples` must be a positive multiple
                of `self.disc_minibatch_size`.
        """
        if len(self._gen_replay_buffer) == 0:
            raise RuntimeError("No generator samples for training. "
                               "Call `train_gen()` first.")
        if n_samples is None:
            n_samples = self.disc_batch_size
        n_updates = n_samples // self.disc_minibatch_size
        assert n_samples % self.disc_minibatch_size == 0
        assert n_updates >= 1
        for _ in range(n_updates):
            gen_samples = self._gen_replay_buffer.sample(
                self.disc_minibatch_size)
            self.train_disc_step(gen_samples=gen_samples)

    def train_disc_step(
        self,
        *,
        gen_samples: Optional[rollout.Transitions] = None,
        expert_samples: Optional[rollout.Transitions] = None,
    ) -> None:
        """Perform a single discriminator update, optionally using provided samples.

        Args:
            gen_samples: Transition samples from the generator policy. If not
                provided, then take `self.disc_batch_size // 2` samples from
                the generator replay buffer.
            expert_samples: Transition samples from the expert. If not
                provided, then take `n_gen` expert samples from the expert
                dataset, where `n_gen` is the number of samples in
                `gen_samples`.
        """
        with logger.accumulate_means("disc"):
            fetches = {
                'train_op_out': self._disc_train_op,
                'train_stats': self._discrim.train_stats,
            }
            # Optionally write TB summaries for collected ops.
            step = self._sess.run(self._global_step)
            write_summaries = self._init_tensorboard and step % 20 == 0
            if write_summaries:
                fetches['events'] = self._summary_op

            # Do the actual update.
            fd = self._build_disc_feed_dict(gen_samples=gen_samples,
                                            expert_samples=expert_samples)
            fetched = self._sess.run(fetches, feed_dict=fd)

            if write_summaries:
                # Use the global step fetched above; `fetched` only contains
                # the keys requested in `fetches`.
                self._summary_writer.add_summary(fetched['events'], step)

            logger.logkv("step", step)
            for k, v in fetched['train_stats'].items():
                logger.logkv(k, v)
            logger.dumpkvs()

    def eval_disc_loss(self, **kwargs) -> float:
        """Evaluates the discriminator loss.

        Args:
            gen_samples (Optional[rollout.Transitions]): Same as in
                `train_disc_step`.
            expert_samples (Optional[rollout.Transitions]): Same as in
                `train_disc_step`.

        Returns:
            The total cross-entropy error in the discriminator's
            classification.
        """
        fd = self._build_disc_feed_dict(**kwargs)
        return np.mean(self._sess.run(self.discrim.disc_loss, feed_dict=fd))

    def train_gen(self,
                  total_timesteps: Optional[int] = None,
                  learn_kwargs: Optional[dict] = None):
        """Trains the generator to maximize the discriminator loss.

        After training ends, populates the generator replay buffer (used in
        discriminator training) with the transitions collected during training.

        Args:
            total_timesteps: The number of transitions to sample from
                `self.venv_train_norm` during training. By default,
                `self.gen_batch_size`.
            learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
                method.
        """
        if total_timesteps is None:
            total_timesteps = self.gen_batch_size
        if learn_kwargs is None:
            learn_kwargs = {}

        with logger.accumulate_means("gen"):
            self.gen_policy.learn(total_timesteps=total_timesteps,
                                  reset_num_timesteps=False,
                                  **learn_kwargs)

        with logger.accumulate_means("gen_buffer"):
            # Log stats for finished trajectories stored in the
            # BufferingWrapper. This will bias toward shorter trajectories
            # because trajectories that are partially finished at the time of
            # this log are popped from the buffer a few lines down.
            #
            # This is useful for getting some statistics for unnormalized
            # rewards. (The rewards logged during the call to `.learn()` are
            # the ground truth rewards, retrieved from Monitor.)
            trajs = self.venv_train_norm_buffering._trajectories
            if len(trajs) > 0:
                stats = rollout.rollout_stats(trajs)
                for k, v in stats.items():
                    util.logger.logkv(k, v)

        gen_samples = self.venv_train_norm_buffering.pop_transitions()
        self._gen_replay_buffer.store(gen_samples)

    def train(
        self,
        total_timesteps: int,
        callback: Optional[Callable[[int], None]] = None,
    ) -> None:
        """Alternates between training the generator and discriminator.

        Every epoch consists of a call to `train_gen(self.gen_batch_size)`, a
        call to `train_disc(self.disc_batch_size)`, and finally a call to
        `callback(epoch)`.

        Training ends once an additional epoch would cause the number of
        transitions sampled from the environment to exceed `total_timesteps`.

        Args:
            total_timesteps: An upper bound on the number of transitions to
                sample from the environment during training.
            callback: A function called at the end of every epoch which takes
                in a single argument, the epoch number. Epoch numbers are in
                `range(total_timesteps // self.gen_batch_size)`.
        """
        n_epochs = total_timesteps // self.gen_batch_size
        assert n_epochs >= 1, ("No updates (need at least "
                               f"{self.gen_batch_size} timesteps, have only "
                               f"total_timesteps={total_timesteps})!")
        for epoch in tqdm.tqdm(range(0, n_epochs), desc="epoch"):
            self.train_gen(self.gen_batch_size)
            self.train_disc(self.disc_batch_size)
            if callback:
                callback(epoch)
            util.logger.dumpkvs()

    def _build_graph(self):
        # Build necessary parts of the TF graph. Most of the real action
        # happens in constructors for self.discrim and self.gen_policy.
        with tf.variable_scope("trainer"):
            with tf.variable_scope("discriminator"):
                disc_opt = self._disc_opt_cls(**self._disc_opt_kwargs)
                self._disc_train_op = disc_opt.minimize(
                    tf.reduce_mean(self.discrim.disc_loss),
                    global_step=self._global_step)

        if self._init_tensorboard:
            with tf.name_scope("summaries"):
                tf.logging.info("building summary directory at " +
                                self._log_dir)
                graph = self._sess.graph if self._init_tensorboard_graph else None
                summary_dir = os.path.join(self._log_dir, 'summary')
                os.makedirs(summary_dir, exist_ok=True)
                self._summary_writer = tf.summary.FileWriter(summary_dir,
                                                             graph=graph)
                self._summary_op = tf.summary.merge_all()

    def _build_disc_feed_dict(
        self,
        *,
        gen_samples: Optional[rollout.Transitions] = None,
        expert_samples: Optional[rollout.Transitions] = None,
    ) -> dict:
        """Build and return feed dict for the next discriminator training update.

        Args:
            gen_samples: Same as in `train_disc_step`.
            expert_samples: Same as in `train_disc_step`.
        """
        if gen_samples is None:
            if len(self._gen_replay_buffer) == 0:
                raise RuntimeError("No generator samples for training. "
                                   "Call `train_gen()` first.")
            gen_samples = self._gen_replay_buffer.sample(
                self.disc_batch_size // 2)
        n_gen = len(gen_samples.obs)

        if expert_samples is None:
            expert_samples = self._exp_replay_buffer.sample(n_gen)
        n_expert = len(expert_samples.obs)

        # Check dimensions.
        n_samples = n_expert + n_gen
        assert n_expert == len(expert_samples.acts)
        assert n_expert == len(expert_samples.next_obs)
        assert n_gen == len(gen_samples.acts)
        assert n_gen == len(gen_samples.next_obs)

        # Normalize expert observations to match generator observations.
        expert_obs_norm = self.venv_train_norm.normalize_obs(
            expert_samples.obs)

        # Concatenate rollouts, and label each row as expert or generator.
        obs = np.concatenate([expert_obs_norm, gen_samples.obs])
        acts = np.concatenate([expert_samples.acts, gen_samples.acts])
        next_obs = np.concatenate(
            [expert_samples.next_obs, gen_samples.next_obs])
        labels_gen_is_one = np.concatenate(
            [np.zeros(n_expert, dtype=int), np.ones(n_gen, dtype=int)])

        # Calculate generator-policy log probabilities.
        log_act_prob = self._gen_policy.action_probability(obs, actions=acts,
                                                           logp=True)
        assert len(log_act_prob) == n_samples
        log_act_prob = log_act_prob.reshape((n_samples,))

        fd = {
            self.discrim.obs_ph: obs,
            self.discrim.act_ph: acts,
            self.discrim.next_obs_ph: next_obs,
            self.discrim.labels_gen_is_one_ph: labels_gen_is_one,
            self.discrim.log_policy_act_prob_ph: log_act_prob,
        }
        return fd
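
# Illustrative usage sketch (not part of the original module): wiring an
# AdversarialTrainer together from the pieces documented above. The environment
# and expert `rollout.Transitions` are assumed to be supplied by the caller,
# and the logger path and PPO2 settings are placeholders; only the class,
# property, and method names come from the code above.
def _example_adversarial_trainer_usage(venv, expert_demos, total_timesteps=100000):
    from stable_baselines import PPO2

    # `__init__` asserts that the imitation logger has been configured first.
    # (The exact `configure` signature assumed here may differ.)
    if not util.logger.is_configured():
        util.logger.configure("output/")

    gen_policy = PPO2("MlpPolicy", env=venv)  # gen_batch_size comes from n_batch
    discrim = discrim_net.DiscrimNetGAIL(venv.observation_space,
                                         venv.action_space)
    trainer = AdversarialTrainer(
        venv,
        gen_policy=gen_policy,
        discrim=discrim,
        expert_demos=expert_demos,
    )
    # Alternate generator and discriminator updates until roughly
    # `total_timesteps` environment transitions have been sampled.
    trainer.train(total_timesteps=total_timesteps)
    return trainer.gen_policy, trainer.reward_test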
def _make_buffering_venv(error_on_premature_reset: bool) -> BufferingWrapper:
    venv = DummyVecEnv([_CountingEnv] * 2)
    venv = BufferingWrapper(venv, error_on_premature_reset)
    venv.reset()
    return venv