Code example #1
    def __init__(self,
                 patience: Optional[int] = None,
                 metric_name: Optional[str] = None,
                 should_decrease: Optional[bool] = None) -> None:
        self._best_so_far: Optional[float] = None
        self._patience = patience
        self._epochs_with_no_improvement = 0
        self._is_best_so_far = True
        self.best_epoch_metrics: Dict[str, float] = {}
        self._epoch_number = 0
        self.best_epoch: Optional[int] = None

        # If the metric name starts with "+", we want it to increase.
        # If the metric name starts with "-", we want it to decrease.
        # We also allow you to not specify a metric name and just set `should_decrease` directly.
        if should_decrease is None and metric_name is None:
            raise ConfigurationError(
                "must specify either `should_decrease` or `metric_name` (but not both)"
            )
        elif should_decrease is not None and metric_name is not None:
            raise ConfigurationError(
                "must specify either `should_decrease` or `metric_name` (but not both)"
            )
        elif metric_name is not None:
            if metric_name[0] == "-":
                self._should_decrease = True
            elif metric_name[0] == "+":
                self._should_decrease = False
            else:
                raise ConfigurationError("metric_name must start with + or -")
        else:
            self._should_decrease = should_decrease
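
To see how the `+`/`-` prefix and `patience` interact, here is a hypothetical usage sketch. It is not part of the original source: the metric values are made up, and it relies on the `add_metric` and `should_stop_early` methods that code example #8 calls on the same tracker.

# Hypothetical usage sketch of the MetricTracker above.
tracker = MetricTracker(patience=3, metric_name="-loss")  # "-" means lower is better

# Made-up validation losses: improvement stops after the second value.
for val_loss in [0.90, 0.85, 0.86, 0.87, 0.88]:
    tracker.add_metric(val_loss)
    if tracker.should_stop_early():
        print("No improvement for 3 epochs, stopping early.")
        break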
Code example #2
 def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
     if top_k > 1 and tie_break:
         raise ConfigurationError(
             "Tie break in Categorical Accuracy "
             "can be done only for maximum (top_k = 1)")
     if top_k <= 0:
         raise ConfigurationError(
             "top_k passed to Categorical Accuracy must be > 0")
     self._top_k = top_k
     self._tie_break = tie_break
     self.correct_count = 0.
     self.total_count = 0.
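
The two constructor guards are easy to exercise directly. The snippet below is a hypothetical sketch only, assuming the `CategoricalAccuracy` class and `ConfigurationError` above are in scope.

acc_top5 = CategoricalAccuracy(top_k=5)           # valid: top-5 accuracy

try:
    CategoricalAccuracy(top_k=5, tie_break=True)  # invalid combination
except ConfigurationError as err:
    print(err)  # tie break is only allowed for the maximum (top_k = 1)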
Code example #3
 def step(self, metric: Optional[float] = None, epoch: Optional[int] = None) -> None:
     if metric is None:
         raise ConfigurationError(
             "This learning rate scheduler requires "
             "a validation metric to compute the schedule and therefore "
             "must be used with a validation dataset.")
     self.lr_scheduler.step(metric, epoch)
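
A hypothetical sketch of the guard, assuming `scheduler` is an instance of the wrapper defined above:

try:
    scheduler.step()                  # no validation metric -> ConfigurationError
except ConfigurationError as err:
    print(err)

scheduler.step(metric=0.42, epoch=3)  # forwarded to the wrapped lr_scheduler.step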
Code example #4
    def update_confusion_matrix(self, logits, labels):
        num_classes = len(self.class_list)
        assert logits.size(-1) == num_classes

        if labels.dim() != logits.dim() - 1:
            raise ConfigurationError(
                "labels must have dimension == logits.size() - 1 but "
                "found tensor of shape: {}".format(logits.size()))
        if (labels >= num_classes).any():
            raise ConfigurationError(
                "A label passed to Categorical Accuracy contains an id >= {}, "
                "the number of classes.".format(num_classes))
        logits = logits.view((-1, num_classes))
        labels = labels.view(-1).long()

        top_k = logits.max(-1)[1].unsqueeze(-1)
        labels = labels.unsqueeze(-1)

        for i, label in enumerate(labels):
            pred = top_k[i]
            label, pred = label.item(), pred.item()
            self.confusion_matrix[label][pred] += 1
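
The shapes the method expects, and the cells it increments, are easiest to see on a tiny batch. The snippet below is illustrative only and simply recomputes the row-wise argmax used inside the method; the effect on `confusion_matrix` is described in the comment.

import torch

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 1.5]])  # shape (batch_size=2, num_classes=3)
labels = torch.tensor([0, 2])             # gold class ids, shape (batch_size=2)

predicted = logits.max(-1)[1]             # tensor([0, 2]), as computed in the method
# With these inputs, update_confusion_matrix would increment
# confusion_matrix[0][0] and confusion_matrix[2][2] by one each.
print(predicted)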
Code example #5
    def _get_prediction_device(self) -> int:
        """
        This method checks the device of the model parameters to determine the cuda_device
        this model should be run on for predictions.  If there are no parameters, it returns -1.

        Returns
        -------
        The cuda device this model should run on for predictions.
        """
        devices = {util.get_device_of(param) for param in self.parameters()}

        if len(devices) > 1:
            devices_string = ", ".join(str(x) for x in devices)
            raise ConfigurationError(
                f"Parameters have mismatching cuda_devices: {devices_string}")
        elif len(devices) == 1:
            return devices.pop()
        else:
            return -1
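
The three branches are simple enough to restate as a tiny stand-alone sketch. The helper below is hypothetical and operates on a plain set of device ids rather than real model parameters; it only mirrors the logic above, where `util.get_device_of` returns -1 for CPU tensors and the CUDA index otherwise.

def resolve_prediction_device(device_ids: set) -> int:
    # One device -> use it, several -> error, none -> CPU (-1), mirroring _get_prediction_device.
    if len(device_ids) > 1:
        devices_string = ", ".join(str(x) for x in device_ids)
        raise ValueError(f"Parameters have mismatching cuda_devices: {devices_string}")
    return device_ids.pop() if device_ids else -1

print(resolve_prediction_device({0}))    # 0  -> predict on GPU 0
print(resolve_prediction_device(set()))  # -1 -> no parameters, predict on CPU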
Code example #6
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, ..., num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class labels of shape (batch_size, ...). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask : ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        """
        predictions, gold_labels, mask = self.unwrap_to_tensors(
            predictions, gold_labels, mask)

        # Some sanity checks.
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise ConfigurationError(
                "gold_labels must have dimension == predictions.size() - 1 but "
                "found tensor of shape: {}".format(predictions.size()))
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to Categorical Accuracy contains an id >= {}, "
                "the number of classes.".format(num_classes))

        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()
        if not self._tie_break:
            # Top K indexes of the predictions (or fewer, if there aren't K of them).
            # Special case topk == 1, because it's common and .max() is much faster than .topk().
            if self._top_k == 1:
                top_k = predictions.max(-1)[1].unsqueeze(-1)
            else:
                top_k = predictions.topk(
                    min(self._top_k, predictions.shape[-1]), -1)[1]

            # This is of shape (batch_size, ..., top_k).
            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
        else:
            # A prediction is correct if the gold label falls on any of the max scores;
            # distribute the score across the number of tied classes.
            max_predictions = predictions.max(-1)[0]
            max_predictions_mask = predictions.eq(
                max_predictions.unsqueeze(-1))
            # max_predictions_mask has shape (rows, num_classes) and gold_labels has shape (batch_size,).
            # The ith entry of gold_labels is the class index (0..num_classes - 1) for the ith row of max_predictions.
            # For each row, check whether the class pointed to by the gold label was among the max-scored classes.
            correct = max_predictions_mask[
                torch.arange(gold_labels.numel()).long(), gold_labels].float()
            tie_counts = max_predictions_mask.sum(-1)
            correct /= tie_counts.float()
            correct.unsqueeze_(-1)

        if mask is not None:
            correct *= mask.view(-1, 1).float()
            self.total_count += mask.sum()
        else:
            self.total_count += gold_labels.numel()
        self.correct_count += correct.sum()
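
A hypothetical end-to-end call of the metric, with shapes matching the docstring above; the class from code examples #2 and #6 is assumed to be in scope. The expected counts follow from the top-1 argmax of each row.

import torch

metric = CategoricalAccuracy(top_k=1)
predictions = torch.tensor([[0.2, 0.7, 0.1],
                            [0.6, 0.3, 0.1]])  # (batch_size=2, num_classes=3)
gold_labels = torch.tensor([1, 2])             # (batch_size=2)

metric(predictions, gold_labels)
# Row 0 predicts class 1 (correct), row 1 predicts class 0 (wrong),
# so correct_count is now 1.0 and total_count is 2.0.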
Code example #7
    def __init__(
            self,
            model: Model,
            optimizer: torch.optim.Optimizer,
            train_dataloader: DataLoader,
            validation_dataloader: Optional[DataLoader] = None,
            patience: Optional[int] = None,
            validation_metric: str = "-loss",
            serialization_dir: Optional[str] = None,
            num_serialized_models_to_keep: int = 20,
            keep_serialized_model_every_num_seconds: Optional[int] = None,
            checkpointer: Optional[Checkpointer] = None,
            model_save_interval: Optional[float] = None,
            cuda_device: int = -1,
            grad_norm: Optional[float] = None,
            grad_clipping: Optional[float] = None,
            learning_rate_scheduler: Optional[LearningRateScheduler] = None,
            momentum_scheduler: Optional[MomentumScheduler] = None,
            summary_interval: int = 100,
            histogram_interval: Optional[int] = None,
            should_log_parameter_statistics: bool = True,
            should_log_learning_rate: bool = False,
            log_batch_size_period: Optional[int] = None,
            moving_average: Optional[MovingAverage] = None,
            num_epochs: int = 20,
            temperature_scheduler: Optional[TemperatureScheduler] = None,
            custom_logger: Optional[CustomLogger] = None) -> None:
        super().__init__(serialization_dir, cuda_device)

        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.model = model

        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self._validation_dataloader = validation_dataloader

        if patience is None:  # no early stopping
            if validation_dataloader:
                logger.warning(
                    "You provided a validation dataset but patience was set to None, "
                    "meaning that early stopping is disabled")
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                "or None (if you want to disable early stopping)".format(
                    patience))

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs

        if checkpointer is not None:
            # We can't easily check if these parameters were passed in, so check against their default values.
            # We don't check against serialization_dir since it is also used by the parent class.
            if (num_serialized_models_to_keep != 20
                    or keep_serialized_model_every_num_seconds is not None):
                raise ConfigurationError(
                    "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
                    "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
                )
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(
                serialization_dir,
                keep_serialized_model_every_num_seconds,
                num_serialized_models_to_keep,
            )

        self._model_save_interval = model_save_interval

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping

        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average
        self._temperature_scheduler = temperature_scheduler

        # We keep the total batch number as an instance variable because it
        # is used inside a closure for the hook which logs activations in
        # ``_enable_activation_logging``.
        self._batch_num_total = 0

        self._tensorboard = TensorboardWriter(
            get_batch_num_total=lambda: self._batch_num_total,
            serialization_dir=serialization_dir,
            summary_interval=summary_interval,
            histogram_interval=histogram_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
        )

        self._log_batch_size_period = log_batch_size_period

        self._last_log = 0.0  # time of last logging

        # Enable activation logging.
        if histogram_interval is not None:
            self._tensorboard.enable_activation_logging(self.model)

        self._custom_logger = custom_logger
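
A hedged wiring sketch for the constructor above. `my_model`, `train_loader`, and `val_loader` are placeholders assumed to exist elsewhere; the keyword names follow the signature, but the values are illustrative only.

import torch

trainer = Trainer(
    model=my_model,
    optimizer=torch.optim.Adam(my_model.parameters(), lr=1e-3),
    train_dataloader=train_loader,
    validation_dataloader=val_loader,
    patience=5,                       # stop after 5 epochs without improvement
    validation_metric="-loss",        # "-" prefix: lower is better
    num_epochs=20,
    serialization_dir="checkpoints",  # where checkpoints and per-epoch metrics are written
)
metrics = trainer.train()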
Code example #8
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0),
                    train_metrics["cpu_memory_MB"])
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            if self._validation_dataloader is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics,
                val_metrics=val_metrics,
                log_to_console=True,
                epoch=epoch + 1,
            )  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"),
                    metrics,
                )

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            if self._custom_logger is not None:
                if self._temperature_scheduler is not None:
                    temperature = self._temperature_scheduler.get_temperature()
                else:
                    temperature = None
                self._custom_logger.log_epoch(epoch, temperature=temperature)

            # Update temperature after each epoch
            if self._temperature_scheduler is not None:
                self._temperature_scheduler.step()

            epochs_trained += 1

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
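
The remaining-time estimate near the end of the epoch loop scales the elapsed time by the ratio of epochs left to epochs completed. A quick hypothetical numeric check of that formula:

training_elapsed_time = 300.0  # seconds spent so far
num_epochs = 10
epoch_counter = 0              # epoch the run (re)started from
epoch = 2                      # current epoch index, i.e. 3 epochs completed

estimated_time_remaining = training_elapsed_time * (
    (num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
print(estimated_time_remaining)  # 700.0 -> 100 s/epoch * 7 remaining epochs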