def train(
    self,
    total_timesteps: int,
    callback: Optional[Callable[[int], None]] = None,
) -> None:
    """Alternates between training the generator and discriminator.

    Every "round" consists of a call to `train_gen(self.gen_batch_size)`,
    `self.n_disc_updates_per_round` calls to `train_disc`, and finally a call
    to `callback(round)`.

    Training ends once an additional "round" would cause the number of
    transitions sampled from the environment to exceed `total_timesteps`.

    Args:
        total_timesteps: An upper bound on the number of transitions to sample
            from the environment during training.
        callback: A function called at the end of every round which takes in a
            single argument, the round number. Round numbers are in
            `range(total_timesteps // self.gen_batch_size)`.
    """
    n_rounds = total_timesteps // self.gen_batch_size
    assert n_rounds >= 1, (
        "No updates (need at least "
        f"{self.gen_batch_size} timesteps, have only "
        f"total_timesteps={total_timesteps})!"
    )
    for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
        self.train_gen(self.gen_batch_size)
        for _ in range(self.n_disc_updates_per_round):
            self.train_disc()
        if callback:
            callback(r)
        logger.dump(self._global_step)
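# --- Hypothetical usage sketch (not part of the original source). ---
# Shows how `train` above could be driven with a per-round callback. `trainer`
# is assumed to be an already-constructed instance of this class; the helper
# and callback names below are illustrative only.
def _example_train_with_callback(trainer, total_timesteps: int) -> None:
    def _on_round_end(round_num: int) -> None:
        # Round numbers run over range(total_timesteps // trainer.gen_batch_size).
        print(f"Finished round {round_num}")

    trainer.train(total_timesteps=total_timesteps, callback=_on_round_end)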
def train_disc_step(
    self,
    *,
    gen_samples: Optional[types.Transitions] = None,
    expert_samples: Optional[types.Transitions] = None,
) -> Dict[str, float]:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        gen_samples: Transition samples from the generator policy. If not
            provided, then take `self.disc_batch_size // 2` samples from the
            generator replay buffer. Observations should not be normalized.
        expert_samples: Transition samples from the expert. If not provided,
            then take `n_gen` expert samples from the expert dataset, where
            `n_gen` is the number of samples in `gen_samples`. Observations
            should not be normalized.

    Returns:
        dict: Statistics for discriminator (e.g. loss, accuracy).
    """
    with logger.accumulate_means("disc"):
        # optionally write TB summaries for collected ops
        write_summaries = self._init_tensorboard and self._global_step % 20 == 0

        # compute loss
        batch = self._make_disc_train_batch(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        disc_logits = self.discrim.logits_gen_is_high(
            batch["state"],
            batch["action"],
            batch["next_state"],
            batch["done"],
            batch["log_policy_act_prob"],
        )
        loss = self.discrim.disc_loss(disc_logits, batch["labels_gen_is_one"])

        # do gradient step
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
        self._disc_step += 1

        # compute/write stats and TensorBoard data
        with th.no_grad():
            train_stats = rew_common.compute_train_stats(
                disc_logits, batch["labels_gen_is_one"], loss
            )
        logger.record("global_step", self._global_step)
        for k, v in train_stats.items():
            logger.record(k, v)
        logger.dump(self._disc_step)

        if write_summaries:
            self._summary_writer.add_histogram("disc_logits", disc_logits.detach())

    return train_stats
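# --- Hypothetical usage sketch (not part of the original source). ---
# With no arguments, `train_disc_step` above falls back to sampling
# `disc_batch_size // 2` generator transitions from the replay buffer and a
# matching number of expert transitions from the expert dataset. The helper
# name is illustrative only.
def _example_disc_step(trainer) -> Dict[str, float]:
    stats = trainer.train_disc_step()
    for name, value in stats.items():
        print(f"{name}: {value}")
    return stats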
def train_disc(
    self,
    *,
    expert_samples: Optional[Mapping] = None,
    gen_samples: Optional[Mapping] = None,
) -> Dict[str, float]:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        expert_samples: Transition samples from the expert in dictionary form.
            If provided, must contain keys corresponding to every field of the
            `Transitions` dataclass except "infos". All corresponding values
            can be either NumPy arrays or Tensors. Extra keys are ignored. Must
            contain `self.expert_batch_size` samples. If this argument is not
            provided, then `self.expert_batch_size` expert samples from
            `self.expert_data_loader` are used by default.
        gen_samples: Transition samples from the generator policy in same
            dictionary form as `expert_samples`. If provided, must contain
            exactly `self.expert_batch_size` samples. If not provided, then
            take `len(expert_samples)` samples from the generator replay
            buffer.

    Returns:
        dict: Statistics for discriminator (e.g. loss, accuracy).
    """
    with logger.accumulate_means("disc"):
        # optionally write TB summaries for collected ops
        write_summaries = self._init_tensorboard and self._global_step % 20 == 0

        # compute loss
        batch = self._make_disc_train_batch(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        disc_logits = self.discrim.logits_gen_is_high(
            batch["state"],
            batch["action"],
            batch["next_state"],
            batch["done"],
            batch["log_policy_act_prob"],
        )
        loss = self.discrim.disc_loss(disc_logits, batch["labels_gen_is_one"])

        # do gradient step
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
        self._disc_step += 1

        # compute/write stats and TensorBoard data
        with th.no_grad():
            train_stats = rew_common.compute_train_stats(
                disc_logits, batch["labels_gen_is_one"], loss
            )
        logger.record("global_step", self._global_step)
        for k, v in train_stats.items():
            logger.record(k, v)
        logger.dump(self._disc_step)

        if write_summaries:
            self._summary_writer.add_histogram("disc_logits", disc_logits.detach())

    return train_stats
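# --- Hypothetical usage sketch (not part of the original source). ---
# Shows passing explicit dictionary samples to `train_disc` above. The keys
# assume the non-"infos" fields of `Transitions` are obs/acts/next_obs/dones;
# the observation/action shapes are placeholders for a toy environment, and
# the zero-filled arrays stand in for real expert data.
import numpy as np  # likely already imported at module level


def _example_train_disc_with_expert_dict(trainer) -> Dict[str, float]:
    n = trainer.expert_batch_size
    expert_samples = {
        "obs": np.zeros((n, 4), dtype=np.float32),
        "acts": np.zeros((n, 1), dtype=np.float32),
        "next_obs": np.zeros((n, 4), dtype=np.float32),
        "dones": np.zeros((n,), dtype=bool),
    }
    # Generator samples are omitted, so they are drawn from the replay buffer.
    return trainer.train_disc(expert_samples=expert_samples)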