Example 1
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint
        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed, a new session is
            created

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))

        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)

        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        # TODO: convert labels into just another vocab and pass number of labels to models.
        labels = read_json("{}.labels".format(basename))
        # FIXME: referring to the `constraint_mask` in the base class where it's not mentioned isn't really clean
        if _state.get('constraint_mask') is not None:
            # Dummy constraint values that will be filled in by the checkpointing
            _state['constraint_mask'] = [
                np.zeros((len(labels), len(labels))) for _ in range(2)
            ]
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
        return model
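
Taken together, the calls above imply a small set of sidecar files next to the checkpoint. The sketch below (placeholders only, not library code) just spells that layout out; the per-embedding metadata files appear in Examples 2 and 7.

# Hypothetical layout sketch; the basename and values are placeholders.
basename = "/path/to/checkpoints/classify-model-12345"
artifacts = {
    "state": f"{basename}.state",     # JSON: constructor kwargs, incl. the 'embeddings' section
    "labels": f"{basename}.labels",   # JSON: label vocabulary
    "weights": f"{basename}.wgt",     # consumed by model.load_weights()
}
# A concrete subclass providing `create` would then restore with something like:
#   model = SomeClassifierModel.load(basename)
print(artifacts)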
Example 2
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed, a new session is
            created

        :return: A restored model
        """
        _state = read_json(basename + '.state')
        _state['model_type'] = kwargs.get('model_type', 'default')
        embeddings = {}
        embeddings_dict = _state.pop("embeddings")

        for key, class_name in embeddings_dict.items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = dict({'vsz': md['vsz'], 'dsz': md['dsz']})
            # Resolve the embeddings class named in the state file's 'embeddings' section
            # (eval assumes the class is importable in the current scope)
            Constructor = eval(class_name)
            embeddings[key] = Constructor(key, **embed_args)

        model = cls.create(embeddings, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")

        return model
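
The only fields this loader pulls out of each per-embedding metadata file are `vsz` and `dsz`. A stand-alone imitation of that read with plain `json` and made-up numbers:

import json

# Stands in for read_json('{basename}-{key}-md.json'); the values are illustrative.
md = json.loads('{"vsz": 30522, "dsz": 768}')
embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}   # the same keys forwarded to the constructor
print(embed_args)   # {'vsz': 30522, 'dsz': 768}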
Example 3
    def load(cls, basename: str, **kwargs) -> 'DependencyParserModelBase':
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed, a new session is
            created

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if __version__ != _state['version']:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                _state['version'], __version__)

        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        # If there is a kwarg with the same name as an embedding object, it is
        # taken to be the input of that layer. This allows passing in subgraphs,
        # like from a tf.split (for data parallel) or preprocessing graphs that
        # convert text to indices
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        labels = {"labels": read_json("{}.labels".format(basename))}
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
        return model
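
The comment about kwargs that share a name with an embedding is the interesting part of this variant; the toy sketch below reproduces that merge with plain dicts (the embedding names and the override value are made up).

# Toy reproduction of the override loop above; names and values are placeholders.
embeddings_info = {"word": "LookupTableEmbeddings", "char": "CharConvEmbeddings"}
_state = {"dropout": 0.5}
kwargs = {"word": "<precomputed input subgraph>"}   # caller-supplied override

for k in embeddings_info:
    if k in kwargs:
        _state[k] = kwargs[k]   # ends up as a kwarg to cls.create(...)

print(_state)   # {'dropout': 0.5, 'word': '<precomputed input subgraph>'}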
Example 4
 def download(self):
     if is_file_correct(self.embedding_file):
         logger.info("embedding file location: {}".format(
             self.embedding_file))
         return self.embedding_file
     dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     dcache = read_json(dcache_path)
     if self.embedding_file in dcache and not self.cache_ignore:
         download_loc = dcache[self.embedding_file]
         logger.info("files for {} found in cache".format(
             self.embedding_file))
         return self._get_embedding_file(download_loc, self.embedding_key)
     else:  # try to download the bundle and unzip
         url = self.embedding_file
         if not validate_url(url):
             raise RuntimeError("can not download from the given url")
         else:
             cache_dir = self.data_download_cache
             temp_file = web_downloader(url)
             unzip_fn = Downloader.ZIPD.get(
                 mime_type(temp_file)) if self.unzip_file else None
             download_loc = extractor(filepath=temp_file,
                                      cache_dir=cache_dir,
                                      extractor_func=unzip_fn)
             if self.sha1 is not None:
                 if os.path.split(download_loc)[-1] != self.sha1:
                     raise RuntimeError(
                         "The sha1 of the downloaded file does not match with the provided one"
                     )
             dcache.update({url: download_loc})
             write_json(
                 dcache,
                 os.path.join(self.data_download_cache, DATA_CACHE_CONF))
             return self._get_embedding_file(download_loc,
                                             self.embedding_key)
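
The sha1 check above relies on the extractor naming the cached artifact after its content hash. Computing such a hash with nothing but the standard library (independent of this codebase) looks like this:

import hashlib

def sha1_of_file(path, chunk_size=1 << 20):
    """Return the hex sha1 of a file, read in chunks to keep memory flat."""
    h = hashlib.sha1()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# e.g. compare against an expected value from the config:
#   if sha1_of_file(temp_file) != expected_sha1: raise RuntimeError("sha1 mismatch")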
Example 5
 def download(self):
     file_loc = self.dataset_file
     if is_file_correct(file_loc):
         return file_loc
     elif validate_url(
             file_loc):  # is it a web URL? check if exists in cache
         url = file_loc
         dcache_path = os.path.join(self.data_download_cache,
                                    DATA_CACHE_CONF)
         dcache = read_json(dcache_path)
         if url in dcache and is_file_correct(
                 dcache[url], self.data_download_cache,
                 url) and not self.cache_ignore:
             logger.info(
                 "file for {} found in cache, not downloading".format(url))
             return dcache[url]
         else:  # download the file in the cache, update the json
             cache_dir = self.data_download_cache
             logger.info(
                 "using {} as data/embeddings cache".format(cache_dir))
             temp_file = web_downloader(url)
             dload_file = extractor(filepath=temp_file,
                                    cache_dir=cache_dir,
                                    extractor_func=Downloader.ZIPD.get(
                                        mime_type(temp_file), None))
             dcache.update({url: dload_file})
             write_json(
                 dcache,
                 os.path.join(self.data_download_cache, DATA_CACHE_CONF))
             return dload_file
     raise RuntimeError(
         "the file [{}] is not in cache and can not be downloaded".format(
             file_loc))
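
All of these downloaders follow the same shape: look the URL up in a JSON cache, and on a miss download, record, and return the local path. A stripped-down, standard-library-only sketch of that pattern (no checksums or archive extraction; the cache file name is a placeholder) is below.

import json
import os
import urllib.request

def cached_download(url, cache_dir, cache_conf="data-cache.json"):
    """Return a local path for url, downloading once and remembering it in a JSON cache file."""
    cache_path = os.path.join(cache_dir, cache_conf)
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    else:
        cache = {}
    if url in cache and os.path.exists(cache[url]):
        return cache[url]                        # cache hit: reuse the earlier download
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, local_path)  # cache miss: fetch and record
    cache[url] = local_path
    with open(cache_path, "w") as f:
        json.dump(cache, f, indent=2)
    return local_path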
Example 6
    def __init__(self, name=None, **kwargs):
        super().__init__(name=name)
        # You don't actually have to pass this if you are using the `load_bert_vocab` call from your
        # tokenizer.  In this case, a singleton variable will contain the vocab and it will be returned
        # by `load_bert_vocab`
        # If you trained your model with MEAD/Baseline, you will have a `*.json` file which you would
        # want to reference here
        vocab_file = kwargs.get('vocab_file')
        if vocab_file and vocab_file.endswith('.json'):
            self.vocab = read_json(vocab_file)
        else:
            self.vocab = load_bert_vocab(kwargs.get('vocab_file'))

        # When we reload, this allows skipping restoration of these embeddings
        # If the embedding wasn't trained with token types, this allows us to add them later
        self.skippable = set(listify(kwargs.get('skip_restore_embeddings',
                                                [])))

        self.cls_index = self.vocab.get('[CLS]', self.vocab.get('<s>'))
        self.vsz = max(self.vocab.values()) + 1
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))
        self.init_embed(**kwargs)
        self.proj_to_dsz = tf.keras.layers.Dense(
            self.dsz, self.d_model) if self.dsz != self.d_model else _identity
        self.init_transformer(**kwargs)
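
The vocab bookkeeping above is easy to check in isolation; the toy vocab below is made up, but the `[CLS]`/`<s>` fallback and the `max(...) + 1` size computation are exactly the ones used here.

# Toy vocab; real ones come from a MEAD/Baseline JSON file or from load_bert_vocab.
vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "hello": 3, "world": 4}

cls_index = vocab.get('[CLS]', vocab.get('<s>'))   # BERT-style token, else RoBERTa-style '<s>'
vsz = max(vocab.values()) + 1                      # vocab size from the largest index
print(cls_index, vsz)                              # 1 5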
Example 7
def reload_embeddings_from_state(embeddings_dict, basename):
    embeddings = {}
    for key, class_name in embeddings_dict.items():
        embed_args = read_json('{}-{}-md.json'.format(basename, key))
        module = embed_args.pop('module')
        name = embed_args.pop('name', None)
        assert name is None or name == key
        mod = import_user_module(module)
        Constructor = getattr(mod, class_name)
        embeddings[key] = Constructor(key, **embed_args)
    return embeddings
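
`import_user_module` plus `getattr` is just dynamic class lookup. A dependency-free equivalent with `importlib`, using a standard-library class as a stand-in for an embeddings class:

import importlib

# Stand-ins: in the loader, module/class_name come from the '{basename}-{key}-md.json' file.
module_name, class_name = "collections", "OrderedDict"

mod = importlib.import_module(module_name)   # roughly what import_user_module does for installed modules
Constructor = getattr(mod, class_name)       # resolve the class by name
obj = Constructor()
print(type(obj))                             # <class 'collections.OrderedDict'>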
Example 8
    def download(self):
        dload_bundle = self.dataset_desc.get("download", None)
        if dload_bundle is not None:  # download a zip/tar/tar.gz directory, look for train, dev test files inside that.
            dcache_path = os.path.join(self.data_download_cache,
                                       DATA_CACHE_CONF)
            dcache = read_json(dcache_path)
            if dload_bundle in dcache and \
                    is_dir_correct(dcache[dload_bundle], self.dataset_desc, self.data_download_cache, dload_bundle,
                                   self.enc_dec) and not self.cache_ignore:
                download_dir = dcache[dload_bundle]
                logger.info(
                    "files for {} found in cache, not downloading".format(
                        dload_bundle))
                updated = _update_md(self.dataset_desc, download_dir)
                return updated
            else:  # try to download the bundle and unzip
                if not validate_url(dload_bundle):
                    raise RuntimeError("can not download from the given url")
                else:
                    cache_dir = self.data_download_cache
                    temp_file = web_downloader(dload_bundle)

                    download_dir = extractor(
                        filepath=temp_file,
                        cache_dir=cache_dir,
                        extractor_func=Downloader.ZIPD.get(
                            mime_type(temp_file), None))
                    if "sha1" in self.dataset_desc:
                        if os.path.split(
                                download_dir)[-1] != self.dataset_desc["sha1"]:
                            raise RuntimeError(
                                "The sha1 of the downloaded file does not match with the provided one"
                            )
                    dcache.update({dload_bundle: download_dir})
                    write_json(
                        dcache,
                        os.path.join(self.data_download_cache,
                                     DATA_CACHE_CONF))
                    updated = _update_md(self.dataset_desc, download_dir)
                    return updated
        else:  # we have download links to every file or they exist
            updated = _update_md(self.dataset_desc, None)
            if not self.enc_dec:
                updated.update({
                    k:
                    SingleFileDownloader(self.dataset_desc[k],
                                         self.data_download_cache).download()
                    for k in self.dataset_desc
                    if k.endswith("_file") and self.dataset_desc[k]
                })
            return updated
Example 9
    def __init__(self, name=None, **kwargs):
        super().__init__(name=name)
        # You don't actually have to pass this if you are using the `load_bert_vocab` call from your
        # tokenizer.  In this case, a singleton variable will contain the vocab and it will be returned
        # by `load_bert_vocab`
        # If you trained your model with MEAD/Baseline, you will have a `*.json` file which you would
        # want to reference here
        vocab_file = kwargs.get('vocab_file')
        if vocab_file and vocab_file.endswith('.json'):
            self.vocab = read_json(kwargs.get('vocab_file'))
        else:
            self.vocab = load_bert_vocab(kwargs.get('vocab_file'))

        self.cls_index = self.vocab['[CLS]']
        self.vsz = max(self.vocab.values()) + 1
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))
        self.init_embed(**kwargs)
        self.proj_to_dsz = tf.keras.layers.Dense(
            self.dsz, self.d_model) if self.dsz != self.d_model else _identity
        self.init_transformer(**kwargs)
Example 10
def load_model_for(activity, filename, **kwargs):
    # Sniff state to see if we need to import things
    state = read_json('{}.state'.format(filename))
    if 'hub_modules' in state:
        for hub_module in state['hub_modules']:
            import_user_module(hub_module)
    # There won't be a module for pytorch (there is no state file to load).
    if 'module' in state:
        import_user_module(state['module'])
    # Allow the user to override the model type (for back compat with the old api),
    # falling back to the model type in the state file or to the default.
    # TODO: Currently in pytorch all models are always reloaded with the load
    # classmethod with a default model class. This is fine given how simple pyt
    # loading is but it could cause problems if a model has a custom load
    model_type = kwargs.get(
        'type',
        kwargs.get('model_type',
                   state.get('type', state.get('model_type', 'default'))))
    creator_fn = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', creator_fn)
    return creator_fn(filename, **kwargs)
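
The nested `kwargs.get(...)` chain encodes a four-level precedence for the model type; the snippet below unrolls it with throwaway dicts so the order is explicit.

# Precedence: kwargs['type'] > kwargs['model_type'] > state['type'] > state['model_type'] > 'default'
state = {'model_type': 'from-state'}
kwargs = {}   # nothing passed by the caller

model_type = kwargs.get(
    'type',
    kwargs.get('model_type',
               state.get('type', state.get('model_type', 'default'))))
print(model_type)   # from-state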
Example 11
 def download(self):
     file_loc = self.dataset_file
     if is_file_correct(file_loc):
         return file_loc
     elif validate_url(
             file_loc):  # is it a web URL? check if exists in cache
         url = file_loc
         dcache_path = os.path.join(self.data_download_cache,
                                    DATA_CACHE_CONF)
         dcache = read_json(dcache_path)
         # If the file already exists in the cache
         if url in dcache and is_file_correct(
                 dcache[url], self.data_download_cache,
                 url) and not self.cache_ignore:
             logger.info(
                 "file for {} found in cache, not downloading".format(url))
             return dcache[url]
         # Otherwise, we want it to be placed in ~/.bl-cache/addons
         else:  # download the file in the cache, update the json
             cache_dir = self.data_download_cache
             addon_path = os.path.join(cache_dir,
                                       AddonDownloader.ADDON_SUBPATH)
             if not os.path.exists(addon_path):
                 os.makedirs(addon_path)
             path_to_save = os.path.join(addon_path,
                                         os.path.basename(file_loc))
             logger.info("using {} as data/addons cache".format(cache_dir))
             web_downloader(url, path_to_save)
             dcache.update({url: path_to_save})
             write_json(
                 dcache,
                 os.path.join(self.data_download_cache, DATA_CACHE_CONF))
             return path_to_save
     raise RuntimeError(
         "the file [{}] is not in cache and can not be downloaded".format(
             file_loc))
Example 12
def test_read_json_strict():
    with pytest.raises(IOError):
        read_json(os.path.join("not_there.json"), strict=True)
Example 13
def test_read_json_given_default():
    gold_default = "default"
    data = read_json(os.path.join(data_loc, "not_there.json"), gold_default)
    assert data == gold_default
Example 14
def test_read_json_default_value():
    gold_default = {}
    data = read_json(os.path.join(data_loc, "not_there.json"))
    assert data == gold_default
Example 15
def test_read_json(gold_data):
    data = read_json(os.path.join(data_loc, "test_json.json"))
    assert data == gold_data
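
Taken together, these tests pin down the contract of `read_json`: a missing file returns the supplied default (an empty dict if none is given) unless `strict=True`, in which case an `IOError` is raised. A minimal reimplementation consistent with that contract (not the library's actual code) would be:

import json
import os

def read_json(filepath, default_value=None, strict=False):
    """Read a JSON file; if it is missing, raise when strict, else return the default."""
    if not os.path.exists(filepath):
        if strict:
            raise IOError("No file {} found".format(filepath))
        return default_value if default_value is not None else {}
    with open(filepath) as f:
        return json.load(f)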
Example 16
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint",
                        type=str,
                        help='Checkpoint name or directory to load')
    parser.add_argument("--sample",
                        type=str2bool,
                        help='Sample from the decoder?  Defaults to `true`',
                        default=1)
    parser.add_argument("--vocab",
                        type=str,
                        help='Vocab file to load',
                        required=False)
    parser.add_argument("--query", type=str, default='hello how are you ?')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        "--nctx",
        type=int,
        default=256,
        help="Max context length (for both encoder and decoder)")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        help=
        "register label of the embeddings, so far support positional or learned-positional"
    )
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument(
        '--rpr_k',
        help=
        "Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[48] * 8,
        nargs='+')
    parser.add_argument("--use_cls", type=str2bool, default=True)
    parser.add_argument("--go_token", default="<GO>")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    vocab_file = args.vocab

    if os.path.isdir(args.checkpoint):
        if not vocab_file:
            vocab_file = os.path.join(args.checkpoint, 'vocabs.json')
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint
        if not vocab_file:
            vocab_file = os.path.join(os.path.dirname(checkpoint),
                                      'vocabs.json')

    vocab = read_json(vocab_file)
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        counts=False,
        known_vocab=vocab,
        embed_type=args.embed_type)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         rpr_k=args.rpr_k,
                         d_k=args.d_k,
                         checkpoint_name=checkpoint,
                         activation=args.activation)
    model.to(args.device)

    cls = None if not args.use_cls else '[CLS]'
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx,
                                 emit_begin_tok=cls)

    index2word = revlut(vocab)
    print('[Query]', args.query)
    print(
        '[Response]', ' '.join(
            decode_sentence(model,
                            vectorizer,
                            args.query.split(),
                            vocab,
                            index2word,
                            args.device,
                            max_response_length=args.nctx,
                            sou_token=args.go_token,
                            sample=args.sample)))
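
For completeness, an invocation of this script might look like the lines below; the script and file names are placeholders, only the flags come from the parser above.

# Hypothetical command line; only the flag names are real.
#   python interact.py \
#       --checkpoint /path/to/checkpoints \
#       --subword_model_file codes.bpe --subword_vocab_file vocab.bpe \
#       --query "hello how are you ?"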
Example 17
def update_cache(key, data_download_cache):
    dcache = read_json(os.path.join(data_download_cache, DATA_CACHE_CONF))
    if key not in dcache:
        return
    del dcache[key]
    write_json(dcache, os.path.join(data_download_cache, DATA_CACHE_CONF))
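
The same drop-and-rewrite step with nothing but the standard library (the cache file name is a placeholder for DATA_CACHE_CONF):

import json
import os

def update_cache_stdlib(key, data_download_cache, cache_conf="data-cache.json"):
    """Remove `key` from the JSON download cache and write the cache back, if the key was present."""
    path = os.path.join(data_download_cache, cache_conf)
    with open(path) as f:
        dcache = json.load(f)
    if key not in dcache:
        return
    del dcache[key]
    with open(path, "w") as f:
        json.dump(dcache, f, indent=2)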