Example #1
    def _compute_target(self, imag_feat, imag_state, imag_action, reward,
                        actor_ent, state_ent, slow):
        reward = tf.cast(reward, tf.float32)
        if 'discount' in self._world_model.heads:
            inp = self._world_model.dynamics.get_feat(imag_state)
            discount = self._world_model.heads['discount'](inp,
                                                           tf.float32).mean()
        else:
            discount = self._config.discount * tf.ones_like(reward)
        if self._config.future_entropy and tf.greater(
                self._config.actor_entropy(), 0):
            reward += self._config.actor_entropy() * actor_ent
        if self._config.future_entropy and tf.greater(
                self._config.actor_state_entropy(), 0):
            reward += self._config.actor_state_entropy() * state_ent
        if slow:
            value = self._slow_value(imag_feat, tf.float32).mode()
        else:
            value = self.value(imag_feat, tf.float32).mode()
        target = tools.lambda_return(reward[:-1],
                                     value[:-1],
                                     discount[:-1],
                                     bootstrap=value[-1],
                                     lambda_=self._config.discount_lambda,
                                     axis=0)
        weights = tf.stop_gradient(
            tf.math.cumprod(
                tf.concat([tf.ones_like(discount[:1]), discount[:-1]], 0), 0))
        return target, weights
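All of these examples delegate return estimation to tools.lambda_return. As a reference, here is a minimal standalone sketch of the TD(λ) recursion that helper is assumed to implement (time-major inputs, a plain Python loop instead of a reverse scan; the names are illustrative, not the repository's actual code):

import tensorflow as tf

def lambda_return_sketch(reward, value, pcont, bootstrap, lambda_):
    # reward, value, pcont: [horizon, batch]; bootstrap: [batch].
    next_values = tf.concat([value[1:], bootstrap[None]], 0)
    # Blend the one-step target with the recursively bootstrapped tail.
    inputs = reward + pcont * next_values * (1 - lambda_)
    returns, last = [], bootstrap
    for t in reversed(range(reward.shape[0])):
        last = inputs[t] + pcont[t] * lambda_ * last
        returns.append(last)
    return tf.stack(returns[::-1], 0)

With lambda_ = 1 this reduces to the discounted Monte-Carlo return bootstrapped by the final value; with lambda_ = 0 it is a one-step TD target.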
Example #2
    def _train(self, data, log_images):
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data[self._c.obs_type]))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = tf.cast(self._reward(imag_feat).mode(), 'float')  # cast handles a Bernoulli reward head's non-float output
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(
                reward[:-1], value[:-1], pcont[:-1],
                bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
            discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
                [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm, value_norm,
                    actor_norm)
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)
                self._reward_summaries(data, reward_pred)
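The discount weights built above down-weight later imagined steps by the cumulative predicted continuation probability, with the first step always receiving weight 1. A small numeric illustration of the same cumprod construction (toy values, batch of one):

import tensorflow as tf

# Toy continuation probabilities for five imagined steps (time-major).
pcont = tf.constant([[0.99], [0.99], [0.5], [0.99], [0.99]])
weights = tf.stop_gradient(
    tf.math.cumprod(
        tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
print(weights.numpy().squeeze())  # [1.  0.99  0.9801  0.49005]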
Example #3
    def _train(self, data, log_images):
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data['laser']))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(reward[:-1],
                                          value[:-1],
                                          pcont[:-1],
                                          bootstrap=value[-1],
                                          lambda_=self._c.disclam,
                                          axis=0)
            discount = tf.stop_gradient(
                tf.math.cumprod(
                    tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(
                discount * value_pred.log_prob(target))

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        if self._c.log_scalars:
            self._scalar_summaries(data, feat, prior_dist, post_dist, likes,
                                   div, model_loss, value_loss, actor_loss,
                                   model_norm, value_norm, actor_norm)
Example #4
    def _trajectory_optimization(self, post):
        def policy(state):
            return self._actor(tf.stop_gradient(
                self._dynamics.get_feat(state))).sample()

        def repeat(x):
            return tf.repeat(x, self._c.num_samples, axis=0)

        states, actions = tools.static_scan_action(
            lambda prev, action, _: self._dynamics.img_step(prev, action),
            lambda prev: policy(prev), tf.range(self._c.horizon), post)

        feat = self._dynamics.get_feat(states)
        reward = self._reward(feat).mode()

        if self._c.pcont:
            pcont = self._pcont(feat).mean()
        else:
            pcont = self._c.discount * tf.ones_like(reward)
        value = self._value(feat).mode()

        # compute the accumulated reward
        returns = tools.lambda_return(reward[:-1],
                                      value[:-1],
                                      pcont[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self._c.disclam,
                                      axis=0)

        accumulated_reward = returns[0, 0]

        # Since the reward and latent dynamics are fully differentiable, we can
        # backpropagate through them to update the actions directly.
        grad = tf.gradients(accumulated_reward, actions)[0]
        act = actions + grad * self._c.traj_opt_lr

        return act
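The last two lines take a single gradient-ascent step on the imagined action sequence through the differentiable reward and dynamics. A self-contained toy analogue of that step (the quadratic objective and shapes are made up for illustration, and tf.gradients is swapped for a GradientTape so the sketch runs eagerly):

import tensorflow as tf

actions = tf.random.normal([10, 1, 2])  # [horizon, batch, action_dim], toy values
traj_opt_lr = 0.1                       # stands in for self._c.traj_opt_lr

with tf.GradientTape() as tape:
    tape.watch(actions)
    # Stand-in for the accumulated imagined return; any differentiable
    # function of the actions illustrates the idea.
    accumulated_reward = -tf.reduce_sum(tf.square(actions - 0.5))

grad = tape.gradient(accumulated_reward, actions)
refined_actions = actions + grad * traj_opt_lr  # one ascent step on the return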
Example #5
    def _train(self, data, log_images):
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)


        with tf.GradientTape() as actor_tape:
            alset = []
            alset_re = []
            imag_feat, prob_traj = self._imagine_ahead(post)
            # Collapse the per-step log-probabilities into one score per imagined
            # trajectory, standardize it around 1, and clip it to a band so it can
            # serve as a bounded reweighting factor.
            prob_traj = tf.reduce_sum(prob_traj, 0, keepdims=True)
            prob_traj = (prob_traj - tf.reduce_mean(prob_traj)) / (
                tf.math.reduce_std(prob_traj) + 1e-9) + 1
            prob_traj = tf.clip_by_value(
                prob_traj, 1 - self._c.reweight_clip, 1 + self._c.reweight_clip)

            entropy = tf.reduce_mean(self._actor(tf.stop_gradient(feat)).entropy())

            tlikes = tools.AttrDict()
            tlikes.reward = reward_pred.log_prob(data['reward'])
            reward = self._reward(imag_feat).mode()

            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)

            value = self._value(imag_feat).mode()

            returns = tools.lambda_return(
                reward[:-1], value[:-1], pcont[:-1],
                bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
            discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
                [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            alset.append(-tf.reduce_mean(discount * returns))
            alset_re.append(
                -tf.reduce_mean(tf.stop_gradient(prob_traj) * discount * returns))
            actor_loss = (tf.reduce_mean(tf.stack(alset_re))
                          - self._c.ent_alpha * entropy)

            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(prob_traj * discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, entropy, model_norm, value_norm,
                    actor_norm)
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)
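A quick numeric check of the prob_traj reweighting used above: after standardization and clipping, the weights are centered at 1 and confined to the clip band (toy log-probabilities and clip value, not taken from the example):

import tensorflow as tf

log_probs = tf.constant([[-3.0, -1.0, 0.0, 2.0]])  # toy summed trajectory log-probs
clip = 0.2                                         # stands in for self._c.reweight_clip
w = (log_probs - tf.reduce_mean(log_probs)) / (
    tf.math.reduce_std(log_probs) + 1e-9) + 1
w = tf.clip_by_value(w, 1 - clip, 1 + clip)
print(w.numpy())  # ~[[0.8 0.8 1.2 1.2]] -- every weight lies inside [0.8, 1.2]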
Example #6
    def _train(self, data, log_images):
        with tf.GradientTape() as model_tape:
            if 'success' in data:
                success_rate = tf.reduce_sum(
                    data['success']) / data['success'].shape[1]
            else:
                success_rate = tf.convert_to_tensor(-1)
            embed = self._encode(data)
            if 'state' in data:
                embed = tf.concat([data['state'], embed], axis=-1)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
            reward_obj = reward_pred.log_prob(data['reward'])

            # Mask out the elements which came from the real world env
            reward_obj = reward_obj * (1 - data['real_world'])

            likes.reward = tf.reduce_mean(reward_obj)
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(reward[:-1],
                                          value[:-1],
                                          pcont[:-1],
                                          bootstrap=value[-1],
                                          lambda_=self._c.disclam,
                                          axis=0)
            discount = tf.stop_gradient(
                tf.math.cumprod(
                    tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(
                discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(data, feat, prior_dist, post_dist,
                                       likes, div, model_loss, value_loss,
                                       actor_loss, model_norm, value_norm,
                                       actor_norm, success_rate)
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)
Example #7
    def _train(self, data, log_images, init_horizon, imagine_depth):
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as actor_tape:
            flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
            discount = None
            imag_feats = []
            returns_lst = []
            discounts = []
            actor_loss = 0.0
            horizon = init_horizon
            for depth in range(imagine_depth):
                if self._c.pcont:  # Last step could be terminal.
                    post = {k: v[:, :-1] for k, v in post.items()}
                post = {k: flatten(v) for k, v in post.items()}
                if depth != 0:
                    post = {
                        k: tf.stop_gradient(
                            tf.gather(v, indices=max_indexes, axis=0))
                        for k, v in post.items()
                    }
                imag_feat, post = self._imagine_ahead(post, horizon)
                tf.print("Imagination Features:", tf.shape(imag_feat))
                reward = self._reward(imag_feat).mode()
                if self._c.pcont:
                    pcont = self._pcont(imag_feat).mean()
                else:
                    pcont = self._c.discount * tf.ones_like(reward)
                value = self._value(imag_feat).mode()
                returns = tools.lambda_return(reward[:-1],
                                              value[:-1],
                                              pcont[:-1],
                                              bootstrap=value[-1],
                                              lambda_=self._c.disclam,
                                              axis=0)

                discount = tf.stop_gradient(
                    tf.math.cumprod(
                        tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0),
                        0))

                if depth != imagine_depth - 1:
                    if self._c.branch_type == "reward":
                        flat_reward = flatten(reward)
                        max_indexes = tf.math.top_k(
                            flat_reward,
                            k=int(2500 / self._strategy.num_replicas_in_sync),
                            sorted=False)[1]
                    elif self._c.branch_type == "uniform":
                        flat_reward = flatten(reward)
                        max_indexes = tf.random.uniform(
                            [int(2500 / self._strategy.num_replicas_in_sync)],
                            minval=0,
                            maxval=flat_reward.shape[0],
                            dtype=tf.int32)
                    elif self._c.branch_type == "value":
                        flat_value = flatten(value)
                        max_indexes = tf.math.top_k(
                            flat_value,
                            k=int(2500 / self._strategy.num_replicas_in_sync),
                            sorted=False)[1]

                horizon = int(horizon * self._c.imagine_decay)

                imag_feats.append(imag_feat)
                returns_lst.append(returns)
                discounts.append(discount)
                actor_loss += -tf.reduce_mean(discount * returns)

            actor_loss /= float(self._strategy.num_replicas_in_sync *
                                imagine_depth)

        with tf.GradientTape() as value_tape:
            value_loss = 0.0
            for imag_feat, returns, discount in zip(imag_feats, returns_lst,
                                                    discounts):
                value_pred = self._value(imag_feat)[:-1]
                target = tf.stop_gradient(returns)
                value_loss += -tf.reduce_mean(
                    discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync *
                                imagine_depth)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(data, feat, prior_dist, post_dist,
                                       likes, div, model_loss, value_loss,
                                       actor_loss, model_norm, value_norm,
                                       actor_norm)
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)
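Between imagination depths, Example #7 keeps only the k most promising flattened states and restarts imagination from them. A toy sketch of that gather step (shapes, k, and the single state entry are illustrative; the real code gathers every tensor in the post dict):

import tensorflow as tf

flat_reward = tf.random.normal([4096])       # toy flattened per-state rewards
k = 256
max_indexes = tf.math.top_k(flat_reward, k=k, sorted=False)[1]

stoch = tf.random.normal([4096, 30])         # one entry of a flattened state dict
branched = tf.stop_gradient(tf.gather(stoch, indices=max_indexes, axis=0))
print(branched.shape)                        # (256, 30)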
Example #8
    def _train(self, data, log_images):
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data["action"])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)

            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data["image"]))

            ######################################################################
            # RE3: + intrinsic rewards
            rand_embed_ = tf.stop_gradient(self._rand_encode(data))
            rand_embed = tf.reshape(rand_embed_, [-1, 50])
            dist = tf.norm(rand_embed[:, None, :] - rand_embed[None, :, :], axis=-1)
            int_reward = -1.0 * tf.math.top_k(-dist, k=self._c.k).values[:, -1]
            norm_int_reward = self._rms(int_reward)
            norm_int_reward = tf.reshape(norm_int_reward, tf.shape(rand_embed_)[:-1])
            likes.reward = tf.reduce_mean(
                reward_pred.log_prob(data["reward"] + self._c.beta * norm_int_reward)
            )
            ######################################################################

            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data["discount"]
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(
                reward[:-1],
                value[:-1],
                pcont[:-1],
                bootstrap=value[-1],
                lambda_=self._c.disclam,
                axis=0,
            )
            discount = tf.stop_gradient(
                tf.math.cumprod(tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)
            )
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data,
                    feat,
                    prior_dist,
                    post_dist,
                    likes,
                    div,
                    model_loss,
                    value_loss,
                    actor_loss,
                    model_norm,
                    value_norm,
                    actor_norm,
                    int_reward,
                    norm_int_reward,
                )
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)
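The RE3 block turns the distance to a nearby neighbour in a fixed random embedding space into an intrinsic reward. A minimal standalone sketch of that k-NN distance computation (random data; k and the embedding size are illustrative, and the running-statistics normalization is omitted):

import tensorflow as tf

embed = tf.random.normal([512, 50])  # toy batch of random-encoder embeddings
k = 3

# Pairwise Euclidean distances between all embeddings in the batch.
dist = tf.norm(embed[:, None, :] - embed[None, :, :], axis=-1)

# top_k on the negated distances picks the k smallest distances per row
# (the zero self-distance included); the last column is the distance used
# as the intrinsic reward, mirroring the example above.
int_reward = -1.0 * tf.math.top_k(-dist, k=k).values[:, -1]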
Example #9
    def _train(self, data, log_images):
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))

            # With the generative observation model we reconstruct the
            # observation from the latent features.
            image_pred = self._decode(feat)
            # CVRL instead computes the contrastive loss directly.
            cont_loss = self._contrastive(feat, embed)

            # the contrastive / generative implementation of the observation model p(o|s)
            if self._c.obs_model == 'generative':
                likes.image = tf.reduce_mean(image_pred.log_prob(
                    data['image']))
            elif self._c.obs_model == 'contrastive':
                likes.image = tf.reduce_mean(cont_loss)

            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale

            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())

        assert self._c.use_dreamer or self._c.use_sac

        if self._c.use_dreamer:
            with tf.GradientTape() as actor_tape:
                imag_feat = self._imagine_ahead(post)
                reward = self._reward(imag_feat).mode()
                if self._c.pcont:
                    pcont = self._pcont(imag_feat).mean()
                else:
                    pcont = self._c.discount * tf.ones_like(reward)
                value = self._value(imag_feat).mode()
                returns = tools.lambda_return(reward[:-1],
                                              value[:-1],
                                              pcont[:-1],
                                              bootstrap=value[-1],
                                              lambda_=self._c.disclam,
                                              axis=0)
                discount = tf.stop_gradient(
                    tf.math.cumprod(
                        tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0),
                        0))
                actor_loss = -tf.reduce_mean(discount * returns)

            with tf.GradientTape() as value_tape:
                value_pred = self._value(imag_feat)[:-1]
                target = tf.stop_gradient(returns)
                value_loss = -tf.reduce_mean(
                    discount * value_pred.log_prob(target))

            actor_norm = self._actor_opt(actor_tape, actor_loss)
            value_norm = self._value_opt(value_tape, value_loss)
        else:
            actor_norm = actor_loss = 0
            value_norm = value_loss = 0

        model_norm = self._model_opt(model_tape, model_loss)
        states = tf.concat([post['stoch'], post['deter']], axis=-1)
        rewards = data['reward']
        dones = tf.zeros_like(rewards)
        actions = data['action']

        # if we use SAC, add the SAC training
        if self._c.use_sac:
            self._sac._do_training(self._step, states, actions, rewards, dones)

        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(data, feat, prior_dist, post_dist,
                                       likes, div, model_loss, value_loss,
                                       actor_loss, model_norm, value_norm,
                                       actor_norm)
            if tf.equal(log_images, True) and self._c.log_imgs:
                self._image_summaries(data, embed, image_pred)
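The contrastive branch of the observation model replaces image reconstruction with an agreement score between latent features and encoder embeddings. CVRL's actual self._contrastive is not shown here; the InfoNCE-style sketch below is only an assumption of what such a term could look like (the Dense projection that aligns the two dimensionalities is likewise illustrative):

import tensorflow as tf

project = tf.keras.layers.Dense(50)  # assumed projection from feature dim to embedding dim

def contrastive_sketch(feat, embed):
    feat = tf.reshape(project(feat), [-1, 50])
    embed = tf.reshape(embed, [-1, 50])
    logits = tf.matmul(feat, embed, transpose_b=True)  # [N, N] similarity scores
    labels = tf.range(tf.shape(logits)[0])
    # Positive pairs sit on the diagonal; return a per-sample score that plays
    # the role of log p(o|s) in the model loss.
    return -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)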