def train_gen(self, total_timesteps: Optional[int] = None, callback=None):
    """Trains the generator to maximize the discriminator loss.

    After the end of training populates the generator replay buffer (used in
    discriminator training) with `self.disc_batch_size` transitions.

    Args:
        total_timesteps: The number of transitions to sample from
            `self.venv_train_norm` during training. By default,
            `self.gen_batch_size`.
        callback: Callback argument to the Stable Baselines `RLModel.learn()`
            method.
    """
    if total_timesteps is None:
        total_timesteps = self.gen_batch_size

    with logger.accumulate_means("gen"):
        self.gen_policy.set_env(self.venv_train_norm_buffering)
        # TODO(adam): learn was not intended to be called for each training
        # batch. It should work, but might incur unnecessary overhead: e.g. in
        # PPO2 a new Runner instance is created each time. Also a hotspot for
        # errors: algorithms not tested for this use case, may reset state
        # accidentally.
        self.gen_policy.learn(
            total_timesteps=total_timesteps,
            reset_num_timesteps=False,
            callback=callback,
        )

    gen_samples = self.venv_train_norm_buffering.pop_transitions()
    self._gen_replay_buffer.store(gen_samples)
def train_gen(
    self,
    total_timesteps: Optional[int] = None,
    learn_kwargs: Optional[Mapping] = None,
):
    """Trains the generator to maximize the discriminator loss.

    After the end of training populates the generator replay buffer (used in
    discriminator training) with `self.disc_batch_size` transitions.

    Args:
        total_timesteps: The number of transitions to sample from
            `self.venv_train` during training. By default,
            `self.gen_batch_size`.
        learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
            method.
    """
    if total_timesteps is None:
        total_timesteps = self.gen_batch_size
    if learn_kwargs is None:
        learn_kwargs = {}

    with logger.accumulate_means("gen"):
        self.gen_algo.learn(
            total_timesteps=total_timesteps,
            reset_num_timesteps=False,
            callback=self.gen_callback,
            **learn_kwargs,
        )
        self._global_step += 1

    gen_samples = self.venv_buffering.pop_transitions()
    self._gen_replay_buffer.store(gen_samples)
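# Taken together with the discriminator updates below, `train_gen` is one
# half of the adversarial loop: the generator gathers transitions, which
# then serve as "fake" samples for the discriminator. A minimal sketch of
# the outer loop; `trainer`, `total_timesteps`, and
# `n_disc_updates_per_round` are assumptions, not names from this module.
def train_sketch(trainer, total_timesteps: int, n_disc_updates_per_round: int = 4):
    n_rounds = total_timesteps // trainer.gen_batch_size
    for _ in range(n_rounds):
        # Roll out the current generator; this also refills the generator
        # replay buffer that the discriminator update draws from.
        trainer.train_gen(trainer.gen_batch_size)
        # Several discriminator steps per generator update is a common choice.
        for _ in range(n_disc_updates_per_round):
            trainer.train_disc()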
def train_disc_step(
    self,
    *,
    gen_samples: Optional[types.Transitions] = None,
    expert_samples: Optional[types.Transitions] = None,
) -> Dict[str, float]:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        gen_samples: Transition samples from the generator policy. If not
            provided, then take `self.disc_batch_size // 2` samples from the
            generator replay buffer. Observations should not be normalized.
        expert_samples: Transition samples from the expert. If not provided,
            then take `n_gen` expert samples from the expert dataset, where
            `n_gen` is the number of samples in `gen_samples`. Observations
            should not be normalized.

    Returns:
        dict: Statistics for discriminator (e.g. loss, accuracy).
    """
    with logger.accumulate_means("disc"):
        # optionally write TB summaries for collected ops
        write_summaries = self._init_tensorboard and self._global_step % 20 == 0

        # compute loss
        batch = self._make_disc_train_batch(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        disc_logits = self.discrim.logits_gen_is_high(
            batch["state"],
            batch["action"],
            batch["next_state"],
            batch["done"],
            batch["log_policy_act_prob"],
        )
        loss = self.discrim.disc_loss(disc_logits, batch["labels_gen_is_one"])

        # do gradient step
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
        self._disc_step += 1

        # compute/write stats and TensorBoard data
        with th.no_grad():
            train_stats = rew_common.compute_train_stats(
                disc_logits, batch["labels_gen_is_one"], loss
            )
        logger.record("global_step", self._global_step)
        for k, v in train_stats.items():
            logger.record(k, v)
        logger.dump(self._disc_step)

        if write_summaries:
            self._summary_writer.add_histogram("disc_logits", disc_logits.detach())

    return train_stats
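# The names `logits_gen_is_high` and `labels_gen_is_one` pin down a sign
# convention: a large logit means "this looks generated", and generator
# samples carry label 1. Under that convention a natural reading of
# `disc_loss` is sigmoid cross-entropy. The sketch below assumes that BCE
# form; the label convention is taken from the method above, the exact
# loss is not confirmed by it.
import torch as th
import torch.nn.functional as F

disc_logits = th.tensor([2.0, -1.5, 0.3])       # high => "generator"
labels_gen_is_one = th.tensor([1.0, 0.0, 1.0])  # 1 => generator sample
loss = F.binary_cross_entropy_with_logits(disc_logits, labels_gen_is_one)
print(loss.item())  # mean binary cross-entropy over the batch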
def train_gen(
    self,
    total_timesteps: Optional[int] = None,
    learn_kwargs: Optional[dict] = None,
):
    """Trains the generator to maximize the discriminator loss.

    After the end of training populates the generator replay buffer (used in
    discriminator training) with `self.disc_batch_size` transitions.

    Args:
        total_timesteps: The number of transitions to sample from
            `self.venv_train_norm` during training. By default,
            `self.gen_batch_size`.
        learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
            method.
    """
    if total_timesteps is None:
        total_timesteps = self.gen_batch_size
    if learn_kwargs is None:
        learn_kwargs = {}

    with logger.accumulate_means("gen"):
        self.gen_policy.learn(
            total_timesteps=total_timesteps,
            reset_num_timesteps=False,
            **learn_kwargs,
        )

    with logger.accumulate_means("gen_buffer"):
        # Log stats for finished trajectories stored in the BufferingWrapper.
        # This will bias toward shorter trajectories because trajectories that
        # are partially finished at the time of this log are popped from the
        # buffer a few lines down.
        #
        # This is useful for getting some statistics for unnormalized rewards.
        # (The rewards logged during the call to `.learn()` are the ground
        # truth rewards, retrieved from Monitor.)
        trajs = self.venv_train_norm_buffering._trajectories
        if len(trajs) > 0:
            stats = rollout.rollout_stats(trajs)
            for k, v in stats.items():
                util.logger.logkv(k, v)

    gen_samples = self.venv_train_norm_buffering.pop_transitions()
    self._gen_replay_buffer.store(gen_samples)
def train_disc_step(
    self,
    *,
    gen_samples: Optional[rollout.Transitions] = None,
    expert_samples: Optional[rollout.Transitions] = None,
) -> None:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        gen_samples: Transition samples from the generator policy. If not
            provided, then take `self.disc_batch_size // 2` samples from the
            generator replay buffer. Observations should not be normalized.
        expert_samples: Transition samples from the expert. If not provided,
            then take `n_gen` expert samples from the expert dataset, where
            `n_gen` is the number of samples in `gen_samples`. Observations
            should not be normalized.
    """
    with logger.accumulate_means("disc"):
        fetches = {
            'train_op_out': self._disc_train_op,
            'train_stats': self._discrim.train_stats,
        }

        # optionally write TB summaries for collected ops
        step = self._sess.run(self._global_step)
        write_summaries = self._init_tensorboard and step % 20 == 0
        if write_summaries:
            fetches['events'] = self._summary_op

        # do actual update
        fd = self._build_disc_feed_dict(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        fetched = self._sess.run(fetches, feed_dict=fd)
        self._discrim.clip_params()

        if write_summaries:
            self._summary_writer.add_summary(fetched['events'], step)

        logger.logkv("step", step)
        for k, v in fetched['train_stats'].items():
            logger.logkv(k, v)
        logger.dumpkvs()
def train_disc(
    self,
    *,
    expert_samples: Optional[Mapping] = None,
    gen_samples: Optional[Mapping] = None,
) -> Dict[str, float]:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        expert_samples: Transition samples from the expert in dictionary form.
            If provided, must contain keys corresponding to every field of the
            `Transitions` dataclass except "infos". All corresponding values
            can be either NumPy arrays or Tensors. Extra keys are ignored.
            Must contain `self.expert_batch_size` samples. If this argument is
            not provided, then `self.expert_batch_size` expert samples from
            `self.expert_data_loader` are used by default.
        gen_samples: Transition samples from the generator policy in the same
            dictionary form as `expert_samples`. If provided, must contain
            exactly `self.expert_batch_size` samples. If not provided, then
            take `len(expert_samples)` samples from the generator replay
            buffer.

    Returns:
        dict: Statistics for discriminator (e.g. loss, accuracy).
    """
    with logger.accumulate_means("disc"):
        # optionally write TB summaries for collected ops
        write_summaries = self._init_tensorboard and self._global_step % 20 == 0

        # compute loss
        batch = self._make_disc_train_batch(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        disc_logits = self.discrim.logits_gen_is_high(
            batch["state"],
            batch["action"],
            batch["next_state"],
            batch["done"],
            batch["log_policy_act_prob"],
        )
        loss = self.discrim.disc_loss(disc_logits, batch["labels_gen_is_one"])

        # do gradient step
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
        self._disc_step += 1

        # compute/write stats and TensorBoard data
        with th.no_grad():
            train_stats = rew_common.compute_train_stats(
                disc_logits, batch["labels_gen_is_one"], loss
            )
        logger.record("global_step", self._global_step)
        for k, v in train_stats.items():
            logger.record(k, v)
        logger.dump(self._disc_step)

        if write_summaries:
            self._summary_writer.add_histogram("disc_logits", disc_logits.detach())

    return train_stats
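# The dictionary form accepted by `train_disc` makes it easy to inject
# hand-built batches, e.g. in tests. A sketch assuming
# `trainer.expert_batch_size == 2`, 3-dimensional observations, and that
# the `Transitions` fields (minus "infos") are obs/acts/next_obs/dones;
# `trainer` and the field names are assumptions, while the batch-size
# contract comes from the docstring above.
import numpy as np

expert_samples = {
    "obs": np.zeros((2, 3), dtype=np.float32),
    "acts": np.array([0, 1]),
    "next_obs": np.zeros((2, 3), dtype=np.float32),
    "dones": np.array([False, True]),
}
# Generator samples are omitted, so a matching batch is drawn from the
# generator replay buffer.
stats = trainer.train_disc(expert_samples=expert_samples)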
def test_hard(tmpdir):
    logger.configure(tmpdir)

    # Part One: Test logging outside of the accumulating scope, and within
    # scopes with two different logging keys (including a repeat).
    sb_logger.record("no_context", 1)

    with logger.accumulate_means("disc"):
        sb_logger.record("C", 2)
        sb_logger.record("D", 2)
        sb_logger.dump()
        sb_logger.record("C", 4)
        sb_logger.dump()

    with logger.accumulate_means("gen"):
        sb_logger.record("E", 2)
        sb_logger.dump()
        sb_logger.record("E", 0)
        sb_logger.dump()

    with logger.accumulate_means("disc"):
        sb_logger.record("C", 3)
        sb_logger.dump()

    sb_logger.dump()  # Writes 1 mean each from "gen" and "disc".

    expect_raw_gen = {"raw/gen/E": [2, 0]}
    expect_raw_disc = {
        "raw/disc/C": [2, 4, 3],
        "raw/disc/D": [2, "", ""],
    }
    expect_default = {
        "mean/gen/E": [1],
        "mean/disc/C": [3],
        "mean/disc/D": [2],
        "no_context": [1],
    }

    _compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect_default)
    _compare_csv_lines(osp.join(tmpdir, "raw", "gen", "progress.csv"), expect_raw_gen)
    _compare_csv_lines(osp.join(tmpdir, "raw", "disc", "progress.csv"), expect_raw_disc)

    # Part Two:
    # Check that we append to the same logs after the first dump to "means/*".
    with logger.accumulate_means("disc"):
        sb_logger.record("D", 100)
        sb_logger.dump()

    sb_logger.record("no_context", 2)
    sb_logger.dump()  # Writes 1 mean from "disc". "gen" is blank.

    expect_raw_gen = {"raw/gen/E": [2, 0]}
    expect_raw_disc = {
        "raw/disc/C": [2, 4, 3, ""],
        "raw/disc/D": [2, "", "", 100],
    }
    expect_default = {
        "mean/gen/E": [1, ""],
        "mean/disc/C": [3, ""],
        "mean/disc/D": [2, 100],
        "no_context": [1, 2],
    }

    _compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect_default)
    _compare_csv_lines(osp.join(tmpdir, "raw", "gen", "progress.csv"), expect_raw_gen)
    _compare_csv_lines(osp.join(tmpdir, "raw", "disc", "progress.csv"), expect_raw_disc)
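# The contract exercised above, in miniature: inside
# `accumulate_means(subdir)` every `record`/`dump` pair writes raw values
# under `raw/<subdir>/` and accumulates them; the next top-level `dump`
# emits their running mean under `mean/<subdir>/`. The paths and key
# prefixes are taken from the assertions in the test.
def demo(tmpdir):
    logger.configure(tmpdir)
    with logger.accumulate_means("disc"):
        sb_logger.record("loss", 1.0)
        sb_logger.dump()  # raw/disc/progress.csv: raw/disc/loss = 1.0
        sb_logger.record("loss", 3.0)
        sb_logger.dump()  # raw/disc/progress.csv: raw/disc/loss = 3.0
    sb_logger.dump()      # progress.csv: mean/disc/loss = 2.0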