Ejemplo n.º 1
0
    def test_identical(self):
        """Round-trip a resnet50 through a pickled checkpoint file.

        The state dict of a freshly created model is serialized, stored under
        the "model" key of a dict, pickled into a temporary file and reloaded
        via ``create_model(checkpoint_file=...)``. Both models must compare
        equal on a (3, 224, 224) input.
        """
        # Keyword arguments shared by both create_model calls (the same
        # model_args dict object is passed each time).
        common = dict(
            model_class=nupic.research.frameworks.pytorch.models.resnets.resnet50,
            model_args=dict(num_classes=3),
            init_batch_norm=False,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        original = create_model(**common)

        with io.BytesIO() as buf:
            serialize_state_dict(buf, original.state_dict())
            checkpoint = {"model": buf.getvalue()}

        with tempfile.NamedTemporaryFile(delete=True) as tmp:
            pickle.dump(checkpoint, tmp)
            # Flush so the pickled bytes are on disk before reopening by name.
            tmp.flush()
            restored = create_model(checkpoint_file=tmp.name, **common)
            self.assertTrue(compare_models(original, restored, (3, 224, 224)))
    def create_model(cls, config, device):
        """
        Creates the BlockModel. The config should specify the model class as
        BlockModel, and then the model_args should contain a "module_args"
        parameter which contains a list of dictionaries that specify
        submodules. A simple example might look like this (submodule class
        names are placeholders)::

            config = dict(
                model_class=BlockModel,
                model_args=dict(
                    module_args=[
                        dict(model_class=SubmoduleA, model_args=dict(...)),
                        dict(model_class=SubmoduleB, model_args=dict(...),
                             checkpoint_file="/path/to/checkpoint"),
                    ],
                ),
            )

        Each ``module_args`` entry must have a ``model_class`` key and may
        set ``model_args`` (default ``{}``), ``init_batch_norm`` (default
        False), ``checkpoint_file`` (default None) and
        ``load_checkpoint_args`` (default ``{}``). Every entry is built with
        ``create_model``. The resulting list is passed to the BlockModel
        constructor as ``model_args["modules"]``. Configs whose
        ``model_class`` is not BlockModel are delegated to the parent class.

        :param config: experiment config dict
        :param device: Pytorch device passed to every ``create_model`` call
        :return: model instance
        """
        if config["model_class"] != BlockModel:
            return super().create_model(config, device)

        # Shallow-copy the model args so that adding the instantiated
        # submodules below does not mutate the caller's config, which may be
        # reused for another model or serialized later.
        model_args = dict(config.get("model_args", {}))
        model_args["modules"] = [
            create_model(
                model_class=module_dict["model_class"],
                model_args=module_dict.get("model_args", {}),
                init_batch_norm=module_dict.get("init_batch_norm", False),
                device=device,
                checkpoint_file=module_dict.get("checkpoint_file", None),
                load_checkpoint_args=module_dict.get("load_checkpoint_args", {}),
            )
            for module_dict in model_args.get("module_args", [])
        ]
        return create_model(
            model_class=config["model_class"],
            model_args=model_args,
            init_batch_norm=config.get("init_batch_norm", False),
            device=device,
            checkpoint_file=config.get("checkpoint_file", None),
            load_checkpoint_args=config.get("load_checkpoint_args", {}),
        )
Ejemplo n.º 3
0
 def create_model(cls, config, device):
     """
     Build a `torch.nn.Module` model from an experiment config.

     :param config: experiment config dict; the following keys are read:
         - model_class: model class. Must inherit from "torch.nn.Module"
           (required)
         - model_args: keyword arguments passed to the model class
           constructor (default: empty dict)
         - init_batch_norm: whether or not to initialize the running batch
           norm mean to 0 (default: False)
         - checkpoint_file: if not None, the model starts from this
           checkpoint. The checkpoint must have been produced with the same
           model_args and model_class as the current experiment
           (default: None)
         - load_checkpoint_args: args passed to `load_state_from_checkpoint`
           (default: empty dict)
     :param device:
         Pytorch device
     :return:
         Model instance
     """
     model_class = config["model_class"]
     # Optional keys and the fallback used when the config omits them.
     optional_keys = {
         "model_args": {},
         "init_batch_norm": False,
         "checkpoint_file": None,
         "load_checkpoint_args": {},
     }
     settings = {
         key: config.get(key, fallback) for key, fallback in optional_keys.items()
     }
     return create_model(model_class=model_class, device=device, **settings)
def _create_test_model(model_class):
    """
    Instantiate ``model_class`` for tests.

    The model is built with ``model_args=dict(config=dict(num_classes=3,
    defaults_sparse=True))`` and ``init_batch_norm=False``. It is placed on
    CUDA when available, otherwise on the CPU.
    """
    target_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    return create_model(
        model_class=model_class,
        model_args={"config": {"num_classes": 3, "defaults_sparse": True}},
        init_batch_norm=False,
        device=target_device,
    )
Ejemplo n.º 5
0
    def __init__(self, config):
        """
        Store the config and build the model it describes.

        :param config: dict read for "loss_function" (default
            ``torch.nn.functional.cross_entropy``), "model_class" (required),
            "model_args", "init_batch_norm", "checkpoint_file",
            "load_checkpoint_args" and "epochs" (required).
        """
        super().__init__()

        self.config = config
        self._loss_function = config.get(
            "loss_function", torch.nn.functional.cross_entropy
        )

        # NOTE(review): no `device` is passed here, so create_model's own
        # default applies -- confirm this is intended.
        model_kwargs = dict(
            model_class=config["model_class"],
            model_args=config.get("model_args", {}),
            init_batch_norm=config.get("init_batch_norm", False),
            checkpoint_file=config.get("checkpoint_file", None),
            load_checkpoint_args=config.get("load_checkpoint_args", {}),
        )
        self.model = create_model(**model_kwargs)

        self.epochs = config["epochs"]
Ejemplo n.º 6
0
def _create_test_model(checkpoint_file=None):
    """
    Create the standard resnet50 model used in tests (on the CPU).

    Without a ``checkpoint_file``, the weights are overwritten to simulate an
    imagenet experiment. Every non-None ``weight`` is filled with 0.042 and
    every non-None ``bias`` with 0.0, and then ``rezero_weights`` is applied.
    With a ``checkpoint_file``, the model is returned exactly as
    ``create_model`` loaded it.
    """
    model = create_model(
        model_class=resnet50,
        model_args=TEST_MODEL_ARGS,
        init_batch_norm=False,
        checkpoint_file=checkpoint_file,
        device="cpu",
    )
    if checkpoint_file is not None:
        return model

    def _fill_constant(module):
        # Deterministic stand-in for trained parameters.
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data.fill_(0.042)
        bias = getattr(module, "bias", None)
        if bias is not None:
            bias.data.fill_(0.0)

    model.apply(_fill_constant)
    model.apply(rezero_weights)
    return model
Ejemplo n.º 7
0
    def test_creaate_model_from_checkpoint(self):
        """A model reloaded from a pickled checkpoint matches the original."""
        reference = _create_test_model()

        # Only the model weights go into the checkpoint; optimizer and other
        # imagenet experiment state are ignored. See
        # ImagenetExperiment.get_state.
        with io.BytesIO() as stream:
            serialize_state_dict(stream, reference.state_dict())
            payload = {"model": stream.getvalue()}

        with tempfile.NamedTemporaryFile() as ckpt:
            # Ray saves checkpoints as pickled dicts.
            pickle.dump(payload, ckpt)
            ckpt.file.flush()
            reloaded = create_model(
                model_class=resnet50,
                model_args=TEST_MODEL_ARGS,
                init_batch_norm=False,
                device="cpu",
                checkpoint_file=ckpt.name,
            )

        self.assertTrue(compare_models(reference, reloaded, (3, 32, 32)))