Code Example #1
    def __init__(self,
                 bert_path: Path,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary,
                 hidden_dim: int = 100,
                 encoder_dropout: float = 0.0,
                 train_bert: bool = False) -> None:
        # We have to pass the vocabulary to the constructor.
        super().__init__(vocab)
        self.word_embeddings = bert_embeddings(pretrained_model=bert_path,
                                               training=train_bert)

        self.encoder_dropout: torch.nn.Module
        if encoder_dropout > 0:
            self.encoder_dropout = torch.nn.Dropout(p=encoder_dropout)
        else:
            self.encoder_dropout = torch.nn.Identity()

        self.pooler = BertPooler(pretrained_model=str(bert_path))
        self.dense1 = torch.nn.Linear(in_features=self.pooler.get_output_dim(),
                                      out_features=hidden_dim)
        self.encoder = encoder
        self.self_attn = LinearSelfAttention(
            input_dim=self.encoder.get_output_dim(), bias=True)
        self.dense2 = torch.nn.Linear(
            in_features=self.encoder.get_output_dim(), out_features=1)
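Every example on this page calls a project helper named bert_embeddings, whose definition is not shown. A minimal sketch of what it might look like, assuming it wraps AllenNLP's PretrainedBertEmbedder in a BasicTextFieldEmbedder (the AllenNLP names are real; the body itself is an assumption, not the project's actual code):

from pathlib import Path

from allennlp.modules.text_field_embedders import (BasicTextFieldEmbedder,
                                                   TextFieldEmbedder)
from allennlp.modules.token_embedders import PretrainedBertEmbedder


def bert_embeddings(pretrained_model: Path,
                    training: bool = False) -> TextFieldEmbedder:
    # Sketch: load a pre-trained BERT and expose it as a TextFieldEmbedder
    # so it can be assigned to self.word_embeddings above. requires_grad
    # controls whether BERT is fine-tuned during training.
    bert = PretrainedBertEmbedder(pretrained_model=str(pretrained_model),
                                  requires_grad=training)
    return BasicTextFieldEmbedder(
        token_embedders={'bert': bert},
        embedder_to_indexer_map={'bert': ['bert', 'bert-offsets']},
        allow_unmatched_keys=True)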
Code Example #2
File: common.py  Project: oyarsa/literate-lamp
def get_word_embeddings(vocabulary: Vocabulary) -> TextFieldEmbedder:
    "Instatiates the word embeddings based on config."
    if ARGS.EMBEDDING_TYPE == 'glove':
        return glove_embeddings(vocabulary,
                                ARGS.GLOVE_PATH,
                                ARGS.GLOVE_EMBEDDING_DIM,
                                training=True)
    if ARGS.EMBEDDING_TYPE == 'bert':
        return bert_embeddings(pretrained_model=ARGS.BERT_PATH,
                               training=ARGS.finetune_embeddings)
    if ARGS.EMBEDDING_TYPE == 'xlnet':
        return xlnet_embeddings(config_path=ARGS.xlnet_config_path,
                                model_path=ARGS.xlnet_model_path,
                                training=ARGS.finetune_embeddings,
                                window_size=ARGS.xlnet_window_size)
    raise ValueError(f'Invalid word embedding type: {ARGS.EMBEDDING_TYPE}')
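A brief usage sketch (hypothetical: train_instances stands for whatever the project's dataset reader produced; Vocabulary.from_instances and get_output_dim are standard AllenNLP calls):

# Hypothetical usage of get_word_embeddings:
vocab = Vocabulary.from_instances(train_instances)
embedder = get_word_embeddings(vocab)
embedding_dim = embedder.get_output_dim()  # e.g. 768 for bert-base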
Code Example #3
File: common.py  Project: oyarsa/literate-lamp
def build_hierarchical_bert(vocab: Vocabulary) -> Model:
    """
    Builds the HierarchicalBert.

    Parameters
    ----------
    vocab : Vocabulary
        Vocabulary built from the problem dataset.

    Returns
    -------
    A `HierarchicalBert` model ready to be trained.
    """
    if ARGS.ENCODER_TYPE == 'lstm':
        encoder_fn = lstm_encoder
    elif ARGS.ENCODER_TYPE == 'gru':
        encoder_fn = gru_encoder
    else:
        raise ValueError(f'Invalid RNN type: {ARGS.ENCODER_TYPE}')

    bert = bert_embeddings(ARGS.BERT_PATH)
    embedding_dim = bert.get_output_dim()

    # Disable dropout for a single-layer RNN: PyTorch warns otherwise,
    # since dropout is only applied between layers of a stacked RNN.
    dropout = ARGS.RNN_DROPOUT if ARGS.RNN_LAYERS > 1 else 0

    if ARGS.ENCODER_TYPE in ['lstm', 'gru']:
        sentence_encoder = encoder_fn(input_dim=embedding_dim,
                                      output_dim=ARGS.HIDDEN_DIM,
                                      num_layers=ARGS.RNN_LAYERS,
                                      bidirectional=ARGS.BIDIRECTIONAL,
                                      dropout=dropout)
        document_encoder = encoder_fn(
            input_dim=sentence_encoder.get_output_dim(),
            output_dim=ARGS.HIDDEN_DIM,
            num_layers=1,
            bidirectional=ARGS.BIDIRECTIONAL,
            dropout=dropout)

    # Instantiate the model with our embeddings, encoders, and vocabulary.
    model = models.HierarchicalBert(bert_path=ARGS.BERT_PATH,
                                    sentence_encoder=sentence_encoder,
                                    document_encoder=document_encoder,
                                    vocab=vocab,
                                    encoder_dropout=ARGS.RNN_DROPOUT)

    return model
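lstm_encoder and gru_encoder are project helpers that this page does not show. A minimal sketch of what lstm_encoder might look like, assuming it wraps a batch-first torch.nn.LSTM in AllenNLP's PytorchSeq2VecWrapper (the wrapper and LSTM signatures are real; the helper body is an assumption):

import torch
from allennlp.modules.seq2vec_encoders import (PytorchSeq2VecWrapper,
                                               Seq2VecEncoder)


def lstm_encoder(input_dim: int,
                 output_dim: int,
                 num_layers: int = 1,
                 bidirectional: bool = False,
                 dropout: float = 0.0) -> Seq2VecEncoder:
    # Sketch: a Seq2VecEncoder that returns the LSTM's final hidden state.
    # With bidirectional=True, get_output_dim() is 2 * output_dim.
    return PytorchSeq2VecWrapper(torch.nn.LSTM(input_size=input_dim,
                                               hidden_size=output_dim,
                                               num_layers=num_layers,
                                               bidirectional=bidirectional,
                                               dropout=dropout,
                                               batch_first=True))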
Code Example #4
    def __init__(self,
                 bert_path: Path,
                 vocab: Vocabulary,
                 train_bert: bool = False
                 ) -> None:
        # We have to pass the vocabulary to the constructor.
        super().__init__(vocab)
        self.word_embeddings = bert_embeddings(pretrained_model=bert_path,
                                               training=train_bert)

        self.pooler = BertPooler(pretrained_model=str(bert_path))

        hidden_dim = self.pooler.get_output_dim()
        self.hidden2logit = torch.nn.Linear(
            in_features=hidden_dim,
            out_features=1
        )
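The matching forward pass is not shown on this page. A minimal sketch, assuming the model receives a single TextField named tokens and emits one logit per instance (Dict is typing.Dict; the field name and output key are assumptions):

    def forward(self,
                tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Sketch: embed the tokens, pool the [CLS] representation with
        # BertPooler, then project it to a single logit.
        embedded = self.word_embeddings(tokens)
        pooled = self.pooler(embedded)
        logit = self.hidden2logit(pooled).squeeze(-1)
        return {'logit': logit}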
Code Example #5
    def __init__(self,
                 bert_path: Path,
                 sentence_encoder: Seq2VecEncoder,
                 document_encoder: Seq2VecEncoder,
                 vocab: Vocabulary,
                 encoder_dropout: float = 0.0,
                 train_bert: bool = False
                 ) -> None:
        # We have to pass the vocabulary to the constructor.
        super().__init__(vocab)
        self.word_embeddings = bert_embeddings(pretrained_model=bert_path,
                                               training=train_bert)

        # Use Identity rather than a bare lambda so the no-op is a real
        # Module (registered as a submodule and scriptable), matching
        # Code Example #1.
        self.encoder_dropout: torch.nn.Module
        if encoder_dropout > 0:
            self.encoder_dropout = torch.nn.Dropout(p=encoder_dropout)
        else:
            self.encoder_dropout = torch.nn.Identity()

        self.sentence_encoder = sentence_encoder
        self.document_encoder = document_encoder
        self.dense = torch.nn.Linear(
            in_features=document_encoder.get_output_dim(),
            out_features=1
        )
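Here too the forward pass is omitted. A hedged sketch of the hierarchical flow, assuming the input is a ListField of sentences embedded to shape (batch, n_sentences, n_tokens, dim) and that util is allennlp.nn.util; the field name, mask handling, and output key are all assumptions:

    def forward(self,
                sentences: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Sketch: encode tokens into sentence vectors, then sentence
        # vectors into a document vector, applying dropout in between.
        mask = util.get_text_field_mask(sentences, num_wrapping_dims=1)
        embedded = self.word_embeddings(sentences)
        batch, n_sent, n_tok, dim = embedded.shape

        sent_vecs = self.sentence_encoder(embedded.view(-1, n_tok, dim),
                                          mask.view(-1, n_tok))
        sent_vecs = self.encoder_dropout(sent_vecs).view(batch, n_sent, -1)

        doc_mask = (mask.sum(dim=-1) > 0).long()  # sentence-level mask
        doc_vec = self.document_encoder(sent_vecs, doc_mask)
        logit = self.dense(self.encoder_dropout(doc_vec)).squeeze(-1)
        return {'logit': logit}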