Example 1
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        start = time.time()
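        # one optimizer step per batch: ceil(len(train_data) / batch_size) batches per epoch, over n_epochs epochs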
        total_steps = (len(self.train_data) // self.batch_size + int(
            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        with inner_tqdm(total=total_steps,
                        postfix='loss:{0:<6.5f}',
                        leave=False,
                        dynamic_ncols=True) as pbar:
            avg_loss = 0
            data_iterator = Batch(self.train_data,
                                  batch_size=self.batch_size,
                                  sampler=self.sampler,
                                  as_numpy=False,
                                  prefetch=self.prefetch)
            for epoch in range(1, self.n_epochs + 1):
                pbar.set_description_str(
                    desc="Epoch {}/{}".format(epoch, self.n_epochs))
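                # during the last final_epochs epochs only the selected structure is trained (the controller is no longer updated)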
                last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
                if epoch == self.n_epochs + 1 - self.final_epochs:
                    print(
                        'Entering the final stage. (Only train the selected structure)'
                    )
                # early stopping
                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)

                # 1. Training the shared parameters omega of the child models
                self.train_shared(pbar)

                # 2. Training the controller parameters theta
                if not last_stage:
                    self.train_controller()

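                # validate every validate_every steps, or once per epoch when validate_every is negative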
                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                    (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
                    if not last_stage:
                        self.derive()
                    eval_res = self._do_validation(epoch=epoch, step=self.step)
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                total_steps) + \
                                self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str)

                # lr decay; early stopping
                self.callback_manager.on_epoch_end(epoch, self.n_epochs,
                                                   self.optimizer)
            # =============== epochs end =================== #
            pbar.close()
Example 2
 def _train(self):
     if not self.use_tqdm:
         from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
     else:
         inner_tqdm = tqdm
     self.step = 0
     self.epoch = 0
     start = time.time()
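     # nn.DataParallel wraps the model, so the underlying forward lives on model.module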
     if isinstance(self.model, nn.DataParallel):
         self._forward_func = self.model.module.forward
     else:
         self._forward_func = self.model.forward
     with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
         self.pbar = pbar
         avg_loss = 0
         data_iterator = self.data_iterator
         self.batch_per_epoch = data_iterator.num_batches
         for epoch in range(1, self.n_epochs + 1):
             self.epoch = epoch
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
             # early stopping
             self.callback_manager.on_epoch_begin()
             for batch_x, batch_y in data_iterator:
                 self.step += 1
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                 prediction = self._data_forward(self.model, batch_x)
                 
                 # edit prediction
                 self.callback_manager.on_loss_begin(batch_y, prediction)
                 loss = self._compute_loss(prediction, batch_y).mean()
                 avg_loss += loss.item()
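                 # divide the loss so gradients accumulate correctly across update_every batches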
                 loss = loss / self.update_every
                 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
                 self._grad_backward(loss)
                 self.callback_manager.on_backward_end()
                 
                 self._update()
                 self.callback_manager.on_step_end()
                 
                 if self.step % self.print_every == 0:
                     avg_loss = float(avg_loss) / self.print_every
                     if self.use_tqdm:
                         print_output = "loss:{:<6.5f}".format(avg_loss)
                         pbar.update(self.print_every)
                     else:
                         end = time.time()
                         diff = timedelta(seconds=round(end - start))
                         print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                             epoch, self.step, avg_loss, diff)
                     pbar.set_postfix_str(print_output)
                     avg_loss = 0
                 self.callback_manager.on_batch_end()
                 
                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                 self.n_steps) + \
                                self.tester._format_eval_results(eval_res)
                     pbar.write(eval_str + '\n')
             
             # ================= mini-batch end ==================== #
             
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
         pbar.close()
         self.pbar = None
Example 3
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        start = time.time()
        data_iterator = Batch(self.train_data,
                              batch_size=self.batch_size,
                              sampler=self.sampler,
                              as_numpy=False)
        total_steps = data_iterator.num_batches * self.n_epochs
        with inner_tqdm(total=total_steps,
                        postfix='loss:{0:<6.5f}',
                        leave=False,
                        dynamic_ncols=True) as pbar:
            avg_loss = 0
            for epoch in range(1, self.n_epochs + 1):
                pbar.set_description_str(
                    desc="Epoch {}/{}".format(epoch, self.n_epochs))
                # early stopping
                self.callback_manager.before_epoch(epoch, self.n_epochs)
                for batch_x, batch_y in data_iterator:
                    indices = data_iterator.get_batch_indices()
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.before_batch(batch_x, batch_y,
                                                       indices)
                    _move_dict_value_to_device(batch_x,
                                               batch_y,
                                               device=self._model_device)
                    prediction = self._data_forward(self.model, batch_x)

                    # edit prediction
                    self.callback_manager.before_loss(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y)
                    avg_loss += loss.item()

                    # Is loss NaN or inf? requires_grad = False
                    self.callback_manager.before_backward(loss, self.model)
                    self._grad_backward(loss)
                    # gradient clipping
                    self.callback_manager.after_backward(self.model)

                    self._update()
                    # lr scheduler; lr_finder; one_cycle
                    self.callback_manager.after_step(self.optimizer)

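                    # log the step loss and the mean of each trainable parameter to TensorBoard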
                    self._summary_writer.add_scalar("loss",
                                                    loss.item(),
                                                    global_step=self.step)
                    for name, param in self.model.named_parameters():
                        if param.requires_grad:
                            self._summary_writer.add_scalar(
                                name + "_mean",
                                param.mean(),
                                global_step=self.step)
                            # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                            # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                    if (self.step + 1) % self.print_every == 0:
                        if self.use_tqdm:
                            print_output = "loss:{0:<6.5f}".format(
                                avg_loss / self.print_every)
                            pbar.update(self.print_every)
                        else:
                            end = time.time()
                            diff = timedelta(seconds=round(end - start))
                            print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                                epoch, self.step, avg_loss, diff)
                        pbar.set_postfix_str(print_output)
                        avg_loss = 0
                    self.step += 1
                    # do nothing
                    self.callback_manager.after_batch()

                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                            and self.dev_data is not None:
                        eval_res = self._do_validation(epoch=epoch,
                                                       step=self.step)
                        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                    total_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str)

                # if self.validate_every < 0 and self.dev_data:
                #     eval_res = self._do_validation(epoch=epoch, step=self.step)
                #     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                #                self.tester._format_eval_results(eval_res)
                #     pbar.write(eval_str)
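                # recreate the Batch iterator so the next epoch starts a fresh pass over train_data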
                if epoch != self.n_epochs:
                    data_iterator = Batch(self.train_data,
                                          batch_size=self.batch_size,
                                          sampler=self.sampler,
                                          as_numpy=False)
                # lr decay; early stopping
                self.callback_manager.after_epoch(epoch, self.n_epochs,
                                                  self.optimizer)
            pbar.close()
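
All three variants share the same validation trigger: a positive validate_every runs evaluation every that many steps, while a negative value runs it once per epoch (when the step count is a multiple of the number of batches), and only if dev_data is provided. Below is a minimal, self-contained sketch of that predicate; the name should_validate and its parameters are illustrative, not part of fastNLP.

def should_validate(step, validate_every, num_batches, has_dev_data):
    # Mirrors the condition used in the loops above.
    if not has_dev_data:
        return False
    if validate_every > 0:
        return step % validate_every == 0    # every validate_every steps
    return step % num_batches == 0           # negative setting: once per epoch

# With 100 batches per epoch and validate_every = -1:
# should_validate(100, -1, 100, True) -> True; should_validate(150, -1, 100, True) -> False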