Example No. 1
class Trainer():
    def __init__(self, params, experience_replay_buffer, metrics, results_dir, env):
        self.parms = params     
        self.D = experience_replay_buffer  
        self.metrics = metrics
        self.env = env
        self.tested_episodes = 0

        self.statistics_path = results_dir+'/statistics' 
        self.model_path = results_dir+'/model' 
        self.video_path = results_dir+'/video' 
        self.rew_vs_pred_rew_path = results_dir+'/rew_vs_pred_rew'
        self.dump_plan_path = results_dir+'/dump_plan'
        
        # if the folders do not exist, create them
        os.makedirs(self.statistics_path, exist_ok=True) 
        os.makedirs(self.model_path, exist_ok=True) 
        os.makedirs(self.video_path, exist_ok=True) 
        os.makedirs(self.rew_vs_pred_rew_path, exist_ok=True) 
        os.makedirs(self.dump_plan_path, exist_ok=True) 
        

        # Create models
        self.transition_model = TransitionModel(self.parms.belief_size, self.parms.state_size, self.env.action_size, self.parms.hidden_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        self.observation_model = ObservationModel(self.parms.belief_size, self.parms.state_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        self.reward_model = RewardModel(self.parms.belief_size, self.parms.state_size, self.parms.hidden_size, self.parms.activation_function).to(device=self.parms.device)
        self.encoder = Encoder(self.parms.embedding_size,self.parms.activation_function).to(device=self.parms.device)
        self.param_list = list(self.transition_model.parameters()) + list(self.observation_model.parameters()) + list(self.reward_model.parameters()) + list(self.encoder.parameters()) 
        self.optimiser = optim.Adam(self.param_list, lr=0 if self.parms.learning_rate_schedule != 0 else self.parms.learning_rate, eps=self.parms.adam_epsilon)
        self.planner = MPCPlanner(self.env.action_size, self.parms.planning_horizon, self.parms.optimisation_iters, self.parms.candidates, self.parms.top_candidates, self.transition_model, self.reward_model,self.env.action_range[0], self.env.action_range[1])

        global_prior = Normal(torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device), torch.ones(self.parms.batch_size, self.parms.state_size, device=self.parms.device))  # Global prior N(0, I)
        self.free_nats = torch.full((1, ), self.parms.free_nats, dtype=torch.float32, device=self.parms.device)  # Allowed deviation in KL divergence

    def load_checkpoints(self):
        self.metrics = torch.load(self.model_path+'/metrics.pth')
        model_path = self.model_path+'/best_model'
        os.makedirs(model_path, exist_ok=True) 
        files = os.listdir(model_path)
        if files:
            checkpoint = [f for f in files if os.path.isfile(os.path.join(model_path, f))]
            model_dicts = torch.load(os.path.join(model_path, checkpoint[0]),map_location=self.parms.device)
            self.transition_model.load_state_dict(model_dicts['transition_model'])
            self.observation_model.load_state_dict(model_dicts['observation_model'])
            self.reward_model.load_state_dict(model_dicts['reward_model'])
            self.encoder.load_state_dict(model_dicts['encoder'])
            self.optimiser.load_state_dict(model_dicts['optimiser'])  
            print("Loading models checkpoints!")
        else:
            print("Checkpoints not found!")


    def update_belief_and_act(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf,explore=False):
        # Infer belief over current state q(s_t|o≤t,a<t) from the history
        encoded_obs = self.encoder(observation).unsqueeze(dim=0).to(device=self.parms.device)       
        belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs)  # Action and observation need extra time dimension
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
        action,pred_next_rew,_,_,_ = self.planner(belief, posterior_state,explore)  # Get action from planner(q(s_t|o≤t,a<t), p)      
        
        if explore:
            action = action + self.parms.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
        action.clamp_(min=min_action, max=max_action)  # Clip action range
        next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # If a single env is instantiated, perform a single action (take the item from the list); otherwise perform all actions
        
        return belief, posterior_state, action, next_observation, reward, done,pred_next_rew 
    
    def fit_buffer(self,episode):
        # Fit the models on data sampled from the replay buffer

        # Model fitting
        losses = []
        tqdm.write("Fitting buffer")
        for s in tqdm(range(self.parms.collect_interval)):

            # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
            observations, actions, rewards, nonterminals = self.D.sample(self.parms.batch_size, self.parms.chunk_size)  # Transitions start at time t = 0
            # Create initial belief and state for time t = 0
            init_belief, init_state = torch.zeros(self.parms.batch_size, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device)
            encoded_obs = bottle(self.encoder, (observations[1:], ))

            # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(init_state, actions[:-1], init_belief, encoded_obs, nonterminals[:-1])
            
            # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
            # LOSS
            observation_loss = F.mse_loss(bottle(self.observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum((2, 3, 4)).mean(dim=(0, 1))
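            # free_nats acts as a floor on the KL term: divergences below free_nats carry no gradient penalty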
            kl_loss = torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), self.free_nats).mean(dim=(0, 1))  
            reward_loss = F.mse_loss(bottle(self.reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))            

            # Update model parameters
            self.optimiser.zero_grad()

            (observation_loss + reward_loss + kl_loss).backward() # BACKPROPAGATION
            nn.utils.clip_grad_norm_(self.param_list, self.parms.grad_clip_norm, norm_type=2)
            self.optimiser.step()
            # Store (0) observation loss (1) reward loss (2) KL loss
            losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()])#, regularizer_loss.item()])

        #save statistics and plot them
        losses = tuple(zip(*losses))  
        self.metrics['observation_loss'].append(losses[0])
        self.metrics['reward_loss'].append(losses[1])
        self.metrics['kl_loss'].append(losses[2])
      
        lineplot(self.metrics['episodes'][-len(self.metrics['observation_loss']):], self.metrics['observation_loss'], 'observation_loss', self.statistics_path)
        lineplot(self.metrics['episodes'][-len(self.metrics['reward_loss']):], self.metrics['reward_loss'], 'reward_loss', self.statistics_path)
        lineplot(self.metrics['episodes'][-len(self.metrics['kl_loss']):], self.metrics['kl_loss'], 'kl_loss', self.statistics_path)
        
    def explore_and_collect(self,episode):
        tqdm.write("Collect new data:")
        reward = 0
        # Data collection
        with torch.no_grad():
            done = False
            observation, total_reward = self.env.reset(), 0
            belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device)
            t = 0
            real_rew = []
            predicted_rew = [] 
            total_steps = self.parms.max_episode_length // self.env.action_repeat
            explore = True

            for t in tqdm(range(total_steps)):
                # Here we need to explore
                belief, posterior_state, action, next_observation, reward, done, pred_next_rew = self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1], explore=explore)
                self.D.append(observation, action.cpu(), reward, done)
                real_rew.append(reward)
                predicted_rew.append(pred_next_rew.to(device=self.parms.device).item())
                total_reward += reward
                observation = next_observation
                if self.parms.flag_render:
                    self.env.render()
                if done:
                    break

        # Update and plot train reward metrics
        self.metrics['steps'].append( (t * self.env.action_repeat) + self.metrics['steps'][-1])
        self.metrics['episodes'].append(episode)
        self.metrics['train_rewards'].append(total_reward)
        self.metrics['predicted_rewards'].append(np.array(predicted_rew).sum())

        lineplot(self.metrics['episodes'][-len(self.metrics['train_rewards']):], self.metrics['train_rewards'], 'train_rewards', self.statistics_path)
        double_lineplot(self.metrics['episodes'], self.metrics['train_rewards'], self.metrics['predicted_rewards'], "train_r_vs_pr", self.statistics_path)

    def train_models(self):
        # Train from episode (num_init_episodes + 1) up to training_episodes
        tqdm.write("Start training.")

        for episode in tqdm(range(self.parms.num_init_episodes + 1, self.parms.training_episodes)):
            self.fit_buffer(episode)       
            self.explore_and_collect(episode)
            if episode % self.parms.test_interval == 0:
                self.test_model(episode)
                torch.save(self.metrics, os.path.join(self.model_path, 'metrics.pth'))
                torch.save({'transition_model': self.transition_model.state_dict(), 'observation_model': self.observation_model.state_dict(), 'reward_model': self.reward_model.state_dict(), 'encoder': self.encoder.state_dict(), 'optimiser': self.optimiser.state_dict()},  os.path.join(self.model_path, 'models_%d.pth' % episode))
            
            if episode % self.parms.storing_dataset_interval == 0:
                self.D.store_dataset(self.parms.dataset_path+'dump_dataset')

        return self.metrics

    def test_model(self, episode=None):  # no exploration during testing
        if episode is None:
            episode = self.tested_episodes


        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        
        # Initialise parallelised test environments
        test_envs = EnvBatcher(ControlSuiteEnv, (self.parms.env_name, self.parms.seed, self.parms.max_episode_length, self.parms.bit_depth), {}, self.parms.test_episodes)
        total_steps = self.parms.max_episode_length // test_envs.action_repeat
        rewards = np.zeros(self.parms.test_episodes)
        
        real_rew = torch.zeros([total_steps,self.parms.test_episodes])
        predicted_rew = torch.zeros([total_steps,self.parms.test_episodes])

        with torch.no_grad():
            observation, total_rewards, video_frames = test_envs.reset(), np.zeros((self.parms.test_episodes, )), []            
            belief, posterior_state, action = torch.zeros(self.parms.test_episodes, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.parms.state_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.env.action_size, device=self.parms.device)
            tqdm.write("Testing model.")
            for t in range(total_steps):     
                belief, posterior_state, action, next_observation, rewards, done, pred_next_rew  = self.update_belief_and_act(test_envs,  belief, posterior_state, action, observation.to(device=self.parms.device), list(rewards), self.env.action_range[0], self.env.action_range[1])
                total_rewards += rewards.numpy()
                real_rew[t] = rewards
                predicted_rew[t]  = pred_next_rew

                observation = self.env.get_original_frame().unsqueeze(dim=0)

                video_frames.append(make_grid(torch.cat([observation, self.observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                observation = next_observation
                if done.sum().item() == self.parms.test_episodes:
                    break
            
        real_rew = torch.transpose(real_rew, 0, 1)
        predicted_rew = torch.transpose(predicted_rew, 0, 1)
        
        #save and plot metrics 
        self.tested_episodes += 1
        self.metrics['test_episodes'].append(episode)
        self.metrics['test_rewards'].append(total_rewards.tolist())

        lineplot(self.metrics['test_episodes'], self.metrics['test_rewards'], 'test_rewards', self.statistics_path)
        
        write_video(video_frames, 'test_episode_%s' % str(episode), self.video_path)  # Lossy compression
        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        test_envs.close()
        return self.metrics


    def dump_plan_video(self, step_before_plan=120): 
        # number of environment steps to take before the plan is created and dumped
        step_before_plan = min(step_before_plan, (self.parms.max_episode_length // self.env.action_repeat))
        
        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        video_frames = []
        reward = 0

        with torch.no_grad():
            observation = self.env.reset()
            belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device)
            tqdm.write("Executing episode.")
            for t in range(step_before_plan):
                belief, posterior_state, action, next_observation, reward, done, _ = self.update_belief_and_act(self.env,  belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1])
                observation = next_observation
                video_frames.append(make_grid(torch.cat([observation.cpu(), self.observation_model(belief, posterior_state).to(device=self.parms.device).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                if done:
                    break
            self.create_and_dump_plan(self.env,  belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1])
            
            
        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        self.env.close()

    def create_and_dump_plan(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf): 

        tqdm.write("Dumping plan")
        video_frames = []

        encoded_obs = self.encoder(observation).unsqueeze(dim=0)
        belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs)  
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
        next_action,_, beliefs, states, plan = self.planner(belief, posterior_state,False)  # Get action from planner(q(s_t|o≤t,a<t), p)      
        predicted_frames = self.observation_model(beliefs, states).to(device=self.parms.device)

        for i in range(self.parms.planning_horizon):
            plan[i].clamp_(min=env.action_range[0], max=env.action_range[1])  # Clip action range
            next_observation, reward, done = env.step(plan[i].cpu())  
            next_observation = next_observation.squeeze(dim=0)
            video_frames.append(make_grid(torch.cat([next_observation, predicted_frames[i]], dim=1) + 0.5, nrow=2).numpy())  # Decentre

        write_video(video_frames, 'dump_plan', self.dump_plan_path, dump_frame=True)  
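
A minimal usage sketch for the Trainer above, assuming hypothetical Parameters, ExperienceReplay and ControlSuiteEnv constructors (their real signatures live elsewhere in this repository) and a metrics dict with the keys the class expects:

# Hypothetical wiring of the Trainer; every name below not defined in the class is an assumption.
params = Parameters()                                   # hyperparameters: belief_size, device, batch_size, ...
env = ControlSuiteEnv(params.env_name, params.seed, params.max_episode_length, params.bit_depth)
D = ExperienceReplay(params.experience_size, env.observation_size, env.action_size,
                     params.bit_depth, params.device)   # assumed constructor
metrics = {'steps': [0], 'episodes': [], 'train_rewards': [], 'predicted_rewards': [],
           'test_episodes': [], 'test_rewards': [],
           'observation_loss': [], 'reward_loss': [], 'kl_loss': []}

trainer = Trainer(params, D, metrics, results_dir='results/run_0', env=env)
# trainer.load_checkpoints()      # optional: resume from results/run_0/model/best_model
metrics = trainer.train_models()  # alternates buffer fitting, data collection and periodic testing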
    
            
Example No. 2
                 results_dir,
                 xaxis='step')
        if not args.symbolic_env:
            episode_str = str(episode).zfill(len(str(args.episodes)))
            write_video(video_frames, 'test_episode_%s' % episode_str,
                        results_dir)  # Lossy compression
            save_image(
                torch.as_tensor(video_frames[-1]),
                os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
        torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

        # Set models to train mode
        transition_model.train()
        observation_model.train()
        reward_model.train()
        encoder.train()
        # Close test environments
        test_envs.close()

    # Checkpoint models
    print("Completed episode {}".format(episode))
    if episode % args.checkpoint_interval == 0:
        print("Saving!")
        torch.save(
            {
                'transition_model': transition_model.state_dict(),
                'observation_model': observation_model.state_dict(),
                'reward_model': reward_model.state_dict(),
                'encoder': encoder.state_dict(),
                'optimiser': optimiser.state_dict()
            }, os.path.join(results_dir, 'models_%d.pth' % episode))
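
A checkpoint written this way can be restored into models built with the same hyperparameters. A minimal sketch, assuming the transition_model, observation_model, reward_model, encoder and optimiser objects from the surrounding script are in scope and that args.device is defined:

# Sketch of restoring the checkpoint saved above; assumes the models and optimiser are already constructed.
model_dicts = torch.load(os.path.join(results_dir, 'models_%d.pth' % episode),
                         map_location=args.device)
transition_model.load_state_dict(model_dicts['transition_model'])
observation_model.load_state_dict(model_dicts['observation_model'])
reward_model.load_state_dict(model_dicts['reward_model'])
encoder.load_state_dict(model_dicts['encoder'])
optimiser.load_state_dict(model_dicts['optimiser'])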
Example No. 3
class Plan(object):
    def __init__(self):

        self.results_dir = os.path.join(
            'results',
            '{}_seed_{}_{}_action_scale_{}_no_explore_{}_pool_len_{}_optimisation_iters_{}_top_planning-horizon'
            .format(args.env, args.seed, args.algo, args.action_scale,
                    args.pool_len, args.optimisation_iters,
                    args.top_planning_horizon))

        args.results_dir = self.results_dir
        args.MultiGPU = True if torch.cuda.device_count() > 1 and args.MultiGPU else False

        self.__basic_setting()
        self.__init_sample()  # Sample the initial seed data

        # Initialise model parameters randomly
        self.transition_model = TransitionModel(
            args.belief_size, args.state_size, self.env.action_size,
            args.hidden_size, args.embedding_size,
            args.dense_activation_function).to(device=args.device)
        self.observation_model = ObservationModel(
            args.symbolic_env, self.env.observation_size, args.belief_size,
            args.state_size, args.embedding_size,
            args.cnn_activation_function).to(device=args.device)
        self.reward_model = RewardModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_activation_function).to(device=args.device)
        self.encoder = Encoder(
            args.symbolic_env, self.env.observation_size, args.embedding_size,
            args.cnn_activation_function).to(device=args.device)

        print("We Have {} GPUS".format(torch.cuda.device_count())
              ) if args.MultiGPU else print("We use CPU")
        self.transition_model = nn.DataParallel(
            self.transition_model.to(device=args.device)
        ) if args.MultiGPU else self.transition_model
        self.observation_model = nn.DataParallel(
            self.observation_model.to(device=args.device)
        ) if args.MultiGPU else self.observation_model
        self.reward_model = nn.DataParallel(
            self.reward_model.to(
                device=args.device)) if args.MultiGPU else self.reward_model

        # encoder = nn.DataParallel(encoder.cuda())
        # actor_model = nn.DataParallel(actor_model.cuda())
        # value_model = nn.DataParallel(value_model.cuda())

        # share the global parameters in multiprocessing
        self.encoder.share_memory()
        self.observation_model.share_memory()
        self.reward_model.share_memory()

        # Set all_model/global_actor_optimizer/global_value_optimizer
        self.param_list = list(self.transition_model.parameters()) + list(
            self.observation_model.parameters()) + list(
                self.reward_model.parameters()) + list(
                    self.encoder.parameters())
        self.model_optimizer = optim.Adam(
            self.param_list,
            lr=0
            if args.learning_rate_schedule != 0 else args.model_learning_rate,
            eps=args.adam_epsilon)

    def update_belief_and_act(self,
                              args,
                              env,
                              belief,
                              posterior_state,
                              action,
                              observation,
                              explore=False):
        # Infer belief over current state q(s_t|o≤t,a<t) from the history
        # print("action size: ",action.size()) torch.Size([1, 6])
        belief, _, _, _, posterior_state, _, _ = self.upper_transition_model(
            posterior_state, action.unsqueeze(dim=0), belief,
            self.encoder(observation).unsqueeze(dim=0), None)
        if hasattr(env, "envs"):
            belief, posterior_state = list(
                map(lambda x: x.view(-1, args.test_episodes, x.shape[2]),
                    [x for x in [belief, posterior_state]]))

        belief, posterior_state = belief.squeeze(
            dim=0), posterior_state.squeeze(
                dim=0)  # Remove time dimension from belief/state
        action = self.algorithms.get_action(belief, posterior_state, explore)

        if explore:
            action = torch.clamp(
                Normal(action, args.action_noise).rsample(), -1, 1
            )  # Add gaussian exploration noise on top of the sampled action
            # action = action + args.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
        next_observation, reward, done = env.step(
            action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu(
            ))  # Perform environment step (action repeats handled internally)
        return belief, posterior_state, action, next_observation, reward, done

    def run(self):
        if args.algo == "dreamer":
            print("DREAMER")
            from algorithms.dreamer import Algorithms
            self.algorithms = Algorithms(self.env.action_size,
                                         self.transition_model, self.encoder,
                                         self.reward_model,
                                         self.observation_model)
        elif args.algo == "p2p":
            print("planing to plan")
            from algorithms.plan_to_plan import Algorithms
            self.algorithms = Algorithms(self.env.action_size,
                                         self.transition_model, self.encoder,
                                         self.reward_model,
                                         self.observation_model)
        elif args.algo == "actor_pool_1":
            print("async sub actor")
            from algorithms.actor_pool_1 import Algorithms_actor
            self.algorithms = Algorithms_actor(self.env.action_size,
                                               self.transition_model,
                                               self.encoder, self.reward_model,
                                               self.observation_model)
        elif args.algo == "aap":
            from algorithms.asynchronous_actor_planet import Algorithms
            self.algorithms = Algorithms(self.env.action_size,
                                         self.transition_model, self.encoder,
                                         self.reward_model,
                                         self.observation_model)
        else:
            print("planet")
            from algorithms.planet import Algorithms
            # args.MultiGPU = False
            self.algorithms = Algorithms(self.env.action_size,
                                         self.transition_model,
                                         self.reward_model)

        if args.test: self.test_only()

        self.global_prior = Normal(
            torch.zeros(args.batch_size, args.state_size, device=args.device),
            torch.ones(args.batch_size, args.state_size,
                       device=args.device))  # Global prior N(0, I)
        self.free_nats = torch.full(
            (1, ), args.free_nats,
            device=args.device)  # Allowed deviation in KL divergence

        # Training (and testing)
        # args.episodes = 1
        for episode in tqdm(range(self.metrics['episodes'][-1] + 1,
                                  args.episodes + 1),
                            total=args.episodes,
                            initial=self.metrics['episodes'][-1] + 1):
            losses = self.train()
            # self.algorithms.save_loss_data(self.metrics['episodes']) # Update and plot loss metrics
            self.save_loss_data(tuple(
                zip(*losses)))  # Update and plot loss metrics
            self.data_collection(episode=episode)  # Data collection
            # args.test_interval = 1
            if episode % args.test_interval == 0:
                self.test(episode=episode)  # Test model
            self.save_model_data(episode=episode)  # save model

        self.env.close()  # Close training environment

    def train_env_model(self, beliefs, prior_states, prior_means,
                        prior_std_devs, posterior_states, posterior_means,
                        posterior_std_devs, observations, actions, rewards,
                        nonterminals):
        # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
        if args.worldmodel_LogProbLoss:
            observation_dist = Normal(
                bottle(self.observation_model, (beliefs, posterior_states)), 1)
            observation_loss = -observation_dist.log_prob(
                observations[1:]).sum(
                    dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
        else:
            observation_loss = F.mse_loss(
                bottle(self.observation_model, (beliefs, posterior_states)),
                observations[1:],
                reduction='none').sum(
                    dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
        if args.worldmodel_LogProbLoss:
            reward_dist = Normal(
                bottle(self.reward_model, (beliefs, posterior_states)), 1)
            reward_loss = -reward_dist.log_prob(rewards[:-1]).mean(dim=(0, 1))
        else:
            reward_loss = F.mse_loss(bottle(self.reward_model,
                                            (beliefs, posterior_states)),
                                     rewards[:-1],
                                     reduction='none').mean(dim=(0, 1))

        # transition loss
        div = kl_divergence(Normal(posterior_means, posterior_std_devs),
                            Normal(prior_means, prior_std_devs)).sum(dim=2)
        kl_loss = torch.max(div, self.free_nats).mean(
            dim=(0, 1)
        )  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
        if args.global_kl_beta != 0:
            kl_loss += args.global_kl_beta * kl_divergence(
                Normal(posterior_means, posterior_std_devs),
                self.global_prior).sum(dim=2).mean(dim=(0, 1))
        # Calculate latent overshooting objective for t > 0
        if args.overshooting_kl_beta != 0:
            overshooting_vars = [
            ]  # Collect variables for overshooting to process in batch
            for t in range(1, args.chunk_size - 1):
                d = min(t + args.overshooting_distance,
                        args.chunk_size - 1)  # Overshooting distance
                t_, d_ = t - 1, d - 1  # Use t_ and d_ to deal with different time indexing for latent states
                seq_pad = (
                    0, 0, 0, 0, 0, t - d + args.overshooting_distance
                )  # Calculate sequence padding so overshooting terms can be calculated in one batch
                # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks
                overshooting_vars.append(
                    (F.pad(actions[t:d],
                           seq_pad), F.pad(nonterminals[t:d], seq_pad),
                     F.pad(rewards[t:d],
                           seq_pad[2:]), beliefs[t_], prior_states[t_],
                     F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad),
                     F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(),
                           seq_pad,
                           value=1),
                     F.pad(
                         torch.ones(d - t,
                                    args.batch_size,
                                    args.state_size,
                                    device=args.device), seq_pad))
                )  # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences
            overshooting_vars = tuple(zip(*overshooting_vars))
            # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once)
            beliefs, prior_states, prior_means, prior_std_devs = self.upper_transition_model(
                torch.cat(overshooting_vars[4], dim=0),
                torch.cat(overshooting_vars[0], dim=1),
                torch.cat(overshooting_vars[3], dim=0), None,
                torch.cat(overshooting_vars[1], dim=1))
            seq_mask = torch.cat(overshooting_vars[7], dim=1)
            # Calculate overshooting KL loss with sequence mask
            kl_loss += (
                1 / args.overshooting_distance
            ) * args.overshooting_kl_beta * torch.max((kl_divergence(
                Normal(torch.cat(overshooting_vars[5], dim=1),
                       torch.cat(overshooting_vars[6], dim=1)),
                Normal(prior_means, prior_std_devs)
            ) * seq_mask).sum(dim=2), self.free_nats).mean(dim=(0, 1)) * (
                args.chunk_size
                - 1
            )  # Update KL loss (compensating for extra average over each overshooting/open loop sequence)
            # Calculate overshooting reward prediction loss with sequence mask
            if args.overshooting_reward_scale != 0:
                reward_loss += (
                    1 / args.overshooting_distance
                ) * args.overshooting_reward_scale * F.mse_loss(
                    bottle(self.reward_model,
                           (beliefs, prior_states)) * seq_mask[:, :, 0],
                    torch.cat(overshooting_vars[2], dim=1),
                    reduction='none'
                ).mean(dim=(0, 1)) * (
                    args.chunk_size - 1
                )  # Update reward loss (compensating for extra average over each overshooting/open loop sequence)
        # Apply linearly ramping learning rate schedule
        if args.learning_rate_schedule != 0:
            for group in self.model_optimizer.param_groups:
                group['lr'] = min(
                    group['lr'] + args.model_learning_rate /
                    args.model_learning_rate_schedule,
                    args.model_learning_rate)
        model_loss = observation_loss + reward_loss + kl_loss
        # Update model parameters
        self.model_optimizer.zero_grad()
        model_loss.backward()
        nn.utils.clip_grad_norm_(self.param_list,
                                 args.grad_clip_norm,
                                 norm_type=2)
        self.model_optimizer.step()
        return observation_loss, reward_loss, kl_loss

    def train(self):
        # Model fitting
        losses = []
        print("training loop")
        # args.collect_interval = 1
        for s in tqdm(range(args.collect_interval)):

            # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
            observations, actions, rewards, nonterminals = self.D.sample(
                args.batch_size,
                args.chunk_size)  # Transitions start at time t = 0
            # Create initial belief and state for time t = 0
            init_belief, init_state = torch.zeros(
                args.batch_size, args.belief_size,
                device=args.device), torch.zeros(args.batch_size,
                                                 args.state_size,
                                                 device=args.device)
            # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
            obs = bottle(self.encoder, (observations[1:], ))
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.upper_transition_model(
                prev_state=init_state,
                actions=actions[:-1],
                prev_belief=init_belief,
                obs=obs,
                nonterminals=nonterminals[:-1])

            # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
            observation_loss, reward_loss, kl_loss = self.train_env_model(
                beliefs, prior_states, prior_means, prior_std_devs,
                posterior_states, posterior_means, posterior_std_devs,
                observations, actions, rewards, nonterminals)

            # Dreamer implementation: actor loss calculation and optimization
            with torch.no_grad():
                actor_states = posterior_states.detach().to(
                    device=args.device).share_memory_()
                actor_beliefs = beliefs.detach().to(
                    device=args.device).share_memory_()

            # if not os.path.exists(os.path.join(os.getcwd(), 'tensor_data/' + args.results_dir)): os.mkdir(os.path.join(os.getcwd(), 'tensor_data/' + args.results_dir))
            torch.save(
                actor_states,
                os.path.join(os.getcwd(),
                             args.results_dir + '/actor_states.pt'))
            torch.save(
                actor_beliefs,
                os.path.join(os.getcwd(),
                             args.results_dir + '/actor_beliefs.pt'))

            # [self.actor_pipes[i][0].send(1) for i, w in enumerate(self.workers_actor)]  # Parent_pipe send data using i'th pipes
            # [self.actor_pipes[i][0].recv() for i, _ in enumerate(self.actor_pool)]  # waitting the children finish

            self.algorithms.train_algorithm(actor_states, actor_beliefs)
            losses.append(
                [observation_loss.item(),
                 reward_loss.item(),
                 kl_loss.item()])

            # if self.algorithms.train_algorithm(actor_states, actor_beliefs) is not None:
            #   merge_actor_loss, merge_value_loss = self.algorithms.train_algorithm(actor_states, actor_beliefs)
            #   losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item(), merge_actor_loss.item(), merge_value_loss.item()])
            # else:
            #   losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()])

        return losses

    def data_collection(self, episode):
        print("Data collection")
        with torch.no_grad():
            observation, total_reward = self.env.reset(), 0
            belief, posterior_state, action = torch.zeros(
                1, args.belief_size, device=args.device), torch.zeros(
                    1, args.state_size,
                    device=args.device), torch.zeros(1,
                                                     self.env.action_size,
                                                     device=args.device)
            pbar = tqdm(range(args.max_episode_length // args.action_repeat))
            for t in pbar:
                # print("step",t)
                belief, posterior_state, action, next_observation, reward, done = self.update_belief_and_act(
                    args, self.env, belief, posterior_state, action,
                    observation.to(device=args.device))
                self.D.append(observation, action.cpu(), reward, done)
                total_reward += reward
                observation = next_observation
                if args.render: self.env.render()
                if done:
                    pbar.close()
                    break

            # Update and plot train reward metrics
            self.metrics['steps'].append(t + self.metrics['steps'][-1])
            self.metrics['episodes'].append(episode)
            self.metrics['train_rewards'].append(total_reward)

            Save_Txt(self.metrics['episodes'][-1],
                     self.metrics['train_rewards'][-1], 'train_rewards',
                     args.results_dir)
            # lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir)

    def test(self, episode):
        print("Test model")
        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        self.algorithms.train_to_eval()
        # self.actor_model_g.eval()
        # self.value_model_g.eval()
        # Initialise parallelised test environments
        test_envs = EnvBatcher(
            Env, (args.env, args.symbolic_env, args.seed,
                  args.max_episode_length, args.action_repeat, args.bit_depth),
            {}, args.test_episodes)

        with torch.no_grad():
            observation, total_rewards, video_frames = test_envs.reset(
            ), np.zeros((args.test_episodes, )), []
            belief, posterior_state, action = torch.zeros(
                args.test_episodes, args.belief_size,
                device=args.device), torch.zeros(
                    args.test_episodes, args.state_size,
                    device=args.device), torch.zeros(args.test_episodes,
                                                     self.env.action_size,
                                                     device=args.device)
            pbar = tqdm(range(args.max_episode_length // args.action_repeat))
            for t in pbar:
                belief, posterior_state, action, next_observation, reward, done = self.update_belief_and_act(
                    args, test_envs, belief, posterior_state, action,
                    observation.to(device=args.device))
                total_rewards += reward.numpy()
                if not args.symbolic_env:  # Collect real vs. predicted frames for video
                    video_frames.append(
                        make_grid(torch.cat([
                            observation,
                            self.observation_model(belief,
                                                   posterior_state).cpu()
                        ],
                                            dim=3) + 0.5,
                                  nrow=5).numpy())  # Decentre
                observation = next_observation
                if done.sum().item() == args.test_episodes:
                    pbar.close()
                    break

        # Update and plot reward metrics (and write video if applicable) and save metrics
        self.metrics['test_episodes'].append(episode)
        self.metrics['test_rewards'].append(total_rewards.tolist())

        Save_Txt(self.metrics['test_episodes'][-1],
                 self.metrics['test_rewards'][-1], 'test_rewards',
                 args.results_dir)
        # Save_Txt(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'],'test_rewards_steps', results_dir, xaxis='step')

        # lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir)
        # lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', results_dir, xaxis='step')
        if not args.symbolic_env:
            episode_str = str(episode).zfill(len(str(args.episodes)))
            write_video(video_frames, 'test_episode_%s' % episode_str,
                        args.results_dir)  # Lossy compression
            save_image(
                torch.as_tensor(video_frames[-1]),
                os.path.join(args.results_dir,
                             'test_episode_%s.png' % episode_str))

        torch.save(self.metrics, os.path.join(args.results_dir, 'metrics.pth'))

        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # self.actor_model_g.train()
        # self.value_model_g.train()
        self.algorithms.eval_to_train()
        # Close test environments
        test_envs.close()

    def test_only(self):
        # Set models to eval mode
        self.transition_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        with torch.no_grad():
            total_reward = 0
            for _ in tqdm(range(args.test_episodes)):
                observation = self.env.reset()
                belief, posterior_state, action = torch.zeros(
                    1, args.belief_size, device=args.device), torch.zeros(
                        1, args.state_size,
                        device=args.device), torch.zeros(1,
                                                         self.env.action_size,
                                                         device=args.device)
                pbar = tqdm(
                    range(args.max_episode_length // args.action_repeat))
                for t in pbar:
                    belief, posterior_state, action, observation, reward, done = self.update_belief_and_act(
                        args, self.env, belief, posterior_state, action,
                        observation.to(device=args.device))
                    total_reward += reward
                    if args.render: self.env.render()
                    if done:
                        pbar.close()
                        break
        print('Average Reward:', total_reward / args.test_episodes)
        self.env.close()
        quit()

    def __basic_setting(self):
        args.overshooting_distance = min(
            args.chunk_size, args.overshooting_distance
        )  # Overshooting distance cannot be greater than chunk size
        print(' ' * 26 + 'Options')
        for k, v in vars(args).items():
            print(' ' * 26 + k + ': ' + str(v))

        print("torch.cuda.device_count() {}".format(torch.cuda.device_count()))
        os.makedirs(args.results_dir, exist_ok=True)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        # Set Cuda
        if torch.cuda.is_available() and not args.disable_cuda:
            print("using CUDA")
            args.device = torch.device('cuda')
            torch.cuda.manual_seed(args.seed)
        else:
            print("using CPU")
            args.device = torch.device('cpu')

        self.summary_name = args.results_dir + "/{}_{}_log"
        self.writer = SummaryWriter(self.summary_name.format(
            args.env, args.id))
        self.env = Env(args.env, args.symbolic_env, args.seed,
                       args.max_episode_length, args.action_repeat,
                       args.bit_depth)
        self.metrics = {
            'steps': [],
            'episodes': [],
            'train_rewards': [],
            'test_episodes': [],
            'test_rewards': [],
            'observation_loss': [],
            'reward_loss': [],
            'kl_loss': [],
            'merge_actor_loss': [],
            'merge_value_loss': []
        }

    def __init_sample(self):
        if args.experience_replay != '' and os.path.exists(
                args.experience_replay):
            self.D = torch.load(args.experience_replay)
            self.metrics['steps'], self.metrics['episodes'] = [
                self.D.steps
            ] * self.D.episodes, list(range(1, self.D.episodes + 1))
        elif not args.test:
            self.D = ExperienceReplay(args.experience_size, args.symbolic_env,
                                      self.env.observation_size,
                                      self.env.action_size, args.bit_depth,
                                      args.device)

            # Initialise dataset D with S random seed episodes
            print(
                "Start Multi Sample Processing -------------------------------"
            )
            start_time = time.time()
            data_lists = [
                Manager().list() for i in range(1, args.seed_episodes + 1)
            ]  # Set Global Lists
            pipes = [Pipe() for i in range(1, args.seed_episodes + 1)
                     ]  # Set Multi Pipe
            workers_init_sample = [
                Worker_init_Sample(child_conn=child, id=i + 1)
                for i, [parent, child] in enumerate(pipes)
            ]

            for i, w in enumerate(workers_init_sample):
                w.start()  # Start Single Process
                pipes[i][0].send(
                    data_lists[i])  # Parent_pipe send data using i'th pipes
            [w.join() for w in workers_init_sample]  # wait for the sub-processes to finish

            for i, [parent, child] in enumerate(pipes):
                # datas = parent.recv()
                for data in list(parent.recv()):
                    if isinstance(data, tuple):
                        assert len(data) == 4
                        self.D.append(data[0], data[1], data[2], data[3])
                    elif isinstance(data, int):
                        t = data
                        self.metrics['steps'].append(t * args.action_repeat + (
                            0 if len(self.metrics['steps']) ==
                            0 else self.metrics['steps'][-1]))
                        self.metrics['episodes'].append(i + 1)
                    else:
                        print(
                            "The received data has an unexpected type and needs to be fixed")
            end_time = time.time()
            print("the process times {} s".format(end_time - start_time))
            print(
                "End Multi Sample Processing -------------------------------")

    def upper_transition_model(self, prev_state, actions, prev_belief, obs,
                               nonterminals):
        actions = torch.transpose(actions, 0, 1) if args.MultiGPU else actions
        nonterminals = torch.transpose(nonterminals, 0, 1).to(
            device=args.device
        ) if args.MultiGPU and nonterminals is not None else nonterminals
        obs = torch.transpose(obs, 0, 1).to(
            device=args.device) if args.MultiGPU and obs is not None else obs
        temp_val = self.transition_model(prev_state.to(device=args.device),
                                         actions.to(device=args.device),
                                         prev_belief.to(device=args.device),
                                         obs, nonterminals)

        return list(
            map(
                lambda x: torch.cat(x.chunk(torch.cuda.device_count(), 0), 1)
                if x.shape[1] != prev_state.shape[0] else x,
                [x for x in temp_val]))

    def save_loss_data(self, losses):
        self.metrics['observation_loss'].append(losses[0])
        self.metrics['reward_loss'].append(losses[1])
        self.metrics['kl_loss'].append(losses[2])
        if len(losses) > 3:
            self.metrics['merge_actor_loss'].append(losses[3])
            self.metrics['merge_value_loss'].append(losses[4])

        Save_Txt(self.metrics['episodes'][-1],
                 self.metrics['observation_loss'][-1], 'observation_loss',
                 args.results_dir)
        Save_Txt(self.metrics['episodes'][-1], self.metrics['reward_loss'][-1],
                 'reward_loss', args.results_dir)
        Save_Txt(self.metrics['episodes'][-1], self.metrics['kl_loss'][-1],
                 'kl_loss', args.results_dir)
        if len(losses) > 3:
            Save_Txt(self.metrics['episodes'][-1],
                     self.metrics['merge_actor_loss'][-1], 'merge_actor_loss',
                     args.results_dir)
            Save_Txt(self.metrics['episodes'][-1],
                     self.metrics['merge_value_loss'][-1], 'merge_value_loss',
                     args.results_dir)

        # lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', results_dir)
        # lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', results_dir)
        # lineplot(metrics['episodes'][-len(metrics['kl_loss']):], metrics['kl_loss'], 'kl_loss', results_dir)
        # lineplot(metrics['episodes'][-len(metrics['actor_loss']):], metrics['actor_loss'], 'actor_loss', results_dir)
        # lineplot(metrics['episodes'][-len(metrics['value_loss']):], metrics['value_loss'], 'value_loss', results_dir)

    def save_model_data(self, episode):
        # writer.add_scalar("train_reward", metrics['train_rewards'][-1], metrics['steps'][-1])
        # writer.add_scalar("train/episode_reward", metrics['train_rewards'][-1], metrics['steps'][-1]*args.action_repeat)
        # writer.add_scalar("observation_loss", metrics['observation_loss'][0][-1], metrics['steps'][-1])
        # writer.add_scalar("reward_loss", metrics['reward_loss'][0][-1], metrics['steps'][-1])
        # writer.add_scalar("kl_loss", metrics['kl_loss'][0][-1], metrics['steps'][-1])
        # writer.add_scalar("actor_loss", metrics['actor_loss'][0][-1], metrics['steps'][-1])
        # writer.add_scalar("value_loss", metrics['value_loss'][0][-1], metrics['steps'][-1])
        # print("episodes: {}, total_steps: {}, train_reward: {} ".format(metrics['episodes'][-1], metrics['steps'][-1], metrics['train_rewards'][-1]))

        # Checkpoint models
        if episode % args.checkpoint_interval == 0:
            # torch.save({'transition_model': transition_model.state_dict(),
            #             'observation_model': observation_model.state_dict(),
            #             'reward_model': reward_model.state_dict(),
            #             'encoder': encoder.state_dict(),
            #             'actor_model': actor_model_g.state_dict(),
            #             'value_model': value_model_g.state_dict(),
            #             'model_optimizer': model_optimizer.state_dict(),
            #             'actor_optimizer': actor_optimizer_g.state_dict(),
            #             'value_optimizer': value_optimizer_g.state_dict()
            #             }, os.path.join(results_dir, 'models_%d.pth' % episode))
            if args.checkpoint_experience:
                torch.save(
                    self.D, os.path.join(args.results_dir, 'experience.pth')
                )  # Warning: will fail with MemoryError with large memory sizes
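
A minimal sketch of how this Plan class appears intended to be driven, assuming args is the module-level argparse namespace it reads throughout:

# Hypothetical entry point; args is assumed to be parsed at module level before Plan() is constructed.
if __name__ == '__main__':
    plan = Plan()   # builds the env and models, and collects the seed episodes
    plan.run()      # picks the algorithm from args.algo, then trains and periodically tests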
Example No. 4
class SACAgent():
    def __init__(self, action_size, state_size, config):
        self.seed = config["seed"]
        torch.manual_seed(self.seed)
        np.random.seed(seed=self.seed)
        self.env = gym.make(config["env_name"])
        self.env = FrameStack(self.env, config)
        self.env.seed(self.seed)
        self.action_size = action_size
        self.state_size = state_size
        self.tau = config["tau"]
        self.gamma = config["gamma"]
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.history_length = config["history_length"]
        self.size = config["size"]
        if not torch.cuda.is_available():
            config["device"] == "cpu"
        self.device = config["device"]
        self.eval = config["eval"]
        self.vid_path = config["vid_path"]
        print("actions size ", action_size)
        self.critic = QNetwork(state_size, action_size, config["fc1_units"],
                               config["fc2_units"]).to(self.device)
        self.q_optim = torch.optim.Adam(self.critic.parameters(),
                                        config["lr_critic"])
        self.target_critic = QNetwork(state_size, action_size,
                                      config["fc1_units"],
                                      config["fc2_units"]).to(self.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp()
        self.alpha_optim = Adam([self.log_alpha], lr=config["lr_alpha"])
        self.policy = SACActor(state_size, action_size).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(),
                                 lr=config["lr_policy"])
        self.encoder = Encoder(config).to(self.device)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(),
                                                  self.lr)
        self.episodes = config["episodes"]
        self.memory = ReplayBuffer((self.history_length, self.size, self.size),
                                   (1, ), config["buffer_size"],
                                   config["image_pad"], self.seed, self.device)
        pathname = config["seed"]
        tensorboard_name = str(config["res_path"]) + '/runs/' + str(pathname)
        self.writer = SummaryWriter(tensorboard_name)
        self.steps = 0
        self.target_entropy = -torch.prod(
            torch.Tensor(action_size).to(self.device)).item()

    def act(self, state, evaluate=False):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            state = state.type(torch.float32).div_(255)
            self.encoder.eval()
            state = self.encoder.create_vector(state)
            self.encoder.train()
            if evaluate is False:
                action = self.policy.sample(state)
            else:
                action_prob, _ = self.policy(state)
                action = torch.argmax(action_prob)
                action = action.cpu().numpy()
                return action
            # action = np.clip(action, self.min_action, self.max_action)
            action = action.cpu().numpy()[0]
        return action

    def train_agent(self):
        average_reward = 0
        scores_window = deque(maxlen=100)
        t0 = time.time()
        for i_episode in range(1, self.episodes):
            episode_reward = 0
            state = self.env.reset()
            t = 0
            while True:
                t += 1
                action = self.act(state)
                next_state, reward, done, _ = self.env.step(action)
                episode_reward += reward
                if i_episode > 10:
                    self.learn()
                self.memory.add(state, reward, action, next_state, done)
                state = next_state
                if done:
                    scores_window.append(episode_reward)
                    break
            if i_episode % self.eval == 0:
                self.eval_policy()
            ave_reward = np.mean(scores_window)
            print("Episode {} Steps {} Reward {} Average reward {:.2f} Time {}".
                  format(i_episode, t, episode_reward, ave_reward,
                         time_format(time.time() - t0)))
            self.writer.add_scalar('Aver_reward', ave_reward, self.steps)

    def learn(self):
        self.steps += 1
        states, rewards, actions, next_states, dones = self.memory.sample(
            self.batch_size)
        states = states.type(torch.float32).div_(255)
        states = self.encoder.create_vector(states)
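        # Keep a detached copy of the encoder features: the actor and temperature
        # updates below must not backpropagate into the shared encoder, which is
        # trained through the critic loss only.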
        states_detached = states.detach()
        qf1, qf2 = self.critic(states)
        q_value1 = qf1.gather(1, actions)
        q_value2 = qf2.gather(1, actions)

        with torch.no_grad():
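            # Soft Bellman target for discrete SAC: expectation over the policy's
            # action probabilities of (min target Q - alpha * log pi), with
            # (1 - done) masking terminal transitions.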
            next_states = next_states.type(torch.float32).div_(255)
            next_states = self.encoder.create_vector(next_states)
            q1_target, q2_target = self.target_critic(next_states)
            min_q_target = torch.min(q1_target, q2_target)
            next_action_prob, next_action_log_prob = self.policy(next_states)
            next_q_target = (
                next_action_prob *
                (min_q_target - self.alpha * next_action_log_prob)).sum(
                    dim=1, keepdim=True)
            next_q_value = rewards + (1 - dones) * self.gamma * next_q_target

        # --------------------------update-q--------------------------------------------------------
        loss = F.mse_loss(q_value1, next_q_value) + F.mse_loss(
            q_value2, next_q_value)
        self.q_optim.zero_grad()
        self.encoder_optimizer.zero_grad()
        loss.backward()
        self.q_optim.step()
        self.encoder_optimizer.step()  # was zero_grad(); the encoder is updated through the critic loss
        self.writer.add_scalar('loss/q', loss, self.steps)

        # --------------------------update-policy--------------------------------------------------------
        action_prob, log_action_prob = self.policy(states_detached)
        with torch.no_grad():
            q_pi1, q_pi2 = self.critic(states_detached)
            min_q_values = torch.min(q_pi1, q_pi2)
        #policy_loss = (action_prob *  ((self.alpha * log_action_prob) - min_q_values).detach()).sum(dim=1).mean()
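        # Discrete-action policy loss: expectation under pi(a|s) of
        # (alpha * log pi - min Q); the Q values are computed under no_grad, so
        # the gradient flows only through the policy probabilities.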
        policy_loss = (action_prob *
                       ((self.alpha * log_action_prob) - min_q_values)).sum(
                           dim=1).mean()
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()
        self.writer.add_scalar('loss/policy', policy_loss, self.steps)

        # --------------------------update-alpha--------------------------------------------------------
        alpha_loss = (action_prob.detach() *
                      (-self.log_alpha *
                       (log_action_prob + self.target_entropy).detach())).sum(
                           dim=1).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.writer.add_scalar('loss/alpha', alpha_loss, self.steps)
        self.soft_update(self.critic, self.target_critic)
        self.alpha = self.log_alpha.exp()

    def soft_update(self, online, target):
        for param, target_param in zip(online.parameters(),
                                       target.parameters()):
            target_param.data.copy_(self.tau * param.data +
                                    (1 - self.tau) * target_param.data)

    def eval_policy(self, eval_episodes=4):
        env = gym.make(self.env_name)
        env = wrappers.Monitor(env,
                               str(self.vid_path) + "/{}".format(self.steps),
                               video_callable=lambda episode_id: True,
                               force=True)
        average_reward = 0
        scores_window = deque(maxlen=100)
        for i_episode in range(eval_episodes):
            print("Eval Episode {} of {} ".format(i_episode, eval_episodes))
            episode_reward = 0
            state = env.reset()
            while True:
                action = self.act(state, evaluate=True)
                state, reward, done, _ = env.step(action)
                episode_reward += reward
                if done:
                    break
            scores_window.append(episode_reward)
        average_reward = np.mean(scores_window)
        self.writer.add_scalar('Eval_reward', average_reward, self.steps)
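
A minimal usage sketch for the SAC agent above. The class name SACAgent, the constructor signature and the environment wiring are assumptions (the listing starts inside __init__), and the config below only lists the keys visibly read in this section; the unseen part of __init__ reads more (e.g. lr, batch_size, gamma, tau, history_length, env_name).

# Hypothetical driver -- class name, signature and config keys are assumptions.
config = {
    "size": 84, "device": "cuda", "eval": 10, "vid_path": "./videos",
    "fc1_units": 256, "fc2_units": 256, "lr_critic": 1e-3, "lr_alpha": 1e-4,
    "lr_policy": 1e-4, "episodes": 500, "buffer_size": 100000,
    "image_pad": 4, "seed": 0, "res_path": "./results",
    # ...plus the keys read in the part of __init__ not shown here
}
agent = SACAgent(state_size, action_size, config)
agent.train_agent()
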
Exemplo n.º 5
0
class AAETrainer(AbstractTrainer):
    def __init__(self, opt):
        super().__init__(opt)

        print('[info] Dataset:', self.opt.dataset)
        print('[info] Alpha = ', self.opt.alpha)
        print('[info] Latent dimension = ', self.opt.latent_dim)

        self.opt = opt
        self.start_visdom()

    def start_visdom(self):
        self.vis = utils.Visualizer(env='Adversarial AutoEncoder Training',
                                    port=8888)

    def build_network(self):
        print('[info] Build the network architecture')
        self.encoder = Encoder(z_dim=self.opt.latent_dim)
        if self.opt.dataset == 'SMPL':
            num_verts = 6890
        elif self.opt.dataset == 'all_animals':
            num_verts = 3889
        else:
            raise ValueError('Unknown dataset: {}'.format(self.opt.dataset))
        self.decoder = Decoder(num_verts=num_verts, z_dim=self.opt.latent_dim)
        self.discriminator = Discriminator(input_dim=self.opt.latent_dim)

        self.encoder.cuda()
        self.decoder.cuda()
        self.discriminator.cuda()

    def build_optimizer(self):
        print('[info] Build the optimizer')
        self.optim_dis = optim.SGD(self.discriminator.parameters(),
                                   lr=self.opt.learning_rate)
        self.optim_AE = optim.Adam(itertools.chain(self.encoder.parameters(),
                                                   self.decoder.parameters()),
                                   lr=self.opt.learning_rate)

    def build_dataset_train(self):
        train_data = ACAPData(mode='train', name=self.opt.dataset)
        self.num_train_data = len(train_data)
        print('[info] Number of training samples = ', self.num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=self.opt.batch_size, shuffle=True)

    def build_dataset_valid(self):
        valid_data = ACAPData(mode='valid', name=self.opt.dataset)
        self.num_valid_data = len(valid_data)
        print('[info] Number of validation samples = ', self.num_valid_data)
        self.valid_loader = torch.utils.data.DataLoader(valid_data,
                                                        batch_size=128,
                                                        shuffle=True)

    def build_losses(self):
        print('[info] Build the loss functions')
        self.mseLoss = torch.nn.MSELoss()
        self.ganLoss = torch.nn.BCELoss()

    def print_iteration_stats(self):
        """
        print stats at each iteration
        """
        print(
            '\r[Epoch %d] [Iteration %d/%d] enc = %f dis = %f rec = %f' %
            (self.epoch, self.iteration,
             int(self.num_train_data / self.opt.batch_size),
             self.enc_loss.item(), self.dis_loss.item(), self.rec_loss.item()),
            end='')

    def train_iteration(self):

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        x = self.data.cuda()

        z = self.encoder(x)
        ''' Discriminator '''
        # sample from N(0, I)
        z_real = Variable(torch.randn(z.size(0), z.size(1))).cuda()

        y_real = Variable(torch.ones(z.size(0))).cuda()
        dis_real_loss = self.ganLoss(
            self.discriminator(z_real).view(-1), y_real)

        y_fake = Variable(torch.zeros(z.size(0))).cuda()
        dis_fake_loss = self.ganLoss(self.discriminator(z).view(-1), y_fake)

        self.optim_dis.zero_grad()
        self.dis_loss = 0.5 * (dis_fake_loss + dis_real_loss)
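        # retain_graph=True: the graph through z (the encoder output) is reused
        # below for the encoder/decoder update.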
        self.dis_loss.backward(retain_graph=True)
        self.optim_dis.step()
        self.dis_losses.append(self.dis_loss.item())
        ''' Autoencoder '''
        # The encoder tries to generate latent vectors that are close to the prior.
        y_real = Variable(torch.ones(z.size(0))).cuda()
        self.enc_loss = self.ganLoss(self.discriminator(z).view(-1), y_real)

        # The decoder tries to make the reconstruction as similar to the input as possible.
        rec = self.decoder(z)
        self.rec_loss = self.mseLoss(rec, x)

        # There is a trade-off here:
        # Latent regularization V.S. Reconstruction quality
        self.EG_loss = self.opt.alpha * self.enc_loss + (
            1 - self.opt.alpha) * self.rec_loss

        self.optim_AE.zero_grad()
        self.EG_loss.backward()
        self.optim_AE.step()

        self.enc_losses.append(self.enc_loss.item())
        self.rec_losses.append(self.rec_loss.item())

        self.print_iteration_stats()
        self.increment_iteration()

    def train_epoch(self):

        self.reset_iteration()
        self.dis_losses = []
        self.enc_losses = []
        self.rec_losses = []
        for step, data in enumerate(self.train_loader):
            self.data = data
            self.train_iteration()

        self.dis_losses = torch.Tensor(self.dis_losses)
        self.dis_losses = torch.mean(self.dis_losses)

        self.enc_losses = torch.Tensor(self.enc_losses)
        self.enc_losses = torch.mean(self.enc_losses)

        self.rec_losses = torch.Tensor(self.rec_losses)
        self.rec_losses = torch.mean(self.rec_losses)

        self.vis.draw_line(win='Encoder Loss', x=self.epoch, y=self.enc_losses)
        self.vis.draw_line(win='Discriminator Loss',
                           x=self.epoch,
                           y=self.dis_losses)
        self.vis.draw_line(win='Reconstruction Loss',
                           x=self.epoch,
                           y=self.rec_losses)

    def valid_iteration(self):

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        x = self.data.cuda()
        z = self.encoder(x)
        recon = self.decoder(z)

        # loss
        rec_loss = self.mseLoss(recon, x)
        self.rec_loss.append(rec_loss.item())
        self.increment_iteration()

    def valid_epoch(self):
        self.reset_iteration()
        self.rec_loss = []
        for step, data in enumerate(self.valid_loader):
            self.data = data
            self.valid_iteration()

        self.rec_loss = torch.Tensor(self.rec_loss)
        self.rec_loss = torch.mean(self.rec_loss)
        self.vis.draw_line(win='Valid reconstruction loss',
                           x=self.epoch,
                           y=self.rec_loss)

    def save_network(self):
        print("\n[info] saving net...")
        torch.save(self.encoder.state_dict(),
                   f"{self.opt.save_path}/Encoder.pth")
        torch.save(self.decoder.state_dict(),
                   f"{self.opt.save_path}/Decoder.pth")
        torch.save(self.discriminator.state_dict(),
                   f"{self.opt.save_path}/Discriminator.pth")
Exemplo n.º 6
0
class Solver:
    def __init__(self):
        self.train_lr = 1e-4
        self.num_classes = 9
        self.clf_target = Classifier().cuda()
        self.clf2 = Classifier().cuda()
        self.clf1 = Classifier().cuda()
        self.encoder = Encoder().cuda()
        self.pretrain_lr = 1e-4
        self.weights_coef = 1e-3

    def to_var(self, x):
        """Converts numpy to variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, requires_grad=False).float()

    def loss(self, predictions, y_true, weights_coef=None):
        """
        :param predictions: list of prediction tensors
        """
        assert len(predictions[0].shape) == 2 and len(
            y_true.shape) == 1, (predictions[0].shape, y_true.shape)
        losses = [F.cross_entropy(y_hat, y_true) for y_hat in predictions]
        loss = sum(losses)

        # """
        # We add the term |W1^T W2| to the cost function, where W1, W2 denote fully connected layers’
        # weights of F1 and F2 which are first applied to the feature F(xi)
        # """
        if weights_coef:
            lw = torch.matmul(self.clf1.fc1.weight,
                              self.clf2.fc1.weight.T).abs().sum().mean()
            loss += weights_coef * lw

        return loss

    def pretrain(self, source_loader, target_val_loader, pretrain_epochs=1):
        source_iter = iter(source_loader)
        source_per_epoch = len(source_iter)
        print("source_per_epoch:", source_per_epoch)

        # pretrain
        log_pre = 250
        lr = self.pretrain_lr
        pretrain_iters = source_per_epoch * pretrain_epochs
        params = reduce(
            lambda a, b: a + b,
            map(lambda i: list(i.parameters()),
                [self.encoder, self.clf1, self.clf2, self.clf_target]))
        pretrain_optimizer = optim.Adam(params, lr)
        accuracies = []

        for step in range(pretrain_iters + 1):
            # ============ Initialization ============#
            # refresh
            if (step + 1) % source_per_epoch == 0:
                source_iter = iter(source_loader)
            # load the data
            source, s_labels = next(source_iter)
            source, s_labels = self.to_var(source), self.to_var(
                s_labels).long().squeeze()

            # ============ Training ============ #
            pretrain_optimizer.zero_grad()
            # forward
            features = self.encoder(source)
            y1_hat = self.clf1(features)
            y2_hat = self.clf2(features)
            y_target_hat = self.clf_target(features)

            # loss
            loss_source_class = self.loss([y1_hat, y2_hat, y_target_hat],
                                          s_labels,
                                          weights_coef=self.weights_coef)

            # one step
            loss_source_class.backward()
            pretrain_optimizer.step()
            pretrain_optimizer.zero_grad()
            # TODO: make this each step and on log_pre step just average and print previous results
            # ============ Validation ============ #
            if (step + 1) % log_pre == 0:
                with torch.no_grad():
                    source_val_features = self.encoder(source)
                    c_source1 = self.clf1(source_val_features)
                    c_source2 = self.clf2(source_val_features)
                    c_target = self.clf_target(source_val_features)
                    print("Train data (source) scores:")
                    print("Step %d | Source clf1=%.2f, clf2=%.2f | Source data clf_t=%.2f" \
                          % (step,
                             accuracy(c_source1, s_labels),
                             accuracy(c_source2, s_labels),
                             accuracy(c_target, s_labels))
                          )
                    acc = self.eval(target_val_loader, self.clf_target)
                    print("Val target data acc=%.2f" % acc)
                    print()

    def pseudo_labeling(self, loader, pool_size=4000, threshold=0.9):
        """
        When C1, C2 denote the class which has the maximum predicted probability for
        y1, y2, we assign a pseudo-label to xk if the following two
        conditions are satisfied. First, we require C1 = C2 to give
        pseudo-labels, which means two different classifiers agree
        with the prediction. The second requirement is that the
        maximizing probability of y1 or y2 exceeds the threshold
        parameter, which we set as 0.9 or 0.95 in the experiment.

        :return:
        """
        pool = []  # x, y_pseudo
        for x, _ in loader:
            batch_size = x.shape[0]
            x = self.to_var(x)
            ys1 = F.softmax(self.clf1(self.encoder(x)), dim=1)
            ys2 = F.softmax(self.clf2(self.encoder(x)), dim=1)
            # _, pseudo_labels = torch.max(pseudo_labels, 1)
            for i in range(batch_size):
                y1 = ys1[i]
                y2 = ys2[i]
                val1, idx1 = torch.max(y1, 0)
                val2, idx2 = torch.max(y2, 0)
                if idx1 == idx2 and max(val1, val2) >= threshold:
                    pool.append((x[i].cpu(), idx1.cpu().item()))
                if len(pool) >= pool_size:
                    return pool
        return pool

    def train(self, source_loader, source_val_loader, target_loader,
              target_val_loader, epochs):
        """
        :param epochs: number of epochs to train for
        """

        # pretrain
        log_pre = 30
        lr = self.train_lr

        params1 = reduce(
            lambda a, b: a + b,
            map(lambda i: list(i.parameters()),
                [self.encoder, self.clf1, self.clf2]))
        params2 = list(self.encoder.parameters()) + list(
            self.clf_target.parameters())
        optimizer1 = optim.Adam(params1, lr)
        optimizer2 = optim.Adam(params2, lr)

        # ad-hoc
        acs1 = []
        acs2 = []
        acs3 = []

        for epoch in range(epochs):
            source_iter = iter(source_loader)
            target_iter = iter(target_loader)

            source_per_epoch = len(source_iter)
            target_per_epoch = len(target_iter)
            if epoch == 0:
                print("source_per_epoch, target_per_epoch:", source_per_epoch,
                      target_per_epoch)
            if epoch == 3:
                for param_group in optimizer1.param_groups:
                    param_group['lr'] = lr * 0.1

                for param_group in optimizer2.param_groups:
                    param_group['lr'] = lr * 0.1
            if epoch == 6:
                for param_group in optimizer1.param_groups:
                    param_group['lr'] = lr * 0.01

                for param_group in optimizer2.param_groups:
                    param_group['lr'] = lr * 0.01

            # ============ Pseudo-labeling  ============ #
            # Fill candidates
            target_candidates = self.pseudo_labeling(target_loader,
                                                     pool_size=4000 * epoch)
            print("Target candidates len:", len(target_candidates))
            if len(target_candidates) <= 1:
                target_candidates = self.pseudo_labeling(target_loader,
                                                         threshold=0.0)
                print("Target candidates len:", len(target_candidates))
            target_candidates_loader = self.wrap_to_loader(
                target_candidates, batch_size=target_loader.batch_size)
            for step, (target,
                       t_labels) in enumerate(target_candidates_loader):
                if (step + 1) % source_per_epoch == 0:
                    source_iter = iter(source_loader)

                source, s_labels = next(source_iter)
                target, t_labels = self.to_var(target), self.to_var(
                    t_labels).long().squeeze()
                source, s_labels = self.to_var(source), self.to_var(
                    s_labels).long().squeeze()

                # ============ Train F, F1, F2  ============ #
                optimizer1.zero_grad()
                # Source data
                # forward
                features = self.encoder(source)
                y1s_hat = self.clf1(features)
                y2s_hat = self.clf2(features)
                # loss
                loss_source_class = self.loss([y1s_hat, y2s_hat],
                                              s_labels,
                                              weights_coef=self.weights_coef)

                # Target data
                # forward
                features = self.encoder(target)
                y1t_hat = self.clf1(features)
                y2t_hat = self.clf2(features)
                # loss
                loss_target_class = self.loss([y1t_hat, y2t_hat],
                                              t_labels,
                                              weights_coef=self.weights_coef)
                # one step
                (loss_source_class + loss_target_class).backward()
                optimizer1.step()
                optimizer1.zero_grad()

                # ============ Train F, Ft  ============ #
                optimizer2.zero_grad()
                # Target data
                # forward
                y_target_hat = self.clf_target(self.encoder(target))
                # loss
                loss_target_class = self.loss([y_target_hat], t_labels)
                # one step
                loss_target_class.backward()
                optimizer2.step()
                optimizer2.zero_grad()

                # ============ Validation ============ #
                acs1.append(accuracy(y1s_hat, s_labels).item())
                acs2.append(accuracy(y2s_hat, s_labels).item())
                acs3.append(accuracy(y_target_hat, t_labels).item())

                if (step + 1) % log_pre == 0:
                    acc = self.eval(target_val_loader, self.clf_target)
                    print("Step %d | Val data target classifier acc=%.2f" %
                          (step, acc))
                    print(
                        "          Train accuracy clf1=%.2f, clf2=%.2f, clf_t=%.2f"
                        % (np.mean(acs1), np.mean(acs2), np.mean(acs3)))
                    acs1 = []
                    acs2 = []
                    acs3 = []

                    # acc1 = self.eval(source_val_loader, self.clf1)
                    # print("        | Val data source classifier1 acc=%.2f" % acc1)
                    # acc2 = self.eval(source_val_loader, self.clf2)
                    # print("        | Val data source classifier2 acc=%.2f" % acc2)
                    print()

    def save_models(self):
        torch.save(self.encoder, 'encoder.pth')
        torch.save(self.clf1, 'clf1.pth')
        torch.save(self.clf2, 'clf2.pth')
        torch.save(self.clf_target, 'clf_target.pth')

    def load_models(self):
        self.encoder = torch.load('encoder.pth')
        self.clf1 = torch.load('clf1.pth')
        self.clf2 = torch.load('clf2.pth')
        self.clf_target = torch.load('clf_target.pth')

    def eval(self, loader, classifier):
        """
        Evaluate encoder + passed classifier
        """
        # for x, y_true in loader:
        #     y_hat = classifier(self.encoder)
        #     acc = accuracy(y_hat, y_true)

        class_correct = [0] * self.num_classes
        class_total = [0.] * self.num_classes
        classes = shl_processing.coarse_label_mapping
        self.encoder.eval()
        classifier.eval()

        for x, y_true in loader:
            # forward pass: compute predicted outputs by passing inputs to the model
            x, y_true = self.to_var(x), self.to_var(y_true).long().squeeze()

            y_hat = classifier(self.encoder(x))
            _, pred = torch.max(y_hat, 1)
            correct = np.squeeze(pred.eq(y_true.data.view_as(pred)))
            # calculate test accuracy for each object class
            for i in range(len(y_true.data)):
                label = y_true.data[i]
                class_correct[label] += correct[i].item()
                class_total[label] += 1

        for i in range(self.num_classes):
            if class_total[i] > 0:
                print('\tTest Accuracy of %10s: %2d%% (%2d/%2d)' %
                      (classes[i], 100 * class_correct[i] / class_total[i],
                       np.sum(class_correct[i]), np.sum(class_total[i])))
            else:
                print('\tTest Accuracy of %10s: N/A (no training examples)' %
                      (classes[i]))

        self.encoder.train()
        classifier.train()

        return 100. * np.sum(class_correct) / np.sum(class_total)

    def wrap_to_loader(self, target_candidates, batch_size):
        """
        :param target_candidates: [(x,y_pseudo)]
        :return:
        """
        assert len(target_candidates) > 0
        tmp = target_candidates  # CondomDataset(target_candidates)
        return torch.utils.data.DataLoader(dataset=tmp,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)

    def confusion_matrix(self, loader, classifier):

        labels = []
        preds = []
        for x, y_true in loader:
            labels += list(y_true.cpu().detach().numpy().flatten())
            x, y_true = self.to_var(x), self.to_var(y_true).long().squeeze()

            y_hat = classifier(self.encoder(x))
            _, pred = torch.max(y_hat, 1)

            preds += list(pred.cpu().detach().numpy().flatten())

        cm = confusion_matrix(labels, preds)

        df_cm = pd.DataFrame(cm,
                             index=coarse_label_mapping,
                             columns=coarse_label_mapping,
                             dtype=int)
        plt.figure(figsize=(10, 7))
        sn.heatmap(df_cm, annot=True)

        plt.show()
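
A usage sketch for the Solver above (pretrain on source, then alternate pseudo-labelling and joint training); the data loaders are assumed to be built elsewhere, and only methods defined in this listing are called.

# Hypothetical usage -- the loaders themselves come from the surrounding project.
solver = Solver()
solver.pretrain(source_loader, target_val_loader, pretrain_epochs=1)
solver.train(source_loader, source_val_loader, target_loader,
             target_val_loader, epochs=9)
solver.save_models()
acc = solver.eval(target_val_loader, solver.clf_target)
print("Final target accuracy: %.2f" % acc)
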
Exemplo n.º 7
0
def main():
    """
    Describe main process including train and validation.
    """

    global start_epoch, checkpoint, fine_tune_encoder, best_bleu4, epochs_since_improvement, word_map

    # Read word map
    word_map_path = os.path.join(data_folder,
                                 'WORDMAP_' + dataset_name + ".json")
    with open(word_map_path, 'r') as j:
        word_map = json.load(j)

    # Set checkpoint or read from checkpoint
    if checkpoint is None:  # No pretrained model, set model from beginning
        decoder = Decoder(embed_dim=embed_dim,
                          decoder_dim=decoder_dim,
                          vocab_size=len(word_map),
                          dropout=dropout_rate)
        decoder_param = list(
            filter(lambda p: p.requires_grad, decoder.parameters()))
        # Make the randomly initialised decoder weights identical across the
        # distributed workers before training starts
        for param in decoder_param:
            tensor0 = param.data
            dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
            param.data = tensor0 / np.sqrt(float(num_nodes))
        decoder_optimizer = optim.Adam(params=decoder_param, lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_param = list(
            filter(lambda p: p.requires_grad, encoder.parameters()))
        if fine_tune_encoder:
            for param in encoder_param:
                tensor0 = param.data
                dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
                param.data = tensor0 / np.sqrt(float(num_nodes))
        encoder_optimizer = optim.Adam(
            params=encoder_param, lr=encoder_lr) if fine_tune_encoder else None
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)
    criterion = nn.CrossEntropyLoss()

    # Data loader
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_set = CaptionDataset(data_folder=h5data_folder,
                               data_name=dataset_name,
                               split="TRAIN",
                               transform=transforms.Compose([normalize]))
    val_set = CaptionDataset(data_folder=h5data_folder,
                             data_name=dataset_name,
                             split="VAL",
                             transform=transforms.Compose([normalize]))
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=workers,
                            pin_memory=True)

    total_start_time = datetime.datetime.now()
    print("Start the 1st epoch at: ", total_start_time)

    # Epoch
    for epoch in range(start_epoch, num_epochs):
        # Pre-check by epochs_since_improvement
        if epochs_since_improvement == 20:  # stop if there has been no improvement for 20 consecutive epochs
            break
        if epochs_since_improvement % 8 == 0 and epochs_since_improvement > 0:
            adjust_learning_rate(decoder_optimizer)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer)

        # For every batch
        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()  # data loading time
        losses = AverageMeter()  # loss (per word decoded)
        top5accs = AverageMeter()  # top5 accuracy
        decoder.train()
        encoder.train()

        start = time.time()
        start_time = datetime.datetime.now(
        )  # Initialize start time for this epoch

        # TRAIN
        for j, (images, captions, caplens) in enumerate(train_loader):
            if fine_tune_encoder and (epoch - start_epoch > 0 or j > 10):
                for group in encoder_optimizer.param_groups:
                    for p in group['params']:
                        state = encoder_optimizer.state[p]
                        if (state['step'] >= 1024):
                            state['step'] = 1000

            if (epoch - start_epoch > 0 or j > 10):
                for group in decoder_optimizer.param_groups:
                    for p in group['params']:
                        state = decoder_optimizer.state[p]
                        if (state['step'] >= 1024):
                            state['step'] = 1000

            data_time.update(time.time() - start)

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)
            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)

            # Define target as original captions excluding <start>
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length-1)
            target, _ = pack_padded_sequence(
                target, dec_lengths, batch_first=True
            )  # Delete all paddings and concat all other parts
            predictions, _ = pack_padded_sequence(
                predictions, dec_lengths,
                batch_first=True)  # (batch_size, sum(dec_lengths))

            loss = criterion(predictions, target)

            # Backward
            decoder_optimizer.zero_grad()
            if encoder_optimizer is not None:
                encoder_optimizer.zero_grad()
            loss.backward()
            ## Clip gradients
            if grad_clip is not None:
                clip_gradient(decoder_optimizer, grad_clip)
                if encoder_optimizer is not None:
                    clip_gradient(encoder_optimizer, grad_clip)
            ## Update
            decoder_optimizer.step()
            if encoder_optimizer is not None:
                encoder_optimizer.step()

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Current Batch Time: {batch_time.val:.3f} (Average: {batch_time.avg:.3f})\t'
                    'Current Data Load Time: {data_time.val:.3f} (Average: {data_time.avg:.3f})\t'
                    'Current Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t'
                    'Current Top-5 Accuracy: {top5.val:.3f} (Average: {top5.avg:.3f})'
                    .format(epoch + 1,
                            j + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Training Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            start = time.time()

        # VALIDATION
        decoder.eval()
        encoder.eval()

        batch_time = AverageMeter()  # forward prop. + back prop. time
        losses = AverageMeter()  # loss (per word decoded)
        top5accs = AverageMeter()  # top5 accuracy
        references = list(
        )  # references (true captions) for calculating BLEU-4 score
        hypotheses = list()  # hypotheses (predictions)

        start_time = datetime.datetime.now()

        for j, (images, captions, caplens, all_caps) in enumerate(val_loader):
            start = time.time()

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)

            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)

            # Define target as original captions excluding <start>
            predictions_copy = predictions.clone()
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length-1)
            target, _ = pack_padded_sequence(
                target, dec_lengths, batch_first=True
            )  # Delete all paddings and concat all other parts
            predictions, _ = pack_padded_sequence(
                predictions, dec_lengths,
                batch_first=True)  # (batch_size, sum(dec_lengths))

            loss = criterion(predictions, target)

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch + 1,
                        j,
                        len(val_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Validation Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            ## Store references (true captions), and hypothesis (prediction) for each image
            ## If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            ## references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # references
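            # The decoder sorted the batch by caption length (sort_ind), so the
            # full reference caption sets must be reordered to stay aligned with
            # predictions_copy.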
            all_caps = all_caps[sort_ind]
            for k in range(all_caps.shape[0]):
                img_caps = all_caps[k].tolist()
                img_captions = list(
                    map(
                        lambda c: [
                            w for w in c if w not in
                            {word_map["<start>"], word_map["<pad>"]}
                        ], img_caps))
                references.append(img_captions)

            # hypotheses
            _, preds = torch.max(predictions_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for i, p in enumerate(preds):
                temp_preds.append(preds[i][:dec_lengths[i]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        ## Compute BLEU-4 Scores
        #recent_bleu4 = corpus_bleu(references, hypotheses, emulate_multibleu=True)
        recent_bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'
            .format(loss=losses, top5=top5accs, bleu=recent_bleu4))

        # CHECK IMPROVEMENT
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement))
        else:
            epochs_since_improvement = 0

        # SAVE CHECKPOINT
        save_checkpoint(dataset_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
        print("Epoch {}, cost time: {}\n".format(epoch + 1,
                                                 now_time - total_start_time))
class DQNAgent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, config):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): number of discrete actions
            config (dict): hyper-parameters (seed, batch_size, lr, tau, fc units, device)
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(config["seed"])
        self.seed = config["seed"]
        self.gamma = 0.99
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.device = config["device"]
        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, self.fc1,
                                       self.fc2, self.seed).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.fc1,
                                        self.fc2, self.seed).to(self.device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=self.lr)
        self.encoder = Encoder(config).to(self.device)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(),
                                                  self.lr)

        # Replay memory
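        # (the replay buffer itself is created outside the agent and passed in via step())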

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, memory, writer):
        self.t_step += 1
        if len(memory) > self.batch_size:
            if self.t_step % 4 == 0:
                experiences = memory.sample(self.batch_size)
                self.learn(experiences, writer)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        state = state.type(torch.float32).div_(255)
        self.qnetwork_local.eval()
        self.encoder.eval()
        with torch.no_grad():
            state = self.encoder.create_vector(state)
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        self.encoder.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, writer):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        states = states.type(torch.float32).div_(255)
        states = self.encoder.create_vector(states)
        next_states = next_states.type(torch.float32).div_(255)
        next_states = self.encoder.create_vector(next_states)
        actions = actions.type(torch.int64)
        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Compute Q targets for current states; (1 - dones) masks terminal transitions
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        writer.add_scalar('Q_loss', loss, self.t_step)
        # Minimize the loss
        self.optimizer.zero_grad()
        self.encoder_optimizer.zero_grad()

        loss.backward()
        self.optimizer.step()
        self.encoder_optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def save(self, filename):
        """
        """
        mkdir("", filename)
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        torch.save(self.optimizer.state_dict(),
                   filename + "_q_net_optimizer.pth")
        torch.save(self.encoder.state_dict(), filename + "_encoder.pth")
        torch.save(self.encoder_optimizer.state_dict(),
                   filename + "_encoder_optimizer.pth")
        print("Save models to {}".format(filename))
# load the new state dict
src_encoder.load_state_dict(src_encoder_dict)
optimizer = optim.SGD(
    list(src_encoder.parameters()) + list(src_classifier.parameters()),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay)

criterion = nn.CrossEntropyLoss()

if cuda: 
    src_encoder = src_encoder.cuda()
    src_classifier = src_classifier.cuda() 
    criterion = criterion.cuda() 

src_encoder.train()
src_classifier.train()
# begin train
for epoch in range(1, epochs+1):
    correct = 0
    for batch_idx, (src_data, label) in enumerate(src_train_loader):
        if cuda:
            src_data, label = src_data.cuda(), label.cuda()
        src_data, label = Variable(src_data), Variable(label)
        optimizer.zero_grad()
        src_feature = src_encoder(src_data)
        output = src_classifier(src_feature)
        loss = criterion(output, label)
        output = F.softmax(output, dim=1)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(label.data.view_as(pred)).cpu().sum()