예제 #1
0
    def load(cls, path):
        """
        Load a serialized ``Encoder`` from a tar archive.

        The archive is opened from ``utils.ensure_ext(path, 'tar')``. It is
        read for the members ``pie-commit.zip`` (optional, only used for a
        version warning), ``label_encoder.zip``, ``parameters.zip`` and
        ``state_dict.pt``.

        Parameters
        ----------
        path : str
            Path to the serialized model archive.

        Returns
        -------
        Encoder
            The rebuilt model, with its state dict loaded onto the CPU
            (``map_location='cpu'``) and switched to evaluation mode via
            ``model.eval()``.
        """
        with tarfile.open(utils.ensure_ext(path, 'tar'), 'r') as tar:
            # check commit (best effort: archives without the member, or any
            # failure reading it, must not prevent loading the model)
            try:
                commit = utils.get_gzip_from_tar(tar, 'pie-commit.zip')
            except Exception:
                commit = None
            # only compare when both commits are known (truthy)
            if (pie.__commit__ and commit) and pie.__commit__ != commit:
                logging.warning(
                    ("Model {} was serialized with a previous "
                     "version of `pie`. This might result in issues. "
                     "Model commit is {}, whereas current `pie` commit is {}."
                     ).format(path, commit, pie.__commit__))

            # load label encoder
            le = pie.dataset.MultiLabelEncoder.load_from_string(
                utils.get_gzip_from_tar(tar, 'label_encoder.zip'))

            # load model parameters (a mapping with 'args' and 'kwargs')
            params = json.loads(utils.get_gzip_from_tar(tar, 'parameters.zip'))

            # instantiate model
            model = Encoder(le, *params['args'], **params['kwargs'])

            # load state_dict
            # NOTE(review): torch.load unpickles the file, so only load
            # archives from trusted sources.
            with utils.tmpfile() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                dictpath = os.path.join(tmppath, 'state_dict.pt')
                model.load_state_dict(torch.load(dictpath, map_location='cpu'))

        model.eval()

        return model
예제 #2
0
 def load_settings(fpath):
     """
     Read the ``Settings`` bundled inside a serialized model archive.

     ``fpath`` is passed through ``utils.ensure_ext(fpath, 'tar')`` before the
     archive is opened; the ``settings.zip`` member is decoded as JSON and
     wrapped in a ``Settings`` object.
     """
     archive_path = utils.ensure_ext(fpath, 'tar')
     with tarfile.open(archive_path, 'r') as tar:
         payload = utils.get_gzip_from_tar(tar, 'settings.zip')
         settings = Settings(json.loads(payload))
     return settings
예제 #3
0
    def load(fpath):
        """
        Load model from path.

        Rebuilds a model from the tar archive opened from
        ``utils.ensure_ext(fpath, 'tar')``. It reads the members
        ``label_encoder.zip``, ``parameters.zip``, ``class.zip``,
        ``settings.zip`` (optional) and ``state_dict.pt``.

        Parameters
        ----------
        fpath : str
            Path to the serialized model archive.

        Returns
        -------
        The model instance, of the class whose name is stored in
        ``class.zip`` (looked up in ``tarte.modules.models``). Its state dict
        is loaded onto the CPU and the model is switched to evaluation mode
        via ``model.eval()``.
        """
        import tarte.modules.models

        with tarfile.open(utils.ensure_ext(fpath, 'tar'), 'r') as tar:

            # load label encoder
            le = MultiEncoder.load(
                json.loads(utils.get_gzip_from_tar(tar, 'label_encoder.zip')))

            # load model parameters (stored as an [args, kwargs] pair)
            args, kwargs = json.loads(
                utils.get_gzip_from_tar(tar, 'parameters.zip'))

            # instantiate model; the class name stored in the archive is
            # resolved against tarte.modules.models
            model_type = getattr(tarte.modules.models,
                                 utils.get_gzip_from_tar(tar, 'class.zip'))
            with utils.shutup():
                model = model_type(le, *args, **kwargs)

            # load settings (best effort: failure only logs a warning)
            try:
                settings = Settings(
                    json.loads(utils.get_gzip_from_tar(tar, 'settings.zip')))
            except Exception:
                logging.warning(
                    "Couldn't load settings for model {}!".format(fpath))
            else:
                model._settings = settings

            # load state_dict
            # NOTE(review): torch.load unpickles the file, so only load
            # archives from trusted sources.
            with utils.tmpfile() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                dictpath = os.path.join(tmppath, 'state_dict.pt')
                model.load_state_dict(torch.load(dictpath, map_location='cpu'))

        model.eval()

        return model
예제 #4
0
    def load(fpath):
        """
        Load model from path.

        Rebuilds a `pie` model from the tar archive opened from
        ``utils.ensure_ext(fpath, 'tar')``. It reads the members
        ``pie-commit.zip`` (optional, only used for a version warning),
        ``label_encoder.zip``, ``tasks.zip``, ``parameters.zip``,
        ``class.zip``, ``settings.zip`` (optional) and ``state_dict.pt``.

        Parameters
        ----------
        fpath : str
            Path to the serialized model archive.

        Returns
        -------
        The model instance, of the class whose name is stored in
        ``class.zip`` (looked up in ``pie.models``). Its state dict is loaded
        onto the CPU and the model is switched to evaluation mode via
        ``model.eval()``.
        """
        import pie

        with tarfile.open(utils.ensure_ext(fpath, 'tar'), 'r') as tar:
            # check commit (best effort: any failure reading it is ignored)
            try:
                commit = utils.get_gzip_from_tar(tar, 'pie-commit.zip')
            except Exception:
                commit = None
            # only compare when both commits are known (truthy)
            if (pie.__commit__ and commit) and pie.__commit__ != commit:
                logging.warning(
                    ("Model {} was serialized with a previous "
                     "version of `pie`. This might result in issues. "
                     "Model commit is {}, whereas current `pie` commit is {}."
                     ).format(fpath, commit, pie.__commit__))

            # load label encoder
            le = MultiLabelEncoder.load_from_string(
                utils.get_gzip_from_tar(tar, 'label_encoder.zip'))

            # load tasks
            tasks = json.loads(utils.get_gzip_from_tar(tar, 'tasks.zip'))

            # load model parameters (a mapping with 'args' and 'kwargs')
            params = json.loads(utils.get_gzip_from_tar(tar, 'parameters.zip'))

            # instantiate model; the class name stored in the archive is
            # resolved against pie.models
            model_type = getattr(pie.models,
                                 utils.get_gzip_from_tar(tar, 'class.zip'))
            with utils.shutup():
                model = model_type(le, tasks, *params['args'],
                                   **params['kwargs'])

            # load settings (best effort: failure only logs a warning)
            try:
                settings = Settings(
                    json.loads(utils.get_gzip_from_tar(tar, 'settings.zip')))
            except Exception:
                logging.warning(
                    "Couldn't load settings for model {}!".format(fpath))
            else:
                model._settings = settings

            # load state_dict
            # NOTE(review): torch.load unpickles the file, so only load
            # archives from trusted sources.
            with utils.tmpfile() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                dictpath = os.path.join(tmppath, 'state_dict.pt')
                model.load_state_dict(torch.load(dictpath, map_location='cpu'))

        model.eval()

        return model
예제 #5
0
 def load_from_pretrained_model(cls, path):
     """
     Load the label encoder stored inside a serialized model archive.

     Model loaders read the encoder from the archive member
     ``label_encoder.zip``, so that member is preferred when present. The
     bare ``label_encoder`` name is kept as a fallback so archives that use
     it still load.

     Parameters
     ----------
     path : str
         Path to the model archive; it is passed through
         ``utils.ensure_ext(path, 'tar')`` before opening.

     Returns
     -------
     Whatever ``cls.load_from_string`` builds from the stored string.
     """
     with tarfile.open(utils.ensure_ext(path, 'tar'), 'r') as tar:
         member = 'label_encoder.zip'
         if member not in tar.getnames():
             # fall back to the member name this method originally read
             member = 'label_encoder'
         return cls.load_from_string(utils.get_gzip_from_tar(tar, member))