Example #1

# Imports assumed by this snippet; Plotter, model, dataLoader and saveRestorer are
# project-local objects that are not defined here.
import sys
import torch
from tqdm import tqdm, tqdm_notebook  # tqdm_notebook is deprecated in newer tqdm (tqdm.notebook.tqdm replaces it)

class Trainer():
    def __init__(self, model, modelParam, config, dataLoader, saveRestorer):
        self.model        = model
        self.modelParam   = modelParam
        self.config       = config
        self.dataLoader   = dataLoader
        self.saveRestorer = saveRestorer
        self.plotter      = Plotter(self.modelParam, self.config)
        return

    def train(self):
        for cur_epoch in range(self.model.start_epoch, self.modelParam['numbOfEpochs']):
            for modeSetup in self.modelParam['modeSetups']:
                mode     = modeSetup[0]
                is_train = modeSetup[1]
                if mode == 'train':
                    self.model.net.train()
                    loss = self.run_epoch(mode, self.model, is_train, cur_epoch)
                else:
                    with torch.no_grad():
                        self.model.net.eval()
                        loss = self.run_epoch(mode, self.model, is_train, cur_epoch)
                self.plotter.update(cur_epoch, loss, mode)
            self.saveRestorer.save(cur_epoch, loss, self.model)
        return


    def run_epoch(self, mode, model, is_train, cur_epoch):
        cur_it = -1
        epochTotalLoss = 0
        numbOfWordsInEpoch = 0

        if self.modelParam['inNotebook']:
            tt = tqdm_notebook(self.dataLoader.myDataDicts[mode], desc='', leave=True, mininterval=0.01, file=sys.stdout)
        else:
            tt = tqdm(self.dataLoader.myDataDicts[mode], desc='', leave=True, mininterval=0.01, file=sys.stdout)
        for dataDict in tt:
            for key in ['xTokens', 'yTokens', 'yWeights', 'vgg_fc7_features']:
                dataDict[key] = dataDict[key].to(model.device)
            cur_it += 1
            batchTotalLoss = 0
            numbOfWordsInBatch = 0
            for iter in range(dataDict['numbOfTruncatedSequences']):
                xTokens  = dataDict['xTokens'][:, :, iter]
                yTokens  = dataDict['yTokens'][:, :, iter]
                yWeights = dataDict['yWeights'][:, :, iter]
                vgg_fc7_features = dataDict['vgg_fc7_features']
                if iter == 0:
                    # First truncated chunk: the image features initialize the hidden state.
                    logits, current_hidden_state = model.net(vgg_fc7_features, xTokens, is_train)
                else:
                    # Later chunks continue from the previous chunk's hidden state
                    # (retain_graph=True below keeps that graph available for backward).
                    logits, current_hidden_state = model.net(vgg_fc7_features, xTokens, is_train, current_hidden_state)
                sumLoss, meanLoss = model.loss_fn(logits, yTokens, yWeights)

                if mode == 'train':
                    model.optimizer.zero_grad()
                    meanLoss.backward(retain_graph=True)
                    model.optimizer.step()

                batchTotalLoss += sumLoss.item()
                numbOfWordsInBatch += yWeights.sum().item()

            epochTotalLoss += batchTotalLoss
            numbOfWordsInEpoch += numbOfWordsInBatch

            desc = f'{mode} | Epoch={cur_epoch} | loss={batchTotalLoss/numbOfWordsInBatch:.4f}'
            tt.set_description(desc)
            tt.update()

        epochLoss = epochTotalLoss / numbOfWordsInEpoch
        return epochLoss
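
Example #1 relies on model.loss_fn(logits, yTokens, yWeights) returning both a summed and a mean loss, and run_epoch divides the accumulated sum by the number of real (non-padding) words given by yWeights. The original loss_fn is not shown above; the sketch below is a minimal, assumed implementation of that contract as a masked cross-entropy, with the shape conventions being assumptions rather than the course's actual code.

import torch
import torch.nn.functional as F

def loss_fn(logits, yTokens, yWeights):
    # Hypothetical masked cross-entropy matching the (sumLoss, meanLoss) contract in run_epoch.
    # Assumed shapes: logits [batch, seqLen, vocabSize], yTokens/yWeights [batch, seqLen],
    # where yWeights is 1.0 for real words and 0.0 for padding.
    vocabSize = logits.shape[-1]
    losses = F.cross_entropy(logits.reshape(-1, vocabSize),
                             yTokens.reshape(-1),
                             reduction='none')
    losses = losses * yWeights.reshape(-1)               # mask out padding positions
    sumLoss = losses.sum()                               # total loss over real words
    meanLoss = sumLoss / yWeights.sum().clamp(min=1.0)   # average loss per real word
    return sumLoss, meanLoss

Dividing by yWeights.sum(), as run_epoch also does per batch and per epoch, keeps the reported loss a per-word average regardless of how much padding each batch contains.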
Example #2

# Imports assumed by this snippet; Plotter, model, saveRestorer and env (a Gym-style
# environment) are project-local objects that are not defined here.
import sys
import numpy as np
import torch
from tqdm import tqdm, tqdm_notebook

class Trainer():
    def __init__(self, model, modelParam, config, saveRestorer, env):
        self.model = model
        self.modelParam = modelParam
        self.config = config
        self.saveRestorer = saveRestorer
        self.env = env
        self.plotter = Plotter(self.modelParam, self.config)
        return

    def train(self):
        running_reward = 10
        given_range = range(self.model.update_counter,
                            self.modelParam['numb_of_updates'])
        if self.modelParam['inNotebook']:
            tt = tqdm_notebook(given_range,
                               desc='',
                               leave=True,
                               mininterval=0.01,
                               file=sys.stdout)
        else:
            tt = tqdm(given_range,
                      desc='',
                      leave=True,
                      mininterval=0.01,
                      file=sys.stdout)
        for update_counter in tt:
            if self.modelParam['is_train']:
                self.model.policyNet.train()
            else:
                self.model.policyNet.eval()
            loss, reward = self.run_update()

            reward = reward / self.modelParam["episode_batch"]  # average total reward per episode
            running_reward = 0.2 * reward + (1 - 0.2) * running_reward  # exponential moving average

            desc = f'Update_counter={update_counter} | reward={reward:.4f} | running_reward={running_reward:.4f}'
            tt.set_description(desc)
            tt.update()

            if update_counter % self.modelParam['storeModelFreq'] == 0:
                self.plotter.update(update_counter, running_reward)
                self.saveRestorer.save(update_counter, running_reward,
                                       self.model)
        return

    def run_update(self):
        episodes_summary = {
            'episodes_log_probs': [],
            'episodes_rewards': [],
            'episodes_total_reward': [],
            'episodes_return': [],
        }

        for episode_ind in range(self.modelParam['episode_batch']):
            # play one episode
            ep_log_probs, ep_rewards, ep_total_reward, ep_returns = self.play_episode()
            episodes_summary['episodes_log_probs'].append(ep_log_probs)
            episodes_summary['episodes_rewards'].append(ep_rewards)
            episodes_summary['episodes_total_reward'].append(ep_total_reward)
            episodes_summary['episodes_return'].append(ep_returns)

        loss = self.gradient_update(episodes_summary)
        reward = sum(episodes_summary['episodes_total_reward'])
        return loss, reward

    def play_episode(self):
        ep_log_probs = []
        ep_rewards = []
        state, ep_total_reward = self.env.reset(), 0
        for t in range(1, self.modelParam['max_episode_len']):  # don't loop forever while learning
            action, log_prob = self.model.select_action(state)
            state, reward, done, _ = self.env.step(action)
            if self.modelParam['render']:
                self.env.render()
            ep_rewards.append(reward)
            ep_log_probs.append(log_prob)
            ep_total_reward += reward
            if done:
                break
        # Compute discounted returns backwards: G_t = r_t + gamma * G_{t+1}
        ep_returns = []
        G_t = 0
        for r in ep_rewards[::-1]:
            G_t = r + self.config['gamma'] * G_t
            ep_returns.insert(0, G_t)
        return ep_log_probs, ep_rewards, ep_total_reward, ep_returns

    def gradient_update(self, episodes_summary):
        policy_loss = []
        #flatten the list of lists into a single list
        episodes_return = [
            item for sublist in episodes_summary['episodes_return']
            for item in sublist
        ]
        episodes_log_probs = [
            item for sublist in episodes_summary['episodes_log_probs']
            for item in sublist
        ]

        # Normalize the returns to zero mean and unit variance to reduce gradient variance.
        eps = np.finfo(np.float32).eps.item()
        episodes_return = torch.tensor(episodes_return, device=self.model.device)
        episodes_return = (episodes_return - episodes_return.mean()) / (episodes_return.std() + eps)
        for log_prob, R in zip(episodes_log_probs, episodes_return):
            # Multiply by -1 so that minimizing the loss performs gradient ascent on the return.
            policy_loss.append(-log_prob * R)
        self.model.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.model.optimizer.step()

        return policy_loss.detach().cpu().item()
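
Example #2's play_episode calls self.model.select_action(state), and gradient_update later concatenates the collected log-probabilities with torch.cat, so each log_prob must be a 1-D tensor attached to the policy network's graph. select_action itself is not shown; the sketch below is one assumed way to implement it for a discrete action space (the free-function signature and the softmax-output assumption on policyNet are illustrative, not taken from the source).

import torch
from torch.distributions import Categorical

def select_action(policyNet, state, device):
    # Hypothetical sketch of model.select_action as used by Trainer.play_episode.
    # Assumes policyNet maps a state vector to a softmax distribution over discrete actions.
    state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)  # [1, state_dim]
    probs = policyNet(state)            # [1, num_actions]
    dist = Categorical(probs)
    action = dist.sample()              # shape [1]
    log_prob = dist.log_prob(action)    # shape [1]; torch.cat in gradient_update relies on this
    return action.item(), log_prob

Returning the log-probability as a tensor rather than a Python float is what lets gradient_update backpropagate the REINFORCE loss through the policy network.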