import warnings

import torch

# `registry`, `get_optimizer_parameters`, and `is_dist_initialized` are assumed to be
# imported from the surrounding package; they are not defined in this excerpt.


def build_optimizer(model, config):
    optimizer_config = config.optimizer
    if not hasattr(optimizer_config, "type"):
        raise ValueError(
            "Optimizer attributes must have a 'type' key "
            "specifying the type of optimizer. "
            "(Custom or PyTorch)"
        )
    optimizer_type = optimizer_config.type

    if not hasattr(optimizer_config, "params"):
        warnings.warn("optimizer attributes have no params defined, defaulting to {}.")
    params = getattr(optimizer_config, "params", {})

    # Prefer a stock torch.optim optimizer; otherwise fall back to a custom
    # optimizer class registered in the registry.
    if hasattr(torch.optim, optimizer_type):
        optimizer_class = getattr(torch.optim, optimizer_type)
    else:
        optimizer_class = registry.get_optimizer_class(optimizer_type)
        if optimizer_class is None:
            raise ValueError(
                "No optimizer class of type {} present in "
                "either torch or registered to registry".format(optimizer_type)
            )

    parameters = get_optimizer_parameters(model, config)

    if optimizer_config.get("enable_state_sharding", False):
        # TODO(vedanuj): Remove once OSS is moved to PT upstream
        try:
            from fairscale.optim.oss import OSS
        except ImportError:
            print(
                "Optimizer state sharding requires fairscale. "
                + "Install using pip install fairscale."
            )
            raise

        assert (
            is_dist_initialized()
        ), "Optimizer state sharding can only be used in distributed mode."

        # Wrap the optimizer so its state is sharded across distributed workers.
        optimizer = OSS(params=parameters, optim=optimizer_class, **params)
    else:
        optimizer = optimizer_class(parameters, **params)
    return optimizer
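
# Hedged usage sketch (not part of the original module): assuming an OmegaConf-style
# config in which `optimizer.type` names a torch.optim class and `optimizer.params`
# holds its keyword arguments. The model and hyperparameter values below are
# illustrative assumptions, not defaults from this codebase; the call also relies on
# the module's own get_optimizer_parameters helper.
def _example_build_optimizer_usage():
    from omegaconf import OmegaConf

    example_config = OmegaConf.create(
        {"optimizer": {"type": "AdamW", "params": {"lr": 1e-4, "weight_decay": 0.01}}}
    )
    example_model = torch.nn.Linear(16, 4)  # stand-in model for demonstration
    return build_optimizer(example_model, example_config)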
def seed_sampler(self, epoch: int):
    if is_dist_initialized():
        for sampler in self.samplers.values():
            if sampler is not None and hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
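
# Hedged usage sketch (comment only, not part of the original code): PyTorch's
# DistributedSampler shuffles deterministically as a function of its epoch, so a
# training loop is expected to call seed_sampler once per epoch before iterating.
# The loader and step names below are illustrative assumptions.
#
#   for epoch in range(num_epochs):
#       train_loader.seed_sampler(epoch)
#       for batch in train_loader:
#           train_step(batch)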