Code Example #1
    def __init__(
        self,
        model: Model,
        meta_optimizer: torch.optim.Optimizer,
        optimizer_cls: str,
        optimizer_kwargs: Dict[str, Any],
        grad_norm: Optional[float] = None,
        grad_clipping: Optional[float] = None,
        update_hook: Optional[Callable] = None,
        inherit: bool = False,
        loss_ratios_per_step: Optional[List[Dict[str, float]]] = None,
    ):
        super(BaseWrapper, self).__init__()
        self.model = model
        self.meta_optimizer = meta_optimizer
        self._grad_clipping = grad_clipping
        self._grad_norm = grad_norm
        self._container = deepcopy(self.model)
        training_util.enable_gradient_clipping(self.model, self._grad_clipping)
        self.optimizer_cls = getattr(torch.optim, optimizer_cls)
        self.optimizer_kwargs = optimizer_kwargs
        self._update_hook = update_hook

        def forward_kwargs(step):
            ratios = {"dep": 1.0, "pos": 0.0}
            if loss_ratios_per_step is not None:
                ratios = loss_ratios_per_step[step]
            return {'return_metric': True, 'loss_ratios': ratios}

        self.forward_kwargs = forward_kwargs
        self._inherit = inherit
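The forward_kwargs closure above picks per-step loss ratios, falling back to a dependency-only default when no schedule is supplied. A small self-contained sketch of the same pattern (the make_forward_kwargs name is illustrative, not part of the original class):

from typing import Any, Dict, List, Optional

def make_forward_kwargs(loss_ratios_per_step: Optional[List[Dict[str, float]]] = None):
    """Reproduce the closure above: per-step loss ratios with a dependency-only default."""
    def forward_kwargs(step: int) -> Dict[str, Any]:
        ratios = {"dep": 1.0, "pos": 0.0}
        if loss_ratios_per_step is not None:
            ratios = loss_ratios_per_step[step]
        return {"return_metric": True, "loss_ratios": ratios}
    return forward_kwargs

fk = make_forward_kwargs([{"dep": 0.5, "pos": 0.5}, {"dep": 1.0, "pos": 0.0}])
print(fk(0))  # {'return_metric': True, 'loss_ratios': {'dep': 0.5, 'pos': 0.5}}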
Code Example #2
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError("Could not recover training from the checkpoint.  Did you mean to output to "
                                     "a different serialization directory or delete the existing serialization "
                                     "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics['best_epoch'] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        if self.callbacks is not None:
            with torch.no_grad():
                for callback in self.callbacks:
                    callback.on_train_begin()

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()

            if self.callbacks is not None:
                with torch.no_grad():
                    for callback in self.callbacks:
                        callback.on_epoch_begin(epoch)

            train_metrics = self._train_epoch(epoch)
            if not self._early_stopping_by_batch:
                # get peak of memory usage
                if 'cpu_memory_MB' in train_metrics:
                    metrics['peak_cpu_memory_MB'] = max(metrics.get('peak_cpu_memory_MB', 0),
                                                        train_metrics['cpu_memory_MB'])
                for key, value in train_metrics.items():
                    if key.startswith('gpu_'):
                        metrics["peak_"+key] = max(metrics.get("peak_"+key, 0), value)

                if self._validation_data is not None:
                    with torch.no_grad():
                        val_metrics_temp = self._estimator.estimate(self._validation_data)
                        # We have a validation set, so compute all the metrics on it.
                        # val_loss, num_batches = self._validation_loss()
                        # val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)
                        val_metrics = {'loss': 0}
                        if 'sentiment_acc' in val_metrics_temp:
                            val_metrics['accuracy'] = val_metrics_temp['sentiment_acc']
                        if 'category_f1' in val_metrics_temp:
                            val_metrics['category_f1'] = val_metrics_temp['category_f1']['fscore']
                        if 'other_metrics' in val_metrics_temp and 'merge_micro_f1' in val_metrics_temp['other_metrics']:
                            val_metrics['merge_micro_f1'] = val_metrics_temp['other_metrics']['merge_micro_f1']
                        # Check validation metric for early stopping
                        val_metrics.update(val_metrics_temp)
                        this_epoch_val_metric = val_metrics[self._validation_metric]
                        self._metric_tracker.add_metric(this_epoch_val_metric)

                        if self._metric_tracker.should_stop_early():
                            logger.info("Ran out of patience.  Stopping training.")
                            break

                self._tensorboard.log_metrics(train_metrics,
                                              val_metrics=val_metrics,
                                              log_to_console=True,
                                              epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

                # Create overall metrics dict
                training_elapsed_time = time.time() - training_start_time
                metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
                metrics["training_start_epoch"] = epoch_counter
                metrics["training_epochs"] = epochs_trained
                metrics["epoch"] = epoch

                for key, value in train_metrics.items():
                    metrics["training_" + key] = value
                for key, value in val_metrics.items():
                    metrics["validation_" + key] = value

                if self._metric_tracker.is_best_so_far():
                    # Update all the best_ metrics.
                    # (Otherwise they just stay the same as they were.)
                    metrics['best_epoch'] = epoch
                    for key, value in val_metrics.items():
                        metrics["best_validation_" + key] = value

                    self._metric_tracker.best_epoch_metrics = val_metrics

                if self._serialization_dir:
                    dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

                # The Scheduler API is agnostic to whether your schedule requires a validation metric -
                # if it doesn't, the validation metric passed here is ignored.
                if self._learning_rate_scheduler:
                    self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
                if self._momentum_scheduler:
                    self._momentum_scheduler.step(this_epoch_val_metric, epoch)

                self._save_checkpoint(epoch)
            else:
                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience.  Stopping training.")
                    break

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                    ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            if self.callbacks is not None:
                with torch.no_grad():
                    for callback in self.callbacks:
                        callback.on_epoch_end(epoch)
            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        # self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
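The estimated-time-remaining computation that recurs in these train() loops is plain proportional extrapolation: the average duration of the epochs finished so far, multiplied by the number of epochs still to go. A minimal sketch of that arithmetic (the helper name is ours, not AllenNLP's):

import datetime
import time

def estimate_time_remaining(training_start_time: float, start_epoch: int,
                            current_epoch: int, num_epochs: int) -> str:
    """Extrapolate remaining training time from the epochs completed so far."""
    elapsed = time.time() - training_start_time
    epochs_done = current_epoch - start_epoch + 1
    epochs_total = num_epochs - start_epoch
    # elapsed / epochs_done is the average epoch duration; multiply by the epochs left.
    remaining = elapsed * (epochs_total / float(epochs_done) - 1)
    return str(datetime.timedelta(seconds=int(remaining)))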
Code Example #3
 def enable_gradient_clipping(self, trainer: 'CallbackTrainer'):
     training_util.enable_gradient_clipping(trainer.model, self.grad_clipping)
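Every example on this page calls training_util.enable_gradient_clipping(model, grad_clipping). Its observable effect is element-wise clamping of each parameter's gradient during backward(); the sketch below reproduces that behavior with per-parameter hooks, which is an assumption about the mechanism, so refer to allennlp.training.util for the actual implementation:

from typing import Optional
import torch

def enable_gradient_clipping_sketch(model: torch.nn.Module,
                                    grad_clipping: Optional[float]) -> None:
    """Clamp every gradient element into [-grad_clipping, grad_clipping] as it is produced."""
    if grad_clipping is None:
        return  # a None threshold disables clipping, matching the Optional[float] arguments above
    for parameter in model.parameters():
        if parameter.requires_grad:
            # register_hook fires during backward(), right after this parameter's grad is computed.
            parameter.register_hook(lambda grad: grad.clamp(-grad_clipping, grad_clipping))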
Code Example #4
File: trainer.py  Project: sun-xiaoyu/allennlp
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for callback in self._epoch_callbacks:
            callback(self, metrics={}, epoch=-1)

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0),
                    train_metrics["cpu_memory_MB"])
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            if self._validation_data_loader is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, val_reg_loss, num_batches = self._validation_loss(
                        epoch)

                    # It is safe again to wait till the validation is done. This is
                    # important to get the metrics right.
                    if self._distributed:
                        dist.barrier()

                    val_metrics = training_util.get_metrics(
                        self.model,
                        val_loss,
                        val_reg_loss,
                        num_batches,
                        reset=True,
                        world_size=self._world_size,
                        cuda_device=[self.cuda_device],
                    )

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            if self._master:
                self._tensorboard.log_metrics(
                    train_metrics,
                    val_metrics=val_metrics,
                    log_to_console=True,
                    epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir and self._master:
                common_util.dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric)

            if self._master:
                self._checkpointer.save_checkpoint(
                    epoch,
                    self,
                    is_best_so_far=self._metric_tracker.is_best_so_far())

            # Wait for the master to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

            for callback in self._epoch_callbacks:
                callback(self, metrics=metrics, epoch=epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
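Example #4 differs from the earlier ones mainly in its distributed bookkeeping: only the master process writes checkpoints and metrics, and dist.barrier() keeps the workers in lockstep around those writes. A sketch of how such flags are typically derived with torch.distributed (assumed setup; AllenNLP wires this up for you when distributed training is configured):

import torch.distributed as dist

distributed = dist.is_available() and dist.is_initialized()
is_master = (not distributed) or dist.get_rank() == 0

if is_master:
    pass  # write checkpoints, dump metrics, log to TensorBoard
if distributed:
    dist.barrier()  # every rank waits here, so nobody reads a half-written checkpoint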
Code Example #5
    def _try_train(self) -> Tuple[Dict[str, Any], int]:
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?"
            )

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        val_metrics: Dict[str, float] = {}
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # Back up the model now, in case something goes wrong later with the evaluation
            if self._primary and self._checkpointer is not None:
                self._checkpointer.shelve_model(epoch, self)
            # Wait for the primary process to finish saving the model checkpoint
            if self._distributed:
                dist.barrier()

            # get peak of memory usage
            for key, value in train_metrics.items():
                if key.startswith("gpu_") and key.endswith("_memory_MB"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
                elif key.startswith("worker_") and key.endswith("_memory_MB"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

            this_epoch_val_metric: float = 0.0
            if self._validation_data_loader is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)

                    # It is safe again to wait till the validation is done. This is
                    # important to get the metrics right.
                    if self._distributed:
                        dist.barrier()

                    val_metrics = training_util.get_metrics(
                        self.model,
                        val_loss,
                        val_reg_loss,
                        batch_loss=None,
                        batch_reg_loss=None,
                        num_batches=num_batches,
                        reset=True,
                        world_size=self._world_size,
                        cuda_device=self.cuda_device,
                    )

                    # Check validation metric for early stopping
                    this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics)
                    self._metric_tracker.add_metrics(val_metrics)

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir and self._primary:
                common_util.dump_metrics(
                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"),
                    metrics,
                )

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric)

            # The checkpointer saves state from the learning rate scheduler and the momentum
            # scheduler, so we have to make sure those are updated before we save the checkpoint here.
            if self._primary and self._checkpointer is not None:
                self._checkpointer.save_checkpoint(
                    epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
                )
            # Wait for the primary process to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

            for callback in self._callbacks:
                callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                        (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
                )
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            epochs_trained += 1

            if self._metric_tracker.should_stop_early():
                logger.info("Ran out of patience. Stopping training.")
                break
        else:
            epoch = self._num_epochs - 1

        # Load the best model state before returning
        best_model_state = (
            None if self._checkpointer is None else self._checkpointer.best_model_state()
        )
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics, epoch
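Note the else clause attached to the epoch loop above: it runs only when the loop exhausts its range without a break (that is, early stopping never fired), which keeps epoch well defined for the return value. A tiny demonstration of the idiom:

num_epochs = 3
for epoch in range(num_epochs):
    if epoch >= 99:  # stand-in for the early-stopping condition, never triggers here
        break
else:
    epoch = num_epochs - 1  # the loop ran to completion without break
print(epoch)  # -> 2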
Code Example #6
File: trainer.py  Project: zhifaceshi/allennlp_visdom
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics['best_epoch'] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        ####################################################################################################
        if self.visdom:

            def create_plot_window(vis, xlabel, ylabel, title):
                return vis.line(X=np.array([1]),
                                Y=np.array([np.nan]),
                                opts=dict(xlabel=xlabel,
                                          ylabel=ylabel,
                                          title=title))

            self.train_loss_window = create_plot_window(
                self.visdom, '#Iterations', 'Loss', 'Training Loss')
            self.consume_time_window = create_plot_window(
                self.visdom, "#Epochs", "Seconds", "Consuming time")
            self.left_time_window = self.visdom.text(
                "Waiting for training.......")
            metric_window = {}
        ##########################################################################################

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if 'cpu_memory_MB' in train_metrics:
                metrics['peak_cpu_memory_MB'] = max(
                    metrics.get('peak_cpu_memory_MB', 0),
                    train_metrics['cpu_memory_MB'])
            for key, value in train_metrics.items():
                if key.startswith('gpu_'):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics,
                val_metrics=val_metrics,
                log_to_console=True,
            )

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = time.strftime(
                "%H:%M:%S", time.gmtime(training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            # print(train_metrics.keys())
            # print(val_metrics.keys())

            ###############################################################################################
            if self.visdom:
                for key in train_metrics.keys():

                    newkey = 'training_' + key
                    if newkey in metric_window:
                        continue
                    else:
                        metric_window[newkey] = create_plot_window(
                            self.visdom, '#Epochs', key, newkey)

                for key in val_metrics.keys():

                    newkey = 'validation_' + key
                    if newkey in metric_window:
                        continue
                    else:
                        metric_window[newkey] = create_plot_window(
                            self.visdom, '#Epochs', key, newkey)
            #################################################################################################

            for key, value in train_metrics.items():
                metrics["training_" + key] = value

                ##########################################################
                if self.visdom:
                    self.visdom.line(X=np.array([epoch]),
                                     Y=np.array([value]),
                                     win=metric_window["training_" + key],
                                     update='append')
                #########################################################

            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

                ##########################################################
                if self.visdom:
                    self.visdom.line(X=np.array([epoch]),
                                     Y=np.array([value]),
                                     win=metric_window["validation_" + key],
                                     update='append')
                ############################################################

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f'metrics_epoch_{epoch}.json'), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info(
                "Epoch duration: %s",
                time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))
            #######################################################################################
            if self.visdom:
                self.visdom.line(X=np.array([epoch]),
                                 Y=np.array([epoch_elapsed_time / 60]),
                                 win=self.consume_time_window,
                                 update='append')
            ############################################################################################
            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                    ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)
                #######################################################################################
                if self.visdom:
                    self.visdom.text(
                        "Estimated training time remaining: {}".format(
                            formatted_time),
                        win=self.left_time_window,
                        append=True)
                ############################################################################################
            epochs_trained += 1

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
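The Visdom calls in example #6 follow the library's usual two-step pattern: create a window once with vis.line(...), then append points with update='append'. A minimal standalone sketch (assumes a Visdom server is reachable at the default address):

import numpy as np
import visdom

vis = visdom.Visdom()  # connects to the local Visdom server by default
win = vis.line(X=np.array([0]), Y=np.array([np.nan]),
               opts=dict(xlabel='#Epochs', ylabel='Loss', title='Training Loss'))
for epoch, loss in enumerate([0.9, 0.6, 0.4]):
    vis.line(X=np.array([epoch]), Y=np.array([loss]), win=win, update='append')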
Code Example #7
File: pt_dist_trainner.py  Project: polixir/abl-sym
    def train(self) -> Dict[str, Any]:
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics['best_epoch'] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value
        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)
            if self._validation_data is not None:
                with torch.no_grad():
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.get_model(),
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir and is_master_rank():
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f'metrics_epoch_{epoch}.json'), metrics)

            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if is_master_rank():
                self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                                           ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
Code Example #8
File: trainer.py  Project: tsiq-peyman/allennlp
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if 'cpu_memory_MB' in train_metrics:
                metrics['peak_cpu_memory_MB'] = max(
                    metrics.get('peak_cpu_memory_MB', 0),
                    train_metrics['cpu_memory_MB'])
            for key, value in train_metrics.items():
                if key.startswith('gpu_'):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(train_metrics,
                                          val_metrics=val_metrics,
                                          log_to_console=True)

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = time.strftime(
                "%H:%M:%S", time.gmtime(training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f'metrics_epoch_{epoch}.json'), metrics)

            if self._learning_rate_scheduler:
                # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
                # if it doesn't, the validation metric passed here is ignored.
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)

            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info(
                "Epoch duration: %s",
                time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                    ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
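All of these trainers delegate early stopping to a metric tracker with add_metric / is_best_so_far / should_stop_early. A minimal illustration of that contract (a sketch under assumed semantics, not AllenNLP's MetricTracker):

from typing import Optional

class PatienceTracker:
    """Stop once the validation metric has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int, should_decrease: bool = False):
        self.patience = patience
        self.should_decrease = should_decrease  # True for losses, False for accuracy-like metrics
        self.best: Optional[float] = None
        self.epochs_without_improvement = 0

    def add_metric(self, value: float) -> None:
        improved = self.best is None or (
            value < self.best if self.should_decrease else value > self.best)
        if improved:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

    def is_best_so_far(self) -> bool:
        return self.epochs_without_improvement == 0

    def should_stop_early(self) -> bool:
        return self.epochs_without_improvement >= self.patience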
Code Example #9
 def _enable_gradient_clipping(self) -> None:
     training_util.enable_gradient_clipping(self._model, self._grad_clipping)
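For completeness, the call can also be made outside a Trainer: only the model's parameters are touched, so passing a bare torch module (an assumption for illustration; the function is typed against AllenNLP's Model) should behave the same way. A minimal usage sketch, assuming AllenNLP is installed:

import torch
from allennlp.training import util as training_util

model = torch.nn.Linear(4, 2)  # assumed stand-in; any module with parameters
training_util.enable_gradient_clipping(model, 5.0)

loss = model(torch.randn(3, 4)).sum()
loss.backward()  # gradients are clamped element-wise to [-5.0, 5.0]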
Code Example #10
    def train(self, experiment: Optional[Experiment] = None) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        self.experiment = experiment

        logger.info("Beginning training.")

        self.val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        self.metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        self.metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            self.metrics["best_validation_" + key] = value

        for callback in self._epoch_callbacks:
            callback(self, metrics={}, epoch=-1, is_master=self._master)

        for epoch in range(epoch_counter, self._num_epochs):
            self.epoch = epoch
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            if experiment:
                with experiment.train():
                    experiment.log_metrics(
                        {
                            k: v
                            for k, v in train_metrics.items() if np.isscalar(v)
                        },
                        step=epoch)

            # get peak of memory usage
            for key, value in train_metrics.items():
                if key.startswith("gpu_") and key.endswith("_memory_MB"):
                    self.metrics["peak_" + key] = max(
                        self.metrics.get("peak_" + key, 0), value)
                elif key.startswith("worker_") and key.endswith("_memory_MB"):
                    self.metrics["peak_" + key] = max(
                        self.metrics.get("peak_" + key, 0), value)

            if self._validation_data_loader is not None and epoch >= self.epochs_before_validate:
                with torch.no_grad():
                    try:
                        if self.external_callbacks:
                            self.external_callbacks.call_if_registered(
                                CallbackName.BEFORE_VALIDATION,
                                annotator=self.annotator,
                                model=self.model,
                                trainer=self,
                                experiment=experiment)

                        # We have a validation set, so compute all the metrics on it.
                        val_loss, val_reg_loss, num_batches, preds = self._validation_loss(
                            epoch)

                        # It is safe again to wait till the validation is done. This is
                        # important to get the metrics right.
                        if self._distributed:
                            dist.barrier()

                        self.val_metrics = training_util.get_metrics(
                            self.model,
                            val_loss,
                            val_reg_loss,
                            num_batches,
                            reset=True,
                            world_size=self._world_size,
                            cuda_device=self.cuda_device,
                        )

                        if self.dataset_writer:
                            if self.decoder:
                                preds = self.decoder.decode_batch(
                                    self.model.vocab, preds)
                            filename = self._serialization_dir + f"/pred_epoch_{epoch}.txt"
                            with open(filename, "w") as f:
                                self.dataset_writer.write_to_file(
                                    self.model.vocab,
                                    OrderedDatasetReader.restore_order(preds),
                                    f)

                            if self.validation_command:
                                self.val_metrics.update(
                                    self.validation_command.evaluate(filename))

                        if self.external_callbacks:
                            self.external_callbacks.call_if_registered(
                                CallbackName.AFTER_VALIDATION,
                                annotator=self.annotator,
                                model=self.model,
                                trainer=self,
                                experiment=experiment)

                        # Check validation metric for early stopping
                        this_epoch_val_metric = self.val_metrics[
                            self._validation_metric]
                        self._metric_tracker.add_metric(this_epoch_val_metric)

                        if self._metric_tracker.should_stop_early():
                            logger.info(
                                "Ran out of patience.  Stopping training.")
                            break

                    except Exception as ex:
                        print("An exception occured:")
                        print(ex)
                        self._checkpointer.save_checkpoint("validation-failed",
                                                           trainer=self)
                        raise

            if self._master:
                self._tensorboard.log_metrics(
                    train_metrics,
                    val_metrics=self.val_metrics,
                    log_to_console=True,
                    epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            self.metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            self.metrics["training_start_epoch"] = epoch_counter
            self.metrics["training_epochs"] = epochs_trained
            self.metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                self.metrics["training_" + key] = value
            for key, value in self.val_metrics.items():
                self.metrics["validation_" + key] = value

            if experiment:
                with experiment.validate():
                    experiment.log_metrics(
                        {
                            k: v
                            for k, v in self.metrics.items() if np.isscalar(v)
                        },
                        step=epoch)

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                self.metrics["best_epoch"] = epoch
                for key, value in self.val_metrics.items():
                    self.metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = self.val_metrics

            if self._serialization_dir and self._master:
                common_util.dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"), self.metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric)

            if self._master:
                self._checkpointer.save_checkpoint(
                    epoch,
                    self,
                    is_best_so_far=self._metric_tracker.is_best_so_far())

            # Wait for the master to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

            for callback in self._epoch_callbacks:
                callback(self,
                         metrics=self.metrics,
                         epoch=epoch,
                         is_master=self._master)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        if self.external_callbacks:
            self.external_callbacks.call_if_registered(
                CallbackName.AFTER_TRAINING,
                annotator=self.annotator,
                model=self.model,
                trainer=self,
                experiment=experiment)

        return self.metrics
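The experiment.log_metrics(...) calls in example #10 first filter the metrics dict down to scalar values with np.isscalar, since nested dicts and arrays cannot be logged as individual metric points. The same filtering in isolation:

import numpy as np

metrics = {"accuracy": 0.91, "loss": 0.34, "per_label_f1": {"A": 0.8, "B": 0.7}}
scalar_metrics = {k: v for k, v in metrics.items() if np.isscalar(v)}
print(scalar_metrics)  # {'accuracy': 0.91, 'loss': 0.34}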
Code Example #11
File: gan_trainer.py  Project: djin31/loss-landscape
    def custom_train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        logger.info("GAN TRAINER HM START")
        try:
            epoch_counter = self.trainer._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        # TODO - gradient clipping?
        training_util.enable_gradient_clipping(self.trainer.model,
                                               self.trainer._grad_clipping)
        #HACK:
        #self.trainer._metric_tracker._patience = 30
        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics['best_epoch'] = self.trainer._metric_tracker.best_epoch
        for key, value in self.trainer._metric_tracker.best_epoch_metrics.items(
        ):
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self.trainer._num_epochs):

            # Start tracemalloc
            # tracemalloc.start()

            epoch_start_time = time.time()
            train_metrics = self.semi_train_epoch(epoch)

            # get peak of memory usage
            if 'cpu_memory_MB' in train_metrics:
                metrics['peak_cpu_memory_MB'] = max(
                    metrics.get('peak_cpu_memory_MB', 0),
                    train_metrics['cpu_memory_MB'])
            for key, value in train_metrics.items():
                if key.startswith('gpu_'):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)
            """
            if self.unlabelled_dataset is not None:
                unlabelled_metrics = unlabelled_train_epoch(self.trainer, self.unlabelled_dataset, epoch)
                for key, value in unlabelled_metrics.items():
                    if key.startswith('gpu_'):
                        metrics["peak_"+'un_'+key] = max(unlabelled_metrics.get("peak_"+key, 0), value)
                    else:
                        metrics['un_'+key] = value
            """

            if self.trainer._validation_data is not None and (
                (epoch - epoch_counter) % self.calc_valid_freq
                    == (self.calc_valid_freq - 1)):
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self.trainer._validation_loss()
                    val_metrics = training_util.get_metrics(self.trainer.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self.trainer._validation_metric]
                    self.trainer._metric_tracker.add_metric(
                        this_epoch_val_metric)

                    if self.trainer._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self.trainer._tensorboard.log_metrics(train_metrics,
                                                  val_metrics=val_metrics,
                                                  log_to_console=True)

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = time.strftime(
                "%H:%M:%S", time.gmtime(training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            is_best_so_far = False
            if self.trainer._metric_tracker.is_best_so_far():
                is_best_so_far = True
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self.trainer._metric_tracker.best_epoch_metrics = val_metrics

            if self.trainer._serialization_dir:
                dump_metrics(
                    os.path.join(self.trainer._serialization_dir,
                                 f'metrics_epoch_{epoch}.json'), metrics)

            #Pdb().set_trace()
            if self.trainer._learning_rate_scheduler:
                # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
                # if it doesn't, the validation metric passed here is ignored.
                self.trainer._learning_rate_scheduler.step(
                    this_epoch_val_metric, epoch)

            self.trainer._save_checkpoint(epoch)
            if self.constraints_model is not None:
                spath = self.save_constraints_model(epoch)
                if is_best_so_far:
                    shutil.copyfile(
                        spath,
                        os.path.join(self.trainer._serialization_dir,
                                     'best_dd_checkpoint.pth'))

                # Start saving checkpoint models after checkpoint_begin after every checkpoint_interval
                #if (self.trainer._checkpointer._save_intermediate_checkpoints) and (epoch >= self.trainer._checkpointer._checkpoint_begin) and (epoch%self.trainer._checkpointer._checkpoint_interval == 0):
                #    shutil.copyfile(spath,os.path.join(self.trainer._serialization_dir,'dd_checkpoint_epoch_'+str(epoch)+'.cpoint'))

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info(
                "Epoch duration: %s",
                time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

            if epoch < self.trainer._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                    ((self.trainer._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            self.trainer.model.train()
            epochs_trained += 1

            # Take snapshot and reveal top memory allocation
            # snapshot = tracemalloc.take_snapshot()
            # top_stats = snapshot.statistics('lineno')

            # print("[ Top 10 ]")
            # for stat in top_stats[:10]:
            #     logger.info(stat)

        # Load the best model state before returning
        best_model_state = self.trainer._checkpointer.best_model_state()
        if best_model_state:
            self.trainer.model.load_state_dict(best_model_state)

        return metrics
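The remaining-time estimate in the loop above extrapolates the average time per completed epoch over the epochs still to run. A minimal standalone sketch of the same arithmetic, with illustrative values (none of these numbers come from the snippet):

import datetime
import time

# Illustrative values only: pretend 10 minutes have elapsed after 5 completed epochs.
training_start_time = time.time() - 600
num_epochs, epoch_counter, epoch = 20, 0, 4

training_elapsed_time = time.time() - training_start_time
# (epochs remaining) / (epochs completed), times the elapsed time so far.
estimated_time_remaining = training_elapsed_time * (
    (num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
print(str(datetime.timedelta(seconds=int(estimated_time_remaining))))  # about 0:30:00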
コード例 #12
0
ファイル: trainer.py プロジェクト: zhangbo2008/pachong2
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: float = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        if self.cold_step_count > 0:  # Cold start for a few steps: freeze the embedding network and only train the newly added classifier layers on top.
            base_lr = self.optimizer.param_groups[0]['lr']
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cold_lr
            self.model.text_field_embedder._token_embedders[
                'bert'].set_weights(freeze=True)

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter,
                           self._num_epochs):  # skip epochs that were already completed.
            if epoch == self.cold_step_count and epoch != 0:  # cold start finished, restore the base learning rate
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = base_lr
                self.model.text_field_embedder._token_embedders[
                    'bert'].set_weights(freeze=False)  # and unfreeze the embedding network

            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0),
                    train_metrics["cpu_memory_MB"])
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            # clear cache before validation
            torch.cuda.empty_cache()
            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics,
                val_metrics=val_metrics,
                log_to_console=True,
                epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            # if self.cold_step_count <= epoch:
            self.scheduler.step(metrics['validation_loss'])

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)
            # Save the model; given this directory, the latest checkpoint inside it is picked up automatically for fine-tuning, which is convenient.
            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        # self._tensorboard.close()

        # Load the best model state (the best checkpoint in the serialization directory) before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
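The cold-start logic above depends on a project-specific set_weights(freeze=...) helper on the BERT token embedder. A rough, generic sketch of the same idea in plain PyTorch, freezing the encoder and lowering the learning rate for the first cold_step_count epochs and then restoring both (all names and sizes here are illustrative):

import torch

encoder = torch.nn.Embedding(1000, 128)   # stands in for the pretrained embedder
head = torch.nn.Linear(128, 10)           # the newly added classifier
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(head.parameters()), lr=1e-5)

cold_step_count, cold_lr, num_epochs = 2, 1e-3, 5
base_lr = optimizer.param_groups[0]['lr']

# Cold start: freeze the encoder and switch to the cold learning rate.
for p in encoder.parameters():
    p.requires_grad_(False)
for group in optimizer.param_groups:
    group['lr'] = cold_lr

for epoch in range(num_epochs):
    if epoch == cold_step_count and epoch != 0:
        # Cold start over: unfreeze the encoder and restore the base lr.
        for p in encoder.parameters():
            p.requires_grad_(True)
        for group in optimizer.param_groups:
            group['lr'] = base_lr
    # ... one epoch of training would go here ...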
コード例 #13
0
    def enable_gradient_clipping(self, trainer):
        training_util.enable_gradient_clipping(trainer.model,
                                               self.grad_clipping)
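AllenNLP's training_util.enable_gradient_clipping used above installs backward hooks that clamp each parameter's gradient into [-grad_clipping, grad_clipping]. A rough sketch of that behaviour in plain PyTorch (not the library's actual implementation):

import torch

def clamp_gradients(model: torch.nn.Module, grad_clipping: float) -> None:
    if grad_clipping is None:
        return
    for parameter in model.parameters():
        if parameter.requires_grad:
            # Clamp every gradient element as it is produced during backward().
            parameter.register_hook(
                lambda grad: grad.clamp(-grad_clipping, grad_clipping))

model = torch.nn.Linear(4, 2)
clamp_gradients(model, grad_clipping=1.0)  # every future backward() clamps these gradients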
コード例 #14
0
def multi_task_training(
        main_trainer_name: Tuple[Trainer, str],
        aux_trainers_names: Tuple[List[Trainer], List[str]]) -> Dict[str, Any]:
    '''
    Runs for as many epochs as the main task requires; if early stopping is
    enabled, it is governed by the main task. Each round trains every
    auxiliary task for one epoch, then the main task for one epoch, and then
    evaluates the main task's validation dataset. If early stopping is
    triggered, training ends; otherwise another auxiliary-then-main round
    begins.

    :param main_trainer_name: A tuple of 1. the main Trainer and 2. the name of its task.
    :param aux_trainers_names: A tuple of 1. a list of auxiliary trainers and
                               2. a list of names associated with those trainers.
    :returns: Metrics for both the auxiliary and main tasks.
    '''
    main_trainer = main_trainer_name[0]
    main_task_name = main_trainer_name[1]
    training_util.enable_gradient_clipping(main_trainer.model,
                                           main_trainer._grad_clipping)
    for aux_trainer in aux_trainers_names[0]:
        training_util.enable_gradient_clipping(aux_trainer.model,
                                               aux_trainer._grad_clipping)

    all_metrics: Dict[str, Any] = {}
    # Metric keys follow the format: split name, auxiliary or main, task name.

    for epoch in range(main_trainer._num_epochs):
        aux_name_validation_metrics: Dict[str, float] = {}
        for aux_trainer, aux_name in zip(*aux_trainers_names):
            logger.warning(f'Training Auxiliary task {aux_name}')
            aux_metrics = train_one_epoch(aux_trainer, epoch)
            all_metrics[f'training_aux_{aux_name}'] = aux_metrics[0]
            all_metrics[f'validation_aux_{aux_name}'] = aux_metrics[1]
            aux_name_validation_metrics[aux_name] = aux_metrics[1]
        logger.warning(f'Training Main task {main_task_name}')
        main_train_metrics, main_val_metrics = train_one_epoch(
            main_trainer, epoch)
        all_metrics[f'training_main_{main_task_name}'] = main_train_metrics
        all_metrics[f'validation_main_{main_task_name}'] = main_val_metrics
        # Early stopping if applicable (main task) and tracking the best metric
        main_validation_metric_name = main_trainer._validation_metric
        main_validation_metric = main_val_metrics[main_validation_metric_name]
        main_trainer._metric_tracker.add_metric(main_validation_metric)

        for aux_trainer in aux_trainers_names[0]:
            multi_task_checkpoint_saver(
                aux_trainer, main_trainer._metric_tracker.is_best_so_far(),
                epoch)
        multi_task_checkpoint_saver(
            main_trainer, main_trainer._metric_tracker.is_best_so_far(), epoch)

        if main_trainer._metric_tracker.should_stop_early():
            logger.info("Ran out of patience.  Stopping training.")
            break
        # Getting the best metrics for the main task
        if main_trainer._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            all_metrics['best_epoch'] = epoch
            for key, value in main_val_metrics.items():
                all_metrics["best_validation_" + key] = value
            main_trainer._metric_tracker.best_epoch_metrics = main_val_metrics
    # Load the best model state before returning
    main_best_model_state = main_trainer._checkpointer.best_model_state()
    if main_best_model_state:
        main_trainer.model.load_state_dict(main_best_model_state)

    for aux_trainer in aux_trainers_names[0]:
        aux_best_model_state = aux_trainer._checkpointer.best_model_state()
        if aux_best_model_state:
            aux_trainer.model.load_state_dict(aux_best_model_state)

    return all_metrics
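A hypothetical call site for the function above, assuming main_trainer, pos_trainer and ner_trainer are already-constructed AllenNLP Trainer instances (all of these names, including the task labels, are made up for illustration):

# One epoch of each auxiliary task, then one epoch of the main task, with the
# main task's validation metric driving early stopping and checkpointing.
all_metrics = multi_task_training(
    main_trainer_name=(main_trainer, 'target_sentiment'),
    aux_trainers_names=([pos_trainer, ner_trainer], ['pos', 'ner']))
print(all_metrics.get('best_epoch'))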
コード例 #15
0
ファイル: trainer.py プロジェクト: valueable/GEC
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        Everything recorded in the metric dictionaries is also written to the JSON files produced during training.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")
        # Gradient clipping: keeps exploding gradients from overshooting the optimum.
        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: float = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        # ------ training starts ------
        training_start_time = time.time()
        # cold_step_count is the number of epochs in which only the final linear layer is trained.
        # Training stages one and two: during the first cold_step_count epochs the
        # original pretrained model is not updated; after that it is.
        # Stage three trains the pretrained model's parameters directly (it holds most of the parameters).
        # Note that the cold lr is used during the cold-step phase;
        # once this phase ends, the base lr is used.
        if self.cold_step_count > 0:
            # 1e-5
            base_lr = self.optimizer.param_groups[0]['lr']
            for param_group in self.optimizer.param_groups:
                # 1e-3
                param_group['lr'] = self.cold_lr
            self.model.text_field_embedder._token_embedders[
                'bert'].set_weights(freeze=True)

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value
        # epoch_counter is 0 if no checkpoint was restored; otherwise training continues from it.
        for epoch in range(epoch_counter, self._num_epochs):
            # back to normal training
            if epoch == self.cold_step_count and epoch != 0:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = base_lr
                self.model.text_field_embedder._token_embedders[
                    'bert'].set_weights(freeze=False)
            # -- start of the current epoch --
            epoch_start_time = time.time()
            # **train**
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0),
                    train_metrics["cpu_memory_MB"])
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            # clear cache before validation
            torch.cuda.empty_cache()
            # Validation is optional, so check whether a validation set was provided.
            if self._validation_data is not None:
                # As usual, no gradients are computed and no parameters are updated during validation.
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    # Fetch the tracked performance metric (e.g. the loss).
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        # This is why there are sometimes fewer checkpoints than epochs: patience ran out.
                        # patience is the early-stopping threshold; training stops after that many epochs without validation improvement.
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics,
                val_metrics=val_metrics,
                log_to_console=True,
                epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            # **end of the epoch**
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch
            # Aggregate the metrics recorded in the training and validation phases.
            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            # if self.cold_step_count <= epoch:
            # scheduler step
            self.scheduler.step(metrics['validation_loss'])
            # These updates live under the pretraingectors directory on server 119.
            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics
            # Store the metrics as JSON.
            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                # scheduler step
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if self._momentum_scheduler:
                # scheduler step
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)
            # save a checkpoint
            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)
            # one epoch finished
            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        # self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
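The scheduler.step(metrics['validation_loss']) call in the loop above matches the interface of torch.optim.lr_scheduler.ReduceLROnPlateau, which lowers the learning rate once the tracked value stops improving. A minimal sketch of that behaviour with placeholder losses (nothing here comes from the trainer itself):

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2)

fake_validation_losses = [0.9, 0.8, 0.8, 0.8, 0.8, 0.8]  # plateaus after the second epoch
for epoch, val_loss in enumerate(fake_validation_losses):
    # ... training and validation would happen here ...
    scheduler.step(val_loss)  # the lr is reduced once patience is exhausted
    print(epoch, optimizer.param_groups[0]['lr'])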
コード例 #16
0
ファイル: trainer.py プロジェクト: yf1291/nlp4
    def train(self, oldmodel) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """#---------------通过oldmodel来加载官方模型,从而进行finetune
        try:
            # epoch_counter = self._restore_checkpoint()
            epoch_counter = 0  # fine-tune directly from epoch 0 to speed up training.
            # Work out how to modify this part so that all of the previous parameters get carried over; this hasn't come up before and needs to be worked through.
            print(oldmodel,
                  'this is the path of the old model')  # This checkpoint already contains the xlnet parameters and the final two linear layers.
            tmp = torch.load(oldmodel, map_location=torch.device('cpu'))
            out_shape = self.model.tag_labels_projection_layer._module.out_features  # pad the tensors below up to out_shape
            now_shape = tmp[
                'tag_labels_projection_layer._module.weight'].shape[0]
            fix_shape = out_shape - now_shape
            # expand by concatenation
            tmp['tag_labels_projection_layer._module.weight'] = torch.cat(
                (tmp['tag_labels_projection_layer._module.weight'],
                 torch.zeros(fix_shape, 768)), 0)
            tmp['tag_labels_projection_layer._module.bias'] = torch.cat(
                (tmp['tag_labels_projection_layer._module.bias'],
                 torch.zeros(fix_shape)), 0)
            # target size to pad up to:

            # tmp is just a plain dict, so it can be edited freely.

            self.model.load_state_dict(
                tmp)  # Convergence is much faster this way. The plan is to pad the shapes up to the size we need.

        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: float = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        if self.cold_step_count > 0:  # Cold start for a few steps: freeze the embedding network and only train the newly added classifier layers on top.
            base_lr = self.optimizer.param_groups[0]['lr']
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cold_lr
            self.model.text_field_embedder._token_embedders[
                'bert'].set_weights(freeze=True)

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter,
                           self._num_epochs):  # skip epochs that were already completed.
            if epoch == self.cold_step_count and epoch != 0:  # cold start finished, restore the base learning rate
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = base_lr
                self.model.text_field_embedder._token_embedders[
                    'bert'].set_weights(freeze=False)  # and unfreeze the embedding network

            epoch_start_time = time.time()  # the next line runs the training epoch
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0),
                    train_metrics["cpu_memory_MB"])
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            # clear cache before validation
            torch.cuda.empty_cache()

            # Evaluate on the validation set to guard against overfitting.
            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics,
                val_metrics=val_metrics,
                log_to_console=True,
                epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            # if self.cold_step_count <= epoch:
            self.scheduler.step(metrics['validation_loss'])

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)
            # Save the model; given this directory, the latest checkpoint inside it is picked up automatically for fine-tuning, which is convenient.

            if self._num_epochs == epoch + 1:  # only keep the final checkpoint.
                self._save_checkpoint(epoch)  # Saving every epoch would end up using a lot of disk space.

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        # self._tensorboard.close()

        # Load the best model state (the best checkpoint in the serialization directory) before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
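The state-dict surgery near the top of this example grows the tag projection layer to a larger label space by zero-padding its weight and bias before load_state_dict. A standalone sketch of the same trick (all sizes are illustrative; 768 is the encoder hidden size assumed by the snippet):

import torch

hidden, old_out, new_out = 768, 5000, 5002
old_weight = torch.randn(old_out, hidden)  # stands in for the saved projection weight
old_bias = torch.randn(old_out)

extra = new_out - old_out
new_weight = torch.cat((old_weight, torch.zeros(extra, hidden)), 0)  # (new_out, hidden)
new_bias = torch.cat((old_bias, torch.zeros(extra)), 0)              # (new_out,)

# The padded tensors can now be loaded into a projection layer of the new size.
layer = torch.nn.Linear(hidden, new_out)
layer.load_state_dict({'weight': new_weight, 'bias': new_bias})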