# Example #1
class TrainingOperator:
    """Abstract class to define training and validation state and logic.

    You must subclass this class and override the ``setup`` method to define
    your training components such as the model, optimizer, data, loss,
    and scheduler. When you pass this class to ``TorchTrainer``, a copy of
    this class will be made on each worker.

    .. code-block:: python

        class MyTrainingOperator(TrainingOperator):

            def setup(self, config):
                model = nn.Linear(1, 1)
                optimizer = torch.optim.SGD(
                    model.parameters(), lr=config.get("lr", 1e-4))
                loss = torch.nn.MSELoss()

                batch_size = config["batch_size"]
                train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
                train_loader = DataLoader(train_data, batch_size=batch_size)
                val_loader = DataLoader(val_data, batch_size=batch_size)

                self.model, self.optimizer, self.criterion = self.register(
                    models=model, optimizers=optimizer, criterion=loss)
                self.register_data(
                    train_loader=train_loader, validation_loader=val_loader)


        trainer = TorchTrainer(
            training_operator_cls=MyTrainingOperator,
            config={"batch_size": 32})
        for i in range(4):
            trainer.train()

    This class provides default implementations for training and validation.
    Set ``self.model``, ``self.optimizer``, and
    ``self.criterion`` to leverage the default training and validation loops.
    If ``self.scheduler`` is set, it will only be called at a batch or epoch
    frequency, depending on the user parameter. Set
    ``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch"
    to increment the scheduler correctly during training. If using a
    learning rate scheduler that depends on validation loss, you can use
    ``trainer.update_scheduler``.

    If you want to provide custom training and validation loops, you can do
    so using this class as well. There are two granularities at which
    you can provide customization: per epoch or per batch.
    You do not need to override both.

    .. image:: raysgd-custom.jpg
        :scale: 80%
        :align: center

    If you are using multiple models, optimizers, or schedulers, you must
    implement custom training and validation.

    The default loops raise ``RuntimeError`` when the attributes they need
    are missing. You are expected to either set the ``self.model``,
    ``self.optimizer``, and ``self.criterion`` instance attributes in
    ``setup`` or implement custom training & validation.
    """

    def __init__(self,
                 config,
                 world_rank,
                 device_ids=None,
                 use_gpu=False,
                 use_fp16=False,
                 use_tqdm=False,
                 apex_args=None,
                 wrap_ddp=False,
                 wrap_distributed_sampler=False,
                 add_dist_sampler=False,
                 scheduler_step_freq=None):
        """Stores worker settings, then calls ``setup``.

        Override ``setup`` instead of this method.

        Args:
            config (dict): Custom configuration; forwarded to ``setup`` and
                exposed as ``self.config``.
            world_rank (int): Rank of this worker. 0 if not distributed.
            device_ids (list, optional): Device IDs for the model. They are
                forwarded to ``DistributedDataParallel`` in ``register``.
            use_gpu (bool): Use CUDA, but only if it is actually available.
            use_fp16 (bool): Initialize models/optimizers with Apex ``amp``
                in ``register`` (when Apex is importable).
            use_tqdm (bool): Show a tqdm progress bar on rank 0 during
                ``train_epoch``.
            apex_args (dict, optional): Keyword arguments forwarded to
                ``amp.initialize``.
            wrap_ddp (bool): Wrap registered models in
                ``DistributedDataParallel``.
            wrap_distributed_sampler (bool): Whether ``register_data``
                considers re-creating data loaders with a sampler.
            add_dist_sampler (bool): Whether ``register_data`` actually adds
                a ``DistributedSampler`` to eligible data loaders.
            scheduler_step_freq (str, optional): "batch", "epoch", "manual"
                or None. It controls when the default loop steps schedulers.

        Raises:
            ValueError: If ``use_tqdm`` is True but tqdm is not installed.
        """
        # You are not expected to override this method.
        self._world_rank = world_rank
        self._config = config
        self._use_fp16 = use_fp16
        self._device_ids = device_ids
        self._use_gpu = use_gpu and torch.cuda.is_available()
        self._device = torch.device("cuda" if self._use_gpu else "cpu")
        if tqdm is None and use_tqdm:
            raise ValueError("tqdm must be installed to use tqdm in training.")
        self._use_tqdm = use_tqdm
        self.global_step = 0
        self._apex_args = apex_args if apex_args else {}
        self._wrap_ddp = wrap_ddp
        self._wrap_distributed_sampler = wrap_distributed_sampler
        self._add_dist_sampler = add_dist_sampler
        self._scheduler_step_freq = scheduler_step_freq

        self.timers = TimerCollection()
        # User-defined setup registers models, optimizers and data loaders.
        self.setup(config)

    def _set_timers(self, timers):
        """Passes in the timers from the Runner."""
        self.timers = timers

    @staticmethod
    def _to_component_list(components, single_type=None):
        """Normalizes one component or an iterable of components to a list.

        Args:
            components: A single component or an iterable of components.
            single_type (type, optional): Instances of this type are always
                treated as a single component, even if they are iterable
                (``nn.Sequential`` is an iterable ``nn.Module``).

        Returns:
            list: ``[components]`` for a single component. Otherwise, a new
            list holding the iterable's items (so generators are not
            exhausted by later passes).
        """
        if single_type is not None and isinstance(components, single_type):
            return [components]
        if not isinstance(components, Iterable):
            return [components]
        return list(components)

    def setup(self, config):
        """Override this method to implement operator setup.

        You should call self.register and self.register_data here to
        register training components and data loaders with Ray SGD.

        Args:
            config (dict): Custom configuration value to be passed to
                all creator and operator constructors. Same as ``self.config``.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def register(self, *, models, optimizers, criterion=None, schedulers=None):
        """Registers parameters with Ray SGD and sets up training components.

        By calling this method to register your models, optimizers,
        criterion, and schedulers, Ray SGD will automatically handle
        necessary setup such as GPU/devices, Distributed Data Parallel, and
        Fp16. The registered components are returned and should be set as
        instance attributes to access during training/validation.

        If more than one model, optimizer, or scheduler is passed in,
        you should implement your own custom training loop.

        .. code-block:: python

            class MyTrainingOperator(TrainingOperator):
                def setup(self, config):
                    model = ...
                    optimizer = ...
                    train_loader = ...
                    val_loader = ...
                    loss = ...

                    self.model, self.optimizer, self.criterion = self.register(
                        models=model, optimizers=optimizer, criterion=loss)

                    # At this point DDP, Cuda, and Fp16
                    # are set up for all our components. We then use
                    # self.model, self.optimizer, etc. in our training loop.

        Args:
            models (torch.nn.Module or Iterable[nn.Module]): Pytorch model or
                multiple Pytorch models to use for training. A single
                ``nn.Module`` is always treated as one model, even when it
                is iterable (e.g. ``nn.Sequential``). If
                `use_gpu=True` is passed into ``TorchTrainer``, and Cuda is
                available, models will automatically be placed on GPU.
                If ``wrap_ddp=True`` is passed into ``TorchTrainer``,
                models will be wrapped in DDP. If wrap_ddp is False,
                you should handle DDP for your models in setup.
            optimizers (torch.optim.Optimizer or Iterable[
                torch.optim.Optimizer]): Pytorch optimizer or multiple Pytorch
                optimizers to use for training.
            criterion (Callable, optional): Function to return loss
                metric given features and target. If not provided,
                must implement a custom training loop.
            schedulers (torch.optim.lr_scheduler or Iterable[
                torch.optim.lr_scheduler], optional): A learning rate
                scheduler or multiple learning rate schedulers.

        Returns:
            tuple: Model(s) and optimizer(s), followed by the criterion if it
            is not None and by the scheduler(s) if any were given. A single
            component is returned as-is; multiple components are returned as
            a list.

        Raises:
            ValueError: If schedulers are given but ``scheduler_step_freq``
                passed into ``TorchTrainer`` is None.
        """
        # Validate before mutating any state so a failed call leaves the
        # operator untouched.
        if schedulers and self.scheduler_step_freq is None:
            raise ValueError("scheduler_step_freq passed into "
                             "TorchTrainer cannot be None if you "
                             "are registering schedulers. Set this to "
                             "'manual' if you will be manually stepping "
                             "the schedulers.")

        return_vals = []
        logger.debug("Registering models.")
        self._original_models = self._to_component_list(
            models, single_type=nn.Module)
        assert all(
            isinstance(model, nn.Module) for model in self._original_models), (
                f"All models must be PyTorch models: {self._original_models}.")
        # self.use_gpu already accounts for torch.cuda.is_available().
        if self.use_gpu:
            self._original_models = [
                model.cuda() for model in self._original_models
            ]

        logger.debug("Registering optimizers.")
        self._optimizers = self._to_component_list(
            optimizers, single_type=torch.optim.Optimizer)

        if schedulers:
            logger.debug("Registering scheduler.")
            self._schedulers = self._to_component_list(schedulers)
        else:
            self._schedulers = None

        if criterion is not None:
            logger.debug("Registering loss.")
            self._criterion = criterion
            if self.use_gpu and hasattr(self._criterion, "cuda"):
                self._criterion = self._criterion.cuda()
        else:
            self._criterion = None

        # ``amp`` may be None when Apex is not installed; fp16 is then skipped.
        if self.use_fp16 and amp:
            logger.debug("Setting up Apex.")
            self._original_models, self._optimizers = amp.initialize(
                self._original_models, self._optimizers, **self._apex_args)
            self._amp = amp

        if self._wrap_ddp:
            logger.debug("Setting up DDP for models.")
            self._models = [
                DistributedDataParallel(model, device_ids=self.device_ids)
                for model in self._original_models
            ]
        else:
            self._models = self._original_models

        return_vals.append(
            self._models[0] if len(self._models) == 1 else self._models)
        return_vals.append(self._optimizers[0]
                           if len(self._optimizers) == 1 else self._optimizers)

        if self._criterion is not None:
            return_vals.append(self._criterion)

        if self._schedulers is not None:
            return_vals.append(self._schedulers[0] if len(self._schedulers) ==
                               1 else self._schedulers)

        return tuple(return_vals)

    def register_data(self, *, train_loader=None, validation_loader=None):
        """Registers data loaders with Ray SGD.

        Calling this method will automatically setup Distributed Sampler for
        these data loaders if add_dist_sampler=True is passed into the
        TorchTrainer. This method does not return the wrapped data loaders.
        You should use the iterators passed into train_epoch and validate
        instead.

        .. code-block:: python

            class MyTrainingOperator(TrainingOperator):
                def setup(self, config):
                    model = ...
                    optimizer = ...
                    train_loader = ...
                    val_loader = ...
                    loss = ...

                    self.model, self.optimizer, self.criterion = self.register(
                        models=model, optimizers=optimizer, criterion=loss)

                    self.register_data(
                        train_loader=train_loader,
                        validation_loader=val_loader)

                    # At this point the data loaders are registered with
                    # Ray SGD and are wrapped with Distributed Samplers if
                    # applicable.

                def train_epoch(self, iterator, info):
                    # If providing custom training or validation methods,
                    # the registered data loaders are passed in through the
                    # iterator parameter.
                    ...

        Args:
            train_loader (Iterator): An iterator for training
                data. If None is explicitly passed in, a Ray SGD Dataset
                must be passed in through TorchTrainer.train. Ray SGD will
                automatically use a Distributed Sampler if TorchTrainer(...,
                add_dist_sampler=True).
            validation_loader (Iterator): An iterator for validation
                data. Ray SGD will automatically use a Distributed Sampler
                if TorchTrainer(..., add_dist_sampler=True).
        """
        logger.debug("Registering data loaders..")
        self._train_loader = train_loader
        self._validation_loader = validation_loader

        if self._wrap_distributed_sampler:
            logger.debug("Wrapping data loaders with DistributedSampler.")

            def with_sampler(loader):
                # Re-create the loader with the same settings plus a
                # DistributedSampler. shuffle must be False because a
                # sampler and shuffle=True are mutually exclusive.
                data_loader_args = {
                    "dataset": loader.dataset,
                    "batch_size": loader.batch_size,
                    "shuffle": False,
                    "num_workers": loader.num_workers,
                    "collate_fn": loader.collate_fn,
                    "pin_memory": loader.pin_memory,
                    "drop_last": loader.drop_last,
                    "timeout": loader.timeout,
                    "worker_init_fn": loader.worker_init_fn,
                    "sampler": DistributedSampler(loader.dataset)
                }
                return DataLoader(**data_loader_args)

            def should_wrap_dataloader(loader):
                # Iterable-style datasets cannot use an index sampler.
                return (isinstance(loader, DataLoader)
                        and not isinstance(loader.dataset, IterableDataset))

            if should_wrap_dataloader(self._train_loader):
                if self._add_dist_sampler:
                    self._train_loader = with_sampler(self._train_loader)

            if self._validation_loader is not None and should_wrap_dataloader(
                    self._validation_loader):
                if self._add_dist_sampler:
                    self._validation_loader = with_sampler(
                        self._validation_loader)

    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the training dataloader.

        By default, this method will iterate over the given iterator and
        call ``self.train_batch`` over each batch. If ``scheduler_step_freq``
        is set, this default method will also step the scheduler accordingly.

        You do not need to call ``train_batch`` in this method if you plan
        to implement a custom optimization/training routine here.

        You may find ``ray.util.sgd.utils.AverageMeterCollection`` useful
        when overriding this method. See example below:

        .. code-block:: python

            def train_epoch(self, ...):
                meter_collection = AverageMeterCollection()
                for batch in iterator:
                    # do some processing
                    metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics

                    # This keeps track of all metrics across multiple batches
                    meter_collection.update(metrics, n=len(batch))

                # Returns stats of the meters.
                stats = meter_collection.summary()
                return stats

        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations. May be None. Its entries are copied
                into every ``batch_info`` given to ``train_batch``.

        Returns:
            A dict of metrics from training.

        Raises:
            RuntimeError: If ``self.model`` was not set.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        model = self.model
        scheduler = getattr(self, "scheduler", None)

        progress_bar = None
        if self.use_tqdm and self.world_rank == 0:
            desc = ""
            if info is not None and "epoch_idx" in info:
                if "num_epochs" in info:
                    desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
                else:
                    desc = f"{info['epoch_idx'] + 1}e"

            # info may be None or lack NUM_STEPS; fall back to the
            # iterator's length when it has one.
            total = info.get(NUM_STEPS) if info else None
            if total is None and hasattr(iterator, "__len__"):
                total = len(iterator)

            progress_bar = tqdm(
                total=total, desc=desc, unit="batch", leave=False)

        metric_meters = AverageMeterCollection()

        # The default validate() switches the model to eval mode, so switch
        # back to training mode here.
        model.train()
        try:
            for batch_idx, batch in enumerate(iterator):
                batch_info = dict(info) if info else {}
                # Per-batch keys take precedence over epoch-level info.
                batch_info.update(
                    batch_idx=batch_idx, global_step=self.global_step)
                metrics = self.train_batch(batch, batch_info=batch_info)

                if progress_bar is not None:
                    progress_bar.n = batch_idx + 1
                    postfix = {}
                    if "train_loss" in metrics:
                        postfix.update(loss=metrics["train_loss"])
                    progress_bar.set_postfix(postfix)

                if (scheduler is not None
                        and self.scheduler_step_freq == SCHEDULER_STEP_BATCH):
                    scheduler.step()

                metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
                self.global_step += 1
        finally:
            # Release the progress bar even if a batch raises.
            if progress_bar is not None:
                progress_bar.close()

        if (scheduler is not None
                and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH):
            scheduler.step()

        return metric_meters.summary()

    def train_batch(self, batch, batch_info):
        """Computes loss and updates the model over one batch.

        This method is responsible for computing the loss and gradient and
        updating the model.

        By default, this method implementation assumes that batches
        are in (\\*features, labels) format. So we also support multiple inputs
        model. If using amp/fp16 training, it will also scale the loss
        automatically.

        You can provide custom loss metrics and training operations if you
        override this method.

        You do not need to override this method if you plan to
        override ``train_epoch``.

        Args:
            batch: One item of the training iterator.
            batch_info (dict): Information dict passed in from ``train_epoch``.

        Returns:
            A dictionary of metrics.
                By default, this dictionary contains "train_loss" and
                "num_samples". "num_samples" corresponds to number of
                datapoints in the batch. However, you can provide any number
                of other values. Consider returning "num_samples" in the
                metrics because by default, ``train_epoch`` uses
                "num_samples" to calculate averages.

        Raises:
            RuntimeError: If ``self.model``, ``self.optimizer`` or
                ``self.criterion`` was not set.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "optimizer"):
            raise RuntimeError("Either set self.optimizer in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "criterion"):
            raise RuntimeError("Either set self.criterion in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        model = self.model
        optimizer = self.optimizer
        criterion = self.criterion
        # unpack features into list to support multiple inputs model
        *features, target = batch
        # Create non_blocking tensors for distributed training
        if self.use_gpu:
            features = [
                feature.cuda(non_blocking=True) for feature in features
            ]
            target = target.cuda(non_blocking=True)

        # Compute output.
        with self.timers.record("fwd"):
            output = model(*features)
            loss = criterion(output, target)

        # Compute gradients in a backward pass.
        with self.timers.record("grad"):
            # Clear gradients accumulated by the previous step.
            optimizer.zero_grad()
            if self.use_fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        # Call step of optimizer to update model params.
        with self.timers.record("apply"):
            optimizer.step()

        return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}

    def validate(self, val_iterator, info):
        """Runs one standard validation pass over the val_iterator.

        This will call ``model.eval()`` and ``torch.no_grad`` when iterating
        over the validation dataloader.

        You also do not need to call ``validate_batch`` if overriding this
        method.

        Args:
            val_iterator (iter): Iterable constructed from the
                validation dataloader.
            info: (dict): Dictionary for information to be used for custom
                validation operations. May be None. Its entries are copied
                into every ``batch_info`` given to ``validate_batch``.

        Returns:
            A dict of metrics from the evaluation.
                By default, returns "val_accuracy" and "val_loss"
                which is computed by aggregating "loss" and "correct" values
                from ``validate_batch`` and dividing it by the sum of
                ``num_samples`` from all calls to ``self.validate_batch``.

        Raises:
            RuntimeError: If ``self.model`` was not set.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "validation loop.")
        model = self.model
        metric_meters = AverageMeterCollection()

        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = dict(info) if info else {}
                batch_info["batch_idx"] = batch_idx
                metrics = self.validate_batch(batch, batch_info)
                metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))

        return metric_meters.summary()

    def validate_batch(self, batch, batch_info):
        """Calculates the loss and accuracy over a given batch.

        You can override this method to provide arbitrary metrics.

        Same as ``train_batch``, this method implementation assumes that
        batches are in (\\*features, labels) format by default. So we also
        support multiple inputs model.

        Args:
            batch: One item of the validation iterator.
            batch_info (dict): Contains information per batch from
                ``validate()``.

        Returns:
            A dict of metrics.
                By default, returns "val_loss", "val_accuracy", and
                "num_samples". When overriding, consider returning
                "num_samples" in the metrics because
                by default, ``validate`` uses "num_samples" to
                calculate averages.

        Raises:
            RuntimeError: If ``self.model`` or ``self.criterion`` was not set.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "criterion"):
            raise RuntimeError("Either set self.criterion in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        model = self.model
        criterion = self.criterion
        # unpack features into list to support multiple inputs model
        *features, target = batch
        if self.use_gpu:
            features = [
                feature.cuda(non_blocking=True) for feature in features
            ]
            target = target.cuda(non_blocking=True)

        # compute output
        with self.timers.record("eval_fwd"):
            output = model(*features)
            loss = criterion(output, target)
            # Class prediction = argmax over dimension 1 of the output.
            _, predicted = torch.max(output.data, 1)

        num_correct = (predicted == target).sum().item()
        num_samples = target.size(0)
        return {
            "val_loss": loss.item(),
            # Guard against an empty batch instead of dividing by zero.
            "val_accuracy": (num_correct / num_samples
                             if num_samples else 0.0),
            NUM_SAMPLES: num_samples
        }

    def state_dict(self):
        """Override this to return a representation of the operator state.

        Any argument passed into self.register and self.register_data will
        automatically be saved.
        Use this method to save any additional state. If your TorchTrainer
        is on a CPU-only machine, make sure this method converts all state
        to be CPU-compatible.

        Returns:
            dict: The state dict of the operator. The default returns None.
        """

    def load_state_dict(self, state_dict):
        """Override this to load the representation of the operator state.

        Anything passed into self.register and self.register_data will
        automatically be loaded. Use this method to load any additional state.

        Args:
            state_dict (dict): State dict as returned by the operator.
        """

    @classmethod
    def from_creators(cls,
                      model_creator,
                      optimizer_creator,
                      data_creator=None,
                      loss_creator=None,
                      scheduler_creator=None,
                      serialize_data_creation=True):
        """A utility method to create a custom TrainingOperator class from
        creator functions. This is useful for backwards compatibility with
        previous versions of Ray. To provide custom training and validation,
        you should subclass the class that is returned by this method instead
        of ``TrainingOperator``.

        Args:
            model_creator (dict -> Model(s)): Constructor function that takes
                in config and returns the model(s) to be optimized. These
                must be ``torch.nn.Module`` objects. If multiple models are
                returned, a ``training_operator_cls`` must be specified.
                You do not need to handle GPU/devices in this function;
                RaySGD will do that under the hood.
            data_creator (dict -> Iterable(s)): Constructor function
                that takes in the passed config and returns one or
                two Iterable objects. Note that even though two Iterable
                objects can be returned, only one will be used for training,
                and the other will be used for validation. If not provided,
                you must pass in a Dataset to ``TorchTrainer.train``.
            optimizer_creator ((models, dict) -> optimizers): Constructor
                function that takes in the return values from
                ``model_creator`` and the passed config and returns One or
                more Torch optimizer objects. You do not need to handle
                GPU/devices in this function; ``RaySGD`` will do that for you.
            loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
                function for the training loss. This can be either a function
                that takes in the provided config for customization or a
                subclass of ``torch.nn.modules.loss._Loss``, which is most
                Pytorch loss classes. For example,
                ``loss_creator=torch.nn.BCELoss``. If not provided, you must
                provide a custom TrainingOperator.
            scheduler_creator ((optimizers, dict) -> scheduler):
                A constructor function for the torch scheduler. This is
                a function that takes in the generated optimizers (from
                ``optimizer_creator``) provided config for customization.
                Be sure to set ``scheduler_step_freq`` to increment the
                scheduler correctly.
            serialize_data_creation (bool): A filelock will be used
                to ensure no race conditions in data downloading among
                different workers on the same node (using the local file
                system). Defaults to True.

        Returns:
            A TrainingOperator class with a ``setup`` method that utilizes
            the passed in creator functions.

        Raises:
            ValueError: If ``model_creator`` or ``optimizer_creator`` is not
                callable.
        """
        if not (callable(model_creator) and callable(optimizer_creator)):
            raise ValueError(
                "Must provide a callable model_creator and optimizer_creator.")

        class CustomCreatorOperator(CreatorOperator):
            _model_creator = model_creator
            _optimizer_creator = optimizer_creator
            _data_creator = data_creator
            _loss_creator = loss_creator
            _scheduler_creator = scheduler_creator
            _serialize_data_creation = serialize_data_creation

        return CustomCreatorOperator

    @property
    def device(self):
        """torch.device: The appropriate torch device, at your convenience."""
        return self._device

    @property
    def config(self):
        """dict: Provided into TorchTrainer."""
        return self._config

    @property
    def world_rank(self):
        """int: The rank of the parent runner. Always 0 if not distributed."""
        return self._world_rank

    @property
    def use_gpu(self):
        """Returns True if cuda is available and use_gpu is True."""
        return self._use_gpu

    @property
    def use_fp16(self):
        """bool: Whether the model and optimizer have been FP16 enabled."""
        return self._use_fp16

    @property
    def use_tqdm(self):
        """bool: Whether tqdm progress bars are enabled."""
        return self._use_tqdm

    @property
    def device_ids(self):
        """List[int]: Device IDs for the model.

        This is useful for using batch norm with DistributedDataParallel.
        """
        return self._device_ids

    @property
    def scheduler_step_freq(self):
        """Optional[str]: The ``scheduler_step_freq`` passed into
        ``TorchTrainer``.

        This is useful to determine when to call scheduler.step.
        """
        return self._scheduler_step_freq
# Example #2
class TrainingOperator:
    """Abstract class for custom training or validation loops.

    The scheduler will only be called at a batch or epoch frequency, depending
    on the user parameter. Be sure to set ``scheduler_step_freq`` in
    ``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler
    correctly during training. If using a learning rate scheduler
    that depends on validation loss, you can use ``trainer.update_scheduler``.

    For both training and validation, there are two granularities that
    you can provide customization: per epoch or per batch.
    You do not need to override both.

    .. image:: raysgd-custom.jpg
        :scale: 80%
        :align: center

    Raises:
        ValueError: If multiple models/optimizers/schedulers are provided
            to this base class directly. You are expected to subclass this
            class if you wish to train over multiple
            models/optimizers/schedulers.
    """
    def __init__(self,
                 config,
                 models,
                 optimizers,
                 train_loader,
                 validation_loader,
                 world_rank,
                 criterion=None,
                 schedulers=None,
                 device_ids=None,
                 use_gpu=False,
                 use_fp16=False,
                 use_tqdm=False):
        """Stores the training components and runs the ``setup`` hook.

        You are not expected to override this method; override ``setup``
        instead.

        Args:
            config (dict): Configuration passed to ``setup``.
            models: Iterable of models.
            optimizers: Iterable of optimizers.
            train_loader: Training data loader.
            validation_loader: Validation data loader.
            world_rank (int): Rank of the parent runner.
            criterion: Loss function, if any.
            schedulers: Iterable of schedulers, or None/empty.
            device_ids (List[int]): Device IDs for the model.
            use_gpu (bool): Request GPU use; only honored if CUDA is
                available.
            use_fp16 (bool): Whether FP16 (amp) training is enabled.
            use_tqdm (bool): Whether to show tqdm progress bars.

        Raises:
            ValueError: If ``use_tqdm`` is set but tqdm is not installed, or
                if multiple models/optimizers/schedulers are passed to this
                base class directly.
        """
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        from collections.abc import Iterable

        self._models = models  # List of models
        assert isinstance(models, Iterable), (
            "Components need to be iterable. Got: {}".format(type(models)))
        self._optimizers = optimizers  # List of optimizers
        assert isinstance(optimizers, Iterable), (
            "Components need to be iterable. Got: {}".format(type(optimizers)))
        self._train_loader = train_loader
        self._validation_loader = validation_loader
        self._world_rank = world_rank
        self._criterion = criterion
        self._schedulers = schedulers
        if schedulers:
            assert isinstance(schedulers, Iterable), (
                "Components need to be iterable. Got: {}".format(
                    type(schedulers)))
        self._config = config
        self._use_fp16 = use_fp16
        self._device_ids = device_ids
        # Only use the GPU when it was requested AND is actually present.
        self._use_gpu = use_gpu and torch.cuda.is_available()
        self._device = torch.device("cuda" if self._use_gpu else "cpu")
        if tqdm is None and use_tqdm:
            raise ValueError("tqdm must be installed to use tqdm in training.")
        self._use_tqdm = use_tqdm
        self.global_step = 0

        # The default loops only handle a single model/optimizer/scheduler;
        # subclasses are expected to handle the multi-component case.
        if type(self) is TrainingOperator:
            for component in (models, schedulers, optimizers):
                if _is_multiple(component):
                    raise ValueError(
                        "Need to provide a custom operator subclassing "
                        "TrainingOperator if using multi-scheduler, "
                        "multi-model or multi-optimizer training/validation.")
        self.timers = TimerCollection()
        # Invoke the user-overridable setup hook.
        self.setup(config)

    def _set_timers(self, timers):
        """Passes in the timers from the Runner."""
        self.timers = timers

    def setup(self, config):
        """Override this method to implement custom operator setup.

            config (dict): Custom configuration value to be passed to
                all creator and operator constructors. Same as ``self.config``.

    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the training dataloader.

        By default, this method will iterate over the given iterator and
        call ``self.train_batch`` over each batch. If ``scheduler_step_freq``
        is set, this default method will also step the scheduler accordingly.

        You do not need to call ``train_batch`` in this method if you plan
        to implement a custom optimization/training routine here.

        You may find ``ray.util.sgd.utils.AverageMeterCollection`` useful
        when overriding this method. See example below:

        .. code-block:: python

            def train_epoch(self, ...):
                meter_collection = AverageMeterCollection()
                for batch in iterator:
                    # do some processing
                    metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics

                    # This keeps track of all metrics across multiple batches
                    meter_collection.update(metrics, n=len(batch))

                # Returns stats of the meters.
                stats = meter_collection.summary()
                return stats

        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations. Its entries are merged into the
                per-batch ``batch_info`` passed to ``train_batch``. May be
                None, which is treated as an empty dict.

        Returns:
            A dict of metrics from training.
        """
        # Treat a missing info dict as empty so the lookups below are safe.
        info = {} if info is None else info
        show_progress = self.use_tqdm and self.world_rank == 0

        if show_progress:
            desc = ""
            if "epoch_idx" in info:
                if "num_epochs" in info:
                    desc = "{}/{}e".format(info["epoch_idx"] + 1,
                                           info["num_epochs"])
                else:
                    desc = "{}e".format(info["epoch_idx"] + 1)
            _progress_bar = tqdm(
                total=info.get(NUM_STEPS) or len(self.train_loader),
                desc=desc,
                unit="batch",
                leave=False)

        metric_meters = AverageMeterCollection()

        # Make sure the model is in training mode; it may have been left in
        # eval mode, e.g. by a previous validation pass.
        self.model.train()
        for batch_idx, batch in enumerate(iterator):
            batch_info = {
                "batch_idx": batch_idx,
                "global_step": self.global_step
            }
            # Expose epoch-level info (such as SCHEDULER_STEP) per batch;
            # the batch-frequency scheduler check below relies on it.
            batch_info.update(info)
            metrics = self.train_batch(batch, batch_info=batch_info)

            if show_progress:
                _progress_bar.n = batch_idx + 1
                postfix = {}
                if "train_loss" in metrics:
                    postfix.update(loss=metrics["train_loss"])
                _progress_bar.set_postfix(postfix)

            if self.scheduler and batch_info.get(
                    SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
                self.scheduler.step()

            metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
            self.global_step += 1

        if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
            self.scheduler.step()

        return metric_meters.summary()

    def train_batch(self, batch, batch_info):
        """Computes loss and updates the model over one batch.

        This method is responsible for computing the loss and gradient and
        updating the model.

        By default, this method implementation assumes that batches
        are in (features, labels) format. If using amp/fp16
        training, it will also scale the loss automatically.

        You can provide custom loss metrics and training operations if you
        override this method. If overriding this method, you can access model,
        optimizer, criterion via ``self.model``, ``self.optimizer``,
        and ``self.criterion``.

        You do not need to override this method if you plan to
        override ``train_epoch``.

        Args:
            batch: One item of the training iterator.
            batch_info (dict): Information dict passed in from ``train_epoch``.

        Returns:
            A dictionary of metrics.
                By default, this dictionary contains "train_loss" and
                "num_samples". "num_samples" corresponds to number of
                datapoints in the batch. However, you can provide any number
                of other values. Consider returning "num_samples" in the
                metrics because by default, ``train_epoch`` uses
                "num_samples" to calculate averages.
        """
        features, target = batch
        # Move tensors to the GPU only when this operator is configured to use
        # it (``_use_gpu`` already folds in CUDA availability). Checking
        # ``torch.cuda.is_available()`` alone would move data to the GPU even
        # when ``use_gpu=False`` and the model lives on the CPU.
        # non_blocking copies are used for distributed training.
        if self._use_gpu:
            features = features.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # Compute output.
        with self.timers.record("fwd"):
            output = self.model(features)
            loss = self.criterion(output, target)

        # Compute gradients in a backward pass.
        with self.timers.record("grad"):
            # Clear gradients accumulated by the previous batch.
            self.optimizer.zero_grad()
            if self.use_fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        # Call step of optimizer to update model params.
        with self.timers.record("apply"):
            self.optimizer.step()

        return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)}

    def validate(self, val_iterator, info):
        """Runs one standard validation pass over the val_iterator.

        This will call ``model.eval()`` and ``torch.no_grad`` when iterating
        over the validation dataloader.

        If overriding this method, you can access model, criterion via
        ``self.model`` and ``self.criterion``. You also do not need to call
        ``validate_batch`` if overriding this method.

        Args:
            val_iterator (iter): Iterable constructed from the
                validation dataloader.
            info (dict): Dictionary for information to be used for custom
                validation operations. Its entries, if any, are merged into
                the per-batch ``batch_info`` passed to ``validate_batch``.

        Returns:
            A dict of metrics from the evaluation.
                By default, returns "val_loss" and "val_accuracy": the
                per-batch values from ``validate_batch`` averaged across
                batches, weighted by each batch's "num_samples".
        """
        metric_meters = AverageMeterCollection()

        # switch to evaluate mode
        self.model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = {"batch_idx": batch_idx}
                if info:
                    batch_info.update(info)
                metrics = self.validate_batch(batch, batch_info)
                metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))

        return metric_meters.summary()

    def validate_batch(self, batch, batch_info):
        """Calculates the loss and accuracy over a given batch.

        You can override this method to provide arbitrary metrics.

        By default, assumes batches are in (features, labels) format and
        that the model output is class scores along dimension 1.

        Args:
            batch: One item of the validation iterator.
            batch_info (dict): Contains information per batch from
                ``validate()``.

        Returns:
            A dict of metrics.
                By default, returns "val_loss", "val_accuracy", and
                "num_samples". When overriding, consider returning
                "num_samples" in the metrics because
                by default, ``validate`` uses "num_samples" to
                calculate averages.
        """
        features, target = batch
        # Only move data to the GPU when this operator is configured to use
        # it; otherwise it would not match a CPU-resident model.
        if self._use_gpu:
            features = features.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output
        with self.timers.record("eval_fwd"):
            output = self.model(features)
            loss = self.criterion(output, target)
            # Predicted class = index of the highest score along dim 1.
            _, predicted = torch.max(output.data, 1)

        num_correct = (predicted == target).sum().item()
        num_samples = target.size(0)
        return {
            "val_loss": loss.item(),
            "val_accuracy": num_correct / num_samples,
            NUM_SAMPLES: num_samples
        }
    def state_dict(self):
        """Override this to return a representation of the operator state.

            dict: The state dict of the operator."""

    def load_state_dict(self, state_dict):
        """Override this to load the representation of the operator state.

            state_dict (dict): State dict as returned by the operator. """

    def device(self):
        """torch.device: The appropriate torch device, at your convenience."""
        return self._device

    def config(self):
        """dict: Provided into TorchTrainer."""
        return self._config

    def model(self):
        """First or only model created by the provided ``model_creator``."""
        return self._models[0]

    def models(self):
        """List of models created by the provided ``model_creator``."""
        return self._models

    def optimizer(self):
        """First or only optimizer(s) created by the ``optimizer_creator``."""
        return self._optimizers[0]

    def optimizers(self):
        """List of optimizers created by the ``optimizer_creator``."""
        return self._optimizers

    def train_loader(self):
        """Iterable: 1st Dataloader from ``data_creator``.
        return self._train_loader

    def validation_loader(self):
        """Iterable: 2nd Dataloader from ``data_creator``."""
        return self._validation_loader

    def world_rank(self):
        """int: The rank of the parent runner. Always 0 if not distributed."""
        return self._world_rank

    def criterion(self):
        """Criterion created by the provided ``loss_creator``."""
        return self._criterion

    def scheduler(self):
        """First or only scheduler(s) created by the ``scheduler_creator``."""
        if self._schedulers:
            return self._schedulers[0]

    def schedulers(self):
        """List of schedulers created by the ``scheduler_creator``."""
        return self._schedulers

    def use_fp16(self):
        """bool: Whether the model and optimizer have been FP16 enabled."""
        return self._use_fp16

    def use_tqdm(self):
        """bool: Whether tqdm progress bars are enabled."""
        return self._use_tqdm

    def device_ids(self):
        """List[int]: Device IDs for the model.

        This is useful for using batch norm with DistributedDataParallel.
        return self._device_ids