Exemplo n.º 1
0
    def __init__(self, **kwargs):
        """Set up the vocab, token embeddings, optional projection and transformer.

        The vocabulary is resolved in this order:

        1. ``vocab_file``, if it is set and exists on disk. A ``*.json`` file
           (what MEAD/Baseline training writes out) is read with
           ``read_config_stream``; any other file goes to ``load_bert_vocab``.
        2. Otherwise ``vocab``, falling back to ``known_vocab``.
        3. If that is missing or is a ``collections.Counter``,
           ``load_bert_vocab(None)`` is used. Passing nothing is fine when the
           tokenizer already called ``load_bert_vocab``: per the author's note,
           it then returns the vocab it cached in a singleton.

        Other kwargs read here: ``skip_restore_embeddings``, ``dsz`` /
        ``d_model`` (default 768) and ``return_mask`` (default False). All
        kwargs are also forwarded to ``init_embed`` and ``init_transformer``.
        """
        super().__init__()

        vocab_file = kwargs.get('vocab_file')
        if not (vocab_file and os.path.exists(vocab_file)):
            # No usable file: take a vocab passed in directly, unless it is
            # absent or only a frequency Counter.
            vocab = kwargs.get('vocab', kwargs.get('known_vocab'))
            if vocab is None or isinstance(vocab, collections.Counter):
                vocab = load_bert_vocab(None)
        elif vocab_file.endswith('.json'):
            vocab = read_config_stream(vocab_file)
        else:
            vocab = load_bert_vocab(vocab_file)
        self.vocab = vocab

        # Names of embeddings a reload may skip restoring, e.g. so token-type
        # embeddings can be added to a model that was trained without them.
        self.skippable = set(listify(kwargs.get('skip_restore_embeddings', [])))

        # Prefer the '[CLS]' entry and fall back to '<s>' (for a dict vocab,
        # None if neither is present).
        self.cls_index = self.vocab.get('[CLS]', self.vocab.get('<s>'))
        # Size by the largest index rather than len(), so gaps in the index
        # space still fit in the table.
        self.vsz = max(self.vocab.values()) + 1
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))

        self.init_embed(**kwargs)
        # self.dsz is presumably set by init_embed (TODO confirm). A linear
        # projection is only added when the embedding width differs from d_model.
        if self.dsz == self.d_model:
            self.proj_to_dsz = _identity
        else:
            self.proj_to_dsz = pytorch_linear(self.dsz, self.d_model)

        self.init_transformer(**kwargs)
        self.return_mask = kwargs.get('return_mask', False)
Exemplo n.º 2
0
    def __init__(self, name=None, **kwargs):
        """Set up the vocab, token embeddings, optional projection and transformer.

        :param name: layer name, forwarded to the parent constructor.

        Vocabulary: a ``vocab_file`` ending in ``.json`` (what MEAD/Baseline
        training writes out) is read with ``read_json``. Anything else,
        including no file at all, goes to ``load_bert_vocab``. Per the author's
        note, ``load_bert_vocab(None)`` returns a vocab the tokenizer already
        cached in a singleton.

        Other kwargs read here: ``skip_restore_embeddings`` and ``dsz`` /
        ``d_model`` (default 768). All kwargs are also forwarded to
        ``init_embed`` and ``init_transformer``.
        """
        super().__init__(name=name)
        vocab_file = kwargs.get('vocab_file')
        if vocab_file and vocab_file.endswith('.json'):
            self.vocab = read_json(vocab_file)
        else:
            self.vocab = load_bert_vocab(vocab_file)

        # Names of embeddings a reload may skip restoring, e.g. so token-type
        # embeddings can be added to a model that was trained without them.
        self.skippable = set(listify(kwargs.get('skip_restore_embeddings', [])))

        # Prefer the '[CLS]' entry and fall back to '<s>'.
        self.cls_index = self.vocab.get('[CLS]', self.vocab.get('<s>'))
        # Size by the largest index rather than len(), so gaps in the index
        # space still fit in the table.
        self.vsz = max(self.vocab.values()) + 1
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))
        self.init_embed(**kwargs)
        # self.dsz is presumably set by init_embed (TODO confirm).
        # Keras Dense takes only the output width (`units`); its second
        # positional parameter is `activation`, and the input width (self.dsz)
        # is inferred when the layer is first built. So this maps
        # dsz-wide embeddings to d_model-wide outputs.
        if self.dsz != self.d_model:
            self.proj_to_dsz = tf.keras.layers.Dense(self.d_model)
        else:
            self.proj_to_dsz = _identity
        self.init_transformer(**kwargs)
Exemplo n.º 3
0
 def __init__(self, name, **kwargs):
     """Wrap a pretrained ``BertModel`` loaded from ``kwargs['handle']``.

     ``dsz`` is stored as given. The vocabulary always comes from
     ``load_bert_vocab(None)``, and ``vsz`` is its number of entries.
     ``name`` and all kwargs are forwarded to the parent constructor.
     """
     super().__init__(name=name, **kwargs)
     self.dsz = kwargs.get('dsz')
     pretrained_handle = kwargs.get('handle')
     self.model = BertModel.from_pretrained(pretrained_handle)
     self.vocab = load_bert_vocab(None)
     # The author notes this is expected to be 30522 and to agree with
     # self.model.embeddings.word_embeddings.num_embeddings; not checked here.
     self.vsz = len(self.vocab)
Exemplo n.º 4
0
    def __init__(self, name, **kwargs):
        """Configure a BERT model referenced by a TF Hub handle.

        Kwargs read here:
            embed_file: appended to ``BERTHubModel.TF_HUB_URL`` to build
                ``self.handle``. Presumably required, since the concatenation
                would fail on None.
            vocab: used as-is when present.
            vocab_file: otherwise passed to ``load_bert_vocab``.
            dsz: stored as ``self.dsz``.
            trainable: defaults to False.

        ``name`` and all kwargs are also forwarded to the parent constructor.
        """
        super().__init__(name=name, **kwargs)
        hub_suffix = kwargs.get('embed_file')
        self.handle = BERTHubModel.TF_HUB_URL + hub_suffix
        self.vocab = (kwargs['vocab'] if 'vocab' in kwargs
                      else load_bert_vocab(kwargs.get('vocab_file')))

        self.vsz = len(self.vocab)
        self.dsz = kwargs.get('dsz')
        self.trainable = kwargs.get('trainable', False)
Exemplo n.º 5
0
    def __init__(self, name, **kwargs):
        """Configure BERT embeddings from a config JSON and a vocab file.

        Kwargs read here:
            dsz: stored as ``self.dsz``.
            bert_config: path given to ``BertConfig.from_json_file``.
            vocab_file: passed to ``load_bert_vocab``.
            use_one_hot_embeddings: defaults to False.
            layers: stored as ``self.layer_indices``; defaults to
                ``[-1, -2, -3, -4]``.
            operator: defaults to ``'concat'``.

        :raises ValueError: if the config's ``vocab_size`` differs from the
            number of entries in the loaded vocab.
        """
        super().__init__(name=name)

        self.dsz = kwargs.get('dsz')
        self.bert_config = BertConfig.from_json_file(kwargs.get('bert_config'))
        self.vocab = load_bert_vocab(kwargs.get('vocab_file'))
        self.vsz = self.bert_config.vocab_size
        # An explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently accept a mismatched config/vocab pair.
        if self.vsz != len(self.vocab):
            raise ValueError(
                'BERT config vocab_size ({}) does not match the number of entries '
                'in the loaded vocab ({})'.format(self.vsz, len(self.vocab))
            )
        self.use_one_hot_embeddings = kwargs.get('use_one_hot_embeddings', False)
        self.layer_indices = kwargs.get('layers', [-1, -2, -3, -4])
        self.operator = kwargs.get('operator', 'concat')
Exemplo n.º 6
0
 def __init__(self, **kwargs):
     """Set up the vocab, token embeddings, optional projection and transformer.

     Vocabulary: a ``vocab_file`` ending in ``.json`` (what MEAD/Baseline
     training writes out) is read with ``read_config_stream``. Anything else,
     including no file at all, goes to ``load_bert_vocab``. Per the author's
     note, ``load_bert_vocab(None)`` returns a vocab the tokenizer already
     cached in a singleton.

     Other kwargs read here: ``dsz`` / ``d_model`` (default 768). All kwargs
     are also forwarded to ``init_embed`` and ``init_transformer``.

     :raises KeyError: if the vocab has neither a ``'[CLS]'`` nor a ``'<s>'``
         entry.
     """
     super().__init__()
     vocab_file = kwargs.get('vocab_file')
     if vocab_file and vocab_file.endswith('.json'):
         self.vocab = read_config_stream(vocab_file)
     else:
         self.vocab = load_bert_vocab(vocab_file)
     # Prefer '[CLS]', but accept '<s>' so vocabs that use that token as the
     # sequence-start marker don't fail here. Still raise KeyError if neither
     # is present, as the plain '[CLS]' lookup did.
     if '[CLS]' in self.vocab:
         self.cls_index = self.vocab['[CLS]']
     elif '<s>' in self.vocab:
         self.cls_index = self.vocab['<s>']
     else:
         raise KeyError("vocab has neither a '[CLS]' nor a '<s>' entry")
     # Size by the largest index rather than len(), so gaps in the index
     # space still fit in the table.
     self.vsz = max(self.vocab.values()) + 1
     self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))
     self.init_embed(**kwargs)
     # self.dsz is presumably set by init_embed (TODO confirm). A linear
     # projection is only added when the embedding width differs from d_model.
     self.proj_to_dsz = pytorch_linear(
         self.dsz, self.d_model) if self.dsz != self.d_model else _identity
     self.init_transformer(**kwargs)