Example 1: k-fold cross-validation driver. Shuffles the languages, carves them into dev folds, trains once per fold, and aggregates each fold's best dev MSE.
    def train(self):
        set_random_seeds(self.random_seed)
        all_langs = self.data_loader.all_langs
        num_langs = len(all_langs)
        idx = list(range(num_langs))
        random.shuffle(idx)

        num_langs_per_fold = (num_langs + self.k_fold - 1) // self.k_fold  # Ceiling division.

        accum_metrics = Metrics()
        for fold in range(self.k_fold):
            # Get the train-dev split: each dev fold takes `num_langs_per_fold`
            # languages, except the last, which takes whatever remains.
            start_idx = fold * num_langs_per_fold
            end_idx = start_idx + num_langs_per_fold if fold < self.k_fold - 1 else num_langs
            dev_langs = [all_langs[idx[i]] for i in range(start_idx, end_idx)]
            logging.imp(f'dev_langs: {sorted(dev_langs)}')  # `imp` is a custom log level defined in the repo.
            train_langs = [all_langs[idx[i]] for i in range(num_langs) if i < start_idx or i >= end_idx]
            assert len(set(dev_langs + train_langs)) == num_langs

            self.trainer.reset()
            best_mse = self.trainer.train(
                self.evaluator,
                train_langs,
                dev_langs,
                fold)

            # Aggregate the best dev metric from every fold.
            accum_metrics += best_mse

        logging.info(accum_metrics.get_table())
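
The split arithmetic in Example 1 is easy to sanity-check in isolation. Below is a standalone sketch (the function name fold_split and the toy language list are illustrative, not from the repo) that reproduces the ceiling-division split; the min() clamp guards a small-n edge case that the original assumes away:

import random
from typing import List, Tuple

def fold_split(all_langs: List[str], k_fold: int, seed: int = 0) -> List[Tuple[List[str], List[str]]]:
    """Illustrative re-implementation of the train/dev split in Example 1."""
    n = len(all_langs)
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    per_fold = (n + k_fold - 1) // k_fold  # Ceiling division.
    folds = []
    for fold in range(k_fold):
        start = fold * per_fold
        # The last fold takes the remainder; min() keeps the slice in bounds.
        end = n if fold == k_fold - 1 else min(start + per_fold, n)
        dev = [all_langs[idx[i]] for i in range(start, end)]
        train = [all_langs[idx[i]] for i in range(n) if i < start or i >= end]
        folds.append((train, dev))
    return folds

# Every language lands in exactly one dev fold.
splits = fold_split(['en', 'de', 'fr', 'es', 'it'], k_fold=2)
assert sorted(lang for _, dev in splits for lang in dev) == ['de', 'en', 'es', 'fr', 'it']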
Example 2: per-fold training loop. Trains until the tracker finishes, periodically evaluates on the dev languages, and checkpoints the model with the best mean dev MSE.
    def train(self,
              evaluator: evaluator.Evaluator,
              train_langs: List[str],
              dev_langs: List[str],
              fold_idx: int) -> Metric:
        # Reset parameters and the optimizer before each fold.
        self._init_params(init_matrix=True, init_vector=True, init_higher_tensor=True)
        self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
        # Main training loop.
        accum_metrics = Metrics()
        best_mse = None
        while not self.tracker.is_finished:
            # Get data first.
            metrics = self.train_loop(train_langs)
            accum_metrics += metrics
            self.tracker.update()

            self.check_metrics(accum_metrics)

            if self.track % self.save_interval == 0:
                self.save(dev_langs, f'{fold_idx}.latest')
                dev_metrics = evaluator.evaluate(dev_langs)
                logging.info(dev_metrics.get_table(title='dev'))
                # Keep a separate checkpoint for the best dev MSE seen so far.
                if best_mse is None or dev_metrics.mse.mean < best_mse:
                    best_mse = dev_metrics.mse.mean
                    logging.imp(f'Updated best mse score: {best_mse:.3f}')
                    self.save(dev_langs, f'{fold_idx}.best')
        return Metric('best_mse', best_mse, 1)
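
None of these snippets show the Metric and Metrics classes they all lean on. Reconstructed purely from the call sites above (a named value with a weight, += accumulation, .mean, get_table, clear, and attribute access such as dev_metrics.mse), a minimal sketch could look like the following; the repo's real implementation may differ:

from typing import Dict

class Metric:
    """A named weighted sum: `value` is a running total, `weight` its denominator."""

    def __init__(self, name: str, value, weight):
        self.name = name
        self.value = value
        self.weight = weight

    @property
    def mean(self):
        return self.value / self.weight

    def __iadd__(self, other: 'Metric') -> 'Metric':
        self.value = self.value + other.value
        self.weight = self.weight + other.weight
        return self

class Metrics:
    """A bag of Metric objects, merged by name on +=."""

    def __init__(self, *metrics: Metric):
        self._metrics: Dict[str, Metric] = {m.name: m for m in metrics}

    def __iadd__(self, other) -> 'Metrics':
        new = [other] if isinstance(other, Metric) else list(other._metrics.values())
        for m in new:
            if m.name in self._metrics:
                self._metrics[m.name] += m
            else:
                self._metrics[m.name] = Metric(m.name, m.value, m.weight)
        return self

    def __getattr__(self, name: str) -> Metric:
        try:
            return self._metrics[name]
        except KeyError:
            raise AttributeError(name)

    def get_table(self, title: str = '') -> str:
        rows = [f'{m.name}\t{m.mean:.3f}' for m in self._metrics.values()]
        return '\n'.join(([title] if title else []) + rows)

    def clear(self):
        self._metrics.clear()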
Example 3: full evaluation pass. Runs the model over the whole data loader under torch.no_grad() and accumulates the analyzed scores.
    def evaluate(self) -> Metrics:
        with torch.no_grad():
            self.model.eval()
            all_metrics = Metrics()
            for batch in self.data_loader:
                scores = self.model.score(batch)
                metrics = self.analyze_scores(scores)
                all_metrics += metrics
        return all_metrics
Example 4: dev-fold evaluation. Computes the summed squared error per batch, weighted by batch size, over the selected dev languages.
    def evaluate(self, dev_langs: List[str]) -> Metrics:
        metrics = Metrics()
        fold_data_loader = self.data_loader.select(dev_langs, self.data_loader.all_langs)
        with torch.no_grad():
            self.model.eval()
            for batch in fold_data_loader:
                output = self.model(batch)
                # Summed squared error, weighted by batch size so that `mse.mean` is per example.
                mse = (output - batch.dist) ** 2
                mse = Metric('mse', mse.sum(), len(batch))
                metrics += mse
        return metrics
Example 5: score analysis. Turns per-feature-group losses and weights into named metrics plus one pooled, weight-averaged loss.
    def analyze_scores(self, scores) -> Metrics:
        metrics = Metrics()
        total_loss = 0.0
        total_weight = 0.0
        for name, (losses, weights) in scores.items():
            if should_include(self.feat_groups, name):
                # Weighted sum of the losses for this feature group.
                loss = (losses * weights).sum()
                weight = weights.sum()
                total_loss += loss
                total_weight += weight
                metrics += Metric(f'loss_{name.snake}', loss, weight)
        # The overall loss pools the weighted sums across all included groups.
        metrics += Metric('loss', total_loss, total_weight)
        return metrics
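
Each group's reported mean in Example 5 is therefore sum(losses * weights) / sum(weights), and the pooled 'loss' metric reuses the same sums. A quick numeric check with plain tensors (the group names are made up):

import torch

scores = {'ptype': (torch.tensor([1.0, 3.0]), torch.tensor([1.0, 1.0])),
          'ctype': (torch.tensor([2.0, 2.0]), torch.tensor([0.5, 0.5]))}

total_loss = sum((losses * weights).sum() for losses, weights in scores.values())  # 4.0 + 2.0
total_weight = sum(weights.sum() for _, weights in scores.values())                # 2.0 + 1.0
assert torch.isclose(total_loss / total_weight, torch.tensor(2.0))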
Example 6: one training epoch over the fold's data loader, backpropagating the per-batch mean MSE.
    def train_loop(self, train_langs: List[str]) -> Metrics:
        fold_data_loader = self.data_loader.select(train_langs, train_langs)
        metrics = Metrics()
        for batch_i, batch in enumerate(fold_data_loader):
            self.model.train()
            self.optimizer.zero_grad()

            output = self.model(batch)
            mse = (output - batch.dist) ** 2
            mse = Metric('mse', mse.sum(), len(batch))
            metrics += mse

            # Backpropagate the per-example mean (sum / batch size).
            mse.mean.backward()
            self.optimizer.step()
        return metrics
Example 7: single Monte Carlo training step. Reweights sampled hypotheses by sharpened, duplicate-masked probabilities and maximizes the expected score.
    def train_loop(self) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        batch = next(self.iterator)
        ret = self.model(batch)
        bs = batch.feat_matrix.size('batch')
        # Sharpen the sample distribution with the concentration parameter and
        # mask duplicate samples with a large negative log-probability.
        modified_log_probs = ret['sample_log_probs'] * self.concentration + (~ret['is_unique']).float() * (-999.9)
        sample_probs = modified_log_probs.log_softmax(dim='sample').exp()
        # Combine LM and word scores, then take the expectation over samples.
        final_ret = ret['lm_score'] + ret['word_score'] * self.score_per_word
        score = (sample_probs * final_ret).sum()
        lm_score = Metric('lm_score', ret['lm_score'].sum(), bs)
        word_score = Metric('word_score', ret['word_score'].sum(), bs)
        score = Metric('score', score, bs)
        metrics = Metrics(score, lm_score, word_score)
        # Maximize the expected score by minimizing its negation.
        loss = -score.mean
        loss.backward()
        self.optimizer.step()
        return metrics
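
The reweighting in Example 7 is worth unpacking: sample log-probabilities are sharpened by the concentration parameter, duplicates are suppressed with a large negative offset, and a softmax over samples turns the result into weights for the scores. A plain-tensor sketch of just that step (no named tensors; all values illustrative):

import torch

concentration = 2.0
log_probs = torch.tensor([-1.0, -2.0, -1.0])   # Per-sample log-probabilities.
is_unique = torch.tensor([True, True, False])  # The third sample is a duplicate.
word_scores = torch.tensor([0.5, 1.0, 9.0])    # Per-sample model scores.

# Sharpen with the concentration, then push duplicates toward zero probability.
modified = log_probs * concentration + (~is_unique).float() * (-999.9)
weights = modified.log_softmax(dim=0).exp()    # Duplicate weight is ~0.

expected_score = (weights * word_scores).sum()  # The training loss is -expected_score.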
Example 8: periodic metrics logging. Prints the accumulated table every check_interval steps and clears it.
    def check_metrics(self, accum_metrics: Metrics):
        if self.track % self.check_interval == 0:
            logging.info(accum_metrics.get_table(f'Step: {self.track}'))
            accum_metrics.clear()