Example 1
    def test_create_model_from_checkpoint(self):
        model1 = create_model(model_class=resnet50,
                              model_args={},
                              init_batch_norm=False,
                              device="cpu")

        # Simulate imagenet experiment by changing the weights
        def init(m):
            if hasattr(m, "weight") and m.weight is not None:
                m.weight.data.fill_(0.042)

        model1.apply(init)

        # Save model checkpoint only, ignoring optimizer and other imagenet
        # experiment objects state. See ImagenetExperiment.get_state
        state = {}
        with io.BytesIO() as buffer:
            serialize_state_dict(buffer, model1.state_dict())
            state["model"] = buffer.getvalue()

        with tempfile.NamedTemporaryFile() as checkpoint_file:
            # Ray saves checkpoints as pickled dicts
            pickle.dump(state, checkpoint_file)
            checkpoint_file.file.flush()

            # Load model from checkpoint
            model2 = create_model(model_class=resnet50,
                                  model_args={},
                                  init_batch_norm=False,
                                  device="cpu",
                                  checkpoint_file=checkpoint_file.name)

        self.assertTrue(compare_models(model1, model2, (3, 32, 32)))

    def test_identical(self):
        model_args = dict(config=dict(
            num_classes=3,
            defaults_sparse=True,
        ))
        model_class = nupic.research.frameworks.pytorch.models.resnets.resnet50
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = create_model(
            model_class=model_class,
            model_args=model_args,
            init_batch_norm=False,
            device=device,
        )

        state = {}
        with io.BytesIO() as buffer:
            serialize_state_dict(buffer, model.state_dict())
            state["model"] = buffer.getvalue()

        with tempfile.NamedTemporaryFile(delete=True) as checkpoint_file:
            pickle.dump(state, checkpoint_file)
            checkpoint_file.flush()

            model2 = create_model(model_class=model_class,
                                  model_args=model_args,
                                  init_batch_norm=False,
                                  device=device,
                                  checkpoint_file=checkpoint_file.name)

            self.assertTrue(compare_models(model, model2, (3, 224, 224)))

    def setup_experiment(self, config):
        """
        Configure the experiment for training

        :param config: Dictionary containing the configuration parameters

            - distributed: Whether or not to use Pytorch Distributed training
            - backend: Pytorch Distributed backend ("nccl", "gloo")
                    Default: nccl
            - world_size: Total number of processes participating
            - rank: Rank of the current process
            - data: Dataset path
            - train_dir: Dataset training data relative path
            - batch_size: Training batch size
            - val_dir: Dataset validation data relative path
            - val_batch_size: Validation batch size
            - workers: how many data loading processes to use
            - num_classes: Limit the dataset size to the given number of classes
            - model_class: Model class. Must inherit from "torch.nn.Module"
            - model_args: Model class arguments passed to the constructor
            - init_batch_norm: Whether or not to initialize running batch norm
                               mean to 0.
            - optimizer_class: Optimizer class.
                               Must inherit from "torch.optim.Optimizer"
            - optimizer_args: Optimizer class arguments passed to the
                              constructor
            - batch_norm_weight_decay: Whether or not to apply weight decay to
                                       batch norm modules parameters
            - lr_scheduler_class: Learning rate scheduler class.
                                 Must inherit from "_LRScheduler"
            - lr_scheduler_args: Learning rate scheduler class arguments
                                 passed to the constructor
            - loss_function: Loss function. See "torch.nn.functional"
            - local_dir: Results path
            - epochs: Number of epochs to train
            - batches_in_epoch: Number of batches per epoch.
                                Useful for debugging
            - progress: Show progress during training
            - profile: Whether or not to enable torch.autograd.profiler.profile
                       during training
            - name: Experiment name. Used as logger name
            - log_level: Python Logging level
            - log_format: Python Logging format
            - seed: the seed to be used for pytorch, python, and numpy
            - mixed_precision: Whether or not to enable apex mixed precision
            - mixed_precision_args: apex mixed precision arguments.
                                    See "amp.initialize"
            - create_train_dataloader: Optional user defined function to create
                                       the training data loader. See below for
                                       input params.
            - create_validation_dataloader: Optional user defined function to create
                                            the validation data loader. See below for
                                            input params.
            - train_model_func: Optional user defined function to train the model,
                                expected to behave similarly to `train_model`
                                in terms of input parameters and return values
            - evaluate_model_func: Optional user defined function to validate the model,
                                   expected to behave similarly to `evaluate_model`
                                   in terms of input parameters and return values
            - init_hooks: list of hooks (functions) to call on the model
                          just following its initialization
            - post_epoch_hooks: list of hooks (functions) to call on the model
                                following each epoch of training
            - checkpoint_file: if not None, will start from this model. The model
                               must have the same model_args and model_class as the
                               current experiment.
            - checkpoint_at_init: boolean argument for whether to create a checkpoint
                                  of the initialized model. This differs from
                                  `checkpoint_at_start` for which the checkpoint occurs
                                  after the first epoch of training as opposed to
                                  before it
            - epochs_to_validate: list of epochs to run validate(). A -1 asks
                                  to run validate before any training occurs.
                                  Default: last three epochs.
            - launch_time: time the config was created (via time.time). Used to report
                           wall clock time until the first batch is done.
                           Default: time.time() in this setup_experiment().
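
        Example config (an illustrative sketch only; the classes and values
        shown here are assumptions, not framework defaults)::

            config = dict(
                name="resnet50_imagenet",
                model_class=resnet50,
                model_args={},
                data="/path/to/imagenet",
                train_dir="train",
                val_dir="val",
                batch_size=128,
                val_batch_size=128,
                workers=8,
                epochs=90,
                optimizer_class=torch.optim.SGD,
                optimizer_args=dict(lr=0.1, momentum=0.9, weight_decay=1e-4),
                loss_function=torch.nn.functional.cross_entropy,
            )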
        """
        # Configure logging related stuff
        log_format = config.get("log_format", logging.BASIC_FORMAT)
        log_level = getattr(logging, config.get("log_level", "INFO").upper())
        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter(log_format))
        self.logger = logging.getLogger(config.get("name",
                                                   type(self).__name__))
        self.logger.setLevel(log_level)
        self.logger.addHandler(console)
        self.progress = config.get("progress", False)
        self.launch_time = config.get("launch_time", time.time())

        # Configure seed
        self.seed = config.get("seed", self.seed)
        set_random_seed(self.seed, False)

        # Configure distributed pytorch
        self.distributed = config.get("distributed", False)
        self.rank = config.get("rank", 0)
        if self.distributed:
            dist_url = config.get("dist_url", "tcp://127.0.0.1:54321")
            backend = config.get("backend", "nccl")
            world_size = config.get("world_size", 1)
            dist.init_process_group(
                backend=backend,
                init_method=dist_url,
                rank=self.rank,
                world_size=world_size,
            )
            # Only enable logs from first process
            self.logger.disabled = self.rank != 0
            self.progress = self.progress and self.rank == 0

        # Configure model
        model_class = config["model_class"]
        model_args = config.get("model_args", {})
        init_batch_norm = config.get("init_batch_norm", False)
        init_hooks = config.get("init_hooks", None)
        self.model = create_model(model_class=model_class,
                                  model_args=model_args,
                                  init_batch_norm=init_batch_norm,
                                  device=self.device,
                                  init_hooks=init_hooks,
                                  checkpoint_file=config.get(
                                      "checkpoint_file", None))
        if self.rank == 0:
            self.logger.debug(self.model)
            params_sparse, nonzero_params_sparse2 = count_nonzero_params(
                self.model)
            self.logger.debug("Params total/nnz %s / %s = %s ", params_sparse,
                              nonzero_params_sparse2,
                              float(nonzero_params_sparse2) / params_sparse)

        # Configure optimizer
        optimizer_class = config.get("optimizer_class", torch.optim.SGD)
        optimizer_args = config.get("optimizer_args", {})
        batch_norm_weight_decay = config.get("batch_norm_weight_decay", True)
        self.optimizer = create_optimizer(
            model=self.model,
            optimizer_class=optimizer_class,
            optimizer_args=optimizer_args,
            batch_norm_weight_decay=batch_norm_weight_decay,
        )

        # Validate mixed precision requirements
        self.mixed_precision = config.get("mixed_precision", False)
        if self.mixed_precision and amp is None:
            self.mixed_precision = False
            self.logger.error(
                "Mixed precision requires NVIDA APEX."
                "Please install apex from https://www.github.com/nvidia/apex"
                "Disabling mixed precision training.")

        # Configure mixed precision training
        if self.mixed_precision:
            amp_args = config.get("mixed_precision_args", {})
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, **amp_args)
            self.logger.info("Using mixed precision")

        # Apply DistributedDataParallel after all other model mutations
        if self.distributed:
            self.model = DistributedDataParallel(self.model)
        else:
            self.model = DataParallel(self.model)

        self.loss_function = config.get("loss_function",
                                        torch.nn.functional.cross_entropy)

        # Configure data loaders
        self.epochs = config.get("epochs", 1)
        self.batches_in_epoch = config.get("batches_in_epoch", sys.maxsize)
        self.epochs_to_validate = config.get(
            "epochs_to_validate", range(self.epochs - 3, self.epochs + 1))
        workers = config.get("workers", 0)
        data_dir = config["data"]
        train_dir = config.get("train_dir", "train")
        num_classes = config.get("num_classes", 1000)

        # Get initial batch size
        self.batch_size = config.get("batch_size", 1)

        # CUDA runtime does not support the fork start method.
        # See https://pytorch.org/docs/stable/notes/multiprocessing.html
        if torch.cuda.is_available():
            multiprocessing.set_start_method("spawn")

        # Configure Training data loader
        self.create_train_dataloader = config.get("create_train_dataloader",
                                                  create_train_dataloader)
        self.train_loader = self.create_train_dataloader(
            data_dir=data_dir,
            train_dir=train_dir,
            batch_size=self.batch_size,
            workers=workers,
            distributed=self.distributed,
            num_classes=num_classes,
            use_auto_augment=config.get("use_auto_augment", False),
        )
        self.total_batches = len(self.train_loader)

        # Configure Validation data loader
        val_dir = config.get("val_dir", "val")
        val_batch_size = config.get("val_batch_size", self.batch_size)
        self.create_validation_dataloader = config.get(
            "create_validation_dataloader", create_validation_dataloader)
        self.val_loader = self.create_validation_dataloader(
            data_dir=data_dir,
            val_dir=val_dir,
            batch_size=val_batch_size,
            workers=workers,
            num_classes=num_classes,
        )

        # Configure learning rate scheduler
        lr_scheduler_class = config.get("lr_scheduler_class", None)
        if lr_scheduler_class is not None:
            lr_scheduler_args = config.get("lr_scheduler_args", {})
            self.logger.info("LR Scheduler args:")
            self.logger.info(pformat(lr_scheduler_args))
            self.logger.info("steps_per_epoch=%s", self.total_batches)
            self.lr_scheduler = create_lr_scheduler(
                optimizer=self.optimizer,
                lr_scheduler_class=lr_scheduler_class,
                lr_scheduler_args=lr_scheduler_args,
                steps_per_epoch=self.total_batches)

        # Only profile from rank 0
        self.profile = config.get("profile", False) and self.rank == 0

        # Set train and validate methods.
        self.train_model = config.get("train_model_func", train_model)
        self.evaluate_model = config.get("evaluate_model_func", evaluate_model)

        # Register post-epoch hooks. To be used as `self.model.apply(post_epoch_hook)`
        self.post_epoch_hooks = config.get("post_epoch_hooks", [])
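
        # Illustrative sketch of a post-epoch hook (an assumption, not part of
        # this class): each hook receives every submodule via
        # `self.model.apply(hook)` and may mutate it in place, for example
        # re-applying a sparsity constraint after the epoch's weight updates.
        # The `rezero_weights` method used below is assumed, not guaranteed,
        # to exist on the sparse modules.
        #
        #     def rezero_hook(module):
        #         if hasattr(module, "rezero_weights"):
        #             module.rezero_weights()
        #
        #     config["post_epoch_hooks"] = [rezero_hook]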