) # Perform environment step (action repeats handled internally) return belief, posterior_state, action[0], next_observation, reward, done print("Starting training!") for episode in tqdm(range(metrics['episodes'][-1] + 1, args.episodes + 1), total=args.episodes, initial=metrics['episodes'][-1] + 1): print("Starting episode {}".format(episode)) # Model fitting losses = [] print("Drawing sequence chunks") for s in tqdm(range(args.collect_interval)): # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags) observations, actions, rewards, nonterminals = D.sample( args.batch_size, args.chunk_size) # Transitions start at time t = 0 # Create initial belief and state for time t = 0 init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros( args.batch_size, args.state_size, device=args.device) # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once) beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = transition_model( init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1]) #print("******************") #print(beliefs.shape) #print(prior_states.shape)
if args.render: env.render() if done: pbar.close() break print('Average Reward:', total_reward / args.test_episodes) env.close() quit() # Training (and testing) for episode in tqdm(range(metrics['episodes'][-1] + 1, args.episodes + 1), total=args.episodes, initial=metrics['episodes'][-1] + 1): data = D.sample(args.batch_size, args.chunk_size) # Model fitting loss_info = agent.update_parameters(data, args.collect_interval) # Update and plot loss metrics losses = tuple(zip(*loss_info)) metrics['observation_loss'].append(losses[0]) metrics['reward_loss'].append(losses[1]) metrics['kl_loss'].append(losses[2]) metrics['pcont_loss'].append(losses[3]) metrics['actor_loss'].append(losses[4]) metrics['value_loss'].append(losses[5]) lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', results_dir) lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', results_dir)
class Dreamer(Agent):
    """Dreamer agent: a latent world model plus an actor and twin critics trained in imagination.

    The agent owns its replay buffer (``self.D``), all networks and their optimisers, and
    exposes ``update_parameters`` (learning), ``infer_state`` / ``select_action`` (acting)
    and ``import_parameters`` / ``export_parameters`` (syncing the rollout networks).
    """

    def __init__(self, args):
        """Build every network, optimiser and the replay buffer.

        :param args: namespace holding all hyper-parameters (sizes, learning rates, device,
            loss scales and flags such as ``pcont``, ``auto_temp`` and ``fix_speed``).
        """
        super().__init__()
        self.args = args

        # World model: recurrent transition model, decoder, reward head and encoder.
        self.transition_model = TransitionModel(
            args.belief_size, args.state_size, args.action_size, args.hidden_size,
            args.embedding_size, args.dense_act).to(device=args.device)
        self.observation_model = ObservationModel(
            args.symbolic, args.observation_size, args.belief_size, args.state_size,
            args.embedding_size,
            activation_function=(args.dense_act if args.symbolic else args.cnn_act),
        ).to(device=args.device)
        self.reward_model = RewardModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.encoder = Encoder(
            args.symbolic, args.observation_size, args.embedding_size,
            args.cnn_act).to(device=args.device)

        # Behaviour learning: one actor and two critics (the min of the pair is used).
        self.actor_model = ActorModel(
            args.action_size, args.belief_size, args.state_size, args.hidden_size,
            activation_function=args.dense_act, fix_speed=args.fix_speed,
            throttle_base=args.throttle_base).to(device=args.device)
        self.value_model = ValueModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.value_model2 = ValueModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.pcont_model = PCONTModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)

        # Target critics: never trained directly, hard-copied at the end of update_parameters.
        self.target_value_model = deepcopy(self.value_model)
        self.target_value_model2 = deepcopy(self.value_model2)
        for target in (self.target_value_model, self.target_value_model2):
            self._set_requires_grad(target.parameters(), False)

        # Parameters updated by the world-model optimiser.
        self.world_param = (list(self.transition_model.parameters())
                            + list(self.observation_model.parameters())
                            + list(self.reward_model.parameters())
                            + list(self.encoder.parameters()))
        if args.pcont:
            self.world_param += list(self.pcont_model.parameters())

        self.world_optimizer = optim.Adam(self.world_param, lr=args.world_lr)
        self.actor_optimizer = optim.Adam(self.actor_model.parameters(), lr=args.actor_lr)
        self.value_optimizer = optim.Adam(
            list(self.value_model.parameters()) + list(self.value_model2.parameters()),
            lr=args.value_lr)

        # Free nats: lower bound applied to the KL term (allowed deviation in KL divergence).
        self.free_nats = torch.full(
            (1, ), args.free_nats, dtype=torch.float32, device=args.device)

        # TODO: change it to the new replay buffer, in buffer.py
        self.D = ExperienceReplay(args.experience_size, args.symbolic, args.observation_size,
                                  args.action_size, args.bit_depth, args.device)

        if self.args.auto_temp:
            # Learnable log-temperature of the entropy term (SAC-style).
            self.log_temp = torch.zeros(1, requires_grad=True, device=args.device)
            # Heuristic target entropy from the SAC paper: -(number of controlled action dims).
            self.target_entropy = -np.prod(
                args.action_size if not args.fix_speed else self.args.action_size - 1).item()
            self.temp_optimizer = optim.Adam([self.log_temp], lr=args.value_lr)  # same lr as critics

    @staticmethod
    def _set_requires_grad(params, flag):
        """Set ``requires_grad`` to ``flag`` on every tensor in the iterable ``params``."""
        for p in params:
            p.requires_grad = flag

    def process_im(self, images, image_size=None, rgb=None):
        """Convert one image to a ``[1, 1, 40, 40]`` grayscale tensor in [-0.5, 0.5].

        The image is resized to 40x40, collapsed to one channel with the weights
        (0.299, 0.587, 0.114) (presumably an RGB channel order — confirm with the caller),
        scaled by 1/255 and shifted by -0.5. ``image_size`` and ``rgb`` are accepted for
        interface compatibility but are currently ignored.
        """
        images = cv2.resize(images, (40, 40))
        images = np.dot(images, [0.299, 0.587, 0.114])
        obs = torch.tensor(images, dtype=torch.float32).div_(255.).sub_(0.5).unsqueeze(
            dim=0)  # shape [1, 40, 40], range: [-0.5, 0.5]
        return obs.unsqueeze(dim=0)  # add batch dimension

    def append_buffer(self, new_traj):
        """Append a collected trajectory ``[(o, a, r, d), ...]`` to the replay buffer (no augmentation)."""
        for observation, action, reward, done in new_traj:
            self.D.append(observation, action.cpu(), reward, done)

    def _compute_loss_world(self, state, data):
        """Return the (observation, scaled reward, KL, scaled pcont) world-model losses.

        ``state`` is the 7-tuple produced by ``transition_model``; ``data`` is
        ``(observations, rewards, nonterminals)`` aligned with it. The pcont loss is 0
        when ``args.pcont`` is disabled.
        """
        (beliefs, prior_states, prior_means, prior_std_devs,
         posterior_states, posterior_means, posterior_std_devs) = state
        observations, rewards, nonterminals = data

        # Decoder reconstruction error, summed over feature dims and averaged over time/batch.
        observation_loss = F.mse_loss(
            bottle(self.observation_model, (beliefs, posterior_states)),
            observations,
            reduction='none').sum(dim=2 if self.args.symbolic else (2, 3, 4)).mean(dim=(0, 1))
        reward_loss = F.mse_loss(
            bottle(self.reward_model, (beliefs, posterior_states)),
            rewards,
            reduction='none').mean(dim=(0, 1))
        # KL(posterior || prior) clipped from below by the free nats.
        kl_loss = torch.max(
            kl_divergence(
                Independent(Normal(posterior_means, posterior_std_devs), 1),
                Independent(Normal(prior_means, prior_std_devs), 1)),
            self.free_nats).mean(dim=(0, 1))

        if self.args.pcont:
            pcont_loss = self.args.pcont_scale * F.binary_cross_entropy(
                bottle(self.pcont_model, (beliefs, posterior_states)), nonterminals)
        else:
            pcont_loss = 0
        return observation_loss, self.args.reward_scale * reward_loss, kl_loss, pcont_loss

    def _compute_loss_actor(self, imag_beliefs, imag_states, imag_ac_logps=None):
        """Return the actor loss: minus the discounted lambda-returns of the imagined rollout.

        Values are the element-wise min of the two critics. When ``imag_ac_logps`` is given,
        ``temp * log_prob`` is subtracted from the values of steps 1..H (entropy bonus).
        """
        imag_rewards = bottle(self.reward_model, (imag_beliefs, imag_states))
        imag_values = torch.min(
            bottle(self.value_model, (imag_beliefs, imag_states)),
            bottle(self.value_model2, (imag_beliefs, imag_states)))

        with torch.no_grad():
            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
        pcont = pcont.detach()

        if imag_ac_logps is not None:
            imag_values[1:] -= self.args.temp * imag_ac_logps  # add entropy here

        returns = cal_returns(imag_rewards[:-1], imag_values[:-1], imag_values[-1],
                              pcont[:-1], lambda_=self.args.disclam)

        # Cumulative discount weight of each step: 1, p0, p0*p1, ...
        discount = torch.cumprod(
            torch.cat([torch.ones_like(pcont[:1]), pcont[:-2]], 0), 0).detach()
        assert list(discount.size()) == list(returns.size())
        return -torch.mean(discount * returns)

    def _compute_loss_critic(self, imag_beliefs, imag_states, imag_ac_logps=None):
        """Return the summed MSE of both critics against lambda-returns built from the target critics."""
        with torch.no_grad():
            # Targets use the min over the two frozen target critics.
            target_imag_values = torch.min(
                bottle(self.target_value_model, (imag_beliefs, imag_states)),
                bottle(self.target_value_model2, (imag_beliefs, imag_states)))
            imag_rewards = bottle(self.reward_model, (imag_beliefs, imag_states))
            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
            if imag_ac_logps is not None:
                target_imag_values[1:] -= self.args.temp * imag_ac_logps

        returns = cal_returns(imag_rewards[:-1], target_imag_values[:-1],
                              target_imag_values[-1], pcont[:-1], lambda_=self.args.disclam)
        target_return = returns.detach()

        value_pred = bottle(self.value_model, (imag_beliefs, imag_states))[:-1]
        value_pred2 = bottle(self.value_model2, (imag_beliefs, imag_states))[:-1]
        value_loss = F.mse_loss(value_pred, target_return, reduction="none").mean(dim=(0, 1))
        value_loss2 = F.mse_loss(value_pred2, target_return, reduction="none").mean(dim=(0, 1))
        return value_loss + value_loss2

    def _latent_imagination(self, beliefs, posterior_states, with_logprob=False):
        """Roll the actor forward ``planning_horizon`` steps in latent space.

        Every (time, batch) posterior state is used as a separate start point, so the
        start tensors are flattened to ``chunk*batch`` rows.

        Returns ``(imag_beliefs, imag_states, imag_ac_logps)``; the first two are stacked
        to ``[horizon + 1, chunk*batch, ...]`` and the last is ``[horizon, chunk*batch, ...]``
        when ``with_logprob`` is true, otherwise ``None``.
        """
        chunk_size, batch_size, _ = list(posterior_states.size())
        flatten_size = chunk_size * batch_size
        posterior_states = posterior_states.detach().reshape(flatten_size, -1)
        beliefs = beliefs.detach().reshape(flatten_size, -1)

        imag_beliefs, imag_states, imag_ac_logps = [beliefs], [posterior_states], []
        for _ in range(self.args.planning_horizon):
            imag_action, imag_ac_logp = self.actor_model(
                imag_beliefs[-1].detach(),
                imag_states[-1].detach(),
                deterministic=False,
                with_logprob=with_logprob,
            )
            imag_action = imag_action.unsqueeze(dim=0)  # add time dim
            imag_belief, imag_state, _, _ = self.transition_model(
                imag_states[-1], imag_action, imag_beliefs[-1])
            imag_beliefs.append(imag_belief.squeeze(dim=0))
            imag_states.append(imag_state.squeeze(dim=0))
            if with_logprob:
                imag_ac_logps.append(imag_ac_logp.squeeze(dim=0))

        imag_beliefs = torch.stack(imag_beliefs, dim=0).to(self.args.device)
        imag_states = torch.stack(imag_states, dim=0).to(self.args.device)
        if with_logprob:
            imag_ac_logps = torch.stack(imag_ac_logps, dim=0).to(self.args.device)
        return imag_beliefs, imag_states, imag_ac_logps if with_logprob else None

    def update_parameters(self, gradient_steps):
        """Run ``gradient_steps`` updates of world model, temperature, actor and critics.

        Returns a list with one ``[observation, reward, kl, pcont, actor, critic]`` loss
        record per step. The target critics are hard-copied from the critics once, after
        the last step.
        """
        loss_info = []  # used to record loss
        for _ in tqdm(range(gradient_steps)):
            observations, actions, rewards, nonterminals = self.D.sample(
                self.args.batch_size, self.args.chunk_size)
            init_belief = torch.zeros(self.args.batch_size, self.args.belief_size,
                                      device=self.args.device)
            init_state = torch.zeros(self.args.batch_size, self.args.state_size,
                                     device=self.args.device)
            # Posterior belief/state over the whole sequence at once.
            (beliefs, prior_states, prior_means, prior_std_devs,
             posterior_states, posterior_means, posterior_std_devs) = self.transition_model(
                init_state, actions, init_belief,
                bottle(self.encoder, (observations, )), nonterminals)

            # --- world model ---
            observation_loss, reward_loss, kl_loss, pcont_loss = self._compute_loss_world(
                state=(beliefs, prior_states, prior_means, prior_std_devs,
                       posterior_states, posterior_means, posterior_std_devs),
                data=(observations, rewards, nonterminals))
            self.world_optimizer.zero_grad()
            (observation_loss + reward_loss + kl_loss + pcont_loss).backward()
            nn.utils.clip_grad_norm_(self.world_param, self.args.grad_clip_norm, norm_type=2)
            self.world_optimizer.step()

            # Freeze world model and both critics so the actor step computes no grads for them.
            self._set_requires_grad(self.world_param, False)
            self._set_requires_grad(self.value_model.parameters(), False)
            self._set_requires_grad(self.value_model2.parameters(), False)

            imag_beliefs, imag_states, imag_ac_logps = self._latent_imagination(
                beliefs, posterior_states, with_logprob=self.args.with_logprob)

            # --- entropy temperature ---
            if self.args.auto_temp:
                temp_loss = -(self.log_temp
                              * (imag_ac_logps[0] + self.target_entropy).detach()).mean()
                self.temp_optimizer.zero_grad()
                temp_loss.backward()
                self.temp_optimizer.step()
                # Detached: the temperature is learned only through temp_loss.
                self.args.temp = self.log_temp.exp().detach()

            # --- actor ---
            actor_loss = self._compute_loss_actor(imag_beliefs, imag_states,
                                                  imag_ac_logps=imag_ac_logps)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm,
                                     norm_type=2)
            self.actor_optimizer.step()

            self._set_requires_grad(self.world_param, True)
            self._set_requires_grad(self.value_model.parameters(), True)
            self._set_requires_grad(self.value_model2.parameters(), True)

            # --- critics ---
            imag_beliefs = imag_beliefs.detach()
            imag_states = imag_states.detach()
            critic_loss = self._compute_loss_critic(imag_beliefs, imag_states,
                                                    imag_ac_logps=imag_ac_logps)
            self.value_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm,
                                     norm_type=2)
            nn.utils.clip_grad_norm_(self.value_model2.parameters(), self.args.grad_clip_norm,
                                     norm_type=2)
            self.value_optimizer.step()

            loss_info.append([
                observation_loss.item(),
                reward_loss.item(),
                kl_loss.item(),
                pcont_loss.item() if self.args.pcont else 0,
                actor_loss.item(),
                critic_loss.item(),
            ])

        # Finally, refresh the target critics once every `gradient_steps` updates.
        with torch.no_grad():
            self.target_value_model.load_state_dict(self.value_model.state_dict())
            self.target_value_model2.load_state_dict(self.value_model2.state_dict())
        return loss_info

    def infer_state(self, observation, action, belief=None, state=None):
        """Filter one step: infer q(s_t | o<=t, a<t) and return ``(belief, posterior_state)``.

        ``action`` has no time dimension; one is added here (and to the encoded
        observation) for the transition model, then removed from the outputs.
        """
        belief, _, _, _, posterior_state, _, _ = self.transition_model(
            state, action.unsqueeze(dim=0), belief,
            self.encoder(observation).unsqueeze(dim=0))
        return belief.squeeze(dim=0), posterior_state.squeeze(dim=0)

    def select_action(self, state, deterministic=False):
        """Return an action tensor (on the model's device) for ``state = (belief, posterior_state)``.

        When not deterministic and the policy is not sampled with log-probs, Gaussian
        exploration noise of scale ``expl_amount`` is added. Column 0 (angle) is clamped to
        [angle_min, angle_max]; column 1 (throttle) is fixed to ``throttle_base`` when
        ``fix_speed`` is set, otherwise clamped to [throttle_min, throttle_max].
        """
        belief, posterior_state = state
        action, _ = self.actor_model(belief, posterior_state, deterministic=deterministic,
                                     with_logprob=False)
        if not deterministic and not self.args.with_logprob:
            action = Normal(action, self.args.expl_amount).rsample()

        action[:, 0].clamp_(min=self.args.angle_min, max=self.args.angle_max)
        if self.args.fix_speed:
            action[:, 1] = self.args.throttle_base
        else:
            action[:, 1].clamp_(min=self.args.throttle_min, max=self.args.throttle_max)
        return action

    def import_parameters(self, params):
        """Load the rollout networks (encoder, policy, transition) from ``params``."""
        self.encoder.load_state_dict(params["encoder"])
        self.actor_model.load_state_dict(params["policy"])
        self.transition_model.load_state_dict(params["transition"])

    def export_parameters(self):
        """Return CPU copies of the state dicts of the networks used for local rollout.

        The live modules stay on their device; each tensor is copied, so the result never
        aliases the training parameters.
        """
        def _cpu_state_dict(module):
            state = module.state_dict()  # keeps the OrderedDict (and its _metadata)
            for key, value in state.items():
                state[key] = value.to("cpu", copy=True)
            return state

        return {
            "encoder": _cpu_state_dict(self.encoder),
            "policy": _cpu_state_dict(self.actor_model),
            "transition": _cpu_state_dict(self.transition_model),
        }
def _train_observation_loss(predictions, targets, visual_size=3 * 64 * 64):
    """Return the observation loss: MSE on the first ``visual_size`` features plus MSE on the rest.

    The first ``visual_size`` features are presumably a flattened 3x64x64 image and the
    remainder symbolic features — confirm against the observation layout.
    """
    visual_loss = F.mse_loss(
        predictions[:, :, :visual_size], targets[:, :, :visual_size]).mean()
    symbol_loss = F.mse_loss(
        predictions[:, :, visual_size:], targets[:, :, visual_size:]).mean()
    return visual_loss + symbol_loss


def _train_overshooting_kl_loss(args, transition_model, actions, nonterminals, beliefs,
                                prior_states, posterior_means, posterior_std_devs, free_nats):
    """Return the latent-overshooting KL term for one sampled batch (0 if the chunk is too short).

    Open-loop prior rollouts of up to ``args.overshooting_distance`` steps start from every
    intermediate time step; shorter tail rollouts are zero-padded and masked so they can be
    run as one batch. The results are kept in local names, so the caller's closed-loop
    ``beliefs`` / ``prior_states`` are left untouched.
    """
    overshooting_vars = []
    for t in range(1, args.chunk_size - 1):
        d = min(t + args.overshooting_distance, args.chunk_size - 1)  # overshooting distance
        # t_ / d_ index latent states, which are offset by one from the observation time.
        t_, d_ = t - 1, d - 1
        # Pad the time dimension so every rollout has length overshooting_distance.
        seq_pad = (0, 0, 0, 0, 0, t - d + args.overshooting_distance)
        overshooting_vars.append((
            F.pad(actions[t:d], seq_pad),                                     # a[t:d]
            F.pad(nonterminals[t:d], seq_pad),                                # z[t+1:d+1]
            beliefs[t_],                                                      # h[t]
            prior_states[t_],                                                 # s[t] prior
            F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad),          # posterior means
            # Posterior std devs are padded with 1 (> 0) to avoid infinite KL divergences.
            F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1),
            F.pad(torch.ones(d - t, args.batch_size, args.state_size, device=args.device),
                  seq_pad),                                                   # validity mask
        ))
    if not overshooting_vars:
        return 0

    (os_actions, os_nonterminals, os_beliefs, os_prior_states,
     os_post_means, os_post_stds, os_masks) = zip(*overshooting_vars)

    # Prior-only rollout from each start point (no observations), all sequences at once.
    _, _, os_prior_means, os_prior_std_devs = transition_model(
        torch.cat(os_prior_states, dim=0),
        torch.cat(os_actions, dim=1),
        torch.cat(os_beliefs, dim=0),
        None,
        torch.cat(os_nonterminals, dim=1))
    seq_mask = torch.cat(os_masks, dim=1)

    kl = kl_divergence(
        Normal(torch.cat(os_post_means, dim=1), torch.cat(os_post_stds, dim=1)),
        Normal(os_prior_means, os_prior_std_devs))
    # Compensate for the extra average over each overshooting / open-loop sequence.
    return ((1 / args.overshooting_distance) * args.overshooting_kl_beta
            * torch.max((kl * seq_mask).sum(dim=2), free_nats).mean(dim=(0, 1))
            * (args.chunk_size - 1))


def _train_save_checkpoint(args, episode_number, models, optimizers, planner):
    """Save model/optimiser state, record the latest checkpoint name, and export TorchScript modules."""
    transition_model, observation_model, reward_model, encoder = models
    transition_optimizer, reward_optimizer = optimizers
    checkpoint_name = 'models_%d.pth' % episode_number
    torch.save(
        {
            'transition_model': transition_model.state_dict(),
            'observation_model': observation_model.state_dict(),
            'reward_model': reward_model.state_dict(),
            'encoder': encoder.state_dict(),
            'transition_optimizer': transition_optimizer.state_dict(),
            'reward_optimizer': reward_optimizer.state_dict(),
        },
        os.path.join(args.checkpoint_dir, checkpoint_name))
    with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f:
        f.write(checkpoint_name)
    planner.save(os.path.join(args.torchscript_dir, "mpc_planner.pth"))
    transition_model.save(os.path.join(args.torchscript_dir, "transition_model.pth"))
    reward_model.save(os.path.join(args.torchscript_dir, "reward_model.pth"))
    observation_model.save(os.path.join(args.torchscript_dir, "observation_decoder.pth"))
    encoder.save(os.path.join(args.torchscript_dir, "observation_encoder.pth"))


def train(args: argparse.Namespace,
          env: Env,
          D: ExperienceReplay,
          models: Tuple[nn.Module, nn.Module, nn.Module, nn.Module],
          optimizer: Tuple[optim.Optimizer, optim.Optimizer],
          param_list: List[nn.parameter.Parameter],
          planner: nn.Module):
    """Alternate world-model training and experience collection for ``args.episodes`` episodes.

    Each episode runs ``args.collect_interval`` gradient steps and then collects one
    trajectory with ``planner`` into ``D``. Each gradient step updates:

    * the transition model (together with whatever ``param_list`` holds), using
      200 * observation loss + KL loss (free nats, optional global-prior and
      latent-overshooting terms);
    * the reward model separately, using MSE on the distance reward plus BCE on the
      collision reward, computed from detached latents.

    Losses are logged to TensorBoard. Every ``args.checkpoint_interval`` episodes a
    checkpoint and TorchScript exports are written.
    """
    # Global prior N(0, I), only used when args.global_kl_beta != 0.
    global_prior = Normal(
        torch.zeros(args.batch_size, args.state_size, device=args.device),
        torch.ones(args.batch_size, args.state_size, device=args.device))
    # Allowed deviation in KL divergence.
    free_nats = torch.full((1, ), args.free_nats, dtype=torch.float32, device=args.device)

    transition_model, observation_model, reward_model, encoder = models
    transition_optimizer, reward_optimizer = optimizer

    summary_writer = SummaryWriter(args.tensorboard_dir)
    try:
        for idx_episode in trange(args.episodes, leave=False, desc="Episode"):
            for idx_train in trange(args.collect_interval, leave=False, desc="Training"):
                # Draw chunks {(o[t], a[t], r[t+1], z[t+1])} ~ D; dims are (L=chunk, n=batch).
                # o[t+1] corrects the transition model, so the model consumes
                # {(o[t+1], a[t], r[t+1], z[t+1])}.
                observations, actions, rewards_dist, rewards_coll, nonterminals = D.sample(
                    args.batch_size, args.chunk_size)
                init_belief = torch.zeros(args.batch_size, args.belief_size, device=args.device)
                init_state = torch.zeros(args.batch_size, args.state_size, device=args.device)

                # deterministic: h[t+1] = f(h[t], a[t])
                # prior:         s[t+1] ~ Prob(s | h[t+1])
                # posterior:     s[t+1] ~ Prob(s | h[t+1], o[t+1])
                (beliefs, prior_states, prior_means, prior_std_devs,
                 posterior_states, posterior_means, posterior_std_devs) = transition_model(
                    init_state, actions[:-1], init_belief,
                    bottle(encoder, (observations[1:], )), nonterminals[:-1])

                predictions = bottle(observation_model, (beliefs, posterior_states))
                observation_loss = _train_observation_loss(predictions, observations[1:])

                # KL(posterior || prior), clipped from below by the free nats.
                kl_loss = torch.max(
                    kl_divergence(Normal(posterior_means, posterior_std_devs),
                                  Normal(prior_means, prior_std_devs)).sum(dim=2),
                    free_nats).mean(dim=(0, 1))
                if args.global_kl_beta != 0:
                    kl_loss = kl_loss + args.global_kl_beta * kl_divergence(
                        Normal(posterior_means, posterior_std_devs),
                        global_prior).sum(dim=2).mean(dim=(0, 1))
                if args.overshooting_kl_beta != 0:
                    kl_loss = kl_loss + _train_overshooting_kl_loss(
                        args, transition_model, actions, nonterminals, beliefs,
                        prior_states, posterior_means, posterior_std_devs, free_nats)

                # TODO: add learning rate schedule
                transition_optimizer.zero_grad()
                loss = observation_loss * 200 + kl_loss
                loss.backward()
                nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2)
                transition_optimizer.step()

                # Reward heads are trained on detached closed-loop latents.
                rewards_dist_predict, rewards_coll_predict = bottle(
                    reward_model.raw, (beliefs.detach(), posterior_states.detach()))
                reward_loss = (
                    F.mse_loss(rewards_dist_predict, rewards_dist[:-1], reduction='mean')
                    + F.binary_cross_entropy(rewards_coll_predict, rewards_coll[:-1],
                                             reduction='mean'))
                reward_optimizer.zero_grad()
                reward_loss.backward()
                reward_optimizer.step()

                global_step = idx_train + idx_episode * args.collect_interval
                summary_writer.add_scalar("observation_loss", observation_loss.item(), global_step)
                summary_writer.add_scalar("reward_loss", reward_loss.item(), global_step)
                summary_writer.add_scalar("kl_loss", kl_loss.item(), global_step)

            for idx_collect in trange(1, leave=False, desc="Collecting"):
                experience = collect_experience(
                    args, env, models, planner, True,
                    desc="Collecting experience {}".format(idx_collect))
                for idx_step in range(len(experience["observation"])):
                    D.append(experience["observation"][idx_step],
                             experience["action"][idx_step],
                             experience["reward_dist"][idx_step],
                             experience["reward_coll"][idx_step],
                             experience["done"][idx_step])

            if (idx_episode + 1) % args.checkpoint_interval == 0:
                _train_save_checkpoint(args, idx_episode + 1, models,
                                       (transition_optimizer, reward_optimizer), planner)
    finally:
        summary_writer.close()