コード例 #1
0
    def save(self, path):
        """
        Serialize model to path

        Writes a tar archive holding four members: 'label_encoder.zip',
        'parameters.zip', 'state_dict.pt' and 'pie-commit.zip'.

        Parameters
        ----------
        path : destination path; it is first passed through
            `utils.ensure_ext(path, 'tar')`.
        """
        path = utils.ensure_ext(path, 'tar')
        # create dir if needed; exist_ok avoids a failure when the directory
        # appears between an existence check and the creation
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        with tarfile.open(path, 'w') as tar:
            # member names get their own variable so that the `path`
            # argument is not rebound while the archive is being written

            # serialize label_encoder
            string, member = json.dumps(
                self.label_encoder.jsonify()), 'label_encoder.zip'
            utils.add_gzip_to_tar(string, member, tar)

            # serialize parameters
            string, member = json.dumps(
                self.get_args_and_kwargs()), 'parameters.zip'
            utils.add_gzip_to_tar(string, member, tar)

            # serialize weights
            with utils.tmpfile() as tmppath:
                torch.save(self.state_dict(), tmppath)
                tar.add(tmppath, arcname='state_dict.pt')

            # serialize current pie commit
            string, member = pie.__commit__, 'pie-commit.zip'
            utils.add_gzip_to_tar(string, member, tar)
コード例 #2
0
    def load(cls, path):
        """
        Load model from path

        Opens the tar archive at `utils.ensure_ext(path, 'tar')`, warns if
        the commit stored in 'pie-commit.zip' differs from `pie.__commit__`,
        rebuilds the label encoder and an `Encoder` from the stored
        parameters, restores the weights from 'state_dict.pt' and returns the
        model after calling `eval()` on it.
        """
        with tarfile.open(utils.ensure_ext(path, 'tar'), 'r') as tar:
            commit = utils.get_gzip_from_tar(tar, 'pie-commit.zip')
            if pie.__commit__ != commit:
                # logging.warn is a deprecated alias of logging.warning
                logging.warning(
                    ("Model {} was serialized with a previous "
                     "version of `pie`. This might result in issues. "
                     "Model commit is {}, whereas current `pie` commit is {}."
                     ).format(path, commit, pie.__commit__))

            # load label encoder
            le = pie.dataset.MultiLabelEncoder.load_from_string(
                utils.get_gzip_from_tar(tar, 'label_encoder.zip'))

            # load model parameters
            params = json.loads(utils.get_gzip_from_tar(tar, 'parameters.zip'))

            # instantiate model
            model = Encoder(le, *params['args'], **params['kwargs'])

            # load state_dict; the archive member is extracted below the
            # temporary path and read back from there
            with utils.tmpfile() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                dictpath = os.path.join(tmppath, 'state_dict.pt')
                model.load_state_dict(torch.load(dictpath, map_location='cpu'))

        model.eval()

        return model
コード例 #3
0
ファイル: base.py プロジェクト: PonteIneptique/tarte
 def load_settings(fpath):
     """
     Read the 'settings.zip' member of the archive at `fpath` and wrap the
     decoded JSON in a `Settings` object
     """
     archive_path = utils.ensure_ext(fpath, 'tar')
     with tarfile.open(archive_path, 'r') as tar:
         raw = utils.get_gzip_from_tar(tar, 'settings.zip')
         decoded = json.loads(raw)
         return Settings(decoded)
コード例 #4
0
ファイル: base_model.py プロジェクト: Corentinfaye/pie
    def load(fpath):
        """
        Load model from path

        Opens the tar archive at `utils.ensure_ext(fpath, 'tar')` and rebuilds
        the model from its members: label encoder, tasks, constructor
        parameters, model class name and weights. A missing or unreadable
        commit or settings member is tolerated; the other members are
        required. Returns the model after calling `eval()` on it.
        """
        import pie

        with tarfile.open(utils.ensure_ext(fpath, 'tar'), 'r') as tar:
            # check commit (archives without a commit member are accepted)
            try:
                commit = utils.get_gzip_from_tar(tar, 'pie-commit.zip')
            except Exception:
                commit = None
            if (pie.__commit__ and commit) and pie.__commit__ != commit:
                # logging.warn is a deprecated alias of logging.warning
                logging.warning(
                    ("Model {} was serialized with a previous "
                     "version of `pie`. This might result in issues. "
                     "Model commit is {}, whereas current `pie` commit is {}."
                     ).format(fpath, commit, pie.__commit__))

            # load label encoder
            le = MultiLabelEncoder.load_from_string(
                utils.get_gzip_from_tar(tar, 'label_encoder.zip'))

            # load tasks
            tasks = json.loads(utils.get_gzip_from_tar(tar, 'tasks.zip'))

            # load model parameters
            params = json.loads(utils.get_gzip_from_tar(tar, 'parameters.zip'))

            # instantiate model; the class is looked up by name in pie.models
            model_type = getattr(pie.models,
                                 utils.get_gzip_from_tar(tar, 'class.zip'))
            with utils.shutup():
                model = model_type(le, tasks, *params['args'],
                                   **params['kwargs'])

            # load settings (best effort: a failure is only reported)
            try:
                settings = Settings(
                    json.loads(utils.get_gzip_from_tar(tar, 'settings.zip')))
                model._settings = settings
            except Exception:
                logging.warning(
                    "Couldn't load settings for model {}!".format(fpath))

            # load state_dict; the archive member is extracted below the
            # temporary path and read back from there
            with utils.tmpfile() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                dictpath = os.path.join(tmppath, 'state_dict.pt')
                model.load_state_dict(torch.load(dictpath, map_location='cpu'))

        model.eval()

        return model
コード例 #5
0
    def tag_file(self, fpath: str, iterator: DataIterator, processor: ProcessorPrototype, no_tokenizer: bool = False):
        """
        Tag the content of `fpath` and write the result next to it

        The whole input file is read, passed to `self.iter_tag`, and every
        item it yields is written to the output file. The output path is
        derived from `fpath` by `utils.ensure_ext` with the original
        extension and the 'pie' infix, and is returned.
        """
        # Read content of the file
        with open(fpath) as source:
            text = source.read()

        extension = os.path.splitext(fpath)[1]
        out_file = utils.ensure_ext(fpath, extension, 'pie')

        with open(out_file, 'w+') as sink:
            sink.writelines(
                self.iter_tag(text, iterator, processor=processor, no_tokenizer=no_tokenizer))

        return out_file
コード例 #6
0
    def load(fpath):
        """
        Load model from path

        Opens the tar archive at `utils.ensure_ext(fpath, 'tar')` and rebuilds
        the model from its members: label encoder, constructor parameters,
        model class name and weights. The commit and settings members are
        optional. Returns the model after calling `eval()` on it.
        """
        import tempfile

        import pie

        with tarfile.open(utils.ensure_ext(fpath, 'tar'), 'r') as tar:
            # check commit
            try:
                commit = get_gzip_from_tar(tar, 'pie-commit.zip')
            except Exception:
                # no commit in file
                commit = None
            if pie.__commit__ is not None and commit is not None \
               and pie.__commit__ != commit:
                # logging.warn is a deprecated alias of logging.warning
                logging.warning(
                    ("Model {} was serialized with a previous "
                     "version of `pie`. This might result in issues. "
                     "Model commit is {}, whereas current `pie` commit is {}."
                     ).format(fpath, commit, pie.__commit__))
            # load label encoder
            le = MultiLabelEncoder.load_from_string(
                get_gzip_from_tar(tar, 'label_encoder.zip'))
            # load model parameters
            params = json.loads(get_gzip_from_tar(tar, 'parameters.zip'))
            # instantiate model; the class is looked up by name in pie.models
            model_type = getattr(pie.models,
                                 get_gzip_from_tar(tar, 'class.zip'))
            with utils.shutup():
                model = model_type(le, *params['args'], **params['kwargs'])
            # (optional) load settings; still best effort, but no longer a
            # bare except, so KeyboardInterrupt/SystemExit propagate
            try:
                settings = Settings(
                    json.loads(get_gzip_from_tar(tar, 'settings.zip')))
                model._settings = settings
            except Exception:
                logging.debug("No settings loaded for model %s", fpath)
            # load state_dict; the temporary directory is removed even when
            # loading fails, and its location follows the platform default
            # instead of a hard-coded '/tmp'
            with tempfile.TemporaryDirectory() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                # map to CPU so that weights saved from a GPU can be read on
                # a machine without one; load_state_dict copies the values
                # into the parameters the model already owns
                model.load_state_dict(
                    torch.load(os.path.join(tmppath, 'state_dict.pt'),
                               map_location='cpu'))

        model.eval()

        return model
コード例 #7
0
ファイル: base_model.py プロジェクト: Corentinfaye/pie
    def save(self, fpath, infix=None, settings=None):
        """
        Serialize model to path

        Writes a tar archive with the label encoder, tasks, model class name,
        constructor parameters and weights. The current `pie` commit and the
        settings are added only when they are not None.

        Parameters
        ----------
        fpath : destination path, passed through
            `utils.ensure_ext(fpath, 'tar', infix)`.
        infix : forwarded to `utils.ensure_ext`.
        settings : optional JSON-serializable settings stored as
            'settings.zip'.

        Returns
        -------
        The path the archive was written to.
        """
        import pie
        fpath = utils.ensure_ext(fpath, 'tar', infix)

        # create dir if necessary; a bare file name has an empty dirname,
        # for which os.makedirs would raise, so it is skipped
        dirname = os.path.dirname(fpath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        with tarfile.open(fpath, 'w') as tar:
            # serialize label_encoder
            string = json.dumps(self.label_encoder.jsonify())
            path = 'label_encoder.zip'
            utils.add_gzip_to_tar(string, path, tar)

            # serialize tasks
            string, path = json.dumps(self.tasks), 'tasks.zip'
            utils.add_gzip_to_tar(string, path, tar)

            # serialize model class
            string, path = str(type(self).__name__), 'class.zip'
            utils.add_gzip_to_tar(string, path, tar)

            # serialize parameters
            string, path = json.dumps(
                self.get_args_and_kwargs()), 'parameters.zip'
            utils.add_gzip_to_tar(string, path, tar)

            # serialize weights
            with utils.tmpfile() as tmppath:
                torch.save(self.state_dict(), tmppath)
                tar.add(tmppath, arcname='state_dict.pt')

            # serialize current pie commit
            if pie.__commit__ is not None:
                string, path = pie.__commit__, 'pie-commit.zip'
                utils.add_gzip_to_tar(string, path, tar)

            # if passed, serialize settings
            if settings is not None:
                string, path = json.dumps(settings), 'settings.zip'
                utils.add_gzip_to_tar(string, path, tar)

        return fpath
コード例 #8
0
ファイル: tagger.py プロジェクト: PonteIneptique/pie
    def tag_file(self, fpath, sep='\t'):
        """
        Tag the sentences of `fpath` and write them to a sibling file

        The output path comes from `utils.ensure_ext` with the original
        extension and the 'pie' infix. The first row written is a header
        ('token' followed by the task names); each token then gets one
        `sep`-joined row and each sentence is closed by an empty line.
        """
        extension = os.path.splitext(fpath)[1]
        wrote_header = False

        with open(utils.ensure_ext(fpath, extension, 'pie'), 'w+') as out:
            batches = utils.chunks(
                lines_from_file(fpath, self.lower), self.batch_size)

            for batch in batches:
                sents, lengths = zip(*batch)
                tagged, tasks = self.tag(sents, lengths)

                for sent in tagged:
                    # the header is emitted once, right before the first
                    # tagged sentence
                    if not wrote_header:
                        out.write(sep.join(['token'] + tasks) + '\n')
                        wrote_header = True

                    out.writelines(
                        sep.join([token] + list(tags)) + '\n'
                        for token, tags in sent)
                    out.write('\n')
コード例 #9
0
ファイル: base.py プロジェクト: PonteIneptique/tarte
    def load(fpath):
        """
        Load model from path

        Opens the tar archive at `utils.ensure_ext(fpath, 'tar')` and rebuilds
        the model from its members: label encoder, constructor arguments,
        model class name and weights. A failure to read the settings member
        is only reported. Returns the model after calling `eval()` on it.
        """
        import tarte.modules.models

        with tarfile.open(utils.ensure_ext(fpath, 'tar'), 'r') as tar:

            # load label encoder
            le = MultiEncoder.load(
                json.loads(utils.get_gzip_from_tar(tar, 'label_encoder.zip')))

            # load model parameters (stored as an [args, kwargs] pair)
            args, kwargs = json.loads(
                utils.get_gzip_from_tar(tar, 'parameters.zip'))

            # instantiate model; the class is looked up by name in
            # tarte.modules.models
            model_type = getattr(tarte.modules.models,
                                 utils.get_gzip_from_tar(tar, 'class.zip'))
            with utils.shutup():
                model = model_type(le, *args, **kwargs)

            # load settings (best effort: a failure is only reported)
            try:
                settings = Settings(
                    json.loads(utils.get_gzip_from_tar(tar, 'settings.zip')))
                model._settings = settings
            except Exception:
                # logging.warn is a deprecated alias of logging.warning
                logging.warning(
                    "Couldn't load settings for model {}!".format(fpath))

            # load state_dict; the archive member is extracted below the
            # temporary path and read back from there
            with utils.tmpfile() as tmppath:
                tar.extract('state_dict.pt', path=tmppath)
                dictpath = os.path.join(tmppath, 'state_dict.pt')
                model.load_state_dict(torch.load(dictpath, map_location='cpu'))

        model.eval()

        return model
コード例 #10
0
 def load_from_pretrained_model(cls, path):
     """
     Build an instance from the label encoder stored in the model archive
     at `path` (the 'tar' extension is ensured by `utils.ensure_ext`)
     """
     archive_path = utils.ensure_ext(path, 'tar')
     with tarfile.open(archive_path, 'r') as tar:
         # NOTE(review): the member name carries no '.zip' suffix, whereas
         # the save/load snippets in this file use 'label_encoder.zip' --
         # confirm this matches what the archive actually contains
         serialized = utils.get_gzip_from_tar(tar, 'label_encoder')
         return cls.load_from_string(serialized)