def train(self):
    set_random_seeds(self.random_seed)
    all_langs = self.data_loader.all_langs
    num_langs = len(all_langs)
    # Shuffle language indices so that fold assignment is random.
    idx = list(range(num_langs))
    random.shuffle(idx)
    # Ceiling division: every fold gets at least this many languages.
    num_langs_per_fold = (num_langs + self.k_fold - 1) // self.k_fold
    accum_metrics = Metrics()
    for fold in range(self.k_fold):
        # Get train-dev split. The last fold takes whatever remains.
        start_idx = fold * num_langs_per_fold
        end_idx = (start_idx + num_langs_per_fold) if fold < self.k_fold - 1 else num_langs
        dev_langs = [all_langs[idx[i]] for i in range(start_idx, end_idx)]
        logging.imp(f'dev_langs: {sorted(dev_langs)}')
        train_langs = [all_langs[idx[i]] for i in range(num_langs) if i < start_idx or i >= end_idx]
        # Every language must land in exactly one of the two splits.
        assert len(set(dev_langs + train_langs)) == num_langs
        self.trainer.reset()
        best_mse = self.trainer.train(self.evaluator, train_langs, dev_langs, fold)
        # Aggregate the best score from every fold.
        accum_metrics += best_mse
    logging.info(accum_metrics.get_table())
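# `set_random_seeds` is called above but not defined in this listing. A minimal
# sketch, assuming it simply seeds every RNG source for reproducibility; the
# actual helper may do more (e.g. seed CUDA or set deterministic flags).
def set_random_seeds(seed: int):
    import random

    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)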
def train(self, evaluator: evaluator.Evaluator, train_langs: List[str],
          dev_langs: List[str], fold_idx: int) -> Metric:
    # Reset parameters and the optimizer for this fold.
    self._init_params(init_matrix=True, init_vector=True, init_higher_tensor=True)
    self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
    # Main training loop.
    accum_metrics = Metrics()
    best_mse = None
    while not self.tracker.is_finished:
        # One pass over the training languages' data.
        metrics = self.train_loop(train_langs)
        accum_metrics += metrics
        self.tracker.update()
        self.check_metrics(accum_metrics)
        if self.track % self.save_interval == 0:
            self.save(dev_langs, f'{fold_idx}.latest')
            dev_metrics = evaluator.evaluate(dev_langs)
            logging.info(dev_metrics.get_table(title='dev'))
            # Keep a separate checkpoint for the best dev mse seen so far.
            if best_mse is None or dev_metrics.mse.mean < best_mse:
                best_mse = dev_metrics.mse.mean
                logging.imp(f'Updated best mse score: {best_mse:.3f}')
                self.save(dev_langs, f'{fold_idx}.best')
    return Metric('best_mse', best_mse, 1)
def evaluate(self) -> Metrics:
    with torch.no_grad():
        self.model.eval()
        all_metrics = Metrics()
        for batch in self.data_loader:
            scores = self.model.score(batch)
            metrics = self.analyze_scores(scores)
            all_metrics += metrics
        return all_metrics
def evaluate(self, dev_langs: List[str]) -> Metrics:
    metrics = Metrics()
    # Dev batches pair the held-out languages with all languages.
    fold_data_loader = self.data_loader.select(dev_langs, self.data_loader.all_langs)
    with torch.no_grad():
        self.model.eval()
        for batch in fold_data_loader:
            output = self.model(batch)
            mse = (output - batch.dist) ** 2
            mse = Metric('mse', mse.sum(), len(batch))
            metrics += mse
    return metrics
def analyze_scores(self, scores) -> Metrics:
    metrics = Metrics()
    total_loss = 0.0
    total_weight = 0.0
    for name, (losses, weights) in scores.items():
        if should_include(self.feat_groups, name):
            loss = (losses * weights).sum()
            weight = weights.sum()
            total_loss += loss
            total_weight += weight
            metrics += Metric(f'loss_{name.snake}', loss, weight)
    metrics += Metric('loss', total_loss, total_weight)
    return metrics
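# A tiny standalone illustration of the weighted aggregation in
# `analyze_scores` above: each group contributes a weight-weighted loss sum,
# and the overall loss metric's mean is total_loss / total_weight. The group
# names here are made up for the example.
import torch

scores = {'group_a': (torch.tensor([1.0, 3.0]), torch.tensor([1.0, 1.0])),
          'group_b': (torch.tensor([2.0]), torch.tensor([0.5]))}
total_loss = sum((losses * weights).sum() for losses, weights in scores.values())  # 4 + 1 = 5
total_weight = sum(weights.sum() for _, weights in scores.values())                # 1 + 1 + 0.5 = 2.5
print(total_loss / total_weight)                                                   # tensor(2.)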
def train_loop(self, train_langs: List[str]) -> Metrics:
    fold_data_loader = self.data_loader.select(train_langs, train_langs)
    metrics = Metrics()
    self.model.train()
    for batch in fold_data_loader:
        self.optimizer.zero_grad()
        output = self.model(batch)
        mse = (output - batch.dist) ** 2
        mse = Metric('mse', mse.sum(), len(batch))
        metrics += mse
        mse.mean.backward()
        self.optimizer.step()
    return metrics
def train_loop(self) -> Metrics:
    self.model.train()
    self.optimizer.zero_grad()
    batch = next(self.iterator)
    ret = self.model(batch)
    bs = batch.feat_matrix.size('batch')
    # Mask out duplicate samples: a large negative logit drives their
    # probability to (effectively) zero after the softmax over samples.
    modified_log_probs = (ret['sample_log_probs'] * self.concentration
                          + (~ret['is_unique']).float() * (-999.9))
    sample_probs = modified_log_probs.log_softmax(dim='sample').exp()
    # Expected total score under the sample distribution.
    final_ret = ret['lm_score'] + ret['word_score'] * self.score_per_word
    score = (sample_probs * final_ret).sum()
    lm_score = Metric('lm_score', ret['lm_score'].sum(), bs)
    word_score = Metric('word_score', ret['word_score'].sum(), bs)
    score = Metric('score', score, bs)
    metrics = Metrics(score, lm_score, word_score)
    # Maximize the score by minimizing its negation.
    loss = -score.mean
    loss.backward()
    self.optimizer.step()
    return metrics
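# A small numeric illustration of the masking trick in `train_loop` above,
# using plain tensors rather than named ones, and omitting the concentration
# factor. Adding -999.9 to duplicate samples' logits makes their softmax
# probability vanish, so only unique samples carry weight.
import torch

log_probs = torch.tensor([0.2, 0.9, 0.9])    # last entry is a duplicate sample
is_unique = torch.tensor([True, True, False])
masked = log_probs + (~is_unique).float() * (-999.9)
print(masked.log_softmax(dim=0).exp())       # ~[0.33, 0.67, 0.00]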
def check_metrics(self, accum_metrics: Metrics):
    if self.track % self.check_interval == 0:
        logging.info(accum_metrics.get_table(f'Step: {self.track}'))
        accum_metrics.clear()
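# The functions above assume a small metric-bookkeeping API: Metric(name,
# total, weight) with a weighted `mean`, and a Metrics container supporting
# `+=` accumulation, attribute access by name (e.g. `dev_metrics.mse`),
# `get_table`, and `clear`. A minimal sketch consistent with that usage; the
# real classes may differ.
class Metric:
    def __init__(self, name, total, weight):
        self.name = name
        self.total = total
        self.weight = weight

    @property
    def mean(self):
        return self.total / self.weight

    def __add__(self, other):
        # Summing totals and weights keeps `mean` a properly weighted average.
        assert self.name == other.name
        return Metric(self.name, self.total + other.total, self.weight + other.weight)


class Metrics:
    def __init__(self, *metrics):
        self._metrics = {m.name: m for m in metrics}

    def __iadd__(self, other):
        # Accept either a single Metric or another Metrics collection.
        incoming = other._metrics.values() if isinstance(other, Metrics) else [other]
        for m in incoming:
            self._metrics[m.name] = (self._metrics[m.name] + m
                                     if m.name in self._metrics else m)
        return self

    def __getattr__(self, name):
        try:
            return self._metrics[name]
        except KeyError:
            raise AttributeError(name)

    def get_table(self, title: str = '') -> str:
        rows = '\n'.join(f'{m.name}\t{float(m.mean):.6f}' for m in self._metrics.values())
        return f'{title}\n{rows}' if title else rows

    def clear(self):
        self._metrics.clear()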