Beispiel #1
0
    )  # Perform environment step (action repeats handled internally)
    return belief, posterior_state, action[0], next_observation, reward, done


# Main training loop. Episode numbering resumes right after the last episode
# recorded in `metrics['episodes']` and runs up to and including `args.episodes`.
print("Starting training!")
for episode in tqdm(range(metrics['episodes'][-1] + 1, args.episodes + 1),
                    total=args.episodes,
                    initial=metrics['episodes'][-1] + 1):
    print("Starting episode {}".format(episode))
    # Model fitting
    # NOTE(review): `losses` is not filled in the visible part of this loop;
    # presumably per-step losses are appended further down -- confirm.
    losses = []
    print("Drawing sequence chunks")
    for s in tqdm(range(args.collect_interval)):
        # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
        observations, actions, rewards, nonterminals = D.sample(
            args.batch_size,
            args.chunk_size)  # Transitions start at time t = 0
        # Create initial belief and state for time t = 0 (all zeros, one row per batch element)
        init_belief, init_state = torch.zeros(args.batch_size,
                                              args.belief_size,
                                              device=args.device), torch.zeros(
                                                  args.batch_size,
                                                  args.state_size,
                                                  device=args.device)
        # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once).
        # Actions/nonterminals drop their last step and observations drop their first step,
        # so action t is paired with the encoding of observation t+1.
        beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = transition_model(
            init_state, actions[:-1], init_belief,
            bottle(encoder, (observations[1:], )), nonterminals[:-1])
        # Disabled debug output:
        #print("******************")
        #print(beliefs.shape)
        #print(prior_states.shape)
Beispiel #2
0
                if args.render:
                    env.render()
                if done:
                    pbar.close()
                    break

    print('Average Reward:', total_reward / args.test_episodes)
    env.close()
    quit()

# Training (and testing)
# Episode numbering resumes right after the last episode recorded in
# `metrics['episodes']` and runs up to and including `args.episodes`.
for episode in tqdm(range(metrics['episodes'][-1] + 1, args.episodes + 1),
                    total=args.episodes,
                    initial=metrics['episodes'][-1] + 1):
    data = D.sample(args.batch_size, args.chunk_size)
    # Model fitting
    # NOTE(review): presumably `loss_info` holds one record of six loss values
    # per gradient step -- confirm against `agent.update_parameters`.
    loss_info = agent.update_parameters(data, args.collect_interval)

    # Update and plot loss metrics
    # zip(*loss_info) transposes the records, so losses[i] is the series of
    # the i-th loss over all records returned for this episode.
    losses = tuple(zip(*loss_info))
    metrics['observation_loss'].append(losses[0])
    metrics['reward_loss'].append(losses[1])
    metrics['kl_loss'].append(losses[2])
    metrics['pcont_loss'].append(losses[3])
    metrics['actor_loss'].append(losses[4])
    metrics['value_loss'].append(losses[5])
    # The x-axis uses the most recent episode indices, as many as there are
    # entries in the plotted metric.
    lineplot(metrics['episodes'][-len(metrics['observation_loss']):],
             metrics['observation_loss'], 'observation_loss', results_dir)
    lineplot(metrics['episodes'][-len(metrics['reward_loss']):],
             metrics['reward_loss'], 'reward_loss', results_dir)
Beispiel #3
0
class Dreamer(Agent):
    # The agent has its own replay buffer, update, act
    def __init__(self, args):
        """Create the networks, optimizers and replay buffer of the agent.

        :param args: object exposing every hyper-parameter as an attribute
            (layer sizes, activation settings, learning rates, device, ...)
        """
        super().__init__()
        self.args = args

        # Values shared by most of the sub-networks below.
        device = args.device
        belief_size = args.belief_size
        state_size = args.state_size
        hidden_size = args.hidden_size
        dense_act = args.dense_act

        # --- networks (construction order kept so random init is unchanged) ---
        self.transition_model = TransitionModel(belief_size, state_size,
                                                args.action_size, hidden_size,
                                                args.embedding_size,
                                                dense_act).to(device=device)

        # The decoder uses the dense activation for symbolic observations and
        # the CNN activation otherwise.
        decoder_act = dense_act if args.symbolic else args.cnn_act
        self.observation_model = ObservationModel(
            args.symbolic,
            args.observation_size,
            belief_size,
            state_size,
            args.embedding_size,
            activation_function=decoder_act).to(device=device)

        self.reward_model = RewardModel(belief_size, state_size, hidden_size,
                                        dense_act).to(device=device)

        self.encoder = Encoder(args.symbolic, args.observation_size,
                               args.embedding_size,
                               args.cnn_act).to(device=device)

        self.actor_model = ActorModel(
            args.action_size,
            belief_size,
            state_size,
            hidden_size,
            activation_function=dense_act,
            fix_speed=args.fix_speed,
            throttle_base=args.throttle_base).to(device=device)

        # Two critics; the loss functions combine them with an element-wise min.
        self.value_model = ValueModel(belief_size, state_size, hidden_size,
                                      dense_act).to(device=device)
        self.value_model2 = ValueModel(belief_size, state_size, hidden_size,
                                       dense_act).to(device=device)

        self.pcont_model = PCONTModel(belief_size, state_size, hidden_size,
                                      dense_act).to(device=device)

        # Target critics start as exact copies and never receive gradients.
        self.target_value_model = deepcopy(self.value_model)
        self.target_value_model2 = deepcopy(self.value_model2)
        for target in (self.target_value_model, self.target_value_model2):
            for weight in target.parameters():
                weight.requires_grad = False

        # --- parameters optimised jointly as the "world model" ---
        self.world_param = [
            *self.transition_model.parameters(),
            *self.observation_model.parameters(),
            *self.reward_model.parameters(),
            *self.encoder.parameters(),
        ]
        if args.pcont:
            self.world_param.extend(self.pcont_model.parameters())

        # --- optimizers ---
        self.world_optimizer = optim.Adam(self.world_param, lr=args.world_lr)
        self.actor_optimizer = optim.Adam(self.actor_model.parameters(),
                                          lr=args.actor_lr)
        critic_params = (list(self.value_model.parameters()) +
                         list(self.value_model2.parameters()))
        self.value_optimizer = optim.Adam(critic_params, lr=args.value_lr)

        # Allowed deviation in KL divergence
        self.free_nats = torch.full((1, ),
                                    args.free_nats,
                                    dtype=torch.float32,
                                    device=device)

        # TODO: change it to the new replay buffer, in buffer.py
        self.D = ExperienceReplay(args.experience_size, args.symbolic,
                                  args.observation_size, args.action_size,
                                  args.bit_depth, device)

        if self.args.auto_temp:
            # Learnable log-temperature of the entropy term, with its own
            # optimizer (re-using the critic learning rate).
            self.log_temp = torch.zeros(1, requires_grad=True, device=device)
            # One action dimension is excluded when the speed is fixed.
            # Heuristic value from the SAC paper.
            if args.fix_speed:
                entropy_dims = self.args.action_size - 1
            else:
                entropy_dims = args.action_size
            self.target_entropy = -np.prod(entropy_dims).item()
            self.temp_optimizer = optim.Adam([self.log_temp],
                                             lr=args.value_lr)

        # TODO: print out the param used in Dreamer
        # var_counts = tuple(count_vars(module) for module in [self., self.ac.q1, self.ac.q2])
        # print('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' % var_counts)

    # def process_im(self, image, image_size=None, rgb=None):
    #   # Resize, put channel first, convert it to a tensor, centre it to [-0.5, 0.5] and add batch dimenstion.
    #
    #   def preprocess_observation_(observation, bit_depth):
    #     # Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5])
    #     observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(
    #       0.5)  # Quantise to given bit depth and centre
    #     observation.add_(torch.rand_like(observation).div_(
    #       2 ** bit_depth))  # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)
    #
    #   image = image[40:, :, :]  # clip the above 40 rows
    #   image = torch.tensor(cv2.resize(image, (40, 40), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1),
    #                         dtype=torch.float32)  # Resize and put channel first
    #
    #   preprocess_observation_(image, self.args.bit_depth)
    #   return image.unsqueeze(dim=0)
    def process_im(self, images, image_size=None, rgb=None):
        """Convert one raw frame into a single-channel input tensor.

        The frame is resized to 40x40, reduced to one channel with the
        weights (0.299, 0.587, 0.114), divided by 255 and shifted by -0.5.

        :param images: image array accepted by ``cv2.resize``
        :param image_size: accepted but not used by this implementation
        :param rgb: accepted but not used by this implementation
        :return: float32 tensor of shape [1, 1, 40, 40]
        """
        resized = cv2.resize(images, (40, 40))
        grey = np.dot(resized, [0.299, 0.587, 0.114])
        obs = torch.tensor(grey, dtype=torch.float32)
        obs.div_(255.).sub_(0.5)  # in place; [0, 255] input maps to [-0.5, 0.5]
        # First unsqueeze adds the channel dimension, the second the batch one.
        return obs.unsqueeze(dim=0).unsqueeze(dim=0)

    def append_buffer(self, new_traj):
        """Store a newly collected trajectory in the replay buffer.

        No data augmentation is applied.

        :param new_traj: iterable of (observation, action, reward, done)
            steps; each action is moved to the CPU before being stored
        """
        for observation, action, reward, done in new_traj:
            self.D.append(observation, action.cpu(), reward, done)

    def _compute_loss_world(self, state, data):
        """Compute the individual loss terms of the world model.

        :param state: 7-tuple (beliefs, prior_states, prior_means,
            prior_std_devs, posterior_states, posterior_means,
            posterior_std_devs) as produced by the transition model
        :param data: tuple (observations, rewards, nonterminals)
        :return: tuple (observation_loss, scaled reward_loss, kl_loss,
            scaled pcont_loss); the last entry is the int 0 when
            ``args.pcont`` is disabled
        """
        (beliefs, prior_states, prior_means, prior_std_devs, posterior_states,
         posterior_means, posterior_std_devs) = state
        observations, rewards, nonterminals = data

        latent = (beliefs, posterior_states)

        # Reconstruction error: summed over the observation dimensions,
        # averaged over the two leading (time, batch) dimensions.
        recon_error = F.mse_loss(bottle(self.observation_model, latent),
                                 observations,
                                 reduction='none')
        obs_dims = 2 if self.args.symbolic else (2, 3, 4)
        observation_loss = recon_error.sum(dim=obs_dims).mean(dim=(0, 1))

        reward_error = F.mse_loss(bottle(self.reward_model, latent),
                                  rewards,
                                  reduction='none')
        reward_loss = reward_error.mean(dim=(0, 1))  # TODO: 5

        # Transition loss: KL(posterior || prior), clipped from below by the
        # free nats.
        posterior = Independent(Normal(posterior_means, posterior_std_devs), 1)
        prior = Independent(Normal(prior_means, prior_std_devs), 1)
        kl_loss = torch.max(kl_divergence(posterior, prior),
                            self.free_nats).mean(dim=(0, 1))

        if self.args.pcont:
            pcont_loss = F.binary_cross_entropy(
                bottle(self.pcont_model, latent), nonterminals)

        scaled_reward_loss = self.args.reward_scale * reward_loss
        if self.args.pcont:
            scaled_pcont_loss = self.args.pcont_scale * pcont_loss
        else:
            scaled_pcont_loss = 0
        return observation_loss, scaled_reward_loss, kl_loss, scaled_pcont_loss

    def _compute_loss_actor(self,
                            imag_beliefs,
                            imag_states,
                            imag_ac_logps=None):
        """Compute the actor loss on imagined trajectories.

        :param imag_beliefs: imagined beliefs, time dimension first
        :param imag_states: imagined states, time dimension first
        :param imag_ac_logps: optional log-probabilities of the imagined
            actions; when given, ``args.temp * logp`` is subtracted from the
            values of every step after the first (entropy term)
        :return: negative mean of the discounted returns
        """
        latent = (imag_beliefs, imag_states)

        # Reward prediction and the element-wise minimum of both critics.
        imag_rewards = bottle(self.reward_model, latent)
        critic_a = bottle(self.value_model, latent)
        critic_b = bottle(self.value_model2, latent)
        imag_values = torch.min(critic_a, critic_b)

        # Per-step discount: predicted continuation or a constant.
        # No gradient flows through it.
        with torch.no_grad():
            if self.args.pcont:
                pcont = bottle(self.pcont_model, latent)
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
        pcont = pcont.detach()

        if imag_ac_logps is not None:
            # add entropy here
            imag_values[1:] -= self.args.temp * imag_ac_logps

        returns = cal_returns(imag_rewards[:-1],
                              imag_values[:-1],
                              imag_values[-1],
                              pcont[:-1],
                              lambda_=self.args.disclam)

        # Cumulative discount weights: a leading 1 followed by the running
        # product of pcont, with the last two steps dropped.
        leading_one = torch.ones_like(pcont[:1])
        discount = torch.cumprod(torch.cat([leading_one, pcont[:-2]], 0),
                                 0).detach()

        assert list(discount.size()) == list(returns.size())
        return -torch.mean(discount * returns)

    def _compute_loss_critic(self,
                             imag_beliefs,
                             imag_states,
                             imag_ac_logps=None):
        """Compute the summed MSE loss of both critics on imagined data.

        The regression target is built without gradients from the target
        critics (element-wise minimum), the predicted rewards and the
        continuation/discount term.

        :param imag_beliefs: imagined beliefs, time dimension first
        :param imag_states: imagined states, time dimension first
        :param imag_ac_logps: optional log-probabilities of the imagined
            actions; when given, ``args.temp * logp`` is subtracted from the
            target values of every step after the first
        :return: value loss of critic 1 plus value loss of critic 2
        """
        latent = (imag_beliefs, imag_states)

        with torch.no_grad():
            # calculate the target with the target nn
            target_a = bottle(self.target_value_model, latent)
            target_b = bottle(self.target_value_model2, latent)
            target_imag_values = torch.min(target_a, target_b)
            imag_rewards = bottle(self.reward_model, latent)

            if self.args.pcont:
                pcont = bottle(self.pcont_model, latent)
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)

            if imag_ac_logps is not None:
                target_imag_values[1:] -= self.args.temp * imag_ac_logps

        target_return = cal_returns(imag_rewards[:-1],
                                    target_imag_values[:-1],
                                    target_imag_values[-1],
                                    pcont[:-1],
                                    lambda_=self.args.disclam).detach()

        # Predictions of the online critics; the last step is dropped.
        prediction_a = bottle(self.value_model, latent)[:-1]
        prediction_b = bottle(self.value_model2, latent)[:-1]

        loss_a = F.mse_loss(prediction_a, target_return,
                            reduction="none").mean(dim=(0, 1))
        loss_b = F.mse_loss(prediction_b, target_return,
                            reduction="none").mean(dim=(0, 1))
        return loss_a + loss_b

    def _latent_imagination(self,
                            beliefs,
                            posterior_states,
                            with_logprob=False):
        """Roll the actor forward in latent space to imagine trajectories.

        Every (time, batch) entry of the inputs becomes the detached start of
        one rollout of ``args.planning_horizon`` steps.

        :param beliefs: 3-D tensor, flattened over its first two dimensions
        :param posterior_states: 3-D tensor [chunk, batch, state_size]
        :param with_logprob: forwarded to the actor; when true the
            log-probabilities of the imagined actions are returned as well
        :return: (imag_beliefs, imag_states, imag_ac_logps); the first two
            are stacked over horizon + 1 steps, the last over horizon steps
            or ``None`` when ``with_logprob`` is false
        """
        seq_len, n_batch, _ = list(posterior_states.size())
        n_starts = seq_len * n_batch  # flatten the tensor

        start_states = posterior_states.detach().reshape(n_starts, -1)
        start_beliefs = beliefs.detach().reshape(n_starts, -1)

        belief_seq = [start_beliefs]
        state_seq = [start_states]
        logp_seq = []

        for _ in range(self.args.planning_horizon):
            # The actor sees detached latents; the transition model below
            # receives the undetached ones.
            imag_action, imag_ac_logp = self.actor_model(
                belief_seq[-1].detach(),
                state_seq[-1].detach(),
                deterministic=False,
                with_logprob=with_logprob,
            )
            action_seq = imag_action.unsqueeze(dim=0)  # add time dim

            next_belief, next_state, _, _ = self.transition_model(
                state_seq[-1], action_seq, belief_seq[-1])
            belief_seq.append(next_belief.squeeze(dim=0))
            state_seq.append(next_state.squeeze(dim=0))
            if with_logprob:
                logp_seq.append(imag_ac_logp.squeeze(dim=0))

        device = self.args.device
        imag_beliefs = torch.stack(belief_seq, dim=0).to(device)
        imag_states = torch.stack(state_seq, dim=0).to(device)
        if with_logprob:
            imag_ac_logps = torch.stack(logp_seq, dim=0).to(device)
        else:
            imag_ac_logps = None

        return imag_beliefs, imag_states, imag_ac_logps

    def update_parameters(self, gradient_steps):
        """Run several gradient steps on world model, temperature, actor and critics.

        Each step samples a fresh batch of sequence chunks from the replay
        buffer, updates the world model, imagines latent rollouts and then
        updates (optionally) the entropy temperature, the actor and both
        critics. After all steps the target critics are synchronised with
        the online critics.

        :param gradient_steps: number of gradient steps to perform
        :return: list with one entry per step:
            [observation_loss, reward_loss, kl_loss, pcont_loss, actor_loss,
            critic_loss] as Python numbers (pcont_loss is 0 when
            ``args.pcont`` is disabled)
        """
        loss_info = []  # used to record loss
        for _ in tqdm(range(gradient_steps)):
            # get state and belief of samples
            observations, actions, rewards, nonterminals = self.D.sample(
                self.args.batch_size, self.args.chunk_size)
            init_belief = torch.zeros(self.args.batch_size,
                                      self.args.belief_size,
                                      device=self.args.device)
            init_state = torch.zeros(self.args.batch_size,
                                     self.args.state_size,
                                     device=self.args.device)

            # Update belief/state using posterior from previous belief/state,
            # previous action and current observation (over entire sequence
            # at once)
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(
                init_state, actions, init_belief,
                bottle(self.encoder, (observations, )),
                nonterminals)  # TODO: 4

            # update paras of world model
            world_model_loss = self._compute_loss_world(
                state=(beliefs, prior_states, prior_means, prior_std_devs,
                       posterior_states, posterior_means, posterior_std_devs),
                data=(observations, rewards, nonterminals))
            observation_loss, reward_loss, kl_loss, pcont_loss = world_model_loss
            self.world_optimizer.zero_grad()
            (observation_loss + reward_loss + kl_loss + pcont_loss).backward()
            nn.utils.clip_grad_norm_(self.world_param,
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            self.world_optimizer.step()

            # Freeze world model and BOTH critics during the actor update so
            # the actor's backward pass does not compute or store gradients
            # for them. (A misspelt attribute previously left the second
            # critic unfrozen.)
            for p in self.world_param:
                p.requires_grad = False
            for critic in (self.value_model, self.value_model2):
                for p in critic.parameters():
                    p.requires_grad = False

            # latent imagination
            imag_beliefs, imag_states, imag_ac_logps = self._latent_imagination(
                beliefs, posterior_states, with_logprob=self.args.with_logprob)

            # update temp
            # NOTE(review): this branch indexes imag_ac_logps, which is None
            # unless args.with_logprob is set -- confirm the two flags are
            # always enabled together.
            if self.args.auto_temp:
                temp_loss = -(
                    self.log_temp *
                    (imag_ac_logps[0] + self.target_entropy).detach()).mean()
                self.temp_optimizer.zero_grad()
                temp_loss.backward()
                self.temp_optimizer.step()
                self.args.temp = self.log_temp.exp()

            # update actor
            actor_loss = self._compute_loss_actor(imag_beliefs,
                                                  imag_states,
                                                  imag_ac_logps=imag_ac_logps)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_model.parameters(),
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            self.actor_optimizer.step()

            # Unfreeze everything that was frozen above.
            for p in self.world_param:
                p.requires_grad = True
            for critic in (self.value_model, self.value_model2):
                for p in critic.parameters():
                    p.requires_grad = True

            # update critic (on detached imagined latents)
            imag_beliefs = imag_beliefs.detach()
            imag_states = imag_states.detach()

            critic_loss = self._compute_loss_critic(
                imag_beliefs, imag_states, imag_ac_logps=imag_ac_logps)

            self.value_optimizer.zero_grad()
            critic_loss.backward()
            # Each critic is clipped separately.
            nn.utils.clip_grad_norm_(self.value_model.parameters(),
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            nn.utils.clip_grad_norm_(self.value_model2.parameters(),
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            self.value_optimizer.step()

            loss_info.append([
                observation_loss.item(),
                reward_loss.item(),
                kl_loss.item(),
                pcont_loss.item() if self.args.pcont else 0,
                actor_loss.item(),
                critic_loss.item()
            ])

        # finally, update target value function every #gradient_steps
        with torch.no_grad():
            self.target_value_model.load_state_dict(
                self.value_model.state_dict())
        with torch.no_grad():
            self.target_value_model2.load_state_dict(
                self.value_model2.state_dict())

        return loss_info

    def infer_state(self, observation, action, belief=None, state=None):
        """Infer the belief over the current state q(s_t|o≤t,a<t).

        Runs the transition model for a single step and returns the updated
        belief and posterior state with the time dimension removed again.

        :param observation: observation passed to the encoder
        :param action: previous action; a time dimension is added here
        :param belief: previous belief
        :param state: previous posterior state
        :return: (belief, posterior_state)
        """
        # Action and observation embedding need an extra time dimension.
        action_step = action.unsqueeze(dim=0)
        embedding_step = self.encoder(observation).unsqueeze(dim=0)

        belief, _, _, _, posterior_state, _, _ = self.transition_model(
            state, action_step, belief, embedding_step)

        # Remove time dimension from belief/state
        return belief.squeeze(dim=0), posterior_state.squeeze(dim=0)

    def select_action(self, state, deterministic=False):
        """Choose an action for the latent state produced by ``infer_state``.

        When acting stochastically and ``args.with_logprob`` is off, Gaussian
        exploration noise with standard deviation ``args.expl_amount`` is
        added. Column 0 (angle) is then clipped, and column 1 (throttle) is
        either clipped or overwritten with ``args.throttle_base`` when the
        speed is fixed.

        :param state: tuple (belief, posterior_state)
        :param deterministic: forwarded to the actor; also disables the
            exploration noise
        :return: the action tensor (not moved to the CPU or converted to numpy)
        """
        belief, posterior_state = state
        action, _ = self.actor_model(belief,
                                     posterior_state,
                                     deterministic=deterministic,
                                     with_logprob=False)

        explore = not (deterministic or self.args.with_logprob)
        if explore:
            # NOTE(review): debug output left in place.
            print("e")
            action = Normal(action, self.args.expl_amount).rsample()

            # clip the angle
            action[:, 0].clamp_(min=self.args.angle_min,
                                max=self.args.angle_max)
            # clip the throttle
            if not self.args.fix_speed:
                action[:, 1].clamp_(min=self.args.throttle_min,
                                    max=self.args.throttle_max)
            else:
                action[:, 1] = self.args.throttle_base

        # NOTE(review): debug output left in place.
        print("action", action)
        # return action.cpu().numpy()
        return action  # still a tensor on the device the actor produced it on

    def import_parameters(self, params):
        """Load the weights needed for local rollouts.

        :param params: mapping with the state dicts under the keys
            "encoder", "policy" and "transition"
        """
        targets = (
            ("encoder", self.encoder),
            ("policy", self.actor_model),
            ("transition", self.transition_model),
        )
        for key, module in targets:
            module.load_state_dict(params[key])

    def export_parameters(self):
        """Return the state dicts of the models used for local rollouts.

        Each module is moved to the CPU before its state dict is taken and
        moved back to ``args.device`` afterwards.

        :return: dict with the keys "encoder", "policy" and "transition"
        """
        modules = {
            "encoder": self.encoder,
            "policy": self.actor_model,
            "transition": self.transition_model,
        }
        params = {
            key: module.cpu().state_dict()
            for key, module in modules.items()
        }

        # Restore the training device once all snapshots are taken.
        for module in modules.values():
            module.to(self.args.device)

        return params
Beispiel #4
0
def train(args: argparse.Namespace,
          env: Env,
          D: ExperienceReplay,
          models: Tuple[nn.Module, nn.Module, nn.Module, nn.Module],
          optimizer: Tuple[optim.Optimizer, optim.Optimizer],
          param_list: List[nn.parameter.Parameter],
          planner: nn.Module):
    """Alternate between fitting the world model and collecting experience.

    For every episode the function performs ``args.collect_interval``
    training steps (observation + KL loss through the first optimizer,
    reward loss through the second), collects one new trajectory with the
    planner, appends it to ``D`` and periodically writes checkpoints.

    Args:
        args: hyper-parameters and paths (batch/chunk sizes, loss weights,
            device, checkpoint/tensorboard/torchscript directories, ...).
        env: environment handed to ``collect_experience``.
        D: replay buffer; must support ``sample`` and ``append``.
        models: (transition_model, observation_model, reward_model, encoder).
        optimizer: (transition_optimizer, reward_optimizer).
        param_list: parameters whose gradient norm is clipped before the
            transition optimizer step.
        planner: planner used for data collection and saved with checkpoints.
    """
    # auxilliary tensors
    global_prior = Normal(
        torch.zeros(args.batch_size, args.state_size, device=args.device),
        torch.ones(args.batch_size, args.state_size, device=args.device)
    )  # Global prior N(0, I)
    # Allowed deviation in KL divergence
    free_nats = torch.full((1, ), args.free_nats, dtype=torch.float32, device=args.device)
    summary_writer = SummaryWriter(args.tensorboard_dir)

    # unpack models
    transition_model, observation_model, reward_model, encoder = models
    transition_optimizer, reward_optimizer = optimizer

    for idx_episode in trange(args.episodes, leave=False, desc="Episode"):
        for idx_train in trange(args.collect_interval, leave=False, desc="Training"):
            # Draw sequence chunks {(o[t], a[t], r[t+1], z[t+1])} ~ D uniformly at random from the dataset
            # The first two dimensions of the tensors are L (chunk size) and n (batch size)
            # We want to use o[t+1] to correct the error of the transition model,
            # so we need to convert the sequence to {(o[t+1], a[t], r[t+1], z[t+1])}
            observations, actions, rewards_dist, rewards_coll, nonterminals = D.sample(args.batch_size, args.chunk_size)
            # Create initial belief and state for time t = 0
            init_belief = torch.zeros(args.batch_size, args.belief_size, device=args.device)
            init_state = torch.zeros(args.batch_size, args.state_size, device=args.device)
            # Transition model forward
            # deterministic: h[t+1] = f(h[t], a[t])
            # prior:         s[t+1] ~ Prob(s|h[t+1])
            # posterior:     s[t+1] ~ Prob(s|h[t+1], o[t+1])
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = transition_model(
                init_state,
                actions[:-1],
                init_belief,
                bottle(encoder, (observations[1:], )),
                nonterminals[:-1]
            )

            # observation loss
            # The flattened observation is split at index 3*64*64 into a
            # visual part and a symbolic part; each part is averaged
            # separately before the two are added.
            predictions = bottle(observation_model, (beliefs, posterior_states))
            visual_loss = F.mse_loss(
                predictions[:, :, :3*64*64],
                observations[1:, :, :3*64*64]
            ).mean()
            symbol_loss = F.mse_loss(
                predictions[:, :, 3*64*64:],
                observations[1:, :, 3*64*64:]
            ).mean()
            observation_loss = visual_loss + symbol_loss

            # KL divergence loss. Minimize the difference between posterior and prior
            kl_loss = torch.max(
                kl_divergence(
                    Normal(posterior_means, posterior_std_devs),
                    Normal(prior_means, prior_std_devs)
                ).sum(dim=2),
                free_nats
            ).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
            if args.global_kl_beta != 0:
                kl_loss += args.global_kl_beta * kl_divergence(
                    Normal(posterior_means, posterior_std_devs),
                    global_prior
                ).sum(dim=2).mean(dim=(0, 1))

            # overshooting loss
            if args.overshooting_kl_beta != 0:
                overshooting_vars = []  # Collect variables for overshooting to process in batch
                for t in range(1, args.chunk_size - 1):
                    d = min(t + args.overshooting_distance, args.chunk_size - 1)  # Overshooting distance
                    # Use t_ and d_ to deal with different time indexing for latent states
                    t_, d_ = t - 1, d - 1
                    # Calculate sequence padding so overshooting terms can be calculated in one batch
                    seq_pad = (0, 0, 0, 0, 0, t - d + args.overshooting_distance)
                    # Store
                    # * a[t:d],
                    # * z[t+1:d+1]
                    # * r[t+1:d+1]
                    # * h[t]
                    # * s[t] prior
                    # * E[s[t:d]] posterior
                    # * Var[s[t:d]] posterior
                    # * mask:
                    #       the last few sequences do not have enough length,
                    #       so we pad it with 0 to the same length as previous sequence for batch operation,
                    #       and use mask to indicate invalid variables.
                    overshooting_vars.append(
                        (F.pad(actions[t:d], seq_pad),
                         F.pad(nonterminals[t:d], seq_pad),
                         F.pad(rewards_dist[t:d], seq_pad[2:]),
                         beliefs[t_],
                         prior_states[t_],
                         F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad),
                         F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1),
                         F.pad(torch.ones(d - t, args.batch_size, args.state_size, device=args.device), seq_pad)
                         )
                    )  # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences

                overshooting_vars = tuple(zip(*overshooting_vars))
                # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once).
                # The open-loop outputs get their own names: the closed-loop
                # `beliefs` computed above are still needed for the reward
                # loss below and must not be overwritten here.
                _, _, os_prior_means, os_prior_std_devs = transition_model(
                    torch.cat(overshooting_vars[4], dim=0),
                    torch.cat(overshooting_vars[0], dim=1),
                    torch.cat(overshooting_vars[3], dim=0),
                    None,
                    torch.cat(overshooting_vars[1], dim=1)
                )
                seq_mask = torch.cat(overshooting_vars[7], dim=1)
                # Calculate overshooting KL loss with sequence mask
                kl_loss += (1 / args.overshooting_distance) * args.overshooting_kl_beta * torch.max(
                    (kl_divergence(
                        Normal(torch.cat(overshooting_vars[5], dim=1), torch.cat(overshooting_vars[6], dim=1)),
                        Normal(os_prior_means, os_prior_std_devs)
                    ) * seq_mask).sum(dim=2),
                    free_nats
                ).mean(dim=(0, 1)) * (args.chunk_size - 1)  # Update KL loss (compensating for extra average over each overshooting/open loop sequence)

            # TODO: add learning rate schedule
            # Update model parameters
            transition_optimizer.zero_grad()
            loss = observation_loss * 200 + kl_loss
            loss.backward()
            nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2)
            transition_optimizer.step()

            # reward loss
            # Computed on detached closed-loop latents, so this step only
            # trains the reward model.
            rewards_dist_predict, rewards_coll_predict = bottle(reward_model.raw, (beliefs.detach(), posterior_states.detach()))
            reward_loss = F.mse_loss(
                rewards_dist_predict,
                rewards_dist[:-1],
                reduction='mean'
            ) + F.binary_cross_entropy(
                rewards_coll_predict,
                rewards_coll[:-1],
                reduction='mean'
            )
            reward_optimizer.zero_grad()
            reward_loss.backward()
            reward_optimizer.step()

            # add tensorboard log
            global_step = idx_train + idx_episode * args.collect_interval
            summary_writer.add_scalar("observation_loss", observation_loss, global_step)
            summary_writer.add_scalar("reward_loss", reward_loss, global_step)
            summary_writer.add_scalar("kl_loss", kl_loss, global_step)

        # Collect one new trajectory and append it step by step to the buffer
        for idx_collect in trange(1, leave=False, desc="Collecting"):
            experience = collect_experience(args, env, models, planner, True, desc="Collecting experience {}".format(idx_collect))
            T = len(experience["observation"])
            for idx_step in range(T):
                D.append(experience["observation"][idx_step],
                         experience["action"][idx_step],
                         experience["reward_dist"][idx_step],
                         experience["reward_coll"][idx_step],
                         experience["done"][idx_step])

        # Checkpoint models
        if (idx_episode + 1) % args.checkpoint_interval == 0:
            record_path = os.path.join(args.checkpoint_dir, "checkpoint")
            checkpoint_path = os.path.join(args.checkpoint_dir, 'models_%d.pth' % (idx_episode+1))
            torch.save(
                {
                    'transition_model': transition_model.state_dict(),
                    'observation_model': observation_model.state_dict(),
                    'reward_model': reward_model.state_dict(),
                    'encoder': encoder.state_dict(),
                    'transition_optimizer': transition_optimizer.state_dict(),
                    'reward_optimizer': reward_optimizer.state_dict()
                },
                checkpoint_path)
            # The "checkpoint" file records the name of the latest checkpoint.
            with open(record_path, "w") as f:
                f.write('models_%d.pth' % (idx_episode+1))
            planner.save(os.path.join(args.torchscript_dir, "mpc_planner.pth"))
            transition_model.save(os.path.join(args.torchscript_dir, "transition_model.pth"))
            reward_model.save(os.path.join(args.torchscript_dir, "reward_model.pth"))
            observation_model.save(os.path.join(args.torchscript_dir, "observation_decoder.pth"))
            encoder.save(os.path.join(args.torchscript_dir, "observation_encoder.pth"))

    summary_writer.close()