def test_create_model_from_checkpoint(self):
    """
    A model created with ``checkpoint_file`` must load the checkpointed
    weights exactly (fixed the "creaate" typo in the original test name).
    """
    model1 = create_model(model_class=resnet50, model_args={},
                          init_batch_norm=False, device="cpu")

    # Simulate imagenet experiment by changing the weights
    def init(m):
        if hasattr(m, "weight") and m.weight is not None:
            m.weight.data.fill_(0.042)

    model1.apply(init)

    # Save model checkpoint only, ignoring optimizer and other imagenet
    # experiment objects state. See ImagenetExperiment.get_state
    state = {}
    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, model1.state_dict())
        state["model"] = buffer.getvalue()

    with tempfile.NamedTemporaryFile() as checkpoint_file:
        # Ray save checkpoints as pickled dicts
        pickle.dump(state, checkpoint_file)
        # Flush via the NamedTemporaryFile wrapper, matching the sibling
        # test, instead of reaching into the private `.file` attribute.
        checkpoint_file.flush()

        # Load model from checkpoint
        model2 = create_model(model_class=resnet50, model_args={},
                              init_batch_norm=False, device="cpu",
                              checkpoint_file=checkpoint_file.name)

    self.assertTrue(compare_models(model1, model2, (3, 32, 32)))
def test_identical(self):
    """A model restored from its own checkpoint must match the original."""
    model_args = {"config": {"num_classes": 3, "defaults_sparse": True}}
    model_class = nupic.research.frameworks.pytorch.models.resnets.resnet50
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    original = create_model(
        model_class=model_class,
        model_args=model_args,
        init_batch_norm=False,
        device=device,
    )

    # Checkpoint only the serialized model weights
    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, original.state_dict())
        state = {"model": buffer.getvalue()}

    with tempfile.NamedTemporaryFile(delete=True) as checkpoint_file:
        pickle.dump(state, checkpoint_file)
        checkpoint_file.flush()

        restored = create_model(model_class=model_class,
                                model_args=model_args,
                                init_batch_norm=False,
                                device=device,
                                checkpoint_file=checkpoint_file.name)

    self.assertTrue(compare_models(original, restored, (3, 224, 224)))
def setup_experiment(self, config):
    """
    Configure the experiment for training

    :param config: Dictionary containing the configuration parameters

        - distributed: Whether or not to use Pytorch Distributed training
        - backend: Pytorch Distributed backend ("nccl", "gloo")
          Default: nccl
        - world_size: Total number of processes participating
        - rank: Rank of the current process
        - data: Dataset path
        - train_dir: Dataset training data relative path
        - batch_size: Training batch size
        - val_dir: Dataset validation data relative path
        - val_batch_size: Validation batch size
        - workers: how many data loading processes to use
        - num_classes: Limit the dataset size to the given number of classes
        - model_class: Model class. Must inherit from "torch.nn.Module"
        - model_args: model class arguments passed to the constructor
        - init_batch_norm: Whether or not to Initialize running batch norm
          mean to 0.
        - optimizer_class: Optimizer class.
          Must inherit from "torch.optim.Optimizer"
        - optimizer_args: Optimizer class arguments passed to the constructor
        - batch_norm_weight_decay: Whether or not to apply weight decay to
          batch norm modules parameters
        - lr_scheduler_class: Learning rate scheduler class.
          Must inherit from "_LRScheduler"
        - lr_scheduler_args: Learning rate scheduler class arguments passed
          to the constructor
        - loss_function: Loss function. See "torch.nn.functional"
        - local_dir: Results path
        - epochs: Number of epochs to train
        - batches_in_epoch: Number of batches per epoch.
          Useful for debugging
        - progress: Show progress during training
        - profile: Whether or not to enable torch.autograd.profiler.profile
          during training
        - name: Experiment name. Used as logger name
        - log_level: Python Logging level
        - log_format: Python Logging format
        - seed: the seed to be used for pytorch, python, and numpy
        - mixed_precision: Whether or not to enable apex mixed precision
        - mixed_precision_args: apex mixed precision arguments.
          See "amp.initialize"
        - create_train_dataloader: Optional user defined function to create
          the training data loader. See below for input params.
        - create_validation_dataloader: Optional user defined function to
          create the validation data loader. See below for input params.
        - train_model_func: Optional user defined function to train the model,
          expected to behave similarly to `train_model` in terms of input
          parameters and return values
        - evaluate_model_func: Optional user defined function to validate the
          model expected to behave similarly to `evaluate_model` in terms of
          input parameters and return values
        - init_hooks: list of hooks (functions) to call on the model just
          following its initialization
        - post_epoch_hooks: list of hooks (functions) to call on the model
          following each epoch of training
        - checkpoint_file: if not None, will start from this model. The
          model must have the same model_args and model_class as the current
          experiment.
        - checkpoint_at_init: boolean argument for whether to create a
          checkpoint of the initialized model. this differs from
          `checkpoint_at_start` for which the checkpoint occurs after the
          first epoch of training as opposed to before it
        - epochs_to_validate: list of epochs to run validate(). A -1 asks
          to run validate before any training occurs.
          Default: last three epochs.
        - launch_time: time the config was created (via time.time). Used to
          report wall clock time until the first batch is done.
          Default: time.time() in this setup_experiment().
    """
    # Configure logging related stuff
    log_format = config.get("log_format", logging.BASIC_FORMAT)
    log_level = getattr(logging, config.get("log_level", "INFO").upper())
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter(log_format))

    self.logger = logging.getLogger(config.get("name", type(self).__name__))
    self.logger.setLevel(log_level)
    self.logger.addHandler(console)
    self.progress = config.get("progress", False)
    self.launch_time = config.get("launch_time", time.time())

    # Configure seed
    self.seed = config.get("seed", self.seed)
    set_random_seed(self.seed, False)

    # Configure distributed pytorch
    self.distributed = config.get("distributed", False)
    self.rank = config.get("rank", 0)
    if self.distributed:
        dist_url = config.get("dist_url", "tcp://127.0.0.1:54321")
        backend = config.get("backend", "nccl")
        world_size = config.get("world_size", 1)
        dist.init_process_group(
            backend=backend,
            init_method=dist_url,
            rank=self.rank,
            world_size=world_size,
        )
        # Only enable logs from first process
        self.logger.disabled = self.rank != 0
        self.progress = self.progress and self.rank == 0

    # Configure model
    model_class = config["model_class"]
    model_args = config.get("model_args", {})
    init_batch_norm = config.get("init_batch_norm", False)
    init_hooks = config.get("init_hooks", None)
    self.model = create_model(model_class=model_class,
                              model_args=model_args,
                              init_batch_norm=init_batch_norm,
                              device=self.device,
                              init_hooks=init_hooks,
                              checkpoint_file=config.get(
                                  "checkpoint_file", None))
    if self.rank == 0:
        self.logger.debug(self.model)
        params_sparse, nonzero_params_sparse2 = count_nonzero_params(
            self.model)
        self.logger.debug("Params total/nnz %s / %s = %s ",
                          params_sparse, nonzero_params_sparse2,
                          float(nonzero_params_sparse2) / params_sparse)

    # Configure optimizer
    optimizer_class = config.get("optimizer_class", torch.optim.SGD)
    optimizer_args = config.get("optimizer_args", {})
    batch_norm_weight_decay = config.get("batch_norm_weight_decay", True)
    self.optimizer = create_optimizer(
        model=self.model,
        optimizer_class=optimizer_class,
        optimizer_args=optimizer_args,
        batch_norm_weight_decay=batch_norm_weight_decay,
    )

    # Validate mixed precision requirements
    self.mixed_precision = config.get("mixed_precision", False)
    if self.mixed_precision and amp is None:
        self.mixed_precision = False
        # FIX: the original message concatenated three fragments with no
        # separators ("...apexDisabling...") and misspelled NVIDIA.
        self.logger.error(
            "Mixed precision requires NVIDIA APEX. "
            "Please install apex from https://www.github.com/nvidia/apex . "
            "Disabling mixed precision training.")

    # Configure mixed precision training
    if self.mixed_precision:
        amp_args = config.get("mixed_precision_args", {})
        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer, **amp_args)
        self.logger.info("Using mixed precision")

    # Apply DistributedDataParallel after all other model mutations
    if self.distributed:
        self.model = DistributedDataParallel(self.model)
    else:
        self.model = DataParallel(self.model)

    self.loss_function = config.get("loss_function",
                                    torch.nn.functional.cross_entropy)

    # Configure data loaders
    self.epochs = config.get("epochs", 1)
    self.batches_in_epoch = config.get("batches_in_epoch", sys.maxsize)
    self.epochs_to_validate = config.get(
        "epochs_to_validate", range(self.epochs - 3, self.epochs + 1))

    workers = config.get("workers", 0)
    data_dir = config["data"]
    train_dir = config.get("train_dir", "train")
    num_classes = config.get("num_classes", 1000)

    # Get initial batch size
    self.batch_size = config.get("batch_size", 1)

    # CUDA runtime does not support the fork start method.
    # See https://pytorch.org/docs/stable/notes/multiprocessing.html
    # NOTE(review): set_start_method raises RuntimeError if the start
    # method was already set by the host process — confirm callers never
    # pre-set it, or pass force=True.
    if torch.cuda.is_available():
        multiprocessing.set_start_method("spawn")

    # Configure Training data loader
    self.create_train_dataloader = config.get("create_train_dataloader",
                                              create_train_dataloader)
    self.train_loader = self.create_train_dataloader(
        data_dir=data_dir,
        train_dir=train_dir,
        batch_size=self.batch_size,
        workers=workers,
        distributed=self.distributed,
        num_classes=num_classes,
        use_auto_augment=config.get("use_auto_augment", False),
    )
    self.total_batches = len(self.train_loader)

    # Configure Validation data loader
    val_dir = config.get("val_dir", "val")
    val_batch_size = config.get("val_batch_size", self.batch_size)
    self.create_validation_dataloader = config.get(
        "create_validation_dataloader", create_validation_dataloader)
    self.val_loader = self.create_validation_dataloader(
        data_dir=data_dir,
        val_dir=val_dir,
        batch_size=val_batch_size,
        workers=workers,
        num_classes=num_classes,
    )

    # Configure learning rate scheduler
    lr_scheduler_class = config.get("lr_scheduler_class", None)
    if lr_scheduler_class is not None:
        lr_scheduler_args = config.get("lr_scheduler_args", {})
        self.logger.info("LR Scheduler args:")
        self.logger.info(pformat(lr_scheduler_args))
        self.logger.info("steps_per_epoch=%s", self.total_batches)
        self.lr_scheduler = create_lr_scheduler(
            optimizer=self.optimizer,
            lr_scheduler_class=lr_scheduler_class,
            lr_scheduler_args=lr_scheduler_args,
            steps_per_epoch=self.total_batches)

    # Only profile from rank 0
    self.profile = config.get("profile", False) and self.rank == 0

    # Set train and validate methods.
    self.train_model = config.get("train_model_func", train_model)
    self.evaluate_model = config.get("evaluate_model_func", evaluate_model)

    # Register post-epoch hooks. To be used as `self.model.apply(post_epoch_hook)`
    self.post_epoch_hooks = config.get("post_epoch_hooks", [])