class Trainer(): def __init__(self, params, experience_replay_buffer,metrics,results_dir,env): self.parms = params self.D = experience_replay_buffer self.metrics = metrics self.env = env self.tested_episodes = 0 self.statistics_path = results_dir+'/statistics' self.model_path = results_dir+'/model' self.video_path = results_dir+'/video' self.rew_vs_pred_rew_path = results_dir+'/rew_vs_pred_rew' self.dump_plan_path = results_dir+'/dump_plan' #if folder do not exists, create it os.makedirs(self.statistics_path, exist_ok=True) os.makedirs(self.model_path, exist_ok=True) os.makedirs(self.video_path, exist_ok=True) os.makedirs(self.rew_vs_pred_rew_path, exist_ok=True) os.makedirs(self.dump_plan_path, exist_ok=True) # Create models self.transition_model = TransitionModel(self.parms.belief_size, self.parms.state_size, self.env.action_size, self.parms.hidden_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device) self.observation_model = ObservationModel(self.parms.belief_size, self.parms.state_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device) self.reward_model = RewardModel(self.parms.belief_size, self.parms.state_size, self.parms.hidden_size, self.parms.activation_function).to(device=self.parms.device) self.encoder = Encoder(self.parms.embedding_size,self.parms.activation_function).to(device=self.parms.device) self.param_list = list(self.transition_model.parameters()) + list(self.observation_model.parameters()) + list(self.reward_model.parameters()) + list(self.encoder.parameters()) self.optimiser = optim.Adam(self.param_list, lr=0 if self.parms.learning_rate_schedule != 0 else self.parms.learning_rate, eps=self.parms.adam_epsilon) self.planner = MPCPlanner(self.env.action_size, self.parms.planning_horizon, self.parms.optimisation_iters, self.parms.candidates, self.parms.top_candidates, self.transition_model, self.reward_model,self.env.action_range[0], self.env.action_range[1]) global_prior = Normal(torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device), torch.ones(self.parms.batch_size, self.parms.state_size, device=self.parms.device)) # Global prior N(0, I) self.free_nats = torch.full((1, ), self.parms.free_nats, dtype=torch.float32, device=self.parms.device) # Allowed deviation in KL divergence def load_checkpoints(self): self.metrics = torch.load(self.model_path+'/metrics.pth') model_path = self.model_path+'/best_model' os.makedirs(model_path, exist_ok=True) files = os.listdir(model_path) if files: checkpoint = [f for f in files if os.path.isfile(os.path.join(model_path, f))] model_dicts = torch.load(os.path.join(model_path, checkpoint[0]),map_location=self.parms.device) self.transition_model.load_state_dict(model_dicts['transition_model']) self.observation_model.load_state_dict(model_dicts['observation_model']) self.reward_model.load_state_dict(model_dicts['reward_model']) self.encoder.load_state_dict(model_dicts['encoder']) self.optimiser.load_state_dict(model_dicts['optimiser']) print("Loading models checkpoints!") else: print("Checkpoints not found!") def update_belief_and_act(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf,explore=False): # Infer belief over current state q(s_t|o≤t,a<t) from the history encoded_obs = self.encoder(observation).unsqueeze(dim=0).to(device=self.parms.device) belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs) # Action and 
observation need extra time dimension belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0) # Remove time dimension from belief/state action,pred_next_rew,_,_,_ = self.planner(belief, posterior_state,explore) # Get action from planner(q(s_t|o≤t,a<t), p) if explore: action = action + self.parms.action_noise * torch.randn_like(action) # Add exploration noise ε ~ p(ε) to the action action.clamp_(min=min_action, max=max_action) # Clip action range next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu()) # If single env is istanceted perform single action (get item from list), else perform all actions return belief, posterior_state, action, next_observation, reward, done,pred_next_rew def fit_buffer(self,episode): #### # Fit data taken from buffer ###### # Model fitting losses = [] tqdm.write("Fitting buffer") for s in tqdm(range(self.parms.collect_interval)): # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags) observations, actions, rewards, nonterminals = self.D.sample(self.parms.batch_size, self.parms.chunk_size) # Transitions start at time t = 0 # Create initial belief and state for time t = 0 init_belief, init_state = torch.zeros(self.parms.batch_size, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device) encoded_obs = bottle(self.encoder, (observations[1:], )) # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once) beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(init_state, actions[:-1], init_belief, encoded_obs, nonterminals[:-1]) # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?) 
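# Shape note (assumption, based on the sampling above): observations is
# (chunk_size, batch, C, H, W), so the reconstruction MSE below is summed over
# the image dims (2, 3, 4) to get one value per (time, batch) element before
# averaging over dims (0, 1); the KL term is summed over the state dimension
# (dim=2) and clamped from below by free_nats before taking the same average.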
# LOSS observation_loss = F.mse_loss(bottle(self.observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum((2, 3, 4)).mean(dim=(0, 1)) kl_loss = torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), self.free_nats).mean(dim=(0, 1)) reward_loss = F.mse_loss(bottle(self.reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1)) # Update model parameters self.optimiser.zero_grad() (observation_loss + reward_loss + kl_loss).backward() # BACKPROPAGATION nn.utils.clip_grad_norm_(self.param_list, self.parms.grad_clip_norm, norm_type=2) self.optimiser.step() # Store (0) observation loss (1) reward loss (2) KL loss losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()])#, regularizer_loss.item()]) #save statistics and plot them losses = tuple(zip(*losses)) self.metrics['observation_loss'].append(losses[0]) self.metrics['reward_loss'].append(losses[1]) self.metrics['kl_loss'].append(losses[2]) lineplot(self.metrics['episodes'][-len(self.metrics['observation_loss']):], self.metrics['observation_loss'], 'observation_loss', self.statistics_path) lineplot(self.metrics['episodes'][-len(self.metrics['reward_loss']):], self.metrics['reward_loss'], 'reward_loss', self.statistics_path) lineplot(self.metrics['episodes'][-len(self.metrics['kl_loss']):], self.metrics['kl_loss'], 'kl_loss', self.statistics_path) def explore_and_collect(self,episode): tqdm.write("Collect new data:") reward = 0 # Data collection with torch.no_grad(): done = False observation, total_reward = self.env.reset(), 0 belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device) t = 0 real_rew = [] predicted_rew = [] total_steps = self.parms.max_episode_length // self.env.action_repeat explore = True for t in tqdm(range(total_steps)): # Here we need to explore belief, posterior_state, action, next_observation, reward, done, pred_next_rew = self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1], explore=explore) self.D.append(observation, action.cpu(), reward, done) real_rew.append(reward) predicted_rew.append(pred_next_rew.to(device=self.parms.device).item()) total_reward += reward observation = next_observation if self.parms.flag_render: env.render() if done: break # Update and plot train reward metrics self.metrics['steps'].append( (t * self.env.action_repeat) + self.metrics['steps'][-1]) self.metrics['episodes'].append(episode) self.metrics['train_rewards'].append(total_reward) self.metrics['predicted_rewards'].append(np.array(predicted_rew).sum()) lineplot(self.metrics['episodes'][-len(self.metrics['train_rewards']):], self.metrics['train_rewards'], 'train_rewards', self.statistics_path) double_lineplot(self.metrics['episodes'], self.metrics['train_rewards'], self.metrics['predicted_rewards'], "train_r_vs_pr", self.statistics_path) def train_models(self): # from (init_episodes) to (training_episodes + init_episodes) tqdm.write("Start training.") for episode in tqdm(range(self.parms.num_init_episodes +1, self.parms.training_episodes) ): self.fit_buffer(episode) self.explore_and_collect(episode) if episode % self.parms.test_interval == 0: self.test_model(episode) torch.save(self.metrics, 
os.path.join(self.model_path, 'metrics.pth')) torch.save({'transition_model': self.transition_model.state_dict(), 'observation_model': self.observation_model.state_dict(), 'reward_model': self.reward_model.state_dict(), 'encoder': self.encoder.state_dict(), 'optimiser': self.optimiser.state_dict()}, os.path.join(self.model_path, 'models_%d.pth' % episode)) if episode % self.parms.storing_dataset_interval == 0: self.D.store_dataset(self.parms.dataset_path+'dump_dataset') return self.metrics def test_model(self, episode=None): #no explore here if episode is None: episode = self.tested_episodes # Set models to eval mode self.transition_model.eval() self.observation_model.eval() self.reward_model.eval() self.encoder.eval() # Initialise parallelised test environments test_envs = EnvBatcher(ControlSuiteEnv, (self.parms.env_name, self.parms.seed, self.parms.max_episode_length, self.parms.bit_depth), {}, self.parms.test_episodes) total_steps = self.parms.max_episode_length // test_envs.action_repeat rewards = np.zeros(self.parms.test_episodes) real_rew = torch.zeros([total_steps,self.parms.test_episodes]) predicted_rew = torch.zeros([total_steps,self.parms.test_episodes]) with torch.no_grad(): observation, total_rewards, video_frames = test_envs.reset(), np.zeros((self.parms.test_episodes, )), [] belief, posterior_state, action = torch.zeros(self.parms.test_episodes, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.parms.state_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.env.action_size, device=self.parms.device) tqdm.write("Testing model.") for t in range(total_steps): belief, posterior_state, action, next_observation, rewards, done, pred_next_rew = self.update_belief_and_act(test_envs, belief, posterior_state, action, observation.to(device=self.parms.device), list(rewards), self.env.action_range[0], self.env.action_range[1]) total_rewards += rewards.numpy() real_rew[t] = rewards predicted_rew[t] = pred_next_rew observation = self.env.get_original_frame().unsqueeze(dim=0) video_frames.append(make_grid(torch.cat([observation, self.observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre observation = next_observation if done.sum().item() == self.parms.test_episodes: break real_rew = torch.transpose(real_rew, 0, 1) predicted_rew = torch.transpose(predicted_rew, 0, 1) #save and plot metrics self.tested_episodes += 1 self.metrics['test_episodes'].append(episode) self.metrics['test_rewards'].append(total_rewards.tolist()) lineplot(self.metrics['test_episodes'], self.metrics['test_rewards'], 'test_rewards', self.statistics_path) write_video(video_frames, 'test_episode_%s' % str(episode), self.video_path) # Lossy compression # Set models to train mode self.transition_model.train() self.observation_model.train() self.reward_model.train() self.encoder.train() # Close test environments test_envs.close() return self.metrics def dump_plan_video(self, step_before_plan=120): #number of steps before to start to collect frames to dump step_before_plan = min(step_before_plan, (self.parms.max_episode_length // self.env.action_repeat)) # Set models to eval mode self.transition_model.eval() self.observation_model.eval() self.reward_model.eval() self.encoder.eval() video_frames = [] reward = 0 with torch.no_grad(): observation = self.env.reset() belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), 
torch.zeros(1, self.env.action_size, device=self.parms.device) tqdm.write("Executing episode.") for t in range(step_before_plan): #floor division belief, posterior_state, action, next_observation, reward, done, _ = self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1]) observation = next_observation video_frames.append(make_grid(torch.cat([observation.cpu(), self.observation_model(belief, posterior_state).to(device=self.parms.device).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre if done: break self.create_and_dump_plan(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1]) # Set models to train mode self.transition_model.train() self.observation_model.train() self.reward_model.train() self.encoder.train() # Close test environments self.env.close() def create_and_dump_plan(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf): tqdm.write("Dumping plan") video_frames = [] encoded_obs = self.encoder(observation).unsqueeze(dim=0) belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs) belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0) # Remove time dimension from belief/state next_action,_, beliefs, states, plan = self.planner(belief, posterior_state,False) # Get action from planner(q(s_t|o≤t,a<t), p) predicted_frames = self.observation_model(beliefs, states).to(device=self.parms.device) for i in range(self.parms.planning_horizon): plan[i].clamp_(min=env.action_range[0], max=self.env.action_range[1]) # Clip action range next_observation, reward, done = env.step(plan[i].cpu()) next_observation = next_observation.squeeze(dim=0) video_frames.append(make_grid(torch.cat([next_observation, predicted_frames[i]], dim=1) + 0.5, nrow=2).numpy()) # Decentre write_video(video_frames, 'dump_plan', self.dump_plan_path, dump_frame=True)
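# --- Hedged sketch (not the MPCPlanner used above): a minimal cross-entropy
# method (CEM) planner, only to illustrate what planning_horizon /
# optimisation_iters / candidates / top_candidates control. `toy_transition`
# and `toy_reward` are illustrative stand-ins for the learned models.
import torch


def toy_transition(state, action):
    # Illustrative latent dynamics: drift the state towards the action.
    return 0.9 * state + 0.1 * action


def toy_reward(state):
    # Illustrative reward: stay close to the origin.
    return -state.pow(2).sum(dim=-1)


def cem_plan(state, action_size, planning_horizon=12, optimisation_iters=10,
             candidates=1000, top_candidates=100):
    mean = torch.zeros(planning_horizon, action_size)
    std = torch.ones(planning_horizon, action_size)
    for _ in range(optimisation_iters):
        # Sample candidate action sequences from the current plan distribution.
        actions = (mean + std * torch.randn(candidates, planning_horizon, action_size)).clamp(-1, 1)
        returns = torch.zeros(candidates)
        s = state.expand(candidates, -1).clone()
        for t in range(planning_horizon):
            s = toy_transition(s, actions[:, t])
            returns = returns + toy_reward(s)
        # Refit mean/std to the elite (top-k) action sequences.
        elite = actions[returns.topk(top_candidates).indices]
        mean, std = elite.mean(dim=0), elite.std(dim=0) + 1e-6
    return mean[0]  # First action of the refined plan (receding-horizon control).


# Example usage with a 2-D toy latent state and 2-D actions:
# first_action = cem_plan(torch.zeros(1, 2), action_size=2)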
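# --- Hedged sketch: the "free nats" clamp used in the KL losses above, as a
# stand-alone toy example (tensor shapes and the value 3.0 are illustrative).
import torch
from torch.distributions import Normal, kl_divergence

posterior = Normal(torch.zeros(4, 8), 0.5 * torch.ones(4, 8))  # (batch, state)
prior = Normal(torch.zeros(4, 8), torch.ones(4, 8))
free_nats = torch.full((1,), 3.0)

kl = kl_divergence(posterior, prior).sum(dim=1)  # sum over state dims -> (batch,)
kl_loss = torch.max(kl, free_nats).mean()        # nats below the threshold are "free"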
class Plan(object): def __init__(self): self.results_dir = os.path.join( 'results', '{}_seed_{}_{}_action_scale_{}_no_explore_{}_pool_len_{}_optimisation_iters_{}_top_planning-horizon' .format(args.env, args.seed, args.algo, args.action_scale, args.pool_len, args.optimisation_iters, args.top_planning_horizon)) args.results_dir = self.results_dir args.MultiGPU = True if torch.cuda.device_count( ) > 1 and args.MultiGPU else False self.__basic_setting() self.__init_sample() # Sampleing The Init Data # Initialise model parameters randomly self.transition_model = TransitionModel( args.belief_size, args.state_size, self.env.action_size, args.hidden_size, args.embedding_size, args.dense_activation_function).to(device=args.device) self.observation_model = ObservationModel( args.symbolic_env, self.env.observation_size, args.belief_size, args.state_size, args.embedding_size, args.cnn_activation_function).to(device=args.device) self.reward_model = RewardModel( args.belief_size, args.state_size, args.hidden_size, args.dense_activation_function).to(device=args.device) self.encoder = Encoder( args.symbolic_env, self.env.observation_size, args.embedding_size, args.cnn_activation_function).to(device=args.device) print("We Have {} GPUS".format(torch.cuda.device_count()) ) if args.MultiGPU else print("We use CPU") self.transition_model = nn.DataParallel( self.transition_model.to(device=args.device) ) if args.MultiGPU else self.transition_model self.observation_model = nn.DataParallel( self.observation_model.to(device=args.device) ) if args.MultiGPU else self.observation_model self.reward_model = nn.DataParallel( self.reward_model.to( device=args.device)) if args.MultiGPU else self.reward_model # encoder = nn.DataParallel(encoder.cuda()) # actor_model = nn.DataParallel(actor_model.cuda()) # value_model = nn.DataParallel(value_model.cuda()) # share the global parameters in multiprocessing self.encoder.share_memory() self.observation_model.share_memory() self.reward_model.share_memory() # Set all_model/global_actor_optimizer/global_value_optimizer self.param_list = list(self.transition_model.parameters()) + list( self.observation_model.parameters()) + list( self.reward_model.parameters()) + list( self.encoder.parameters()) self.model_optimizer = optim.Adam( self.param_list, lr=0 if args.learning_rate_schedule != 0 else args.model_learning_rate, eps=args.adam_epsilon) def update_belief_and_act(self, args, env, belief, posterior_state, action, observation, explore=False): # Infer belief over current state q(s_t|o≤t,a<t) from the history # print("action size: ",action.size()) torch.Size([1, 6]) belief, _, _, _, posterior_state, _, _ = self.upper_transition_model( posterior_state, action.unsqueeze(dim=0), belief, self.encoder(observation).unsqueeze(dim=0), None) if hasattr(env, "envs"): belief, posterior_state = list( map(lambda x: x.view(-1, args.test_episodes, x.shape[2]), [x for x in [belief, posterior_state]])) belief, posterior_state = belief.squeeze( dim=0), posterior_state.squeeze( dim=0) # Remove time dimension from belief/state action = self.algorithms.get_action(belief, posterior_state, explore) if explore: action = torch.clamp( Normal(action, args.action_noise).rsample(), -1, 1 ) # Add gaussian exploration noise on top of the sampled action # action = action + args.action_noise * torch.randn_like(action) # Add exploration noise ε ~ p(ε) to the action next_observation, reward, done = env.step( action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu( )) # Perform environment step (action repeats 
handled internally) return belief, posterior_state, action, next_observation, reward, done def run(self): if args.algo == "dreamer": print("DREAMER") from algorithms.dreamer import Algorithms self.algorithms = Algorithms(self.env.action_size, self.transition_model, self.encoder, self.reward_model, self.observation_model) elif args.algo == "p2p": print("planing to plan") from algorithms.plan_to_plan import Algorithms self.algorithms = Algorithms(self.env.action_size, self.transition_model, self.encoder, self.reward_model, self.observation_model) elif args.algo == "actor_pool_1": print("async sub actor") from algorithms.actor_pool_1 import Algorithms_actor self.algorithms = Algorithms_actor(self.env.action_size, self.transition_model, self.encoder, self.reward_model, self.observation_model) elif args.algo == "aap": from algorithms.asynchronous_actor_planet import Algorithms self.algorithms = Algorithms(self.env.action_size, self.transition_model, self.encoder, self.reward_model, self.observation_model) else: print("planet") from algorithms.planet import Algorithms # args.MultiGPU = False self.algorithms = Algorithms(self.env.action_size, self.transition_model, self.reward_model) if args.test: self.test_only() self.global_prior = Normal( torch.zeros(args.batch_size, args.state_size, device=args.device), torch.ones(args.batch_size, args.state_size, device=args.device)) # Global prior N(0, I) self.free_nats = torch.full( (1, ), args.free_nats, device=args.device) # Allowed deviation in KL divergence # Training (and testing) # args.episodes = 1 for episode in tqdm(range(self.metrics['episodes'][-1] + 1, args.episodes + 1), total=args.episodes, initial=self.metrics['episodes'][-1] + 1): losses = self.train() # self.algorithms.save_loss_data(self.metrics['episodes']) # Update and plot loss metrics self.save_loss_data(tuple( zip(*losses))) # Update and plot loss metrics self.data_collection(episode=episode) # Data collection # args.test_interval = 1 if episode % args.test_interval == 0: self.test(episode=episode) # Test model self.save_model_data(episode=episode) # save model self.env.close() # Close training environment def train_env_model(self, beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, observations, actions, rewards, nonterminals): # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?) 
if args.worldmodel_LogProbLoss: observation_dist = Normal( bottle(self.observation_model, (beliefs, posterior_states)), 1) observation_loss = -observation_dist.log_prob( observations[1:]).sum( dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1)) else: observation_loss = F.mse_loss( bottle(self.observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum( dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1)) if args.worldmodel_LogProbLoss: reward_dist = Normal( bottle(self.reward_model, (beliefs, posterior_states)), 1) reward_loss = -reward_dist.log_prob(rewards[:-1]).mean(dim=(0, 1)) else: reward_loss = F.mse_loss(bottle(self.reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1)) # transition loss div = kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2) kl_loss = torch.max(div, self.free_nats).mean( dim=(0, 1) ) # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out if args.global_kl_beta != 0: kl_loss += args.global_kl_beta * kl_divergence( Normal(posterior_means, posterior_std_devs), self.global_prior).sum(dim=2).mean(dim=(0, 1)) # Calculate latent overshooting objective for t > 0 if args.overshooting_kl_beta != 0: overshooting_vars = [ ] # Collect variables for overshooting to process in batch for t in range(1, args.chunk_size - 1): d = min(t + args.overshooting_distance, args.chunk_size - 1) # Overshooting distance t_, d_ = t - 1, d - 1 # Use t_ and d_ to deal with different time indexing for latent states seq_pad = ( 0, 0, 0, 0, 0, t - d + args.overshooting_distance ) # Calculate sequence padding so overshooting terms can be calculated in one batch # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks overshooting_vars.append( (F.pad(actions[t:d], seq_pad), F.pad(nonterminals[t:d], seq_pad), F.pad(rewards[t:d], seq_pad[2:]), beliefs[t_], prior_states[t_], F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad), F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1), F.pad( torch.ones(d - t, args.batch_size, args.state_size, device=args.device), seq_pad)) ) # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences overshooting_vars = tuple(zip(*overshooting_vars)) # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once) beliefs, prior_states, prior_means, prior_std_devs = self.upper_transition_model( torch.cat(overshooting_vars[4], dim=0), torch.cat(overshooting_vars[0], dim=1), torch.cat(overshooting_vars[3], dim=0), None, torch.cat(overshooting_vars[1], dim=1)) seq_mask = torch.cat(overshooting_vars[7], dim=1) # Calculate overshooting KL loss with sequence mask kl_loss += ( 1 / args.overshooting_distance ) * args.overshooting_kl_beta * torch.max((kl_divergence( Normal(torch.cat(overshooting_vars[5], dim=1), torch.cat(overshooting_vars[6], dim=1)), Normal(prior_means, prior_std_devs) ) * seq_mask).sum(dim=2), self.free_nats).mean(dim=(0, 1)) * ( args.chunk_size - 1 ) # Update KL loss (compensating for extra average over each overshooting/open loop sequence) # Calculate overshooting reward prediction loss with sequence mask if args.overshooting_reward_scale != 0: reward_loss += ( 1 / args.overshooting_distance ) * args.overshooting_reward_scale * F.mse_loss( bottle(self.reward_model, (beliefs, 
prior_states)) * seq_mask[:, :, 0], torch.cat(overshooting_vars[2], dim=1), reduction='none' ).mean(dim=(0, 1)) * ( args.chunk_size - 1 ) # Update reward loss (compensating for extra average over each overshooting/open loop sequence) # Apply linearly ramping learning rate schedule if args.learning_rate_schedule != 0: for group in self.model_optimizer.param_groups: group['lr'] = min( group['lr'] + args.model_learning_rate / args.model_learning_rate_schedule, args.model_learning_rate) model_loss = observation_loss + reward_loss + kl_loss # Update model parameters self.model_optimizer.zero_grad() model_loss.backward() nn.utils.clip_grad_norm_(self.param_list, args.grad_clip_norm, norm_type=2) self.model_optimizer.step() return observation_loss, reward_loss, kl_loss def train(self): # Model fitting losses = [] print("training loop") # args.collect_interval = 1 for s in tqdm(range(args.collect_interval)): # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags) observations, actions, rewards, nonterminals = self.D.sample( args.batch_size, args.chunk_size) # Transitions start at time t = 0 # Create initial belief and state for time t = 0 init_belief, init_state = torch.zeros( args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device) # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once) obs = bottle(self.encoder, (observations[1:], )) beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.upper_transition_model( prev_state=init_state, actions=actions[:-1], prev_belief=init_belief, obs=obs, nonterminals=nonterminals[:-1]) # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?) 
observation_loss, reward_loss, kl_loss = self.train_env_model( beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, observations, actions, rewards, nonterminals) # Dreamer implementation: actor loss calculation and optimization with torch.no_grad(): actor_states = posterior_states.detach().to( device=args.device).share_memory_() actor_beliefs = beliefs.detach().to( device=args.device).share_memory_() # if not os.path.exists(os.path.join(os.getcwd(), 'tensor_data/' + args.results_dir)): os.mkdir(os.path.join(os.getcwd(), 'tensor_data/' + args.results_dir)) torch.save( actor_states, os.path.join(os.getcwd(), args.results_dir + '/actor_states.pt')) torch.save( actor_beliefs, os.path.join(os.getcwd(), args.results_dir + '/actor_beliefs.pt')) # [self.actor_pipes[i][0].send(1) for i, w in enumerate(self.workers_actor)] # Parent_pipe send data using i'th pipes # [self.actor_pipes[i][0].recv() for i, _ in enumerate(self.actor_pool)] # waitting the children finish self.algorithms.train_algorithm(actor_states, actor_beliefs) losses.append( [observation_loss.item(), reward_loss.item(), kl_loss.item()]) # if self.algorithms.train_algorithm(actor_states, actor_beliefs) is not None: # merge_actor_loss, merge_value_loss = self.algorithms.train_algorithm(actor_states, actor_beliefs) # losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item(), merge_actor_loss.item(), merge_value_loss.item()]) # else: # losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()]) return losses def data_collection(self, episode): print("Data collection") with torch.no_grad(): observation, total_reward = self.env.reset(), 0 belief, posterior_state, action = torch.zeros( 1, args.belief_size, device=args.device), torch.zeros( 1, args.state_size, device=args.device), torch.zeros(1, self.env.action_size, device=args.device) pbar = tqdm(range(args.max_episode_length // args.action_repeat)) for t in pbar: # print("step",t) belief, posterior_state, action, next_observation, reward, done = self.update_belief_and_act( args, self.env, belief, posterior_state, action, observation.to(device=args.device)) self.D.append(observation, action.cpu(), reward, done) total_reward += reward observation = next_observation if args.render: self.env.render() if done: pbar.close() break # Update and plot train reward metrics self.metrics['steps'].append(t + self.metrics['steps'][-1]) self.metrics['episodes'].append(episode) self.metrics['train_rewards'].append(total_reward) Save_Txt(self.metrics['episodes'][-1], self.metrics['train_rewards'][-1], 'train_rewards', args.results_dir) # lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir) def test(self, episode): print("Test model") # Set models to eval mode self.transition_model.eval() self.observation_model.eval() self.reward_model.eval() self.encoder.eval() self.algorithms.train_to_eval() # self.actor_model_g.eval() # self.value_model_g.eval() # Initialise parallelised test environments test_envs = EnvBatcher( Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth), {}, args.test_episodes) with torch.no_grad(): observation, total_rewards, video_frames = test_envs.reset( ), np.zeros((args.test_episodes, )), [] belief, posterior_state, action = torch.zeros( args.test_episodes, args.belief_size, device=args.device), torch.zeros( args.test_episodes, args.state_size, device=args.device), 
torch.zeros(args.test_episodes, self.env.action_size, device=args.device) pbar = tqdm(range(args.max_episode_length // args.action_repeat)) for t in pbar: belief, posterior_state, action, next_observation, reward, done = self.update_belief_and_act( args, test_envs, belief, posterior_state, action, observation.to(device=args.device)) total_rewards += reward.numpy() if not args.symbolic_env: # Collect real vs. predicted frames for video video_frames.append( make_grid(torch.cat([ observation, self.observation_model(belief, posterior_state).cpu() ], dim=3) + 0.5, nrow=5).numpy()) # Decentre observation = next_observation if done.sum().item() == args.test_episodes: pbar.close() break # Update and plot reward metrics (and write video if applicable) and save metrics self.metrics['test_episodes'].append(episode) self.metrics['test_rewards'].append(total_rewards.tolist()) Save_Txt(self.metrics['test_episodes'][-1], self.metrics['test_rewards'][-1], 'test_rewards', args.results_dir) # Save_Txt(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'],'test_rewards_steps', results_dir, xaxis='step') # lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir) # lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', results_dir, xaxis='step') if not args.symbolic_env: episode_str = str(episode).zfill(len(str(args.episodes))) write_video(video_frames, 'test_episode_%s' % episode_str, args.results_dir) # Lossy compression save_image( torch.as_tensor(video_frames[-1]), os.path.join(args.results_dir, 'test_episode_%s.png' % episode_str)) torch.save(self.metrics, os.path.join(args.results_dir, 'metrics.pth')) # Set models to train mode self.transition_model.train() self.observation_model.train() self.reward_model.train() self.encoder.train() # self.actor_model_g.train() # self.value_model_g.train() self.algorithms.eval_to_train() # Close test environments test_envs.close() def test_only(self): # Set models to eval mode self.transition_model.eval() self.reward_model.eval() self.encoder.eval() with torch.no_grad(): total_reward = 0 for _ in tqdm(range(args.test_episodes)): observation = self.env.reset() belief, posterior_state, action = torch.zeros( 1, args.belief_size, device=args.device), torch.zeros( 1, args.state_size, device=args.device), torch.zeros(1, self.env.action_size, device=args.device) pbar = tqdm( range(args.max_episode_length // args.action_repeat)) for t in pbar: belief, posterior_state, action, observation, reward, done = self.update_belief_and_act( args, self.env, belief, posterior_state, action, observation.to(evice=args.device)) total_reward += reward if args.render: self.env.render() if done: pbar.close() break print('Average Reward:', total_reward / args.test_episodes) self.env.close() quit() def __basic_setting(self): args.overshooting_distance = min( args.chunk_size, args.overshooting_distance ) # Overshooting distance cannot be greater than chunk size print(' ' * 26 + 'Options') for k, v in vars(args).items(): print(' ' * 26 + k + ': ' + str(v)) print("torch.cuda.device_count() {}".format(torch.cuda.device_count())) os.makedirs(args.results_dir, exist_ok=True) np.random.seed(args.seed) torch.manual_seed(args.seed) # Set Cuda if torch.cuda.is_available() and not args.disable_cuda: print("using CUDA") args.device = torch.device('cuda') torch.cuda.manual_seed(args.seed) else: print("using CPU") args.device = torch.device('cpu') self.summary_name = 
args.results_dir + "/{}_{}_log" self.writer = SummaryWriter(self.summary_name.format( args.env, args.id)) self.env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth) self.metrics = { 'steps': [], 'episodes': [], 'train_rewards': [], 'test_episodes': [], 'test_rewards': [], 'observation_loss': [], 'reward_loss': [], 'kl_loss': [], 'merge_actor_loss': [], 'merge_value_loss': [] } def __init_sample(self): if args.experience_replay is not '' and os.path.exists( args.experience_replay): self.D = torch.load(args.experience_replay) self.metrics['steps'], self.metrics['episodes'] = [ self.D.steps ] * self.D.episodes, list(range(1, self.D.episodes + 1)) elif not args.test: self.D = ExperienceReplay(args.experience_size, args.symbolic_env, self.env.observation_size, self.env.action_size, args.bit_depth, args.device) # Initialise dataset D with S random seed episodes print( "Start Multi Sample Processing -------------------------------" ) start_time = time.time() data_lists = [ Manager().list() for i in range(1, args.seed_episodes + 1) ] # Set Global Lists pipes = [Pipe() for i in range(1, args.seed_episodes + 1) ] # Set Multi Pipe workers_init_sample = [ Worker_init_Sample(child_conn=child, id=i + 1) for i, [parent, child] in enumerate(pipes) ] for i, w in enumerate(workers_init_sample): w.start() # Start Single Process pipes[i][0].send( data_lists[i]) # Parent_pipe send data using i'th pipes [w.join() for w in workers_init_sample] # wait sub_process done for i, [parent, child] in enumerate(pipes): # datas = parent.recv() for data in list(parent.recv()): if isinstance(data, tuple): assert len(data) == 4 self.D.append(data[0], data[1], data[2], data[3]) elif isinstance(data, int): t = data self.metrics['steps'].append(t * args.action_repeat + ( 0 if len(self.metrics['steps']) == 0 else self.metrics['steps'][-1])) self.metrics['episodes'].append(i + 1) else: print( "The Recvive Data Have Some Problems, Need To Fix") end_time = time.time() print("the process times {} s".format(end_time - start_time)) print( "End Multi Sample Processing -------------------------------") def upper_transition_model(self, prev_state, actions, prev_belief, obs, nonterminals): actions = torch.transpose(actions, 0, 1) if args.MultiGPU else actions nonterminals = torch.transpose(nonterminals, 0, 1).to( device=args.device ) if args.MultiGPU and nonterminals is not None else nonterminals obs = torch.transpose(obs, 0, 1).to( device=args.device) if args.MultiGPU and obs is not None else obs temp_val = self.transition_model(prev_state.to(device=args.device), actions.to(device=args.device), prev_belief.to(device=args.device), obs, nonterminals) return list( map( lambda x: torch.cat(x.chunk(torch.cuda.device_count(), 0), 1) if x.shape[1] != prev_state.shape[0] else x, [x for x in temp_val])) def save_loss_data(self, losses): self.metrics['observation_loss'].append(losses[0]) self.metrics['reward_loss'].append(losses[1]) self.metrics['kl_loss'].append(losses[2]) self.metrics['merge_actor_loss'].append( losses[3]) if losses.__len__() > 3 else None self.metrics['merge_value_loss'].append( losses[4]) if losses.__len__() > 3 else None Save_Txt(self.metrics['episodes'][-1], self.metrics['observation_loss'][-1], 'observation_loss', args.results_dir) Save_Txt(self.metrics['episodes'][-1], self.metrics['reward_loss'][-1], 'reward_loss', args.results_dir) Save_Txt(self.metrics['episodes'][-1], self.metrics['kl_loss'][-1], 'kl_loss', args.results_dir) Save_Txt(self.metrics['episodes'][-1], 
self.metrics['merge_actor_loss'][-1], 'merge_actor_loss', args.results_dir) if losses.__len__() > 3 else None Save_Txt(self.metrics['episodes'][-1], self.metrics['merge_value_loss'][-1], 'merge_value_loss', args.results_dir) if losses.__len__() > 3 else None # lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', results_dir) # lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', results_dir) # lineplot(metrics['episodes'][-len(metrics['kl_loss']):], metrics['kl_loss'], 'kl_loss', results_dir) # lineplot(metrics['episodes'][-len(metrics['actor_loss']):], metrics['actor_loss'], 'actor_loss', results_dir) # lineplot(metrics['episodes'][-len(metrics['value_loss']):], metrics['value_loss'], 'value_loss', results_dir) def save_model_data(self, episode): # writer.add_scalar("train_reward", metrics['train_rewards'][-1], metrics['steps'][-1]) # writer.add_scalar("train/episode_reward", metrics['train_rewards'][-1], metrics['steps'][-1]*args.action_repeat) # writer.add_scalar("observation_loss", metrics['observation_loss'][0][-1], metrics['steps'][-1]) # writer.add_scalar("reward_loss", metrics['reward_loss'][0][-1], metrics['steps'][-1]) # writer.add_scalar("kl_loss", metrics['kl_loss'][0][-1], metrics['steps'][-1]) # writer.add_scalar("actor_loss", metrics['actor_loss'][0][-1], metrics['steps'][-1]) # writer.add_scalar("value_loss", metrics['value_loss'][0][-1], metrics['steps'][-1]) # print("episodes: {}, total_steps: {}, train_reward: {} ".format(metrics['episodes'][-1], metrics['steps'][-1], metrics['train_rewards'][-1])) # Checkpoint models if episode % args.checkpoint_interval == 0: # torch.save({'transition_model': transition_model.state_dict(), # 'observation_model': observation_model.state_dict(), # 'reward_model': reward_model.state_dict(), # 'encoder': encoder.state_dict(), # 'actor_model': actor_model_g.state_dict(), # 'value_model': value_model_g.state_dict(), # 'model_optimizer': model_optimizer.state_dict(), # 'actor_optimizer': actor_optimizer_g.state_dict(), # 'value_optimizer': value_optimizer_g.state_dict() # }, os.path.join(results_dir, 'models_%d.pth' % episode)) if args.checkpoint_experience: torch.save( self.D, os.path.join(args.results_dir, 'experience.pth') ) # Warning: will fail with MemoryError with large memory sizes
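# --- Hedged sketch: the linearly ramping learning-rate schedule referenced above
# (the optimiser is created with lr=0 when a schedule is requested, and each update
# adds model_learning_rate / schedule_steps until the target is reached).
# `target_lr` and `schedule_steps` are illustrative values.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)
target_lr, schedule_steps = 1e-3, 100
optimiser = optim.Adam(model.parameters(), lr=0.0)

for step in range(5):
    for group in optimiser.param_groups:
        group['lr'] = min(group['lr'] + target_lr / schedule_steps, target_lr)
    # ... forward pass, backward pass and optimiser.step() would go here ...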
class SACAgent(): def __init__(self, action_size, state_size, config): self.seed = config["seed"] torch.manual_seed(self.seed) np.random.seed(seed=self.seed) self.env = gym.make(config["env_name"]) self.env = FrameStack(self.env, config) self.env.seed(self.seed) self.action_size = action_size self.state_size = state_size self.tau = config["tau"] self.gamma = config["gamma"] self.batch_size = config["batch_size"] self.lr = config["lr"] self.history_length = config["history_length"] self.size = config["size"] if not torch.cuda.is_available(): config["device"] == "cpu" self.device = config["device"] self.eval = config["eval"] self.vid_path = config["vid_path"] print("actions size ", action_size) self.critic = QNetwork(state_size, action_size, config["fc1_units"], config["fc2_units"]).to(self.device) self.q_optim = torch.optim.Adam(self.critic.parameters(), config["lr_critic"]) self.target_critic = QNetwork(state_size, action_size, config["fc1_units"], config["fc2_units"]).to(self.device) self.target_critic.load_state_dict(self.critic.state_dict()) self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) self.alpha = self.log_alpha.exp() self.alpha_optim = Adam([self.log_alpha], lr=config["lr_alpha"]) self.policy = SACActor(state_size, action_size).to(self.device) self.policy_optim = Adam(self.policy.parameters(), lr=config["lr_policy"]) self.encoder = Encoder(config).to(self.device) self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), self.lr) self.episodes = config["episodes"] self.memory = ReplayBuffer((self.history_length, self.size, self.size), (1, ), config["buffer_size"], config["image_pad"], self.seed, self.device) pathname = config["seed"] tensorboard_name = str(config["res_path"]) + '/runs/' + str(pathname) self.writer = SummaryWriter(tensorboard_name) self.steps = 0 self.target_entropy = -torch.prod( torch.Tensor(action_size).to(self.device)).item() def act(self, state, evaluate=False): with torch.no_grad(): state = torch.FloatTensor(state).to(self.device).unsqueeze(0) state = state.type(torch.float32).div_(255) self.encoder.eval() state = self.encoder.create_vector(state) self.encoder.train() if evaluate is False: action = self.policy.sample(state) else: action_prob, _ = self.policy(state) action = torch.argmax(action_prob) action = action.cpu().numpy() return action # action = np.clip(action, self.min_action, self.max_action) action = action.cpu().numpy()[0] return action def train_agent(self): average_reward = 0 scores_window = deque(maxlen=100) t0 = time.time() for i_epiosde in range(1, self.episodes): episode_reward = 0 state = self.env.reset() t = 0 while True: t += 1 action = self.act(state) next_state, reward, done, _ = self.env.step(action) episode_reward += reward if i_epiosde > 10: self.learn() self.memory.add(state, reward, action, next_state, done) state = next_state if done: scores_window.append(episode_reward) break if i_epiosde % self.eval == 0: self.eval_policy() ave_reward = np.mean(scores_window) print("Epiosde {} Steps {} Reward {} Reward averge{:.2f} Time {}". 
format(i_epiosde, t, episode_reward, np.mean(scores_window), time_format(time.time() - t0))) self.writer.add_scalar('Aver_reward', ave_reward, self.steps) def learn(self): self.steps += 1 states, rewards, actions, next_states, dones = self.memory.sample( self.batch_size) states = states.type(torch.float32).div_(255) states = self.encoder.create_vector(states) states_detached = states.detach() qf1, qf2 = self.critic(states) q_value1 = qf1.gather(1, actions) q_value2 = qf2.gather(1, actions) with torch.no_grad(): next_states = next_states.type(torch.float32).div_(255) next_states = self.encoder.create_vector(next_states) q1_target, q2_target = self.target_critic(next_states) min_q_target = torch.min(q1_target, q2_target) next_action_prob, next_action_log_prob = self.policy(next_states) next_q_target = ( next_action_prob * (min_q_target - self.alpha * next_action_log_prob)).sum( dim=1, keepdim=True) next_q_value = rewards + (1 - dones) * self.gamma * next_q_target # --------------------------update-q-------------------------------------------------------- loss = F.mse_loss(q_value1, next_q_value) + F.mse_loss( q_value2, next_q_value) self.q_optim.zero_grad() self.encoder_optimizer.zero_grad() loss.backward() self.q_optim.step() self.encoder_optimizer.zero_grad() self.writer.add_scalar('loss/q', loss, self.steps) # --------------------------update-policy-------------------------------------------------------- action_prob, log_action_prob = self.policy(states_detached) with torch.no_grad(): q_pi1, q_pi2 = self.critic(states_detached) min_q_values = torch.min(q_pi1, q_pi2) #policy_loss = (action_prob * ((self.alpha * log_action_prob) - min_q_values).detach()).sum(dim=1).mean() policy_loss = (action_prob * ((self.alpha * log_action_prob) - min_q_values)).sum( dim=1).mean() self.policy_optim.zero_grad() policy_loss.backward() self.policy_optim.step() self.writer.add_scalar('loss/policy', policy_loss, self.steps) # --------------------------update-alpha-------------------------------------------------------- alpha_loss = (action_prob.detach() * (-self.log_alpha * (log_action_prob + self.target_entropy).detach())).sum( dim=1).mean() self.alpha_optim.zero_grad() alpha_loss.backward() self.alpha_optim.step() self.writer.add_scalar('loss/alpha', alpha_loss, self.steps) self.soft_udapte(self.critic, self.target_critic) self.alpha = self.log_alpha.exp() def soft_udapte(self, online, target): for param, target_parm in zip(online.parameters(), target.parameters()): target_parm.data.copy_(self.tau * param.data + (1 - self.tau) * target_parm.data) def eval_policy(self, eval_episodes=4): env = gym.make(self.env_name) env = wrappers.Monitor(env, str(self.vid_path) + "/{}".format(self.steps), video_callable=lambda episode_id: True, force=True) average_reward = 0 scores_window = deque(maxlen=100) for i_epiosde in range(eval_episodes): print("Eval Episode {} of {} ".format(i_epiosde, eval_episodes)) episode_reward = 0 state = env.reset() while True: action = self.act(state, evaluate=True) state, reward, done, _ = env.step(action) episode_reward += reward if done: break scores_window.append(episode_reward) average_reward = np.mean(scores_window) self.writer.add_scalar('Eval_reward', average_reward, self.steps)
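# --- Hedged sketch: the Polyak ("soft") target-network update performed by
# SACAgent.soft_udapte above, as a stand-alone example; `tau` is illustrative.
import torch
import torch.nn as nn


def soft_update(online: nn.Module, target: nn.Module, tau: float = 0.005):
    with torch.no_grad():
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


online_net = nn.Linear(8, 2)
target_net = nn.Linear(8, 2)
target_net.load_state_dict(online_net.state_dict())  # start from identical weights
soft_update(online_net, target_net, tau=0.005)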
class AAETrainer(AbstractTrainer): def __init__(self, opt): super().__init__(opt) print('[info] Dataset:', self.opt.dataset) print('[info] Alhpa = ', self.opt.alpha) print('[info] Latent dimension = ', self.opt.latent_dim) self.opt = opt self.start_visdom() def start_visdom(self): self.vis = utils.Visualizer(env='Adversarial AutoEncoder Training', port=8888) def build_network(self): print('[info] Build the network architecture') self.encoder = Encoder(z_dim=self.opt.latent_dim) if self.opt.dataset == 'SMPL': num_verts = 6890 elif self.opt.dataset == 'all_animals': num_verts = 3889 self.decoder = Decoder(num_verts=num_verts, z_dim=self.opt.latent_dim) self.discriminator = Discriminator(input_dim=self.opt.latent_dim) self.encoder.cuda() self.decoder.cuda() self.discriminator.cuda() def build_optimizer(self): print('[info] Build the optimizer') self.optim_dis = optim.SGD(self.discriminator.parameters(), lr=self.opt.learning_rate) self.optim_AE = optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), lr=self.opt.learning_rate) def build_dataset_train(self): train_data = ACAPData(mode='train', name=self.opt.dataset) self.num_train_data = len(train_data) print('[info] Number of training samples = ', self.num_train_data) self.train_loader = torch.utils.data.DataLoader( train_data, batch_size=self.opt.batch_size, shuffle=True) def build_dataset_valid(self): valid_data = ACAPData(mode='valid', name=self.opt.dataset) self.num_valid_data = len(valid_data) print('[info] Number of validation samples = ', self.num_valid_data) self.valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=128, shuffle=True) def build_losses(self): print('[info] Build the loss functions') self.mseLoss = torch.nn.MSELoss() self.ganLoss = torch.nn.BCELoss() def print_iteration_stats(self): """ print stats at each iteration """ print( '\r[Epoch %d] [Iteration %d/%d] enc = %f dis = %f rec = %f' % (self.epoch, self.iteration, int(self.num_train_data / self.opt.batch_size), self.enc_loss.item(), self.dis_loss.item(), self.rec_loss.item()), end='') def train_iteration(self): self.encoder.train() self.decoder.train() self.discriminator.train() x = self.data.cuda() z = self.encoder(x) ''' Discriminator ''' # sample from N(0, I) z_real = Variable(torch.randn(z.size(0), z.size(1))).cuda() y_real = Variable(torch.ones(z.size(0))).cuda() dis_real_loss = self.ganLoss( self.discriminator(z_real).view(-1), y_real) y_fake = Variable(torch.zeros(z.size(0))).cuda() dis_fake_loss = self.ganLoss(self.discriminator(z).view(-1), y_fake) self.optim_dis.zero_grad() self.dis_loss = 0.5 * (dis_fake_loss + dis_real_loss) self.dis_loss.backward(retain_graph=True) self.optim_dis.step() self.dis_losses.append(self.dis_loss.item()) ''' Autoencoder ''' # Encoder hopes to generate latent vectors that are closed to prior. y_real = Variable(torch.ones(z.size(0))).cuda() self.enc_loss = self.ganLoss(self.discriminator(z).view(-1), y_real) # Decoder hopes to make the reconstruction as similar to input as possible. rec = self.decoder(z) self.rec_loss = self.mseLoss(rec, x) # There is a trade-off here: # Latent regularization V.S. 
Reconstruction quality self.EG_loss = self.opt.alpha * self.enc_loss + ( 1 - self.opt.alpha) * self.rec_loss self.optim_AE.zero_grad() self.EG_loss.backward() self.optim_AE.step() self.enc_losses.append(self.enc_loss.item()) self.rec_losses.append(self.rec_loss.item()) self.print_iteration_stats() self.increment_iteration() def train_epoch(self): self.reset_iteration() self.dis_losses = [] self.enc_losses = [] self.rec_losses = [] for step, data in enumerate(self.train_loader): self.data = data self.train_iteration() self.dis_losses = torch.Tensor(self.dis_losses) self.dis_losses = torch.mean(self.dis_losses) self.enc_losses = torch.Tensor(self.enc_losses) self.enc_losses = torch.mean(self.enc_losses) self.rec_losses = torch.Tensor(self.rec_losses) self.rec_losses = torch.mean(self.rec_losses) self.vis.draw_line(win='Encoder Loss', x=self.epoch, y=self.enc_losses) self.vis.draw_line(win='Discriminator Loss', x=self.epoch, y=self.dis_losses) self.vis.draw_line(win='Reconstruction Loss', x=self.epoch, y=self.rec_losses) def valid_iteration(self): self.encoder.eval() self.decoder.eval() self.discriminator.eval() x = self.data.cuda() z = self.encoder(x) recon = self.decoder(z) # loss rec_loss = self.mseLoss(recon, x) self.rec_loss.append(rec_loss.item()) self.increment_iteration() def valid_epoch(self): self.reset_iteration() self.rec_loss = [] for step, data in enumerate(self.valid_loader): self.data = data self.valid_iteration() self.rec_loss = torch.Tensor(self.rec_loss) self.rec_loss = torch.mean(self.rec_loss) self.vis.draw_line(win='Valid reconstruction loss', x=self.epoch, y=self.rec_loss) def save_network(self): print("\n[info] saving net...") torch.save(self.encoder.state_dict(), f"{self.opt.save_path}/Encoder.pth") torch.save(self.decoder.state_dict(), f"{self.opt.save_path}/Decoder.pth") torch.save(self.discriminator.state_dict(), f"{self.opt.save_path}/Discriminator.pth")
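# --- Hedged sketch: the alpha-weighted trade-off optimised by AAETrainer above,
# EG_loss = alpha * adversarial (encoder) loss + (1 - alpha) * reconstruction loss.
# Toy tensors only; `alpha`, shapes and the discriminator stand-in are illustrative.
import torch
import torch.nn as nn

bce, mse = nn.BCELoss(), nn.MSELoss()
alpha = 0.05

disc_out_on_fake = torch.sigmoid(torch.randn(16))    # stand-in for discriminator(z)
reconstruction, target = torch.randn(16, 30), torch.randn(16, 30)

enc_loss = bce(disc_out_on_fake, torch.ones(16))     # encoder wants z judged "real"
rec_loss = mse(reconstruction, target)
eg_loss = alpha * enc_loss + (1 - alpha) * rec_loss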
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn
import torch
import torch.nn.functional as F
import torch.optim as optim
from functools import reduce
from sklearn.metrics import confusion_matrix
from torch.autograd import Variable

import shl_processing
from shl_processing import coarse_label_mapping
# Classifier, Encoder and accuracy() are defined elsewhere in this repo.


class Solver:
    def __init__(self):
        self.train_lr = 1e-4
        self.num_classes = 9
        self.clf_target = Classifier().cuda()
        self.clf2 = Classifier().cuda()
        self.clf1 = Classifier().cuda()
        self.encoder = Encoder().cuda()
        self.pretrain_lr = 1e-4
        self.weights_coef = 1e-3

    def to_var(self, x):
        """Converts a tensor to a float Variable, on GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, requires_grad=False).float()

    def loss(self, predictions, y_true, weights_coef=None):
        """
        :param predictions: list of prediction tensors
        """
        assert len(predictions[0].shape) == 2 and len(y_true.shape) == 1, \
            (predictions[0].shape, y_true.shape)
        losses = [F.cross_entropy(y_hat, y_true) for y_hat in predictions]
        loss = sum(losses)
        # We add the term |W1^T W2| to the cost function, where W1, W2 denote the
        # weights of the first fully connected layers of F1 and F2, which are
        # applied to the feature F(x_i).
        if weights_coef:
            lw = torch.matmul(self.clf1.fc1.weight,
                              self.clf2.fc1.weight.T).abs().sum()
            loss += weights_coef * lw
        return loss

    def pretrain(self, source_loader, target_val_loader, pretrain_epochs=1):
        source_iter = iter(source_loader)
        source_per_epoch = len(source_iter)
        print("source_per_epoch:", source_per_epoch)
        log_pre = 250
        lr = self.pretrain_lr
        pretrain_iters = source_per_epoch * pretrain_epochs
        params = reduce(
            lambda a, b: a + b,
            map(lambda m: list(m.parameters()),
                [self.encoder, self.clf1, self.clf2, self.clf_target]))
        pretrain_optimizer = optim.Adam(params, lr)
        for step in range(pretrain_iters + 1):
            # ============ Initialization ============ #
            # refresh the source iterator before it is exhausted
            if (step + 1) % source_per_epoch == 0:
                source_iter = iter(source_loader)
            # load the data
            source, s_labels = next(source_iter)
            source = self.to_var(source)
            s_labels = self.to_var(s_labels).long().squeeze()

            # ============ Training ============ #
            pretrain_optimizer.zero_grad()
            # forward
            features = self.encoder(source)
            y1_hat = self.clf1(features)
            y2_hat = self.clf2(features)
            y_target_hat = self.clf_target(features)
            # loss
            loss_source_class = self.loss([y1_hat, y2_hat, y_target_hat],
                                          s_labels,
                                          weights_coef=self.weights_coef)
            # one step
            loss_source_class.backward()
            pretrain_optimizer.step()
            pretrain_optimizer.zero_grad()

            # TODO: compute metrics every step and only average/print them on log_pre steps
            # ============ Validation ============ #
            if (step + 1) % log_pre == 0:
                with torch.no_grad():
                    source_val_features = self.encoder(source)
                    c_source1 = self.clf1(source_val_features)
                    c_source2 = self.clf2(source_val_features)
                    c_target = self.clf_target(source_val_features)
                print("Train data (source) scores:")
                print("Step %d | Source clf1=%.2f, clf2=%.2f | Source data clf_t=%.2f"
                      % (step,
                         accuracy(c_source1, s_labels),
                         accuracy(c_source2, s_labels),
                         accuracy(c_target, s_labels)))
                acc = self.eval(target_val_loader, self.clf_target)
                print("Val target data acc=%.2f" % acc)
                print()

    def pseudo_labeling(self, loader, pool_size=4000, threshold=0.9):
        """
        Let C1, C2 denote the classes with the maximum predicted probability
        under y1, y2. We assign a pseudo-label to x_k if two conditions are
        satisfied. First, C1 = C2, i.e. the two classifiers agree on the
        prediction. Second, the maximum probability of y1 or y2 exceeds the
        threshold parameter, which is set to 0.9 or 0.95 in the experiments.

        :return: list of (x, y_pseudo) pairs, at most pool_size long
        """
        pool = []  # (x, y_pseudo)
        for x, _ in loader:
            batch_size = x.shape[0]
            x = self.to_var(x)
            ys1 = F.softmax(self.clf1(self.encoder(x)), dim=1)
            ys2 = F.softmax(self.clf2(self.encoder(x)), dim=1)
            for i in range(batch_size):
                y1, y2 = ys1[i], ys2[i]
                val1, idx1 = torch.max(y1, 0)
                val2, idx2 = torch.max(y2, 0)
                if idx1 == idx2 and max(val1, val2) >= threshold:
                    pool.append((x[i].cpu(), idx1.cpu().item()))
                if len(pool) >= pool_size:
                    return pool
        return pool

    def train(self, source_loader, source_val_loader, target_loader,
              target_val_loader, epochs):
        """
        :param epochs: number of epochs to train on the target pseudo-labels
        """
        log_pre = 30
        lr = self.train_lr
        params1 = reduce(
            lambda a, b: a + b,
            map(lambda m: list(m.parameters()),
                [self.encoder, self.clf1, self.clf2]))
        params2 = list(self.encoder.parameters()) + \
            list(self.clf_target.parameters())
        optimizer1 = optim.Adam(params1, lr)
        optimizer2 = optim.Adam(params2, lr)

        # ad-hoc running training accuracies
        acs1, acs2, acs3 = [], [], []
        for epoch in range(epochs):
            source_iter = iter(source_loader)
            target_iter = iter(target_loader)
            source_per_epoch = len(source_iter)
            target_per_epoch = len(target_iter)
            if epoch == 0:
                print("source_per_epoch, target_per_epoch:",
                      source_per_epoch, target_per_epoch)
            # step-wise learning-rate decay
            if epoch == 3:
                for param_group in optimizer1.param_groups:
                    param_group['lr'] = lr * 0.1
                for param_group in optimizer2.param_groups:
                    param_group['lr'] = lr * 0.1
            if epoch == 6:
                for param_group in optimizer1.param_groups:
                    param_group['lr'] = lr * 0.01
                for param_group in optimizer2.param_groups:
                    param_group['lr'] = lr * 0.01

            # ============ Pseudo-labeling ============ #
            # Fill the candidate pool; the pool grows with the epoch index
            target_candidates = self.pseudo_labeling(target_loader,
                                                     pool_size=4000 * epoch)
            print("Target candidates len:", len(target_candidates))
            if len(target_candidates) <= 1:
                # fall back to an unfiltered pool if (almost) nothing passed the threshold
                target_candidates = self.pseudo_labeling(target_loader,
                                                         threshold=0.0)
                print("Target candidates len:", len(target_candidates))
            target_candidates_loader = self.wrap_to_loader(
                target_candidates, batch_size=target_loader.batch_size)

            for step, (target, t_labels) in enumerate(target_candidates_loader):
                if (step + 1) % source_per_epoch == 0:
                    source_iter = iter(source_loader)
                source, s_labels = next(source_iter)
                target = self.to_var(target)
                t_labels = self.to_var(t_labels).long().squeeze()
                source = self.to_var(source)
                s_labels = self.to_var(s_labels).long().squeeze()

                # ============ Train F, F1, F2 ============ #
                optimizer1.zero_grad()
                # source data
                features = self.encoder(source)
                y1s_hat = self.clf1(features)
                y2s_hat = self.clf2(features)
                loss_source_class = self.loss([y1s_hat, y2s_hat], s_labels,
                                              weights_coef=self.weights_coef)
                # target data (pseudo-labelled)
                features = self.encoder(target)
                y1t_hat = self.clf1(features)
                y2t_hat = self.clf2(features)
                loss_target_class = self.loss([y1t_hat, y2t_hat], t_labels,
                                              weights_coef=self.weights_coef)
                # one step
                (loss_source_class + loss_target_class).backward()
                optimizer1.step()
                optimizer1.zero_grad()

                # ============ Train F, Ft ============ #
                optimizer2.zero_grad()
                y_target_hat = self.clf_target(self.encoder(target))
                loss_target_class = self.loss([y_target_hat], t_labels)
                loss_target_class.backward()
                optimizer2.step()
                optimizer2.zero_grad()

                # ============ Validation ============ #
                acs1.append(accuracy(y1s_hat, s_labels).item())
                acs2.append(accuracy(y2s_hat, s_labels).item())
                acs3.append(accuracy(y_target_hat, t_labels).item())
                if (step + 1) % log_pre == 0:
                    acc = self.eval(target_val_loader, self.clf_target)
                    print("Step %d | Val data target classifier acc=%.2f"
                          % (step, acc))
                    print("         Train accuracy clf1=%.2f, clf2=%.2f, clf_t=%.2f"
                          % (np.mean(acs1), np.mean(acs2), np.mean(acs3)))
                    acs1, acs2, acs3 = [], [], []
                    print()

    def save_models(self):
        torch.save(self.encoder, 'encoder.pth')
        torch.save(self.clf1, 'clf1.pth')
        torch.save(self.clf2, 'clf2.pth')
        torch.save(self.clf_target, 'clf_target.pth')

    def load_models(self):
        self.encoder = torch.load('encoder.pth')
        self.clf1 = torch.load('clf1.pth')
        self.clf2 = torch.load('clf2.pth')
        self.clf_target = torch.load('clf_target.pth')

    def eval(self, loader, classifier):
        """Evaluate encoder + the passed classifier; returns overall accuracy in %."""
        class_correct = [0] * self.num_classes
        class_total = [0.] * self.num_classes
        classes = shl_processing.coarse_label_mapping
        self.encoder.eval()
        classifier.eval()
        for x, y_true in loader:
            # forward pass: compute predicted outputs by passing inputs to the model
            x, y_true = self.to_var(x), self.to_var(y_true).long().squeeze()
            y_hat = classifier(self.encoder(x))
            _, pred = torch.max(y_hat, 1)
            correct = np.squeeze(pred.eq(y_true.data.view_as(pred)))
            # accumulate test accuracy for each class
            for i in range(len(y_true.data)):
                label = y_true.data[i]
                class_correct[label] += correct[i].item()
                class_total[label] += 1
        for i in range(self.num_classes):
            if class_total[i] > 0:
                print('\tTest Accuracy of %10s: %2d%% (%2d/%2d)'
                      % (classes[i],
                         100 * class_correct[i] / class_total[i],
                         class_correct[i], class_total[i]))
            else:
                print('\tTest Accuracy of %10s: N/A (no examples)' % classes[i])
        self.encoder.train()
        classifier.train()
        return 100. * np.sum(class_correct) / np.sum(class_total)

    def wrap_to_loader(self, target_candidates, batch_size):
        """
        :param target_candidates: list of (x, y_pseudo) pairs
        :return: a DataLoader over the pseudo-labelled pool
        """
        assert len(target_candidates) > 0
        return torch.utils.data.DataLoader(dataset=target_candidates,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)

    def confusion_matrix(self, loader, classifier):
        labels, preds = [], []
        for x, y_true in loader:
            labels += list(y_true.cpu().detach().numpy().flatten())
            x, y_true = self.to_var(x), self.to_var(y_true).long().squeeze()
            y_hat = classifier(self.encoder(x))
            _, pred = torch.max(y_hat, 1)
            preds += list(pred.cpu().detach().numpy().flatten())
        cm = confusion_matrix(labels, preds)
        df_cm = pd.DataFrame(cm, index=coarse_label_mapping,
                             columns=coarse_label_mapping,
                             dtype=int)  # np.int is deprecated; use the builtin int
        plt.figure(figsize=(10, 7))
        sn.heatmap(df_cm, annot=True)
        plt.show()
import datetime
import json
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from nltk.translate.bleu_score import corpus_bleu
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
from torchvision import transforms

# Encoder, Decoder, CaptionDataset, AverageMeter, adjust_learning_rate,
# clip_gradient, compute_accuracy and save_checkpoint are defined elsewhere in
# this repo, as are the configuration globals used below.


def main():
    """Main process: training and validation."""
    global start_epoch, checkpoint, fine_tune_encoder, best_bleu4, \
        epochs_since_improvement, word_map

    # Read word map
    word_map_path = os.path.join(data_folder, 'WORDMAP_' + dataset_name + ".json")
    with open(word_map_path, 'r') as j:
        word_map = json.load(j)

    # Build the models from scratch, or restore them from a checkpoint
    if checkpoint is None:
        decoder = Decoder(embed_dim=embed_dim,
                          decoder_dim=decoder_dim,
                          vocab_size=len(word_map),
                          dropout=dropout_rate)
        # materialise the filter into a list, so it is not exhausted by the loop below
        decoder_param = list(filter(lambda p: p.requires_grad, decoder.parameters()))
        # average the initial parameters across nodes (distributed setup)
        for param in decoder_param:
            tensor0 = param.data
            dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
            param.data = tensor0 / np.sqrt(float(num_nodes))
        decoder_optimizer = optim.Adam(params=decoder_param, lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_param = list(filter(lambda p: p.requires_grad, encoder.parameters()))
        if fine_tune_encoder:
            for param in encoder_param:
                tensor0 = param.data
                dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
                param.data = tensor0 / np.sqrt(float(num_nodes))
        encoder_optimizer = (optim.Adam(params=encoder_param, lr=encoder_lr)
                             if fine_tune_encoder else None)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)
    criterion = nn.CrossEntropyLoss()

    # Data loaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_set = CaptionDataset(data_folder=h5data_folder, data_name=dataset_name,
                               split="TRAIN",
                               transform=transforms.Compose([normalize]))
    val_set = CaptionDataset(data_folder=h5data_folder, data_name=dataset_name,
                             split="VAL",
                             transform=transforms.Compose([normalize]))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True,
                            num_workers=workers, pin_memory=True)

    total_start_time = datetime.datetime.now()
    print("Start the 1st epoch at:", total_start_time)

    for epoch in range(start_epoch, num_epochs):
        # Stop after 20 consecutive epochs without improvement; decay the
        # learning rate after every 8 epochs without improvement
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer)

        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()   # data loading time
        losses = AverageMeter()      # loss (per word decoded)
        top5accs = AverageMeter()    # top-5 accuracy
        decoder.train()
        encoder.train()
        start = time.time()
        start_time = datetime.datetime.now()  # start time of this epoch

        # TRAIN
        for j, (images, captions, caplens) in enumerate(train_loader):
            # workaround kept from the original code: cap Adam's per-parameter
            # step counter to avoid an overflow in its bias correction
            if fine_tune_encoder and (epoch - start_epoch > 0 or j > 10):
                for group in encoder_optimizer.param_groups:
                    for p in group['params']:
                        state = encoder_optimizer.state[p]
                        if state['step'] >= 1024:
                            state['step'] = 1000
            if epoch - start_epoch > 0 or j > 10:
                for group in decoder_optimizer.param_groups:
                    for p in group['params']:
                        state = decoder_optimizer.state[p]
                        if state['step'] >= 1024:
                            state['step'] = 1000

            data_time.update(time.time() - start)
            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)

            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)
            # Targets are the original captions excluding <start>
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length - 1)
            # Drop all padding and concatenate the remaining timesteps
            target = pack_padded_sequence(target, dec_lengths,
                                          batch_first=True).data
            predictions = pack_padded_sequence(predictions, dec_lengths,
                                               batch_first=True).data  # (sum(dec_lengths), vocab)
            loss = criterion(predictions, target)

            # Backward
            decoder_optimizer.zero_grad()
            if encoder_optimizer is not None:
                encoder_optimizer.zero_grad()
            loss.backward()
            # Clip gradients
            if grad_clip is not None:
                clip_gradient(decoder_optimizer, grad_clip)
                if encoder_optimizer is not None:
                    clip_gradient(encoder_optimizer, grad_clip)
            # Update
            decoder_optimizer.step()
            if encoder_optimizer is not None:
                encoder_optimizer.step()

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Current Batch Time: {batch_time.val:.3f} (Average: {batch_time.avg:.3f})\t'
                      'Current Data Load Time: {data_time.val:.3f} (Average: {data_time.avg:.3f})\t'
                      'Current Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t'
                      'Current Top-5 Accuracy: {top5.val:.3f} (Average: {top5.avg:.3f})'
                      .format(epoch + 1, j + 1, len(train_loader),
                              batch_time=batch_time, data_time=data_time,
                              loss=losses, top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Training Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)
            start = time.time()

        # VALIDATION
        decoder.eval()
        encoder.eval()
        batch_time = AverageMeter()  # forward prop. time
        losses = AverageMeter()      # loss (per word decoded)
        top5accs = AverageMeter()    # top-5 accuracy
        references = list()  # true captions for the BLEU-4 score
        hypotheses = list()  # predicted captions
        start_time = datetime.datetime.now()
        for j, (images, captions, caplens, all_caps) in enumerate(val_loader):
            start = time.time()
            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)

            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)
            predictions_copy = predictions.clone()
            # Targets are the original captions excluding <start>
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length - 1)
            target = pack_padded_sequence(target, dec_lengths,
                                          batch_first=True).data
            predictions = pack_padded_sequence(predictions, dec_lengths,
                                               batch_first=True).data
            loss = criterion(predictions, target)

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'
                      .format(epoch + 1, j + 1, len(val_loader),
                              batch_time=batch_time, loss=losses,
                              top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Validation Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            # Store references (true captions) and hypotheses (predictions) for
            # each image. For n images with references a, b, c, ... each:
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...],
            # hypotheses = [hyp1, hyp2, ...]
            all_caps = all_caps[sort_ind]  # images were re-sorted inside the decoder
            for k in range(all_caps.shape[0]):
                img_caps = all_caps[k].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c
                                   if w not in {word_map["<start>"],
                                                word_map["<pad>"]}],
                        img_caps))
                references.append(img_captions)
            _, preds = torch.max(predictions_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for i, p in enumerate(preds):
                temp_preds.append(preds[i][:dec_lengths[i]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)
            assert len(references) == len(hypotheses)

        # Compute the BLEU-4 score over the whole validation split
        recent_bleu4 = corpus_bleu(references, hypotheses)
        print('\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'
              .format(loss=losses, top5=top5accs, bleu=recent_bleu4))

        # Check improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % epochs_since_improvement)
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(dataset_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
        print("Epoch {}, cost time: {}\n".format(epoch + 1,
                                                 now_time - total_start_time))
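# A minimal self-contained illustration (not part of the repo) of the
# pack_padded_sequence trick used in main() above: packing with
# batch_first=True drops every padded timestep and concatenates the rest, so
# the cross-entropy loss is computed only over real words.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

scores = torch.randn(3, 5, 7)           # (batch, max_len, vocab) decoder scores
targets = torch.randint(0, 7, (3, 5))   # padded target word indices
lengths = [5, 3, 2]                     # true decode lengths, sorted descending
packed_scores = pack_padded_sequence(scores, lengths, batch_first=True).data
packed_targets = pack_padded_sequence(targets, lengths, batch_first=True).data
print(packed_scores.shape, packed_targets.shape)  # (10, 7) and (10,) = sum(lengths)
loss = torch.nn.functional.cross_entropy(packed_scores, packed_targets)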
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# QNetwork, Encoder and mkdir are defined elsewhere in this repo.


class DQNAgent():
    """Interacts with and learns from the environment."""

    def __init__(self, state_size, action_size, config):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            config (dict): hyperparameters (seed, batch_size, lr, tau, fc1_units, fc2_units, device)
        """
        self.state_size = state_size
        self.action_size = action_size
        random.seed(config["seed"])
        self.seed = config["seed"]
        self.gamma = 0.99
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.device = config["device"]

        # Q-Networks (local and target)
        self.qnetwork_local = QNetwork(state_size, action_size, self.fc1,
                                       self.fc2, self.seed).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.fc1,
                                        self.fc2, self.seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.lr)
        self.encoder = Encoder(config).to(self.device)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), self.lr)

        # Time step counter (the networks are updated every 4 steps)
        self.t_step = 0

    def step(self, memory, writer):
        self.t_step += 1
        if len(memory) > self.batch_size and self.t_step % 4 == 0:
            experiences = memory.sample(self.batch_size)
            self.learn(experiences, writer)

    def act(self, state, eps=0.):
        """Returns an action for the given state following an epsilon-greedy policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        state = state.type(torch.float32).div_(255)  # rescale pixels to [0, 1]
        self.qnetwork_local.eval()
        self.encoder.eval()
        with torch.no_grad():
            state = self.encoder.create_vector(state)
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        self.encoder.train()
        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(self.action_size))

    def learn(self, experiences, writer):
        """Update value parameters using a batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tensors
        """
        states, actions, rewards, next_states, dones = experiences
        states = self.encoder.create_vector(states.type(torch.float32).div_(255))
        next_states = self.encoder.create_vector(
            next_states.type(torch.float32).div_(255))
        actions = actions.type(torch.int64)
        # Max predicted Q values for the next states, from the target network
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Q targets for the current states; (1 - dones) zeroes the bootstrap
        # term on terminal transitions (assuming dones is 1 for terminal states)
        Q_targets = rewards + self.gamma * Q_targets_next * (1 - dones)
        # Expected Q values from the local network
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        writer.add_scalar('Q_loss', loss, self.t_step)
        # Minimize the loss for both the Q-network and the encoder
        self.optimizer.zero_grad()
        self.encoder_optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.encoder_optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)

    def soft_update(self, local_model, target_model):
        """Soft-update model parameters: θ_target = τ*θ_local + (1 - τ)*θ_target.

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def save(self, filename):
        """Save the Q-network, the encoder and their optimizer states."""
        mkdir("", filename)
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        torch.save(self.optimizer.state_dict(), filename + "_q_net_optimizer.pth")
        torch.save(self.encoder.state_dict(), filename + "_encoder.pth")
        torch.save(self.encoder_optimizer.state_dict(),
                   filename + "_encoder_optimizer.pth")
        print("Save models to {}".format(filename))
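# A minimal self-contained sketch (not part of the repo) of the soft target
# update performed by DQNAgent.soft_update above: with a small tau the target
# network trails the local network as an exponential moving average of its
# weights, which stabilises the bootstrapped Q targets.
import torch
import torch.nn as nn

local = nn.Linear(4, 2)
target = nn.Linear(4, 2)
tau = 1e-3
with torch.no_grad():
    for t_p, l_p in zip(target.parameters(), local.parameters()):
        t_p.copy_(tau * l_p + (1.0 - tau) * t_p)  # θ_target ← τ·θ_local + (1-τ)·θ_target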
# load the new state dict
src_encoder.load_state_dict(src_encoder_dict)

optimizer = optim.SGD(
    list(src_encoder.parameters()) + list(src_classifier.parameters()),
    lr=lr, momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

if cuda:
    src_encoder = src_encoder.cuda()
    src_classifier = src_classifier.cuda()
    criterion = criterion.cuda()

src_encoder.train()
src_classifier.train()

# begin training on the source domain
for epoch in range(1, epochs + 1):
    correct = 0
    for batch_idx, (src_data, label) in enumerate(src_train_loader):
        if cuda:
            src_data, label = src_data.cuda(), label.cuda()
        src_data, label = Variable(src_data), Variable(label)
        optimizer.zero_grad()
        src_feature = src_encoder(src_data)
        output = src_classifier(src_feature)
        loss = criterion(output, label)
        output = F.softmax(output, dim=1)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(label.data.view_as(pred)).cpu().sum()
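# Aside (not from this repo): the accuracy bookkeeping used above, in isolation.
# max(1, keepdim=True)[1] takes the argmax class per row; eq/view_as counts hits.
import torch

output = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
label = torch.tensor([1, 0, 0])
pred = output.max(1, keepdim=True)[1]         # shape (3, 1): classes 1, 0, 1
correct = pred.eq(label.view_as(pred)).sum()  # 2 of the 3 predictions are right
print(correct.item())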