Example no. 1
    def train_disc_step(
        self,
        *,
        gen_samples: Optional[types.Transitions] = None,
        expert_samples: Optional[types.Transitions] = None,
    ) -> Dict[str, float]:
        """Perform a single discriminator update, optionally using provided samples.

        Args:
          gen_samples: Transition samples from the generator policy. If not
            provided, then take `self.disc_batch_size // 2` samples from the
            generator replay buffer. Observations should not be normalized.
          expert_samples: Transition samples from the expert. If not
            provided, then take `n_gen` expert samples from the expert
            dataset, where `n_gen` is the number of samples in `gen_samples`.
            Observations should not be normalized.

        Returns:
          dict: Statistics for discriminator (e.g. loss, accuracy).
        """
        with logger.accumulate_means("disc"):
            # optionally write TB summaries for collected ops
            write_summaries = self._init_tensorboard and self._global_step % 20 == 0

            # compute loss
            batch = self._make_disc_train_batch(gen_samples=gen_samples,
                                                expert_samples=expert_samples)
            disc_logits = self.discrim.logits_gen_is_high(
                batch["state"],
                batch["action"],
                batch["next_state"],
                batch["done"],
                batch["log_policy_act_prob"],
            )
            loss = self.discrim.disc_loss(disc_logits,
                                          batch["labels_gen_is_one"])

            # do gradient step
            self._disc_opt.zero_grad()
            loss.backward()
            self._disc_opt.step()
            self._disc_step += 1

            # compute/write stats and TensorBoard data
            with th.no_grad():
                train_stats = rew_common.compute_train_stats(
                    disc_logits, batch["labels_gen_is_one"], loss)
            logger.record("global_step", self._global_step)
            for k, v in train_stats.items():
                logger.record(k, v)
            logger.dump(self._disc_step)
            if write_summaries:
                self._summary_writer.add_histogram("disc_logits",
                                                   disc_logits.detach())

        return train_stats
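The call to `self.discrim.disc_loss` above is opaque in this excerpt. For adversarial imitation methods like GAIL, the discriminator loss is typically sigmoid cross-entropy on the raw logits, with generator samples carrying label 1 (matching `labels_gen_is_one`). A minimal sketch under that assumption; `disc_loss_sketch` is a hypothetical stand-in, not the library's implementation:

    import torch as th
    import torch.nn.functional as F

    def disc_loss_sketch(disc_logits_gen_is_high: th.Tensor,
                         labels_gen_is_one: th.Tensor) -> th.Tensor:
        # Binary cross-entropy on logits: a high logit means the
        # discriminator predicts "generated" (label 1).
        return F.binary_cross_entropy_with_logits(
            disc_logits_gen_is_high, labels_gen_is_one.float())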
Example no. 2
import numpy as np
import pytest
import torch as th

from imitation.rewards import common  # path assumed from the `common.` prefix


@pytest.mark.parametrize("n_samples", [1, 10, 1000])  # hypothetical sizes
def test_compute_train_stats(n_samples):
    disc_logits_gen_is_high = th.from_numpy(
        np.random.standard_normal([n_samples]) * 10)
    labels_gen_is_one = th.from_numpy(np.random.randint(2, size=[n_samples]))
    disc_loss = th.tensor(np.random.random() * 10)
    stats = common.compute_train_stats(disc_logits_gen_is_high,
                                       labels_gen_is_one, disc_loss)
    for k, v in stats.items():
        assert isinstance(k, str)
        assert isinstance(v, float)
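The test only pins down the contract: `compute_train_stats` maps raw logits, binary labels, and the scalar loss to a `str -> float` dict. A sketch of a helper satisfying that contract (internals are hypothetical; the real function records more statistics):

    from typing import Dict

    import torch as th

    def compute_train_stats_sketch(
        disc_logits_gen_is_high: th.Tensor,
        labels_gen_is_one: th.Tensor,
        disc_loss: th.Tensor,
    ) -> Dict[str, float]:
        with th.no_grad():
            # Threshold logits at 0: positive logit => predict label 1.
            preds = (disc_logits_gen_is_high > 0).long()
            acc = (preds == labels_gen_is_one.long()).float().mean()
        return {"disc_loss": float(disc_loss), "disc_acc": float(acc)}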
Example no. 3
    def train_disc(
        self,
        *,
        expert_samples: Optional[Mapping] = None,
        gen_samples: Optional[Mapping] = None,
    ) -> Dict[str, float]:
        """Perform a single discriminator update, optionally using provided samples.

        Args:
            expert_samples: Transition samples from the expert in dictionary form.
                If provided, must contain keys corresponding to every field of the
                `Transitions` dataclass except "infos". All corresponding values can be
                either NumPy arrays or Tensors. Extra keys are ignored. Must contain
                `self.expert_batch_size` samples.

                If this argument is not provided, then `self.expert_batch_size` expert
                samples from `self.expert_data_loader` are used by default.
            gen_samples: Transition samples from the generator policy in the same
                dictionary form as `expert_samples`. If provided, must contain
                exactly `self.expert_batch_size` samples. If not provided, then
                `self.expert_batch_size` samples are taken from the generator
                replay buffer.

        Returns:
            dict: Statistics for discriminator (e.g. loss, accuracy).
        """
        with logger.accumulate_means("disc"):
            # optionally write TB summaries for collected ops
            write_summaries = self._init_tensorboard and self._global_step % 20 == 0

            # compute loss
            batch = self._make_disc_train_batch(
                gen_samples=gen_samples, expert_samples=expert_samples
            )
            disc_logits = self.discrim.logits_gen_is_high(
                batch["state"],
                batch["action"],
                batch["next_state"],
                batch["done"],
                batch["log_policy_act_prob"],
            )
            loss = self.discrim.disc_loss(disc_logits, batch["labels_gen_is_one"])

            # do gradient step
            self._disc_opt.zero_grad()
            loss.backward()
            self._disc_opt.step()
            self._disc_step += 1

            # compute/write stats and TensorBoard data
            with th.no_grad():
                train_stats = rew_common.compute_train_stats(
                    disc_logits, batch["labels_gen_is_one"], loss
                )
            logger.record("global_step", self._global_step)
            for k, v in train_stats.items():
                logger.record(k, v)
            logger.dump(self._disc_step)
            if write_summaries:
                self._summary_writer.add_histogram("disc_logits", disc_logits.detach())

        return train_stats
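For context, a hypothetical call site for `train_disc`, assuming `trainer` is an already-constructed adversarial trainer. The keys mirror the `Transitions` fields named in the docstring (`obs`, `acts`, `next_obs`, `dones`); the dimensions are illustrative placeholders:

    import numpy as np

    B, obs_dim, act_dim = trainer.expert_batch_size, 8, 2  # dims are placeholders
    expert_batch = {
        "obs": np.zeros((B, obs_dim), dtype=np.float32),
        "acts": np.zeros((B, act_dim), dtype=np.float32),
        "next_obs": np.zeros((B, obs_dim), dtype=np.float32),
        "dones": np.zeros((B,), dtype=bool),
    }
    # gen_samples omitted: per the docstring, B samples are then drawn from
    # the generator replay buffer.
    stats = trainer.train_disc(expert_samples=expert_batch)
    print(stats)  # e.g. loss and accuracy entries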