Example 1
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        start = time.time()
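        # fastNLP's Batch yields (batch_x, batch_y) dict pairs over the DataSet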
        data_iterator = Batch(self.train_data,
                              batch_size=self.batch_size,
                              sampler=self.sampler,
                              as_numpy=False)
        total_steps = data_iterator.num_batches * self.n_epochs
        with inner_tqdm(total=total_steps,
                        postfix='loss:{0:<6.5f}',
                        leave=False,
                        dynamic_ncols=True) as pbar:
            avg_loss = 0
            for epoch in range(1, self.n_epochs + 1):
                pbar.set_description_str(
                    desc="Epoch {}/{}".format(epoch, self.n_epochs))
                # early stopping
                self.callback_manager.before_epoch(epoch, self.n_epochs)
                for batch_x, batch_y in data_iterator:
                    indices = data_iterator.get_batch_indices()
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.before_batch(batch_x, batch_y,
                                                       indices)
                    _move_dict_value_to_device(batch_x,
                                               batch_y,
                                               device=self._model_device)
                    prediction = self._data_forward(self.model, batch_x)

                    # edit prediction
                    self.callback_manager.before_loss(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y)
                    avg_loss += loss.item()

                    # Is loss NaN or inf? requires_grad = False
                    self.callback_manager.before_backward(loss, self.model)
                    self._grad_backward(loss)
                    # gradient clipping
                    self.callback_manager.after_backward(self.model)

                    self._update()
                    # lr scheduler; lr_finder; one_cycle
                    self.callback_manager.after_step(self.optimizer)

                    self._summary_writer.add_scalar("loss",
                                                    loss.item(),
                                                    global_step=self.step)
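                    # track the mean of every trainable parameter in TensorBoard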
                    for name, param in self.model.named_parameters():
                        if param.requires_grad:
                            self._summary_writer.add_scalar(
                                name + "_mean",
                                param.mean(),
                                global_step=self.step)
                            # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                            # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                    if (self.step + 1) % self.print_every == 0:
                        if self.use_tqdm:
                            print_output = "loss:{0:<6.5f}".format(
                                avg_loss / self.print_every)
                            pbar.update(self.print_every)
                        else:
                            end = time.time()
                            diff = timedelta(seconds=round(end - start))
                            print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                                epoch, self.step, avg_loss, diff)
                        pbar.set_postfix_str(print_output)
                        avg_loss = 0
                    self.step += 1
                    # do nothing
                    self.callback_manager.after_batch()

                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                            and self.dev_data is not None:
                        eval_res = self._do_validation(epoch=epoch,
                                                       step=self.step)
                        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                    total_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str)

                # if self.validate_every < 0 and self.dev_data:
                #     eval_res = self._do_validation(epoch=epoch, step=self.step)
                #     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                #                self.tester._format_eval_results(eval_res)
                #     pbar.write(eval_str)
                if epoch != self.n_epochs:
                    data_iterator = Batch(self.train_data,
                                          batch_size=self.batch_size,
                                          sampler=self.sampler,
                                          as_numpy=False)
                # lr decay; early stopping
                self.callback_manager.after_epoch(epoch, self.n_epochs,
                                                  self.optimizer)
            pbar.close()
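
A note on the pseudo_tqdm fallback used at the top of Examples 1 and 3: it lets
the loop drive a single progress-bar interface whether or not tqdm output is
wanted. The class below is a minimal hypothetical stand-in written for
illustration, not fastNLP's actual implementation; it covers only the tqdm
methods these examples call. (All three examples also assume module-level
imports of time, datetime.timedelta, tqdm, Batch, and
_move_dict_value_to_device.)

    class PseudoTqdm:
        """No-op progress bar exposing the subset of the tqdm API used above."""

        def __init__(self, total=None, postfix=None, leave=False,
                     dynamic_ncols=True, **kwargs):
            self.total = total

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            self.close()
            return False

        def update(self, n=1):
            pass  # nothing to redraw

        def set_description_str(self, desc=None):
            pass

        def set_postfix_str(self, s):
            pass

        def write(self, s):
            print(s)  # still surface evaluation messages on stdout

        def close(self):
            pass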
Example 2
    def train_shared(self, pbar=None, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.
            dag: If not None, is used instead of calling sample().

        BPTT is truncated at 35 timesteps (see the sketch after this example).

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.batch_size)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        train_idx = 0
        avg_loss = 0
        start = time.time()  # the timing print in the non-tqdm branch needs this
        data_iterator = Batch(self.train_data,
                              batch_size=self.batch_size,
                              sampler=self.sampler,
                              as_numpy=False,
                              prefetch=self.prefetch)

        for batch_x, batch_y in data_iterator:
            _move_dict_value_to_device(batch_x,
                                       batch_y,
                                       device=self._model_device)
            indices = data_iterator.get_batch_indices()
            # negative sampling; replace unknown; re-weight batch_y
            self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
            # prediction = self._data_forward(self.model, batch_x)

            # honor the documented dag argument instead of always sampling
            dags = dag if dag is not None else self.controller.sample(1)
            inputs, targets = batch_x, batch_y
            # self.callback_manager.on_loss_begin(batch_y, prediction)
            loss, hidden, extra_out = self.get_loss(inputs, targets, hidden,
                                                    dags)
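            # in-place detach truncates BPTT: gradients cannot flow past this
            # chunk boundary into earlier timesteps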
            hidden.detach_()

            avg_loss += loss.item()

            # Is loss NaN or inf? requires_grad = False
            self.callback_manager.on_backward_begin(loss, self.model)
            self._grad_backward(loss)
            self.callback_manager.on_backward_end(self.model)

            self._update()
            self.callback_manager.on_step_end(self.optimizer)

            if (self.step + 1) % self.print_every == 0 and pbar is not None:
                if self.use_tqdm:
                    print_output = "loss:{0:<6.5f}".format(avg_loss /
                                                           self.print_every)
                    pbar.update(self.print_every)
                else:
                    end = time.time()
                    diff = timedelta(seconds=round(end - start))
                    # this helper tracks no epoch counter, so report step only
                    print_output = "[step: {:>4}] train loss: {:>4.6} time: {}".format(
                        self.step, avg_loss / self.print_every, diff)
                pbar.set_postfix_str(print_output)
                avg_loss = 0
            self.step += 1
            step += 1
            self.shared_step += 1
            self.callback_manager.on_batch_end()
            # stop once the documented max_step warm-up bound is reached
            if max_step is not None and step >= max_step:
                break
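
The hidden.detach_() call above is what implements the truncated BPTT the
docstring mentions: each weight update backpropagates only through the current
chunk, because the hidden state carried into the next chunk is cut out of the
autograd graph. A self-contained sketch of the pattern, using an illustrative
toy recurrence rather than the ENAS shared model:

    import torch

    torch.manual_seed(0)
    w = torch.randn(4, 4, requires_grad=True)
    h = torch.zeros(1, 4)  # hidden state; the initial value needs no grad

    for chunk in range(3):
        for t in range(35):            # one truncation window of 35 timesteps
            h = torch.tanh(h @ w)
        loss = h.pow(2).mean()
        loss.backward()                # reaches back only to the last detach
        w.grad.zero_()                 # a real optimizer would step here
        h = h.detach()                 # cut the graph at the chunk boundary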
Example 3
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        start = time.time()
        total_steps = (len(self.train_data) // self.batch_size + int(
            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            avg_loss = 0
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  prefetch=self.prefetch)
            for epoch in range(1, self.n_epochs+1):
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                # early stopping
                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
                for batch_x, batch_y in data_iterator:
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                    indices = data_iterator.get_batch_indices()
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                    prediction = self._data_forward(self.model, batch_x)

                    # edit prediction
                    self.callback_manager.on_loss_begin(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y)
                    avg_loss += loss.item()

                    # Is loss NaN or inf? requires_grad = False
                    self.callback_manager.on_backward_begin(loss, self.model)
                    self._grad_backward(loss)
                    self.callback_manager.on_backward_end(self.model)

                    self._update()
                    self.callback_manager.on_step_end(self.optimizer)

                    if (self.step+1) % self.print_every == 0:
                        if self.use_tqdm:
                            print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                            pbar.update(self.print_every)
                        else:
                            end = time.time()
                            diff = timedelta(seconds=round(end - start))
                            print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                                epoch, self.step, avg_loss, diff)
                        pbar.set_postfix_str(print_output)
                        avg_loss = 0
                    self.step += 1
                    self.callback_manager.on_batch_end()

                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                            and self.dev_data is not None:
                        eval_res = self._do_validation(epoch=epoch, step=self.step)
                        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                    total_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str)

                # ================= mini-batch end ==================== #

                # lr decay; early stopping
                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
            # =============== epochs end =================== #
            pbar.close()
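
The validation trigger shared by Examples 1 and 3 packs two modes into one
condition: a positive validate_every means "evaluate every N steps", a negative
value means "evaluate at each epoch boundary" (steps divisible by the number of
batches per epoch), and zero disables step-based evaluation. A restatement of
that logic as a standalone helper, written here for illustration only:

    def should_validate(step, validate_every, batches_per_epoch, has_dev_data):
        """Mirror the trigger condition used in Examples 1 and 3."""
        if not has_dev_data:
            return False  # no dev set, nothing to evaluate on
        if validate_every > 0:
            return step % validate_every == 0
        if validate_every < 0:
            return step % batches_per_epoch == 0
        return False      # validate_every == 0 never triggers mid-epoch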