class TrainingOperator: """Abstract class to define training and validation state and logic. You must subclass this class and override the ``setup`` method to define your training components such as the model, optimizer, data, loss, and scheduler. When you pass this class to ``TorchTrainer``, a copy of this class will be made on each worker. .. code-block:: python class MyTrainingOperator(TrainingOperator): def setup(self, config): model = nn.Linear(1, 1) optimizer = torch.optim.SGD( model.parameters(), lr=config.get("lr", 1e-4)) loss = torch.nn.MSELoss() batch_size = config["batch_size"] train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5) train_loader = DataLoader(train_data, batch_size=batch_size) val_loader = DataLoader(val_data, batch_size=batch_size) self.model, self.optimizer = self.register( models=model, optimizers=optimizer, criterion=loss) self.register_data( train_loader=train_loader, validation_loader=val_loader) trainer = TorchTrainer( training_operator_cls=MyTrainingOperator, config={"batch_size": 32}, use_gpu=True ) for i in range(4): trainer.train() This class provides default implementations for training and validation. Set ``self.model``, ``self.optimizer``, and ``self.criterion`` to leverage the default training and validation loops. If ``self.scheduler`` is set, it will only be called at a batch or epoch frequency, depending on the user parameter. Set ``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler correctly during training. If using a learning rate scheduler that depends on validation loss, you can use ``trainer.update_scheduler``. If you want to provide custom training and validation loops, you can do so using this class as well. There are two granularities that you can provide customization: per epoch or per batch. You do not need to override both. .. image:: raysgd-custom.jpg :scale: 80% :align: center If you are using multiple models, optimizers, or schedulers, you must implement custom training and validation. Raises: ValueError You are expected to either set ``self.model``, ``self.optimizer``, and ``self.criterion`` instance attributes in setup or implement custom training & validation. """ def __init__(self, config, world_rank, device_ids=None, use_gpu=False, use_fp16=False, use_tqdm=False, apex_args=None, wrap_ddp=False, wrap_distributed_sampler=False, add_dist_sampler=False, scheduler_step_freq=None): # You are not expected to override this method. self._world_rank = world_rank self._config = config self._use_fp16 = use_fp16 self._device_ids = device_ids self._use_gpu = use_gpu and torch.cuda.is_available() self._device = torch.device("cuda" if self._use_gpu else "cpu") if tqdm is None and use_tqdm: raise ValueError("tqdm must be installed to use tqdm in training.") self._use_tqdm = use_tqdm self.global_step = 0 self._apex_args = apex_args if apex_args else {} self._wrap_ddp = wrap_ddp self._wrap_distributed_sampler = wrap_distributed_sampler self._add_dist_sampler = add_dist_sampler self._scheduler_step_freq = scheduler_step_freq self.timers = TimerCollection() self.setup(config) def _set_timers(self, timers): """Passes in the timers from the Runner.""" self.timers = timers def setup(self, config): """Override this method to implement operator setup. You should call self.register and self.register_data here to register training components and data loaders with Ray SGD. Args: config (dict): Custom configuration value to be passed to all creator and operator constructors. Same as ``self.config``. """ raise NotImplementedError def register(self, *, models, optimizers, criterion=None, schedulers=None): """Registers parameters with Ray SGD and sets up training components. By calling this method to register your models, optimizers, criterion, and schedulers, Ray SGD will automatically handle necessary setup such as GPU/devices, Distributed Data Parallel, and Fp16. The registered components are returned and should be set as instance attributes to access during training/validation. If more than one model, optimizer, or scheduler is passed in, you should implement your own custom training loop. .. code-block:: python class MyTrainingOperator(TrainingOperator): def setup(self, config): model = ... optimizer = ... train_loader = ... val_loader = ... loss = ... self.model, self.optimizer, self.criterion = self.register( models=model, optimizers=optimizer, criterion=loss) # At this point DDP, Cuda, and Fp16 # are set up for all our components. We then use # self.model, self.optimizer, etc. in our training loop. self.register_data(train_loader=train_loader, validation_loader=val_loader) Args: models (torch.nn.Module or Iterable[nn.Module]): Pytorch model or multiple Pytorch models to use for training. If `use_gpu=True` is passed into ``TorchTrainer``, and Cuda is available, models will automatically be placed on GPU. If ``wrap_ddp=True`` is passed into ``TorchTrainer``, models will be wrapped in DDP. If wrap_ddp is False, you should handle DDP for your models in setup. optimizers (torch.optim.Optimizer or Iterable[ torch.optim.Optimizer]): Pytorch optimizer or multiple Pytorch optimizers to use for training. criterion (Callable, optional): Function to return loss metric given features and target. If not provided, must implement a custom training loop. schedulers (torch.optim.lr_scheduler or Iterable[ torch.optim.lr_scheduler], optional): A learning rate scheduler or multiple learning rate schedulers. Returns: Tuple of model, optimizer, criterion if not None, and scheduler if not None. """ return_vals = [] logger.debug("Registering models.") self._original_models = models if not isinstance(self._original_models, Iterable): self._original_models = [self._original_models] assert all( isinstance(model, nn.Module) for model in self._original_models), ( f"All models must be PyTorch models: {self._original_models}.") if self.use_gpu and torch.cuda.is_available(): self._original_models = [ model.cuda() for model in self._original_models ] logger.debug("Registering optimizers.") self._optimizers = optimizers if not isinstance(self._optimizers, Iterable): self._optimizers = [self._optimizers] if schedulers: logger.debug("Registering scheduler.") self._schedulers = schedulers if not isinstance(self._schedulers, Iterable): self._schedulers = [self._schedulers] else: self._schedulers = None if criterion: logger.debug("Registering loss.") self._criterion = criterion if self.use_gpu and torch.cuda.is_available(): if hasattr(self._criterion, "cuda"): self._criterion = self._criterion.cuda() else: self._criterion = None if self.use_fp16 and amp: logger.debug("Setting up Apex.") self._original_models, self._optimizers = amp.initialize( self._original_models, self._optimizers, **self._apex_args) self._amp = amp if self._wrap_ddp: logging.debug("Setting up DDP for models.") self._models = [ DistributedDataParallel(model, device_ids=self.device_ids) for model in self._original_models ] else: self._models = self._original_models if len(self._models) == 1: return_vals.append(self._models[0]) else: return_vals.append(self._models) if len(self._optimizers) == 1: return_vals.append(self._optimizers[0]) else: return_vals.append(self._optimizers) if self._criterion is not None: return_vals.append(self._criterion) if self._schedulers is not None: if self.scheduler_step_freq is None: raise ValueError("scheduler_step_freq passed into " "TorchTrainer cannot be None if you " "are registering schedulers. Set this to " "'manual' if you will be manually stepping " "the schedulers.") if len(self._schedulers) == 1: return_vals.append(self._schedulers[0]) else: return_vals.append(self._schedulers) return tuple(return_vals) def register_data(self, *, train_loader=None, validation_loader=None): """Registers data loaders with Ray SGD. Calling this method will automatically setup Distributed Sampler for these data loaders if add_dist_sampler=True is passed into the TorchTrainer. This method does not return the wrapped data loaders. You should use the iterators passed into train_epoch and validate instead. .. code-block:: python class MyTrainingOperator(TrainingOperator): def setup(self, config): model = ... optimizer = ... train_loader = ... val_loader = ... loss = ... self.model, self.optimizer, self.criterion = self.register( models=model, optimizers=optimizer, criterion=loss) self.register_data(train_loader=train_loader, validation_loader=val_loader) # At this point the data loaders are registered with # Ray SGD and are wrapped with Distributed Samplers if # applicable. def train_epoch(self, iterator, info): # If providing custom training or validation methods, # the registered data loaders are passed in through the # iterator parameter. ... Args: train_loader (Iterator): An iterator for training data. If None is explicitly passed in, a Ray SGD Dataset must be passed in through TorchTrainer.train. Ray SGD will automatically use a Distributed Sampler if TorchTrainer(..., add_dist_sampler=True). validation_loader (Iterator): An iterator for validation data. Ray SGD will automatically use a Distributed Sampler if TorchTrainer(..., add_dist_sampler=True). """ logger.debug("Registering data loaders..") self._train_loader = train_loader self._validation_loader = validation_loader if self._wrap_distributed_sampler: logging.debug("Wrapping data loaders with DistributedSampler.") def with_sampler(loader): # Automatically set the DistributedSampler data_loader_args = { "dataset": loader.dataset, "batch_size": loader.batch_size, "shuffle": False, "num_workers": loader.num_workers, "collate_fn": loader.collate_fn, "pin_memory": loader.pin_memory, "drop_last": loader.drop_last, "timeout": loader.timeout, "worker_init_fn": loader.worker_init_fn, "sampler": DistributedSampler(loader.dataset) } return DataLoader(**data_loader_args) def should_wrap_dataloader(loader): return (isinstance(loader, DataLoader) and not isinstance(loader.dataset, IterableDataset)) if should_wrap_dataloader(self._train_loader): if self._add_dist_sampler: self._train_loader = with_sampler(self._train_loader) if self._validation_loader is not None and should_wrap_dataloader( self._validation_loader): if self._add_dist_sampler: self._validation_loader = with_sampler( self._validation_loader) def train_epoch(self, iterator, info): """Runs one standard training pass over the training dataloader. By default, this method will iterate over the given iterator and call ``self.train_batch`` over each batch. If ``scheduler_step_freq`` is set, this default method will also step the scheduler accordingly. You do not need to call ``train_batch`` in this method if you plan to implement a custom optimization/training routine here. You may find ``ray.util.sgd.utils.AverageMeterCollection`` useful when overriding this method. See example below: .. code-block:: python def train_epoch(self, ...): meter_collection = AverageMeterCollection() self.model.train() for batch in iterator: # do some processing metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics # This keeps track of all metrics across multiple batches meter_collection.update(metrics, n=len(batch)) # Returns stats of the meters. stats = meter_collection.summary() return stats Args: iterator (iter): Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed. info (dict): Dictionary for information to be used for custom training operations. Returns: A dict of metrics from training. """ if not hasattr(self, "model"): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "training loop.") model = self.model scheduler = None if hasattr(self, "scheduler"): scheduler = self.scheduler if self.use_tqdm and self.world_rank == 0: desc = "" if info is not None and "epoch_idx" in info: if "num_epochs" in info: desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e" else: desc = f"{info['epoch_idx'] + 1}e" # TODO: Implement len for Dataset? total = info[NUM_STEPS] if total is None: if hasattr(iterator, "__len__"): total = len(iterator) _progress_bar = tqdm(total=total, desc=desc, unit="batch", leave=False) metric_meters = AverageMeterCollection() model.train() for batch_idx, batch in enumerate(iterator): batch_info = { "batch_idx": batch_idx, "global_step": self.global_step } batch_info.update(info) metrics = self.train_batch(batch, batch_info=batch_info) if self.use_tqdm and self.world_rank == 0: _progress_bar.n = batch_idx + 1 postfix = {} if "train_loss" in metrics: postfix.update(loss=metrics["train_loss"]) _progress_bar.set_postfix(postfix) if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_BATCH: scheduler.step() metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH: scheduler.step() return metric_meters.summary() def train_batch(self, batch, batch_info): """Computes loss and updates the model over one batch. This method is responsible for computing the loss and gradient and updating the model. By default, this method implementation assumes that batches are in (\\*features, labels) format. So we also support multiple inputs model. If using amp/fp16 training, it will also scale the loss automatically. You can provide custom loss metrics and training operations if you override this method. You do not need to override this method if you plan to override ``train_epoch``. Args: batch: One item of the validation iterator. batch_info (dict): Information dict passed in from ``train_epoch``. Returns: A dictionary of metrics. By default, this dictionary contains "loss" and "num_samples". "num_samples" corresponds to number of datapoints in the batch. However, you can provide any number of other values. Consider returning "num_samples" in the metrics because by default, ``train_epoch`` uses "num_samples" to calculate averages. """ if not hasattr(self, "model"): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "training loop.") if not hasattr(self, "optimizer"): raise RuntimeError("Either set self.optimizer in setup function " "or override this method to implement a custom " "training loop.") if not hasattr(self, "criterion"): raise RuntimeError("Either set self.criterion in setup function " "or override this method to implement a custom " "training loop.") model = self.model optimizer = self.optimizer criterion = self.criterion # unpack features into list to support multiple inputs model *features, target = batch # Create non_blocking tensors for distributed training if self.use_gpu: features = [ feature.cuda(non_blocking=True) for feature in features ] target = target.cuda(non_blocking=True) # Compute output. with self.timers.record("fwd"): output = model(*features) loss = criterion(output, target) # Compute gradients in a backward pass. with self.timers.record("grad"): optimizer.zero_grad() if self.use_fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() # Call step of optimizer to update model params. with self.timers.record("apply"): optimizer.step() return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)} def validate(self, val_iterator, info): """Runs one standard validation pass over the val_iterator. This will call ``model.eval()`` and ``torch.no_grad`` when iterating over the validation dataloader. You also do not need to call ``validate_batch`` if overriding this method. Args: val_iterator (iter): Iterable constructed from the validation dataloader. info: (dict): Dictionary for information to be used for custom validation operations. Returns: A dict of metrics from the evaluation. By default, returns "val_accuracy" and "val_loss" which is computed by aggregating "loss" and "correct" values from ``validate_batch`` and dividing it by the sum of ``num_samples`` from all calls to ``self.validate_batch``. """ if not hasattr(self, "model"): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "validation loop.") model = self.model metric_meters = AverageMeterCollection() # switch to evaluate mode model.eval() with torch.no_grad(): for batch_idx, batch in enumerate(val_iterator): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info) metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) return metric_meters.summary() def validate_batch(self, batch, batch_info): """Calcuates the loss and accuracy over a given batch. You can override this method to provide arbitrary metrics. Same as ``train_batch``, this method implementation assumes that batches are in (\\*features, labels) format by default. So we also support multiple inputs model. Args: batch: One item of the validation iterator. batch_info (dict): Contains information per batch from ``validate()``. Returns: A dict of metrics. By default, returns "val_loss", "val_accuracy", and "num_samples". When overriding, consider returning "num_samples" in the metrics because by default, ``validate`` uses "num_samples" to calculate averages. """ if not hasattr(self, "model"): raise RuntimeError("Either set self.model in setup function or " "override this method to implement a custom " "training loop.") if not hasattr(self, "criterion"): raise RuntimeError("Either set self.criterion in setup function " "or override this method to implement a custom " "training loop.") model = self.model criterion = self.criterion # unpack features into list to support multiple inputs model *features, target = batch if self.use_gpu: features = [ feature.cuda(non_blocking=True) for feature in features ] target = target.cuda(non_blocking=True) # compute output with self.timers.record("eval_fwd"): output = model(*features) loss = criterion(output, target) _, predicted = torch.max(output.data, 1) num_correct = (predicted == target).sum().item() num_samples = target.size(0) return { "val_loss": loss.item(), "val_accuracy": num_correct / num_samples, NUM_SAMPLES: num_samples } def state_dict(self): """Override this to return a representation of the operator state. Any argument passed into self.register and self.register_data will automatically be saved. Use this method to save any additional state. If your TorchTrainer is on a CPU-only machine, make sure this method converts all state to be CPU-compatible. Returns: dict: The state dict of the operator.""" pass def load_state_dict(self, state_dict): """Override this to load the representation of the operator state. Anything passed into self.register and self.register_data will automatically be loaded. Use this method to load any additional state. Args: state_dict (dict): State dict as returned by the operator. """ pass @classmethod def from_creators(cls, model_creator, optimizer_creator, data_creator=None, loss_creator=None, scheduler_creator=None, serialize_data_creation=True): """A utility method to create a custom TrainingOperator class from creator functions. This is useful for backwards compatibility with previous versions of Ray. To provide custom training and validation, you should subclass the class that is returned by this method instead of ``TrainingOperator``. Args: model_creator (dict -> Model(s)): Constructor function that takes in config and returns the model(s) to be optimized. These must be ``torch.nn.Module`` objects. If multiple models are returned, a ``training_operator_cls`` must be specified. You do not need to handle GPU/devices in this function; RaySGD will do that under the hood. data_creator (dict -> Iterable(s)): Constructor function that takes in the passed config and returns one or two Iterable objects. Note that even though two Iterable objects can be returned, only one will be used for training, and the other will be used for validation. If not provided, you must pass in a Dataset to ``TorchTrainer.train``. optimizer_creator ((models, dict) -> optimizers): Constructor function that takes in the return values from ``model_creator`` and the passed config and returns One or more Torch optimizer objects. You do not need to handle GPU/devices in this function; ``RaySGD`` will do that for you. loss_creator (torch.nn.*Loss class | dict -> loss): A constructor function for the training loss. This can be either a function that takes in the provided config for customization or a subclass of ``torch.nn.modules.loss._Loss``, which is most Pytorch loss classes. For example, ``loss_creator=torch.nn.BCELoss``. If not provided, you must provide a custom TrainingOperator. scheduler_creator ((optimizers, dict) -> scheduler): A constructor function for the torch scheduler. This is a function that takes in the generated optimizers (from ``optimizer_creator``) provided config for customization. Be sure to set ``scheduler_step_freq`` to increment the scheduler correctly. serialize_data_creation (bool): A filelock will be used to ensure no race conditions in data downloading among different workers on the same node (using the local file system). Defaults to True. Returns: A TrainingOperator class with a ``setup`` method that utilizes the passed in creator functions. """ if not (callable(model_creator) and callable(optimizer_creator)): raise ValueError( "Must provide a callable model_creator and optimizer_creator.") class CustomCreatorOperator(CreatorOperator): _model_creator = model_creator _optimizer_creator = optimizer_creator _data_creator = data_creator _loss_creator = loss_creator _scheduler_creator = scheduler_creator _serialize_data_creation = serialize_data_creation return CustomCreatorOperator @property def device(self): """torch.device: The appropriate torch device, at your convenience.""" return self._device @property def config(self): """dict: Provided into TorchTrainer.""" return self._config @property def world_rank(self): """int: The rank of the parent runner. Always 0 if not distributed.""" return self._world_rank @property def use_gpu(self): """Returns True if cuda is available and use_gpu is True.""" return self._use_gpu @property def use_fp16(self): """bool: Whether the model and optimizer have been FP16 enabled.""" return self._use_fp16 @property def use_tqdm(self): """bool: Whether tqdm progress bars are enabled.""" return self._use_tqdm @property def device_ids(self): """List[int]: Device IDs for the model. This is useful for using batch norm with DistributedDataParallel. """ return self._device_ids @property def scheduler_step_freq(self): """Optional[str]: The ``scheduler_step_freq`` passed into ``TorchTrainer`` This is useful to determine when to call scheduler.step. """ return self._scheduler_step_freq
class TrainingOperator: """Abstract class for custom training or validation loops. The scheduler will only be called at a batch or epoch frequency, depending on the user parameter. Be sure to set ``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler correctly during training. If using a learning rate scheduler that depends on validation loss, you can use ``trainer.update_scheduler``. For both training and validation, there are two granularities that you can provide customization: per epoch or per batch. You do not need to override both. .. image:: raysgd-custom.jpg :scale: 80% :align: center Raises: ValueError if multiple models/optimizers/schedulers are provided. You are expected to subclass this class if you wish to train over multiple models/optimizers/schedulers. """ def __init__(self, config, models, optimizers, train_loader, validation_loader, world_rank, criterion=None, schedulers=None, device_ids=None, use_gpu=False, use_fp16=False, use_tqdm=False): # You are not expected to override this method. self._models = models # List of models assert isinstance(models, collections.Iterable), ( "Components need to be iterable. Got: {}".format(type(models))) self._optimizers = optimizers # List of optimizers assert isinstance(optimizers, collections.Iterable), ( "Components need to be iterable. Got: {}".format(type(optimizers))) self._train_loader = train_loader self._validation_loader = validation_loader self._world_rank = world_rank self._criterion = criterion self._schedulers = schedulers if schedulers: assert isinstance(schedulers, collections.Iterable), ( "Components need to be iterable. Got: {}".format( type(schedulers))) self._config = config self._use_fp16 = use_fp16 self._device_ids = device_ids self._use_gpu = use_gpu and torch.cuda.is_available() self._device = torch.device("cuda" if self._use_gpu else "cpu") if tqdm is None and use_tqdm: raise ValueError("tqdm must be installed to use tqdm in training.") self._use_tqdm = use_tqdm self.global_step = 0 if type(self) is TrainingOperator: for component in (models, schedulers, optimizers): if _is_multiple(component): raise ValueError( "Need to provide a custom operator subclassing " "TrainingOperator if using multi-scheduler, " "multi-model or multi-optimizer training/validation.") self.timers = TimerCollection() self.setup(config) def _set_timers(self, timers): """Passes in the timers from the Runner.""" self.timers = timers def setup(self, config): """Override this method to implement custom operator setup. Args: config (dict): Custom configuration value to be passed to all creator and operator constructors. Same as ``self.config``. """ pass def train_epoch(self, iterator, info): """Runs one standard training pass over the training dataloader. By default, this method will iterate over the given iterator and call ``self.train_batch`` over each batch. If ``scheduler_step_freq`` is set, this default method will also step the scheduler accordingly. You do not need to call ``train_batch`` in this method if you plan to implement a custom optimization/training routine here. You may find ``ray.util.sgd.utils.AverageMeterCollection`` useful when overriding this method. See example below: .. code-block:: python def train_epoch(self, ...): meter_collection = AverageMeterCollection() self.model.train() for batch in iterator: # do some processing metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics # This keeps track of all metrics across multiple batches meter_collection.update(metrics, n=len(batch)) # Returns stats of the meters. stats = meter_collection.summary() return stats Args: iterator (iter): Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed. info (dict): Dictionary for information to be used for custom training operations. Returns: A dict of metrics from training. """ if self.use_tqdm and self.world_rank == 0: desc = "" if info is not None and "epoch_idx" in info: if "num_epochs" in info: desc = "{}/{}e".format(info["epoch_idx"] + 1, info["num_epochs"]) else: desc = "{}e".format(info["epoch_idx"] + 1) _progress_bar = tqdm(total=info[NUM_STEPS] or len(self.train_loader), desc=desc, unit="batch", leave=False) metric_meters = AverageMeterCollection() self.model.train() for batch_idx, batch in enumerate(iterator): batch_info = { "batch_idx": batch_idx, "global_step": self.global_step } batch_info.update(info) metrics = self.train_batch(batch, batch_info=batch_info) if self.use_tqdm and self.world_rank == 0: _progress_bar.n = batch_idx + 1 postfix = {} if "train_loss" in metrics: postfix.update(loss=metrics["train_loss"]) _progress_bar.set_postfix(postfix) if self.scheduler and batch_info.get( SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: self.scheduler.step() metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: self.scheduler.step() return metric_meters.summary() def train_batch(self, batch, batch_info): """Computes loss and updates the model over one batch. This method is responsible for computing the loss and gradient and updating the model. By default, this method implementation assumes that batches are in (features, labels) format. If using amp/fp16 training, it will also scale the loss automatically. You can provide custom loss metrics and training operations if you override this method. If overriding this method, you can access model, optimizer, criterion via ``self.model``, ``self.optimizer``, and ``self.criterion``. You do not need to override this method if you plan to override ``train_epoch``. Args: batch: One item of the validation iterator. batch_info (dict): Information dict passed in from ``train_epoch``. Returns: A dictionary of metrics. By default, this dictionary contains "loss" and "num_samples". "num_samples" corresponds to number of datapoints in the batch. However, you can provide any number of other values. Consider returning "num_samples" in the metrics because by default, ``train_epoch`` uses "num_samples" to calculate averages. """ features, target = batch # Create non_blocking tensors for distributed training if torch.cuda.is_available(): features = features.cuda(non_blocking=True) target = target.cuda(non_blocking=True) # Compute output. with self.timers.record("fwd"): output = self.model(features) loss = self.criterion(output, target) # Compute gradients in a backward pass. with self.timers.record("grad"): self.optimizer.zero_grad() if self.use_fp16: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() # Call step of optimizer to update model params. with self.timers.record("apply"): self.optimizer.step() return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)} def validate(self, val_iterator, info): """Runs one standard validation pass over the val_iterator. This will call ``model.eval()`` and ``torch.no_grad`` when iterating over the validation dataloader. If overriding this method, you can access model, criterion via ``self.model`` and ``self.criterion``. You also do not need to call ``validate_batch`` if overriding this method. Args: val_iterator (iter): Iterable constructed from the validation dataloader. info: (dict): Dictionary for information to be used for custom validation operations. Returns: A dict of metrics from the evaluation. By default, returns "val_accuracy" and "val_loss" which is computed by aggregating "loss" and "correct" values from ``validate_batch`` and dividing it by the sum of ``num_samples`` from all calls to ``self.validate_batch``. """ metric_meters = AverageMeterCollection() # switch to evaluate mode self.model.eval() with torch.no_grad(): for batch_idx, batch in enumerate(val_iterator): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info) metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) return metric_meters.summary() def validate_batch(self, batch, batch_info): """Calcuates the loss and accuracy over a given batch. You can override this method to provide arbitrary metrics. Args: batch: One item of the validation iterator. batch_info (dict): Contains information per batch from ``validate()``. Returns: A dict of metrics. By default, returns "val_loss", "val_accuracy", and "num_samples". When overriding, consider returning "num_samples" in the metrics because by default, ``validate`` uses "num_samples" to calculate averages. """ features, target = batch if torch.cuda.is_available(): features = features.cuda(non_blocking=True) target = target.cuda(non_blocking=True) # compute output with self.timers.record("eval_fwd"): output = self.model(features) loss = self.criterion(output, target) _, predicted = torch.max(output.data, 1) num_correct = (predicted == target).sum().item() num_samples = target.size(0) return { "val_loss": loss.item(), "val_accuracy": num_correct / num_samples, NUM_SAMPLES: num_samples } def state_dict(self): """Override this to return a representation of the operator state. Returns: dict: The state dict of the operator.""" pass def load_state_dict(self, state_dict): """Override this to load the representation of the operator state. Args: state_dict (dict): State dict as returned by the operator. """ pass @property def device(self): """torch.device: The appropriate torch device, at your convenience.""" return self._device @property def config(self): """dict: Provided into TorchTrainer.""" return self._config @property def model(self): """First or only model created by the provided ``model_creator``.""" return self._models[0] @property def models(self): """List of models created by the provided ``model_creator``.""" return self._models @property def optimizer(self): """First or only optimizer(s) created by the ``optimizer_creator``.""" return self._optimizers[0] @property def optimizers(self): """List of optimizers created by the ``optimizer_creator``.""" return self._optimizers @property def train_loader(self): """Iterable: 1st Dataloader from ``data_creator``. """ return self._train_loader @property def validation_loader(self): """Iterable: 2nd Dataloader from ``data_creator``.""" return self._validation_loader @property def world_rank(self): """int: The rank of the parent runner. Always 0 if not distributed.""" return self._world_rank @property def criterion(self): """Criterion created by the provided ``loss_creator``.""" return self._criterion @property def scheduler(self): """First or only scheduler(s) created by the ``scheduler_creator``.""" if self._schedulers: return self._schedulers[0] @property def schedulers(self): """List of schedulers created by the ``scheduler_creator``.""" return self._schedulers @property def use_fp16(self): """bool: Whether the model and optimizer have been FP16 enabled.""" return self._use_fp16 @property def use_tqdm(self): """bool: Whether tqdm progress bars are enabled.""" return self._use_tqdm @property def device_ids(self): """List[int]: Device IDs for the model. This is useful for using batch norm with DistributedDataParallel. """ return self._device_ids