def test_identical(self):
    """Check that a model rebuilt from a pickled checkpoint matches the source.

    A resnet50 with ``num_classes=3`` is created, and its serialized state dict
    is stored under the ``"model"`` key of a dict. That dict is pickled into a
    temporary file, and a second model is created from the file via
    ``checkpoint_file``. The test asserts that ``compare_models`` returns a
    truthy value for the two models with input shape (3, 224, 224).
    """
    args = {"num_classes": 3}
    resnet_cls = nupic.research.frameworks.pytorch.models.resnets.resnet50
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Both models are built with exactly the same settings (the same
    # model_args object included); only the checkpoint differs.
    shared_kwargs = {
        "model_class": resnet_cls,
        "model_args": args,
        "init_batch_norm": False,
        "device": target,
    }
    original = create_model(**shared_kwargs)

    # Checkpoint layout: a dict whose "model" entry holds the serialized
    # state-dict bytes.
    with io.BytesIO() as stream:
        serialize_state_dict(stream, original.state_dict())
        payload = stream.getvalue()
    checkpoint = {"model": payload}

    with tempfile.NamedTemporaryFile(delete=True) as tmp:
        pickle.dump(checkpoint, tmp)
        # Flush so the pickled bytes are visible when create_model reopens
        # the file by name.
        tmp.flush()
        restored = create_model(checkpoint_file=tmp.name, **shared_kwargs)
        self.assertTrue(compare_models(original, restored, (3, 224, 224)))
def create_model(cls, config, device):
    """
    Create imagenet model from an ImagenetExperiment config.

    :param config: mapping with the keys below; only ``model_class`` is required
        - model_class: Model class. Must inherit from "torch.nn.Module"
        - model_args: model class arguments passed to the constructor
          (default: {})
        - init_batch_norm: Whether or not to initialize running batch norm
          mean to 0 (default: False)
        - checkpoint_file: if not None, will start from this model. The model
          must have the same model_args and model_class as the current
          experiment (default: None)
        - load_checkpoint_args: args to be passed to
          `load_state_from_checkpoint` (default: {})
    :param device: Pytorch device
    :return: Model instance
    """
    # The required key is read first so that a missing "model_class" raises
    # KeyError before any optional key is looked up.
    model_class = config["model_class"]
    # Optional settings paired with the fallback used when a key is absent.
    # The {} literals are rebuilt on every call, so no default is shared
    # across calls.
    optional = {
        name: config.get(name, fallback)
        for name, fallback in (
            ("model_args", {}),
            ("init_batch_norm", False),
            ("checkpoint_file", None),
            ("load_checkpoint_args", {}),
        )
    }
    # Delegates to the module-level create_model function (not this method).
    return create_model(model_class=model_class, device=device, **optional)
def create_model(cls, config, device):
    """
    Create imagenet model from an ImagenetExperiment config.

    :param config: mapping with the keys below; only ``model_class`` is required
        - model_class: Model class. Must inherit from "torch.nn.Module"
        - model_args: model class arguments passed to the constructor
          (default: {})
        - init_batch_norm: Whether or not to initialize running batch norm
          mean to 0 (default: False)
        - checkpoint_file: if not None, will start from this model. The model
          must have the same model_args and model_class as the current
          experiment (default: None)
        - resize_buffers_for_checkpoint: if True, this will resize the model
          buffers to match those in the checkpoint. This is helpful for
          loading buffers with sparse levels not matching the model_args
          (default: False)
    :param device: Pytorch device
    :return: Model instance
    """
    # The required key is read first so that a missing "model_class" raises
    # KeyError before any optional key is looked up.
    model_class = config["model_class"]
    # Optional settings paired with the fallback used when a key is absent.
    # The {} literal is rebuilt on every call, so no default is shared.
    optional = {
        name: config.get(name, fallback)
        for name, fallback in (
            ("model_args", {}),
            ("init_batch_norm", False),
            ("checkpoint_file", None),
            ("resize_buffers_for_checkpoint", False),
        )
    }
    # Delegates to the module-level create_model function (not this method).
    return create_model(model_class=model_class, device=device, **optional)
def _create_test_model(model_class):
    """Build a small test model of ``model_class`` via ``create_model``.

    The model is constructed with ``config={"num_classes": 3,
    "defaults_sparse": True}`` as its arguments and no batch-norm
    initialization. It is placed on CUDA when available, otherwise on CPU.

    :param model_class: model constructor forwarded to ``create_model``
    :return: the created model
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    return create_model(
        model_class=model_class,
        model_args={"config": {"num_classes": 3, "defaults_sparse": True}},
        init_batch_norm=False,
        device=target,
    )
def _create_test_model(checkpoint_file=None):
    """
    Create standard resnet50 model to be used in tests.

    :param checkpoint_file: optional checkpoint path forwarded to
        ``create_model``. When it is None, every module's weight is filled
        with 0.042 and its bias with 0.0 to simulate an imagenet experiment.
        When a checkpoint is given, the loaded values are left untouched.
    :return: the CPU model, after ``rezero_weights`` has been applied to every
        module (in both cases)
    """
    model = create_model(
        model_class=resnet50,
        model_args=TEST_MODEL_ARGS,
        init_batch_norm=False,
        checkpoint_file=checkpoint_file,
        device="cpu",
    )

    def _fill_constants(module):
        # Modules that lack a weight/bias attribute, or have it set to None,
        # are skipped.
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data.fill_(0.042)
        bias = getattr(module, "bias", None)
        if bias is not None:
            bias.data.fill_(0.0)

    if checkpoint_file is None:
        model.apply(_fill_constants)
    model.apply(rezero_weights)
    return model
def test_creaate_model_from_checkpoint(self):
    """Rebuild a model from a pickled checkpoint and compare it to the source.

    NOTE: "creaate" in the method name is a typo. It is kept unchanged so the
    test's identifier stays stable.
    """
    source = _create_test_model()

    # Save model checkpoint only, ignoring optimizer and other imagenet
    # experiment objects state. See ImagenetExperiment.get_state
    with io.BytesIO() as stream:
        serialize_state_dict(stream, source.state_dict())
        checkpoint = {"model": stream.getvalue()}

    with tempfile.NamedTemporaryFile() as tmp:
        # Ray save checkpoints as pickled dicts
        pickle.dump(checkpoint, tmp)
        # Flush so the data is visible when the file is reopened by name.
        tmp.flush()
        restored = create_model(
            model_class=resnet50,
            model_args=TEST_MODEL_ARGS,
            init_batch_norm=False,
            device="cpu",
            checkpoint_file=tmp.name,
        )
        self.assertTrue(compare_models(source, restored, (3, 32, 32)))