def sample_posterior(self, images, actions, step_types, features=None):
    """Samples latents from the posterior over a whole sequence.

    Step 0 uses ``self.latent_first_posterior`` on the step-0 features. Every
    later step t builds both ``self.latent_first_posterior(features[t])`` and
    ``self.latent_posterior(features[t], sample[t-1], action[t-1])``. Through an
    element-wise ``tf.where``, it keeps the former wherever ``step_types`` at t
    equals ``ts.StepType.FIRST``.

    Args:
      images: input to ``self.compressor``; only read when ``features`` is None.
      actions: batch-major actions, truncated to the first
        ``step_types.shape[1] - 1`` steps. Transposed with perm [1, 0, 2], so
        it must be rank 3.
      step_types: rank-2 batch-major step types. Dimension 1 must be
        statically known, because ``.value`` is read from it.
      features: optional precomputed compressor output (rank 3, batch-major).

    Returns:
      ``(latent_samples, latent_dists)``. In both, the per-step results are
      stacked along axis 1 (batch-major again).
    """
    num_transitions = step_types.shape[1].value - 1
    actions = actions[:, :num_transitions]

    if features is None:
      features = self.compressor(images)

    # Go time-major so that indexing with [step] selects a single time slice.
    features = tf.transpose(features, [1, 0, 2])
    actions = tf.transpose(actions, [1, 0, 2])
    step_types = tf.transpose(step_types, [1, 0])

    per_step_dists = []
    per_step_samples = []
    for step in range(num_transitions + 1):
      if step > 0:
        starts_episode = tf.equal(step_types[step], ts.StepType.FIRST)
        restart_dist = self.latent_first_posterior(features[step])
        carried_dist = self.latent_posterior(
            features[step], per_step_samples[-1], actions[step - 1])
        # Element-wise choice: use the first-step posterior where the step is
        # marked FIRST, the transition posterior everywhere else.
        step_dist = nest_utils.map_distribution_structure(
            functools.partial(tf.where, starts_episode), restart_dist,
            carried_dist)
      else:
        step_dist = self.latent_first_posterior(features[step])
      step_sample = step_dist.sample()

      per_step_dists.append(step_dist)
      per_step_samples.append(step_sample)

    def stack_over_time(*tensors):
      return tf.stack(tensors, axis=1)

    stacked_dists = nest_utils.map_distribution_structure(
        stack_over_time, *per_step_dists)
    stacked_samples = tf.stack(per_step_samples, axis=1)
    return stacked_samples, stacked_dists
  def sample_prior_or_posterior(self, actions, step_types=None, images=None):
    """Samples a latent sequence, using the posterior only where images exist.

    The first ``images.shape[1]`` steps (none when ``images`` is None) are
    drawn from the posterior networks, conditioned on compressed image
    features. All remaining steps are drawn from the prior networks. At any
    step t > 0 whose step type equals ``ts.StepType.FIRST``, the first-step
    distribution replaces the transition distribution element-wise.

    Args:
      actions: batch-major actions (rank 3, transposed with perm [1, 0, 2]).
      step_types: optional rank-2 batch-major step types. When None, an
        all-``ts.StepType.MID`` tensor with ``actions.shape[1] + 1`` steps is
        built; this requires ``actions.shape[1]`` to be statically known. When
        given, its static length fixes the horizon and ``actions`` is
        truncated to match.
      images: optional conditioning images, fed to ``self.compressor``.

    Returns:
      ``(latent_samples, latent_dists)``, stacked along axis 1.
    """
    if step_types is None:
      batch_size = tf.shape(actions)[0]
      num_transitions = actions.shape[1].value  # must be statically known
      step_types = tf.fill(
          [batch_size, num_transitions + 1], ts.StepType.MID)
    else:
      num_transitions = step_types.shape[1].value - 1
      actions = actions[:, :num_transitions]

    conditioned = images is not None
    if conditioned:
      features = self.compressor(images)

    # Time-major layout: indexing with [step] yields a single time slice.
    actions = tf.transpose(actions, [1, 0, 2])
    step_types = tf.transpose(step_types, [1, 0])
    if conditioned:
      features = tf.transpose(features, [1, 0, 2])
      num_conditioned_steps = images.shape[1].value
    else:
      num_conditioned_steps = 0

    per_step_dists = []
    per_step_samples = []
    for step in range(num_transitions + 1):
      use_posterior = step < num_conditioned_steps
      if step == 0:
        if use_posterior:
          step_dist = self.latent_first_posterior(features[step])
        else:
          # step_types is only used to infer batch_size
          step_dist = self.latent_first_prior(step_types[step])
      else:
        starts_episode = tf.equal(step_types[step], ts.StepType.FIRST)
        if use_posterior:
          restart_dist = self.latent_first_posterior(features[step])
          carried_dist = self.latent_posterior(
              features[step], per_step_samples[-1], actions[step - 1])
        else:
          restart_dist = self.latent_first_prior(step_types[step])
          carried_dist = self.latent_prior(
              per_step_samples[-1], actions[step - 1])
        # Element-wise choice between the first-step and transition dists.
        step_dist = nest_utils.map_distribution_structure(
            functools.partial(tf.where, starts_episode), restart_dist,
            carried_dist)
      step_sample = step_dist.sample()

      per_step_dists.append(step_dist)
      per_step_samples.append(step_sample)

    stacked_dists = nest_utils.map_distribution_structure(
        lambda *tensors: tf.stack(tensors, axis=1), *per_step_dists)
    stacked_samples = tf.stack(per_step_samples, axis=1)
    return stacked_samples, stacked_dists
Example #3
0
    def compute_loss(self,
                     images,
                     actions,
                     step_types,
                     rewards=None,
                     discounts=None,
                     latent_posterior_samples_and_dists=None):
        """Computes the negative sequence ELBO of the two-latent model.

        Per sequence, the ELBO is the image log-likelihood minus the latent1
        and latent2 KL terms, all summed over time. The reward and discount
        log-likelihoods are added when ``self.model_reward`` /
        ``self.model_discount`` are enabled. The loss is the negated batch mean
        of the ELBO. Prior and image-conditioned prior rollouts are also
        decoded, for visualization only.

        Args:
          images: observed images that ``self.decoder``'s likelihood is
            evaluated on; indexed as (batch, time, ...) below. Presumably
            ``sequence_length + 1`` steps long — TODO confirm.
          actions: batch-major actions; the first ``sequence_length`` steps
            are used.
          step_types: (batch, sequence_length + 1) step types. Dimension 1
            must be statically known.
          rewards: batch-major rewards; required when ``self.model_reward``.
          discounts: batch-major discounts; required when
            ``self.model_discount``.
          latent_posterior_samples_and_dists: optional
            ``((latent1_samples, latent2_samples), (latent1_dists,
            latent2_dists))``. Sampled with ``self.sample_posterior`` when
            None.

        Returns:
          ``(loss, outputs)``. ``loss`` is a scalar; ``outputs`` maps names to
          scalar diagnostics and image tensors.

        Raises:
          ValueError: if ``rewards`` is None while ``self.model_reward`` is
            set, or ``discounts`` is None while ``self.model_discount`` is set.
        """
        # Fail fast with a clear message instead of an opaque TypeError from
        # slicing None deep inside graph construction.
        if self.model_reward and rewards is None:
            raise ValueError(
                'rewards must be provided when model_reward is enabled.')
        if self.model_discount and discounts is None:
            raise ValueError(
                'discounts must be provided when model_discount is enabled.')

        sequence_length = step_types.shape[1].value - 1

        if latent_posterior_samples_and_dists is None:
            latent_posterior_samples_and_dists = self.sample_posterior(
                images, actions, step_types)
        (latent1_posterior_samples, latent2_posterior_samples), (
            latent1_posterior_dists,
            latent2_posterior_dists) = (latent_posterior_samples_and_dists)
        # The two rollouts below only feed the 'prior_images' and
        # 'conditional_prior_images' outputs; they do not enter the loss.
        (latent1_prior_samples,
         latent2_prior_samples), _ = self.sample_prior_or_posterior(
             actions, step_types)  # for visualization
        (latent1_conditional_prior_samples, latent2_conditional_prior_samples
         ), _ = self.sample_prior_or_posterior(
             actions, step_types, images=images[:, :1]
         )  # for visualization. condition on first image only

        def where_and_concat(reset_masks, first_prior_tensors,
                             after_first_prior_tensors):
            """Merges first-step and transition prior parameters over time.

            Step 0 always comes from ``first_prior_tensors``. For steps >= 1,
            ``first_prior_tensors`` is used where ``reset_masks`` is True, and
            ``after_first_prior_tensors`` (which start at step 1) elsewhere.
            """
            after_first_prior_tensors = tf.where(reset_masks[:, 1:],
                                                 first_prior_tensors[:, 1:],
                                                 after_first_prior_tensors)
            prior_tensors = tf.concat(
                [first_prior_tensors[:, 0:1], after_first_prior_tensors],
                axis=1)
            return prior_tensors

        # True at t=0 and at every later step whose type is StepType.FIRST.
        reset_masks = tf.concat([
            tf.ones_like(step_types[:, 0:1], dtype=tf.bool),
            tf.equal(step_types[:, 1:], ts.StepType.FIRST)
        ],
                                axis=1)

        # Tiled along a trailing latent axis so the mask matches the
        # distribution parameters element-wise (presumably needed by TF1
        # tf.where's same-shape requirement — confirm).
        latent1_reset_masks = tf.tile(reset_masks[:, :, None],
                                      [1, 1, self.latent1_size])
        latent1_first_prior_dists = self.latent1_first_prior(step_types)
        # these distributions start at t=1 and the inputs are from t-1
        latent1_after_first_prior_dists = self.latent1_prior(
            latent2_posterior_samples[:, :sequence_length],
            actions[:, :sequence_length])
        latent1_prior_dists = nest_utils.map_distribution_structure(
            functools.partial(where_and_concat, latent1_reset_masks),
            latent1_first_prior_dists, latent1_after_first_prior_dists)

        latent2_reset_masks = tf.tile(reset_masks[:, :, None],
                                      [1, 1, self.latent2_size])
        latent2_first_prior_dists = self.latent2_first_prior(
            latent1_posterior_samples)
        # these distributions start at t=1 and the last 2 inputs are from t-1
        latent2_after_first_prior_dists = self.latent2_prior(
            latent1_posterior_samples[:, 1:sequence_length + 1],
            latent2_posterior_samples[:, :sequence_length],
            actions[:, :sequence_length])
        latent2_prior_dists = nest_utils.map_distribution_structure(
            functools.partial(where_and_concat, latent2_reset_masks),
            latent2_first_prior_dists, latent2_after_first_prior_dists)

        outputs = {}

        # KL(posterior || prior): analytic if configured, otherwise a
        # single-sample estimate at the posterior samples. Summed over time.
        if self.kl_analytic:
            latent1_kl_divergences = tfd.kl_divergence(latent1_posterior_dists,
                                                       latent1_prior_dists)
        else:
            latent1_kl_divergences = (
                latent1_posterior_dists.log_prob(latent1_posterior_samples) -
                latent1_prior_dists.log_prob(latent1_posterior_samples))
        latent1_kl_divergences = tf.reduce_sum(latent1_kl_divergences, axis=1)
        outputs.update({
            'latent1_kl_divergence':
            tf.reduce_mean(latent1_kl_divergences),
        })
        # When the latent2 posterior and prior compare equal, the KL term is
        # exactly zero and is skipped.
        if self.latent2_posterior == self.latent2_prior:
            latent2_kl_divergences = 0.0
        else:
            if self.kl_analytic:
                latent2_kl_divergences = tfd.kl_divergence(
                    latent2_posterior_dists, latent2_prior_dists)
            else:
                latent2_kl_divergences = (
                    latent2_posterior_dists.log_prob(latent2_posterior_samples)
                    - latent2_prior_dists.log_prob(latent2_posterior_samples))
            latent2_kl_divergences = tf.reduce_sum(latent2_kl_divergences,
                                                   axis=1)
        outputs.update({
            'latent2_kl_divergence':
            tf.reduce_mean(latent2_kl_divergences),
        })
        outputs.update({
            'kl_divergence':
            tf.reduce_mean(latent1_kl_divergences + latent2_kl_divergences),
        })

        likelihood_dists = self.decoder(latent1_posterior_samples,
                                        latent2_posterior_samples)
        likelihood_log_probs = likelihood_dists.log_prob(images)
        likelihood_log_probs = tf.reduce_sum(likelihood_log_probs, axis=1)
        # Squared error against the decoder's mean (.distribution.loc), summed
        # over the event dimensions and then over time.
        reconstruction_error = tf.reduce_sum(
            tf.square(images - likelihood_dists.distribution.loc),
            axis=list(range(-len(likelihood_dists.event_shape), 0)))
        reconstruction_error = tf.reduce_sum(reconstruction_error, axis=1)
        outputs.update({
            'log_likelihood':
            tf.reduce_mean(likelihood_log_probs),
            'reconstruction_error':
            tf.reduce_mean(reconstruction_error),
        })

        # summed over the time dimension
        elbo = likelihood_log_probs - latent1_kl_divergences - latent2_kl_divergences

        if self.model_reward:
            # The reward for transition t -> t+1 is predicted from the latents
            # at t and t+1 and the action at t.
            reward_dists = self.reward_predictor(
                latent1_posterior_samples[:, :sequence_length],
                latent2_posterior_samples[:, :sequence_length],
                actions[:, :sequence_length],
                latent1_posterior_samples[:, 1:sequence_length + 1],
                latent2_posterior_samples[:, 1:sequence_length + 1])
            # Ignore transitions whose source step t is a LAST step.
            reward_valid_mask = tf.cast(
                tf.not_equal(step_types[:, :sequence_length],
                             ts.StepType.LAST), tf.float32)
            reward_log_probs = reward_dists.log_prob(
                rewards[:, :sequence_length])
            reward_log_probs = tf.reduce_sum(reward_log_probs *
                                             reward_valid_mask,
                                             axis=1)
            reward_reconstruction_error = tf.square(
                rewards[:, :sequence_length] - reward_dists.loc)
            reward_reconstruction_error = tf.reduce_sum(
                reward_reconstruction_error * reward_valid_mask, axis=1)
            outputs.update({
                'reward_log_likelihood':
                tf.reduce_mean(reward_log_probs),
                'reward_reconstruction_error':
                tf.reduce_mean(reward_reconstruction_error),
            })
            elbo += reward_log_probs

        if self.model_discount:
            discount_dists = self.discount_predictor(
                latent1_posterior_samples[:, 1:sequence_length + 1],
                latent2_posterior_samples[:, 1:sequence_length + 1])
            discount_log_probs = discount_dists.log_prob(
                discounts[:, :sequence_length])
            discount_log_probs = tf.reduce_sum(discount_log_probs, axis=1)
            # Number of steps whose predicted mode matches the target; summed
            # (not averaged) over time, then batch-averaged below.
            discount_accuracy = tf.cast(
                tf.equal(tf.cast(discount_dists.mode(), tf.float32),
                         discounts[:, :sequence_length]), tf.float32)
            discount_accuracy = tf.reduce_sum(discount_accuracy, axis=1)
            outputs.update({
                'discount_log_likelihood':
                tf.reduce_mean(discount_log_probs),
                'discount_accuracy':
                tf.reduce_mean(discount_accuracy),
            })
            elbo += discount_log_probs

        # average over the batch dimension
        loss = -tf.reduce_mean(elbo)

        posterior_images = likelihood_dists.mean()
        prior_images = self.decoder(latent1_prior_samples,
                                    latent2_prior_samples).mean()
        conditional_prior_images = self.decoder(
            latent1_conditional_prior_samples,
            latent2_conditional_prior_samples).mean()

        outputs.update({
            'elbo': tf.reduce_mean(elbo),
            'images': images,
            'posterior_images': posterior_images,
            'prior_images': prior_images,
            'conditional_prior_images': conditional_prior_images,
        })
        return loss, outputs
Example #4
0
    def model_loss(self,
                   images,
                   actions,
                   step_types,
                   rewards,
                   discounts,
                   latent_posterior_samples_and_dists=None,
                   weights=None):
        """Computes the model loss and writes summaries for the model outputs.

        The steps are:
          * Optionally truncate the batch (and any precomputed posterior
            samples/distributions) to ``self._model_batch_size``.
          * Delegate to ``self._model_network.compute_loss``.
          * Write a scalar summary for every rank-0 output and a GIF summary
            for every rank-5 output.
          * Apply ``weights`` and average.

        Args:
          images, actions, step_types, rewards, discounts: batch-major inputs
            forwarded to ``compute_loss``. Each is sliced along axis 0 when
            ``self._model_batch_size`` is set.
          latent_posterior_samples_and_dists: optional ``(samples, dists)``
            pair forwarded to ``compute_loss``, sliced the same way.
          weights: optional multiplier applied to the loss before the final
            ``tf.reduce_mean``.

        Returns:
          The scalar model loss.

        Raises:
          NotImplementedError: if a model output is neither rank 0 nor rank 5.
        """
        with tf.name_scope('model_loss'):
            if self._model_batch_size is not None:
                # Allow model batch size to be smaller than the batch size of the
                # other losses. This is because the model loss already gets a lot of
                # supervision from having a loss over all time steps.
                images, actions, step_types, rewards, discounts = tf.nest.map_structure(
                    lambda x: x[:self._model_batch_size],
                    (images, actions, step_types, rewards, discounts))
                if latent_posterior_samples_and_dists is not None:
                    latent_posterior_samples, latent_posterior_dists = latent_posterior_samples_and_dists
                    latent_posterior_samples = tf.nest.map_structure(
                        lambda x: x[:self._model_batch_size],
                        latent_posterior_samples)
                    latent_posterior_dists = slac_nest_utils.map_distribution_structure(
                        lambda x: x[:self._model_batch_size],
                        latent_posterior_dists)
                    latent_posterior_samples_and_dists = (
                        latent_posterior_samples, latent_posterior_dists)

            model_loss, outputs = self._model_network.compute_loss(
                images,
                actions,
                step_types,
                rewards=rewards,
                discounts=discounts,
                latent_posterior_samples_and_dists=
                latent_posterior_samples_and_dists)

            # GIF frame rate: 10 by default, otherwise the inverse of the
            # control timestep rounded to an int. It is the same for every
            # output, so it is computed once.
            fps = 10 if self._control_timestep is None else int(
                np.round(1.0 / self._control_timestep))
            for name, output in outputs.items():
                ndims = output.shape.ndims
                if ndims == 0:
                    tf.contrib.summary.scalar(name, output)
                elif ndims == 5:
                    # Only the first self._num_images_per_summary entries along
                    # axis 0 are rendered.
                    if self._debug_summaries:
                        _gif_summary(name + '/original',
                                     output[:self._num_images_per_summary],
                                     fps,
                                     step=self.train_step_counter)
                    _gif_summary(name,
                                 output[:self._num_images_per_summary],
                                 fps,
                                 saturate=True,
                                 step=self.train_step_counter)
                else:
                    raise NotImplementedError(
                        'Cannot summarize model output %r with rank %s; only '
                        'rank-0 and rank-5 outputs are supported.' %
                        (name, ndims))

            # NOTE(review): compute_loss appears to return an already
            # batch-averaged scalar. If so, per-example weights only rescale it
            # by their mean, and they are not sliced to _model_batch_size.
            # Confirm this is intended.
            if weights is not None:
                model_loss *= weights

            model_loss = tf.reduce_mean(input_tensor=model_loss)

            if self._debug_summaries:
                common.generate_tensor_summaries('model_loss', model_loss,
                                                 self.train_step_counter)

            return model_loss