Exemple #1
0
    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the train_iterator.

        By default, this method will iterate over the given iterator and
        call ``self.train_batch`` over each batch.

        If ``scheduler_step_freq`` is set, this class will also step the
        scheduler accordingly.

        You do not need to call ``train_batch`` in this method if you plan
        to implement a custom optimization/training routine here.

        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations.

        Returns:
            A dict of metrics from training.
        """
        self._losses = AverageMeter()

        self.model.train()
        with self.timers["epoch_time"]:
            for batch_idx, batch in enumerate(iterator):
                batch_info = {
                    "batch_idx": batch_idx,
                    "global_step": self.global_step
                }
                batch_info.update(info)
                metrics = self.train_batch(batch, batch_info=batch_info)

                if self.scheduler and batch_info.get(
                        SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
                    self.scheduler.step()

                if "loss" in metrics:
                    self._losses.update(metrics["loss"],
                                        n=metrics.get("num_samples", 1))
                self.global_step += 1

        if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
            self.scheduler.step()

        stats = {
            BATCH_COUNT: batch_idx + 1,
            "mean_train_loss": self._losses.avg,
            "last_train_loss": self._losses.val,
            "epoch_time": self.timers["epoch_time"].last
        }
        stats.update({
            timer_tag: timer.mean
            for timer_tag, timer in self.timers.items()
        })
        return stats
Exemple #2
0
    def validate(self, val_iterator, info):
        """Runs one standard validation pass over the val_iterator.

        This will call ``model.eval()`` and ``torch.no_grad`` when iterating
        over the validation dataset.

        If overriding this method, you can access model, criterion via
        ``self.model`` and ``self.criterion``. You also do not need to call
        ``validate_batch`` if overriding this method.

        Args:
            val_iterator (iter): Iterable constructed over the
                validation dataset.
            info: (dict): Dictionary for information to be used for custom
                validation operations.

        Returns:
            A dict of metrics from the evaluation.
                By default, returns "mean_accuracy" and "mean_validation_loss"
                which is computed by aggregating "loss" and "correct" values
                from ``validate_batch`` and dividing it by the sum of
                ``num_samples`` from all calls to ``self.validate_batch``.
        """
        losses = AverageMeter()
        total_correct = 0

        # switch to evaluate mode
        self.model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = {"batch_idx": batch_idx}
                batch_info.update(info)
                metrics = self.validate_batch(batch, batch_info)
                if "loss" in metrics:
                    losses.update(metrics["loss"],
                                  n=metrics.get("num_samples", 1))

                if "num_correct" in metrics:
                    total_correct += metrics["num_correct"]

        stats = {
            "batch_count": batch_idx + 1,
            "mean_validation_loss": losses.avg,
            "mean_accuracy": total_correct / losses.count
        }
        return stats
Exemple #3
0
    def train_epoch(self, iterator, info):
        self._losses = AverageMeter()

        self.model.train()
        try:
            with self.timers["epoch_time"]:
                for batch_idx, batch in enumerate(iterator):
                    batch_info = {
                        "batch_idx": batch_idx,
                        "global_step": self.global_step
                    }
                    batch_info.update(info)
                    metrics = self.train_batch(batch, batch_info=batch_info)

                    if self.scheduler and batch_info.get(
                            SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
                        self.scheduler.step()

                    if "loss" in metrics:
                        self._losses.update(metrics["loss"],
                                            n=metrics.get("num_samples", 1))
                    self.global_step += 1
        except Exception as e:
            print("CAUGHT EXCEPTION", e)
        finally:
            if self.index == 0:
                print("[TRAINER] FINALIZING EPOCH AND SAVING")
                self.save()  # save no matter what

        if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
            self.scheduler.step()

        stats = {
            BATCH_COUNT: batch_idx + 1,
            "mean_train_loss": self._losses.avg,
            "last_train_loss": self._losses.val,
            "epoch_time": self.timers["epoch_time"].last
        }
        stats.update({
            timer_tag: timer.mean
            for timer_tag, timer in self.timers.items()
        })
        return stats
Exemple #4
0
class TrainingOperator:
    """Abstract class for custom training or validation loops.

    The scheduler will only be called at a batch or epoch frequency, depending
    on the user parameter. Be sure to set ``scheduler_step_freq`` in
    ``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler
    correctly during training. If using a learning rate scheduler
    that depends on validation loss, you can use ``trainer.update_scheduler``.

    For both training and validation, there are two granularities that
    you can provide customization: per epoch or per batch.
    You do not need to override both.

    .. image:: raysgd-custom.jpg
        :scale: 80%
        :align: center

    Raises:
        ValueError if multiple models/optimizers/schedulers are provided.
            You are expected to subclass this class if you wish
            to train over multiple models/optimizers/schedulers.
    """
    def __init__(self,
                 config,
                 models,
                 optimizers,
                 criterion,
                 schedulers=None,
                 use_fp16=False):
        # You are not expected to override this method.
        self.timers = {
            k: TimerStat()
            for k in ["fwd", "grad", "apply", "epoch_time"]
        }
        self._validated_customization = False
        self._models = models  # List of models
        assert isinstance(models, collections.Iterable), (
            "Components need to be iterable. Got: {}".format(type(models)))
        self._optimizers = optimizers  # List of optimizers
        assert isinstance(optimizers, collections.Iterable), (
            "Components need to be iterable. Got: {}".format(type(optimizers)))
        self._criterion = criterion
        self._schedulers = schedulers
        if schedulers:
            assert isinstance(schedulers, collections.Iterable), (
                "Components need to be iterable. Got: {}".format(
                    type(schedulers)))
        self._config = config
        self._use_fp16 = use_fp16
        self.global_step = 0

        if type(self) is TrainingOperator:
            for component in (models, schedulers, optimizers):
                if _is_multiple(component):
                    raise ValueError(
                        "Need to provide a custom operator subclassing "
                        "TrainingOperator if using multi-scheduler, "
                        "multi-model or multi-optimizer training/validation.")

        self.setup(config)

    def setup(self, config):
        """Override this method to implement custom operator setup.

        Args:
            config (dict): Custom configuration value to be passed to
                all creator and operator constructors. Same as ``self.config``.
        """
        pass

    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the train_iterator.

        By default, this method will iterate over the given iterator and
        call ``self.train_batch`` over each batch.

        If ``scheduler_step_freq`` is set, this class will also step the
        scheduler accordingly.

        You do not need to call ``train_batch`` in this method if you plan
        to implement a custom optimization/training routine here.

        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations.

        Returns:
            A dict of metrics from training.
        """
        self._losses = AverageMeter()

        self.model.train()
        with self.timers["epoch_time"]:
            for batch_idx, batch in enumerate(iterator):
                batch_info = {
                    "batch_idx": batch_idx,
                    "global_step": self.global_step
                }
                batch_info.update(info)
                metrics = self.train_batch(batch, batch_info=batch_info)

                if self.scheduler and batch_info.get(
                        SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
                    self.scheduler.step()

                if "loss" in metrics:
                    self._losses.update(metrics["loss"],
                                        n=metrics.get("num_samples", 1))
                self.global_step += 1

        if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
            self.scheduler.step()

        stats = {
            BATCH_COUNT: batch_idx + 1,
            "mean_train_loss": self._losses.avg,
            "last_train_loss": self._losses.val,
            "epoch_time": self.timers["epoch_time"].last
        }
        stats.update({
            timer_tag: timer.mean
            for timer_tag, timer in self.timers.items()
        })
        return stats

    def train_batch(self, batch, batch_info):
        """Computes loss and updates the model over one batch.

        This method is responsible for computing the loss and gradient and
        updating the model.

        By default, this method implementation assumes that batches
        are in (features, labels) format. If using amp/fp16
        training, it will also scale the loss automatically.

        You can provide custom loss metrics and training operations if you
        override this method. If overriding this method, you can access model,
        optimizer, criterion via ``self.model``, ``self.optimizer``,
        and ``self.criterion``.

        You do not need to override this method if you plan to
        override ``train_epoch``.

        Args:
            batch: One item of the validation iterator.
            batch_info (dict): Information dict passed in from ``train_epoch``.

        Returns:
            A dictionary of metrics.
                By default, this dictionary contains "loss" and "num_samples".
                "num_samples" corresponds to number of datapoints in the batch.
                However, you can provide any number of other values.

        """
        features, target = batch
        # Create non_blocking tensors for distributed training
        if torch.cuda.is_available():
            features = features.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # Compute output.
        with self.timers["fwd"]:
            output = self.model(features)
            loss = self.criterion(output, target)

        # Compute gradients in a backward pass.
        with self.timers["grad"]:
            self.optimizer.zero_grad()
            if self.use_fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        # Call step of optimizer to update model params.
        with self.timers["apply"]:
            self.optimizer.step()
        return {"loss": loss.item(), "num_samples": features.size(0)}

    def validate(self, val_iterator, info):
        """Runs one standard validation pass over the val_iterator.

        This will call ``model.eval()`` and ``torch.no_grad`` when iterating
        over the validation dataset.

        If overriding this method, you can access model, criterion via
        ``self.model`` and ``self.criterion``. You also do not need to call
        ``validate_batch`` if overriding this method.

        Args:
            val_iterator (iter): Iterable constructed over the
                validation dataset.
            info: (dict): Dictionary for information to be used for custom
                validation operations.

        Returns:
            A dict of metrics from the evaluation.
                By default, returns "mean_accuracy" and "mean_validation_loss"
                which is computed by aggregating "loss" and "correct" values
                from ``validate_batch`` and dividing it by the sum of
                ``num_samples`` from all calls to ``self.validate_batch``.
        """
        losses = AverageMeter()
        total_correct = 0

        # switch to evaluate mode
        self.model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = {"batch_idx": batch_idx}
                batch_info.update(info)
                metrics = self.validate_batch(batch, batch_info)
                if "loss" in metrics:
                    losses.update(metrics["loss"],
                                  n=metrics.get("num_samples", 1))

                if "num_correct" in metrics:
                    total_correct += metrics["num_correct"]

        stats = {
            "batch_count": batch_idx + 1,
            "mean_validation_loss": losses.avg,
            "mean_accuracy": total_correct / losses.count
        }
        return stats

    def validate_batch(self, batch, batch_info):
        """Calcuates the loss and accuracy over a given batch.

        You can override this method to provide arbitrary metrics.

        Args:
            batch: One item of the validation iterator.
            batch_info (dict): Contains information per batch from
                ``validate()``.

        Returns:
            A dict of metrics.
                By default, returns "loss", "num_correct", and "num_samples".
        """
        features, target = batch
        if torch.cuda.is_available():
            features = features.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output
        output = self.model(features)
        loss = self.criterion(output, target)
        _, predicted = torch.max(output.data, 1)

        return {
            "loss": loss.item(),
            "num_correct": (predicted == target).sum().item(),
            "num_samples": target.size(0)
        }

    def state_dict(self):
        """Returns a serializable representation of the operator state."""
        pass

    def load_state_dict(self, state_dict):
        """Loads a serializable representation of the operator state."""
        pass

    @property
    def config(self):
        """Dictionary as provided into TorchTrainer."""
        return self._config

    @property
    def model(self):
        """First or only model created by the provided ``model_creator``."""
        return self._models[0]

    @property
    def models(self):
        """List of models created by the provided ``model_creator``."""
        return self._models

    @property
    def optimizer(self):
        """First or only optimizer(s) created by the ``optimizer_creator``."""
        return self._optimizers[0]

    @property
    def optimizers(self):
        """List of optimizers created by the ``optimizer_creator``."""
        return self._optimizers

    @property
    def criterion(self):
        """Criterion created by the provided ``loss_creator``."""
        return self._criterion

    @property
    def scheduler(self):
        """First or only scheduler(s) created by the ``scheduler_creator``."""
        if self._schedulers:
            return self._schedulers[0]

    @property
    def schedulers(self):
        """List of schedulers created by the ``scheduler_creator``."""
        return self._schedulers

    @property
    def use_fp16(self):
        """Whether the model and optimizer have been FP16 enabled."""
        return self._use_fp16
Exemple #5
0
class GNN1TrainingOperator(TrainingOperator):
    def __init__(self, *args, **kwargs):
        try:
            self.ckpt_dir = kwargs.pop("ckpt_dir")
            self.ckpt_freq = kwargs.pop("ckpt_freq")
            self.index = kwargs.pop("index")
            self.save_counter = 0
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
            self.lr = kwargs.pop("lr")
        except KeyError:
            raise Exception("DING DONG!")

        super(GNN1TrainingOperator, self).__init__(*args, **kwargs)
        self.batcher = Batcher(device=self.device)
        if self.index == 0:
            self.logger = TrainLogger(logdir=args[0].get("logdir"))
            self.logger.write_log("SOME HYPERPARAMETERS:")
            for key, value in self.config.items():
                if key == "lr" or key == "optimizer" or key == "batch_size" or key == "ckpt_freq" or key == "data_dir":
                    self.logger.write_log(f"{key}: {value}")  # oops
        self.optim = self.optimizers[0]
        try:
            latest = get_latest_ckpt(self.ckpt_dir)
            print(f"LOADING WEIGHTS FROM {latest}")
            state = torch.load(latest)
            self.model.load_state_dict(state["models"][0], strict=False)
            self.optim.load_state_dict(state["optimizers"][0])
            if self.lr is not None:
                print(f"FORCING LEARNING RATE: {self.lr}")
                for g in self.optim.param_groups:
                    g["lr"] = self.lr
            self.save_counter = state["SAVE_COUNTER"]
            self.global_step = state["GLOBAL_STEP"]
            print(f"SUCCESSFULLY LOADED CHECKPOINT FROM {latest}")
            self.save_counter += 1
        except Exception as e:
            print("WARNING: EXCEPTION ", e)
            pass

    def save(self):
        if self.index == 0:
            ckpt_path = os.path.join(self.ckpt_dir,
                                     f"ckpt_{self.save_counter}.pth")
            torch.save(
                {
                    "models":
                    [{k: v.cpu()
                      for k, v in self.model.state_dict().items()}],
                    "optimizers": [self.optim.state_dict()],
                    "GLOBAL_STEP":
                    self.global_step,
                    "SAVE_COUNTER":
                    self.save_counter
                }, ckpt_path)

            update_index(ckpt_path)
            self.logger.write_log(
                f"[TRAIN LOOP -- HEAD WORKER] SAVED CHECKPOINT TO {ckpt_path}")

    def train_epoch(self, iterator, info):
        self._losses = AverageMeter()

        self.model.train()
        try:
            with self.timers["epoch_time"]:
                for batch_idx, batch in enumerate(iterator):
                    batch_info = {
                        "batch_idx": batch_idx,
                        "global_step": self.global_step
                    }
                    batch_info.update(info)
                    metrics = self.train_batch(batch, batch_info=batch_info)

                    if self.scheduler and batch_info.get(
                            SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
                        self.scheduler.step()

                    if "loss" in metrics:
                        self._losses.update(metrics["loss"],
                                            n=metrics.get("num_samples", 1))
                    self.global_step += 1
        except Exception as e:
            print("CAUGHT EXCEPTION", e)
        finally:
            if self.index == 0:
                print("[TRAINER] FINALIZING EPOCH AND SAVING")
                self.save()  # save no matter what

        if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
            self.scheduler.step()

        stats = {
            BATCH_COUNT: batch_idx + 1,
            "mean_train_loss": self._losses.avg,
            "last_train_loss": self._losses.val,
            "epoch_time": self.timers["epoch_time"].last
        }
        stats.update({
            timer_tag: timer.mean
            for timer_tag, timer in self.timers.items()
        })
        return stats

    def train_batch(self, batch, batch_info):
        num_samples = len(batch)
        drat_loss, core_loss, core_clause_loss, loss, x, l2_loss = train_step(
            self.model,
            self.batcher,
            self.optim,
            batch,
            use_NMSDP_to_sparse2=True,
            CUDA_FLAG=torch.cuda.is_available(),
            device=self.device,
            use_glue_counts=True)
        GLOBAL_STEP_COUNT = batch_info.get("global_step")

        if self.index == 0:
            self.logger.write_log(
                f"[TRAIN LOOP] Finished global step {GLOBAL_STEP_COUNT}. Loss: {loss}."
            )
            self.logger.write_scalar("drat_loss", drat_loss, GLOBAL_STEP_COUNT)
            self.logger.write_scalar("core_loss", core_loss, GLOBAL_STEP_COUNT)
            self.logger.write_scalar("core_clause_loss", core_clause_loss,
                                     GLOBAL_STEP_COUNT)
            self.logger.write_scalar("loss", loss, GLOBAL_STEP_COUNT)
            if (GLOBAL_STEP_COUNT + 1) % self.ckpt_freq == 0:
                if self.save_counter == 0:
                    util.check_make_path(self.ckpt_dir)
                self.save()
                self.save_counter += 1

        return {
            "drat_loss": drat_loss,
            "core_loss": core_loss,
            "loss": loss,
            "num_samples": num_samples
        }