import sys

import torch
from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

# Plotter (and the model / dataLoader / saveRestorer objects passed in below) are
# project-local helpers defined elsewhere in this repository.


class Trainer():
    def __init__(self, model, modelParam, config, dataLoader, saveRestorer):
        self.model = model
        self.modelParam = modelParam
        self.config = config
        self.dataLoader = dataLoader
        self.saveRestorer = saveRestorer
        self.plotter = Plotter(self.modelParam, self.config)
        return

    def train(self):
        for cur_epoch in range(self.model.start_epoch, self.modelParam['numbOfEpochs']):
            for modeSetup in self.modelParam['modeSetups']:
                mode = modeSetup[0]
                is_train = modeSetup[1]
                if mode == 'train':
                    self.model.net.train()
                    loss = self.run_epoch(mode, self.model, is_train, cur_epoch)
                else:
                    # validation/test pass: no gradients needed
                    with torch.no_grad():
                        self.model.net.eval()
                        loss = self.run_epoch(mode, self.model, is_train, cur_epoch)
                self.plotter.update(cur_epoch, loss, mode)
            self.saveRestorer.save(cur_epoch, loss, self.model)
        return

    def run_epoch(self, mode, model, is_train, cur_epoch):
        cur_it = -1
        epochTotalLoss = 0
        numbOfWordsInEpoch = 0
        if self.modelParam['inNotebook']:
            tt = tqdm_notebook(self.dataLoader.myDataDicts[mode], desc='', leave=True,
                               mininterval=0.01, file=sys.stdout)
        else:
            tt = tqdm(self.dataLoader.myDataDicts[mode], desc='', leave=True,
                      mininterval=0.01, file=sys.stdout)
        for dataDict in tt:
            for key in ['xTokens', 'yTokens', 'yWeights', 'vgg_fc7_features']:
                dataDict[key] = dataDict[key].to(model.device)
            cur_it += 1
            batchTotalLoss = 0
            numbOfWordsInBatch = 0
            # the captions are split into truncated sub-sequences; the hidden state is
            # carried over from one sub-sequence to the next
            for seq_idx in range(dataDict['numbOfTruncatedSequences']):
                xTokens = dataDict['xTokens'][:, :, seq_idx]
                yTokens = dataDict['yTokens'][:, :, seq_idx]
                yWeights = dataDict['yWeights'][:, :, seq_idx]
                vgg_fc7_features = dataDict['vgg_fc7_features']
                if seq_idx == 0:
                    logits, current_hidden_state = model.net(vgg_fc7_features, xTokens, is_train)
                else:
                    logits, current_hidden_state = model.net(vgg_fc7_features, xTokens,
                                                             is_train, current_hidden_state)
                sumLoss, meanLoss = model.loss_fn(logits, yTokens, yWeights)
                if mode == 'train':
                    model.optimizer.zero_grad()
                    meanLoss.backward(retain_graph=True)
                    model.optimizer.step()
                batchTotalLoss += sumLoss.item()
                numbOfWordsInBatch += yWeights.sum().item()
            epochTotalLoss += batchTotalLoss
            numbOfWordsInEpoch += numbOfWordsInBatch
            desc = f'{mode} | Epoch={cur_epoch} | loss={batchTotalLoss / numbOfWordsInBatch:.4f}'
            tt.set_description(desc)
            tt.update()
        epochLoss = epochTotalLoss / numbOfWordsInEpoch
        return epochLoss
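
# A minimal, self-contained sketch (not part of the original project) of the loss
# bookkeeping that Trainer.run_epoch() above relies on: a padding-masked, weighted
# cross-entropy over caption tokens that returns both a summed and a mean loss, plus
# the per-word normalization batchTotalLoss / numbOfWordsInBatch. The shapes,
# vocabulary size and this loss_fn implementation are illustrative assumptions;
# the real model.loss_fn is defined elsewhere in the project.
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch_size, seq_len, vocab_size = 2, 5, 10
    logits = torch.randn(batch_size, seq_len, vocab_size)            # fake network outputs
    yTokens = torch.randint(0, vocab_size, (batch_size, seq_len))    # fake target word ids
    yWeights = torch.ones(batch_size, seq_len)
    yWeights[1, 3:] = 0                                              # zero weight on padded positions

    def loss_fn(logits, yTokens, yWeights):
        # per-word cross-entropy, masked by yWeights so padding does not contribute
        ce = F.cross_entropy(logits.reshape(-1, vocab_size), yTokens.reshape(-1), reduction='none')
        ce = ce * yWeights.reshape(-1)
        sumLoss = ce.sum()
        meanLoss = sumLoss / yWeights.sum()      # mean over real (non-padded) words only
        return sumLoss, meanLoss

    sumLoss, meanLoss = loss_fn(logits, yTokens, yWeights)
    batchTotalLoss = sumLoss.item()
    numbOfWordsInBatch = yWeights.sum().item()
    print(f'loss per word = {batchTotalLoss / numbOfWordsInBatch:.4f} '
          f'(equals meanLoss = {meanLoss.item():.4f})')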
import sys

import numpy as np
import torch
from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

# Plotter (and the model / env / saveRestorer objects passed in below) are
# project-local helpers defined elsewhere in this repository.


class Trainer():
    def __init__(self, model, modelParam, config, saveRestorer, env):
        self.model = model
        self.modelParam = modelParam
        self.config = config
        self.saveRestorer = saveRestorer
        self.env = env
        self.plotter = Plotter(self.modelParam, self.config)
        return

    def train(self):
        running_reward = 10
        given_range = range(self.model.update_counter, self.modelParam['numb_of_updates'])
        if self.modelParam['inNotebook']:
            tt = tqdm_notebook(given_range, desc='', leave=True, mininterval=0.01, file=sys.stdout)
        else:
            tt = tqdm(given_range, desc='', leave=True, mininterval=0.01, file=sys.stdout)
        for update_counter in tt:
            if self.modelParam['is_train']:
                self.model.policyNet.train()
            else:
                self.model.policyNet.eval()
            loss, reward = self.run_update()
            reward = reward / self.modelParam['episode_batch']
            # exponential moving average of the per-episode reward
            running_reward = 0.2 * reward + (1 - 0.2) * running_reward
            desc = f'Update_counter={update_counter} | reward={reward:.4f} | running_reward={running_reward:.4f}'
            tt.set_description(desc)
            tt.update()
            if update_counter % self.modelParam['storeModelFreq'] == 0:
                self.plotter.update(update_counter, running_reward)
                self.saveRestorer.save(update_counter, running_reward, self.model)
        return

    def run_update(self):
        episodes_summary = {
            'episodes_log_probs': [],
            'episodes_rewards': [],
            'episodes_total_reward': [],
            'episodes_return': [],
        }
        for episode_ind in range(self.modelParam['episode_batch']):
            # play one episode and store its trajectory statistics
            ep_log_probs, ep_rewards, ep_total_reward, ep_returns = self.play_episode()
            episodes_summary['episodes_log_probs'].append(ep_log_probs)
            episodes_summary['episodes_rewards'].append(ep_rewards)
            episodes_summary['episodes_total_reward'].append(ep_total_reward)
            episodes_summary['episodes_return'].append(ep_returns)
        loss = self.gradient_update(episodes_summary)
        reward = sum(episodes_summary['episodes_total_reward'])
        return loss, reward

    def play_episode(self):
        ep_log_probs = []
        ep_rewards = []
        state, ep_total_reward = self.env.reset(), 0
        for t in range(1, self.modelParam['max_episode_len']):  # don't loop forever while learning
            action, log_prob = self.model.select_action(state)
            state, reward, done, _ = self.env.step(action)
            if self.modelParam['render']:
                self.env.render()
            ep_rewards.append(reward)
            ep_log_probs.append(log_prob)
            ep_total_reward += reward
            if done:
                break
        # calculate the discounted return G_t = r_t + gamma * G_{t+1}, iterating backwards
        ep_returns = []
        G_t = 0
        for r in ep_rewards[::-1]:
            G_t = r + self.config['gamma'] * G_t
            ep_returns.insert(0, G_t)
        return ep_log_probs, ep_rewards, ep_total_reward, ep_returns

    def gradient_update(self, episodes_summary):
        policy_loss = []
        # flatten the lists of per-episode lists into single lists
        episodes_return = [item for sublist in episodes_summary['episodes_return'] for item in sublist]
        episodes_log_probs = [item for sublist in episodes_summary['episodes_log_probs'] for item in sublist]
        # normalize the returns to zero mean and unit variance
        eps = np.finfo(np.float32).eps.item()
        episodes_return = torch.tensor(episodes_return, device=self.model.device)
        episodes_return = (episodes_return - episodes_return.mean()) / (episodes_return.std() + eps)
        for log_prob, R in zip(episodes_log_probs, episodes_return):
            policy_loss.append(-log_prob * R)  # multiply by -1 so minimizing the loss performs gradient ascent on the return
        self.model.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.model.optimizer.step()
        return policy_loss.detach().cpu().item()
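
# A minimal, self-contained sketch (not part of the original project) illustrating
# the REINFORCE update that play_episode() and gradient_update() above perform,
# using a tiny throwaway linear policy and a fake 5-step episode. The policy,
# optimizer settings, rewards and gamma used here are illustrative assumptions,
# not values taken from the original project.
if __name__ == '__main__':
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.distributions import Categorical

    torch.manual_seed(0)
    gamma = 0.99
    policy = nn.Linear(4, 2)                          # toy policy: 4-d state -> 2 action logits
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

    # "play" a fake episode: random states, constant reward of 1 per step
    ep_log_probs, ep_rewards = [], []
    for _ in range(5):
        state = torch.randn(1, 4)
        dist = Categorical(logits=policy(state))
        action = dist.sample()
        ep_log_probs.append(dist.log_prob(action))    # shape (1,), keeps the graph to the policy
        ep_rewards.append(1.0)

    # discounted return G_t = r_t + gamma * G_{t+1}, computed backwards (as in play_episode)
    ep_returns, G_t = [], 0.0
    for r in ep_rewards[::-1]:
        G_t = r + gamma * G_t
        ep_returns.insert(0, G_t)

    # normalize the returns and take one gradient step (as in gradient_update)
    eps = np.finfo(np.float32).eps.item()
    returns = torch.tensor(ep_returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    policy_loss = torch.cat([-lp * R for lp, R in zip(ep_log_probs, returns)]).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    print('REINFORCE loss for the fake episode:', policy_loss.item())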