def train_disc_step(
    self,
    *,
    gen_samples: Optional[types.Transitions] = None,
    expert_samples: Optional[types.Transitions] = None,
) -> Dict[str, float]:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        gen_samples: Transition samples from the generator policy. If not
            provided, then take `self.disc_batch_size // 2` samples from the
            generator replay buffer. Observations should not be normalized.
        expert_samples: Transition samples from the expert. If not provided,
            then take `n_gen` expert samples from the expert dataset, where
            `n_gen` is the number of samples in `gen_samples`. Observations
            should not be normalized.

    Returns:
        dict: Statistics for discriminator (e.g. loss, accuracy).
    """
    with logger.accumulate_means("disc"):
        # TensorBoard summaries are emitted only every 20th global step.
        should_write_summaries = (
            self._init_tensorboard and self._global_step % 20 == 0
        )

        # Assemble the training batch and score it with the discriminator.
        batch = self._make_disc_train_batch(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        logits = self.discrim.logits_gen_is_high(
            batch["state"],
            batch["action"],
            batch["next_state"],
            batch["done"],
            batch["log_policy_act_prob"],
        )
        loss = self.discrim.disc_loss(logits, batch["labels_gen_is_one"])

        # One gradient step on the discriminator parameters.
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
        self._disc_step += 1

        # Statistics are computed gradient-free, then logged immediately.
        with th.no_grad():
            train_stats = rew_common.compute_train_stats(
                logits, batch["labels_gen_is_one"], loss
            )
        logger.record("global_step", self._global_step)
        for stat_name, stat_value in train_stats.items():
            logger.record(stat_name, stat_value)
        logger.dump(self._disc_step)

        if should_write_summaries:
            self._summary_writer.add_histogram("disc_logits", logits.detach())

    return train_stats
def test_compute_train_stats(n_samples):
    """Smoke-test `compute_train_stats` on random logits, labels, and loss.

    Verifies that every statistic is returned as a plain ``float`` keyed by
    a ``str`` name.
    """
    rand_logits = th.from_numpy(np.random.standard_normal([n_samples]) * 10)
    rand_labels = th.from_numpy(np.random.randint(2, size=[n_samples]))
    rand_loss = th.tensor(np.random.random() * 10)

    stats = common.compute_train_stats(rand_logits, rand_labels, rand_loss)
    for stat_name, stat_value in stats.items():
        assert isinstance(stat_name, str)
        assert isinstance(stat_value, float)
def train_disc(
    self,
    *,
    expert_samples: Optional[Mapping] = None,
    gen_samples: Optional[Mapping] = None,
) -> Dict[str, float]:
    """Perform a single discriminator update, optionally using provided samples.

    Args:
        expert_samples: Transition samples from the expert in dictionary form.
            If provided, must contain keys corresponding to every field of the
            `Transitions` dataclass except "infos". All corresponding values
            can be either NumPy arrays or Tensors. Extra keys are ignored. Must
            contain `self.expert_batch_size` samples. If this argument is not
            provided, then `self.expert_batch_size` expert samples from
            `self.expert_data_loader` are used by default.
        gen_samples: Transition samples from the generator policy in same
            dictionary form as `expert_samples`. If provided, must contain
            exactly `self.expert_batch_size` samples. If not provided, then
            take `len(expert_samples)` samples from the generator replay
            buffer.

    Returns:
        dict: Statistics for discriminator (e.g. loss, accuracy).
    """
    with logger.accumulate_means("disc"):
        # TensorBoard summaries are only written every 20th global step.
        write_summaries = (
            self._init_tensorboard and self._global_step % 20 == 0
        )

        # Build the mixed expert/generator batch, then score it.
        train_batch = self._make_disc_train_batch(
            gen_samples=gen_samples, expert_samples=expert_samples
        )
        disc_logits = self.discrim.logits_gen_is_high(
            train_batch["state"],
            train_batch["action"],
            train_batch["next_state"],
            train_batch["done"],
            train_batch["log_policy_act_prob"],
        )
        loss = self.discrim.disc_loss(
            disc_logits, train_batch["labels_gen_is_one"]
        )

        # Single optimizer step on the discriminator.
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
        self._disc_step += 1

        # Compute stats without tracking gradients, then log them.
        with th.no_grad():
            train_stats = rew_common.compute_train_stats(
                disc_logits, train_batch["labels_gen_is_one"], loss
            )
        logger.record("global_step", self._global_step)
        for key, value in train_stats.items():
            logger.record(key, value)
        logger.dump(self._disc_step)

        if write_summaries:
            self._summary_writer.add_histogram(
                "disc_logits", disc_logits.detach()
            )

    return train_stats