Code example #1
class BertClassifier(Model):
    def __init__(self, args, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, label: torch.Tensor) -> torch.Tensor:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        output["loss"] = self.loss(class_logits, label)

        return output
Code example #2
def build_elmo_model(vocab: Vocabulary) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedding = ElmoTokenEmbedder()
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=embedder.get_output_dim(), averaged=True)
    
    return SimpleClassifier(vocab, embedder, encoder)
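
A minimal usage sketch for the builder above. The dataset reader and the training-file path are illustrative placeholders, not part of the original example; the Vocabulary.from_instances pattern mirrors code example #5 below.

# Hypothetical driver code: build_dataset_reader() and the path are assumptions.
reader = build_dataset_reader()
train_instances = list(reader.read("path/to/train_data"))
vocab = Vocabulary.from_instances(train_instances)
model = build_elmo_model(vocab)
print(model)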
Code example #3
def build_pool_transformer_model(vocab: Vocabulary, transformer_model: str) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedding = PretrainedTransformerEmbedder(model_name=transformer_model)
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=embedder.get_output_dim(), averaged=True)
    #encoder = ClsPooler(embedding_dim=embedder.get_output_dim())
    return SimpleClassifier(vocab, embedder, encoder)
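
The commented-out line above hints at a CLS-pooling variant. A sketch of that alternative, assuming AllenNLP's ClsPooler seq2vec encoder is available and reusing the Model and SimpleClassifier names from the snippet above, might look like this:

from allennlp.modules.seq2vec_encoders import ClsPooler

def build_cls_transformer_model(vocab: Vocabulary, transformer_model: str) -> Model:
    # Embed with the pretrained transformer, then keep only the [CLS] token's
    # vector instead of averaging all wordpiece embeddings.
    embedding = PretrainedTransformerEmbedder(model_name=transformer_model)
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = ClsPooler(embedding_dim=embedder.get_output_dim())
    return SimpleClassifier(vocab, embedder, encoder)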
Code example #4
def multitask_learning():
    # load datasetreader 
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 10 
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    conll_reader = ConllCorefBertReader(max_span_width = max_span_width, token_indexers = {"tokens": token_indexer})
    swag_reader = SWAGDatasetReader(tokenizer=token_indexer.wordpiece_tokenizer,lazy=True, token_indexers=token_indexer)
    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    conll_datasets, swag_datasets = load_datasets(conll_reader, swag_reader, directory)
    conll_vocab = Vocabulary()
    swag_vocab = Vocabulary()
    conll_iterator = BasicIterator(batch_size=batch_size)
    conll_iterator.index_with(conll_vocab)

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)


    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",top_layer_only=True, requires_grad=True)

    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder}, allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()

    seq2seq = PytorchSeq2SeqWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    seq2vec = PytorchSeq2VecWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(input_dim = 2336, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim = 7776, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())
    model1 = CoreferenceResolver(vocab=conll_vocab,
                                 text_field_embedder=word_embedding,
                                 context_layer=seq2seq,
                                 mention_feedforward=mention_feedforward,
                                 antecedent_feedforward=antecedent_feedforward,
                                 feature_size=768,
                                 max_span_width=max_span_width,
                                 spans_per_word=0.4,
                                 max_antecedents=250,
                                 lexical_dropout=0.2)

    model2 = SWAGExampleModel(vocab=swag_vocab, text_field_embedder=word_embedding, phrase_encoder=seq2vec)
    optimizer1 = optim.Adam(model1.parameters(), lr=lr)
    optimizer2 = optim.Adam(model2.parameters(), lr=lr)

    swag_train_iterator = swag_iterator(swag_datasets[0], num_epochs=1, shuffle=True)
    conll_train_iterator = conll_iterator(conll_datasets[0], num_epochs=1, shuffle=True)
    swag_val_iterator = swag_iterator(swag_datasets[1], num_epochs=1, shuffle=True)
    conll_val_iterator = conll_iterator(conll_datasets[1], num_epochs=1, shuffle=True)
    task_infos = {"swag": {"model": model2, "optimizer": optimizer2, "loss": 0.0, "iterator": swag_iterator, "train_data": swag_datasets[0], "val_data": swag_datasets[1], "num_train": len(swag_datasets[0]), "num_val": len(swag_datasets[1]), "lr": lr, "score": {"accuracy":0.0}}, \
                    "conll": {"model": model1, "iterator": conll_iterator, "loss": 0.0, "val_data": conll_datasets[1], "train_data": conll_datasets[0], "optimizer": optimizer1, "num_train": len(conll_datasets[0]), "num_val": len(conll_datasets[1]),"lr": lr, "score": {"coref_prediction": 0.0, "coref_recall": 0.0, "coref_f1": 0.0,"mention_recall": 0.0}}}
    USE_GPU = 1
    trainer = MultiTaskTrainer(
        task_infos=task_infos, 
        num_epochs=epochs,
        serialization_dir=directory + "saved_models/multitask/"
    ) 
    metrics = trainer.train()
Code example #5
def use_glove():
    embedding_dim = 300
    project_dim = 200

    train_reader = StanfordSentimentTreeBankDatasetReader()
    dev_reader = StanfordSentimentTreeBankDatasetReader(use_subtrees=False)
    train_dataset = train_reader.read('~/nlp/dataset/sst/trees/train.txt')
    dev_dataset = dev_reader.read('~/nlp/dataset/sst/trees/dev.txt')

    print(
        f"total train samples: {len(train_dataset)}, dev samples: {len(dev_dataset)}"
    )

    # Build the vocabulary from the train and dev datasets
    vocab = Vocabulary.from_instances(train_dataset + dev_dataset)
    vocab_dim = vocab.get_vocab_size('tokens')
    print("vocab: ", vocab.get_vocab_size('labels'), vocab_dim)

    glove_embeddings_file = '~/nlp/pretrainedEmbeddings/glove/glove.840B.300d.txt'

    # If you want to actually load a pretrained embedding file,
    # you currently need to do that by calling Embedding.from_params()
    # see https://github.com/allenai/allennlp/issues/2694
    token_embedding = Embedding.from_params(vocab=vocab,
                                            params=Params({
                                                'pretrained_file':
                                                glove_embeddings_file,
                                                'embedding_dim': embedding_dim,
                                                'trainable': False
                                            }))
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
    print(word_embeddings.get_output_dim())

    # use batch_to_ids to convert sentences to character ids
    sentence_lists = [["I", 'have', 'a', "dog"],
                      ["How", 'are', 'you', ',', 'today', 'is', "Monday"]]

    sentence_ids = batch_to_ids(sentence_lists, vocab)
    embeddings = token_embedding(sentence_ids)

    for sentence in sentence_lists:
        for text in sentence:
            indice = vocab.get_token_index(text)
            print(f"text: {text}, indice: {indice}")

    # calculate distance based on the GloVe embeddings
    import scipy.spatial.distance
    tokens = [["dog", "ate", "an", "apple", "for", "breakfast"]]
    tokens2 = [["cat", "ate", "an", "carrot", "for", "breakfast"]]
    token_ids = batch_to_ids(tokens, vocab)
    token_ids2 = batch_to_ids(tokens2, vocab)
    vectors = token_embedding(token_ids)
    vectors2 = token_embedding(token_ids2)

    print('embedding shape ', vectors.shape)
    print('\nvector ', vectors[0][0], vectors2[0][0])
    distance = scipy.spatial.distance.cosine(vectors[0][0], vectors2[0][0])
    print(f"embedding distance: {distance}")
Code example #6
def build_model(options_file, weight_file):
    vocab = Vocabulary()
    iterator = BucketIterator(batch_size=config.batch_size, sorting_keys=[("tokens", "num_tokens")])
    iterator.index_with(vocab)

    elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)
    word_embeddings = BasicTextFieldEmbedder({"tokens": elmo_embedder})
    encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(nn.LSTM(word_embeddings.get_output_dim(), config.hidden_size, bidirectional=True, batch_first=True))
    model = BaselineModel(word_embeddings, encoder, vocab)

    return model, iterator, vocab
Code example #7
def train_only_swag():
    # load datasetreader 
    # Save logging to a local file
    # Single-task training on SWAG only
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 100
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    swag_reader = SWAGDatasetReader(tokenizer=token_indexer.wordpiece_tokenizer,lazy=True, token_indexers=token_indexer)
    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    swag_datasets = load_swag(swag_reader, directory)
    swag_vocab = Vocabulary()

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)

    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",top_layer_only=True, requires_grad=True)

    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder}, allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()
    seq2vec = PytorchSeq2VecWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(input_dim = 2336, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim = 7776, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())

    model = SWAGExampleModel(vocab=swag_vocab, text_field_embedder=word_embedding, phrase_encoder=seq2vec)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    USE_GPU = 1
    val_iterator = swag_iterator(swag_datasets[1], num_epochs=1, shuffle=True)
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=swag_iterator,
        validation_iterator = swag_iterator, 
        train_dataset=swag_datasets[0],
        validation_dataset = swag_datasets[1], 
        validation_metric = "+accuracy",
        cuda_device=0 if USE_GPU else -1,
        serialization_dir= directory + "saved_models/current_run_model_state_swag",
        num_epochs=epochs,
    )    

    metrics = trainer.train()
    # save the model
    with open(directory + "saved_models/current_run_model_state", 'wb') as f:
        torch.save(model.state_dict(), f)
Code example #8
def main():
    cuda_device = -1

    torch.manual_seed(SEED)

    elmo_embedder = ElmoTokenEmbedder(OPTION_FILE, WEIGHT_FILE)
    word_embeddings = BasicTextFieldEmbedder({"tokens": elmo_embedder})

    lstm = PytorchSeq2VecWrapper(
        torch.nn.LSTM(word_embeddings.get_output_dim(),
                      HIDDEN_DIM,
                      bidirectional=True,
                      batch_first=True))

    train_dataset, dev_dataset = dataset_reader(train=True, elmo=True)
    vocab = Vocabulary()

    model = BaseModel(word_embeddings=word_embeddings,
                      encoder=lstm,
                      vocabulary=vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)

    iterator = data_iterator(vocab)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=train_dataset,
                      validation_dataset=dev_dataset,
                      cuda_device=cuda_device,
                      num_epochs=EPOCHS,
                      patience=5)

    trainer.train()

    print("*******Save Model*******\n")

    output_elmo_model_file = os.path.join(PRETRAINED_ELMO,
                                          "lstm_elmo_model.bin")
    torch.save(model.state_dict(), output_elmo_model_file)
Code example #9
def load_elmo_model():
    elmo_embedders = ElmoTokenEmbedder(OPTION_FILE, WEIGHT_FILE)
    word_embeddings = BasicTextFieldEmbedder({"tokens": elmo_embedders})

    encoder = PytorchSeq2VecWrapper(
        torch.nn.LSTM(word_embeddings.get_output_dim(),
                      HIDDEN_DIM,
                      bidirectional=True,
                      batch_first=True))

    vocabulary = Vocabulary()

    model = BaseModel(word_embeddings=word_embeddings,
                      encoder=encoder,
                      vocabulary=vocabulary)

    output_elmo_model_file = os.path.join(PRETRAINED_ELMO,
                                          "lstm_elmo_model.bin")
    model.load_state_dict(torch.load(output_elmo_model_file))
    return model
Code example #10
File: models.py Project: JakobBozic/nlp-ner
def get_model(vocab, params):
    emb_d = params["embedding_dim"]
    hidden_d = params["hidden_dim"]

    use_elmo_embeddings = params['use_elmo']
    use_lstm = params['use_lstm']
    n_layers = params["num_layers"]

    bidirectional = params['bidirectional']

    if use_elmo_embeddings:
        token_embedder = ElmoTokenEmbedder(ELMO_OPTIONS_FILE,
                                           ELMO_WEIGHTS_FILE)
    else:
        token_embedder = Embedding(
            num_embeddings=vocab.get_vocab_size('tokens'), embedding_dim=emb_d)

    word_embedder = BasicTextFieldEmbedder({"tokens": token_embedder})
    emb_d = word_embedder.get_output_dim()

    if use_lstm:
        encoder = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(emb_d,
                          hidden_d,
                          num_layers=n_layers,
                          batch_first=True,
                          bidirectional=bidirectional))
    else:
        encoder = PytorchSeq2SeqWrapper(
            torch.nn.GRU(emb_d,
                         hidden_d,
                         num_layers=n_layers,
                         batch_first=True,
                         bidirectional=bidirectional))

    model = NerModel(word_embedder,
                     encoder,
                     vocab,
                     num_categories=(3 if params["dataset"] == "senti" else 4))
    return model
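
A hypothetical call to get_model, with made-up values for the keys the function reads (the real project presumably loads these from a configuration file):

# Illustrative only: none of these values come from the original project.
params = {
    "embedding_dim": 100,
    "hidden_dim": 128,
    "use_elmo": False,       # True switches to the ELMo token embedder
    "use_lstm": True,        # False selects the GRU encoder instead
    "num_layers": 2,
    "bidirectional": True,
    "dataset": "conll",      # anything other than "senti" yields 4 categories
}
model = get_model(vocab, params)  # `vocab` built elsewhere, as in the other examples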
Code example #11
class TempXCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        # self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 768

        # self.author_dim = self.sk_dim + self.time_dim
        self.author_dim = self.sk_dim

        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.author_dim),
                                              requires_grad=True)  # (m, d)

        # self.ctx_attention = MultiHeadCtxAttention(h=8, d_model=self.sk_dim + self.time_dim)
        self.temp_ctx_attention_ns = TempCtxAttentionNS(
            h=8,
            d_model=self.author_dim,
            d_query=self.sk_dim,
            d_time=self.time_dim)

        # temporal context
        self.time_encoder = TimeEncoder(self.time_dim,
                                        dropout=0.1,
                                        span=1,
                                        date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.author_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.coherence_func = CoherenceInnerProd()

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any,
                date: Any, accept_usr: Any) -> torch.Tensor:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)  # (n, l, d)
        token_hidden = self.encoder(embeddings,
                                    mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]

        # transfer the date into time embedding
        # TODO: use answer date for time embedding
        time_embed = gen_time_encoding(self.time_encoder, date)

        # token_temp_embed = torch.cat((token_embed, time_embed), 1)
        token_temp_embed = token_embed + time_embed
        author_tctx_embed = self.temp_ctx_attention_ns(token_embed,
                                                       self.author_embeddings,
                                                       self.author_embeddings,
                                                       time_embed)  # (n, m, d)

        # add layer norm for author context embedding
        author_tctx_embed = self.ctx_layer_norm(author_tctx_embed)  # (n, m, d)

        # generate loss
        loss, coherence = self.rank_loss(token_temp_embed, author_tctx_embed,
                                         answerers, accept_usr)

        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]

        # self.rank_recall(predict, truth)
        # self.mrr(predict, truth)
        self.mrr(predict, accept_usr)

        return output
Code example #12
class BertNoCtxRanker(ModelBase):
    def __init__(self, args, num_authors: int, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim), requires_grad=True)  # (m, d)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim), requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)


    def build_coherence(self, token_hidden, author_embeds):

        # token_hidden (n, d, l)
        # author_embeds (m, d)

        n, _, l = token_hidden.shape
        m = author_embeds.shape[0]

        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        coherence = torch.einsum('nd,md->nm', [token_embed, author_embeds])  # (n, m)

        return coherence


    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, label: Any, date: Any) -> torch.Tensor:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch

        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask)  # (n, d, l)
        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        # coherence = self.build_coherence(token_hidden, self.author_embeddings)

        # generate positive loss
        # all_labels = list(range(self.num_authors))
        # loss = 0
        # for i, pos_labels in enumerate(label):
        #
        #     pos_labels = torch.tensor(pos_labels)
        #     if torch.cuda.is_available(): pos_labels = pos_labels.cuda()
        #     pos_coherence = coherence[i, pos_labels]
        #     pos_loss = torch.sum(-torch.log(self.sigmoid(pos_coherence))) / len(pos_labels)
        #
        #     neg_labels = torch.tensor([item for item in all_labels if item not in pos_labels])
        #     if torch.cuda.is_available(): neg_labels = neg_labels.cuda()
        #     neg_coherence = coherence[i, neg_labels]
        #     neg_loss = torch.sum(-torch.log(self.sigmoid(-neg_coherence))) / len(neg_labels)
        #
        #     loss += (pos_loss + neg_loss)
        #     pass

        # generate negative loss

        # # positive author embeddings
        # pos_author_embeds, pos_size = self.gen_pos_author_embeds(label)  # (n, p, d, k)
        #
        # # negative author embeddings
        # neg_size = pos_size  # choose negative samples the same as positive size
        # neg_author_embeds = self.gen_neg_author_embeds(label, neg_size)
        #
        # pos_coherence = self.build_coherence(token_hidden, pos_author_embeds)
        # neg_coherence = self.build_coherence(token_hidden, neg_author_embeds)
        #
        # pos_loss = torch.sum(torch.sum(torch.log(self.sigmoid(-pos_coherence)))) / pos_size
        # neg_loss = torch.sum(torch.sum(torch.log(self.sigmoid(neg_coherence)))) / neg_size

        # loss = pos_loss + neg_loss

        # loss, coherence = self.cohere_loss(token_embed, self.author_embeddings, label, no_ctx=True)
        loss, coherence = self.triplet_loss(token_embed, self.author_embeddings, label, no_ctx=True)

        output = {"loss": loss, "coherence": coherence}
        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        self.rank_recall(predict, label)

        return output
Code example #13
class MultiSpanTempModel(ModelBase):
    def __init__(self,
                 num_authors: int,
                 out_sz: int,
                 vocab: Vocabulary,
                 date_span: Any,
                 num_shift: int,
                 spans: List,
                 encoder: Any,
                 max_vocab_size: int,
                 ignore_time: bool,
                 ns_mode: bool = False,
                 num_sk: int = 20):
        super().__init__(vocab)

        self.date_span = date_span

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = num_sk, 768
        self.ignore_time = ignore_time
        self.ns_mode = ns_mode
        if self.ns_mode:
            self.author_embeddings = nn.Parameter(torch.randn(
                num_authors, self.sk_dim),
                                                  requires_grad=True)  # (m, d)
        else:
            self.author_embeddings = nn.Parameter(
                torch.randn(num_authors, self.num_sk, self.sk_dim),
                requires_grad=True)  # (m, k, d)
        self.encode_type = encoder
        if self.encode_type == "bert":
            # init word embedding
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model="bert-base-uncased",
                top_layer_only=True,  # conserve memory
            )
            self.word_embeddings = BasicTextFieldEmbedder(
                {"tokens": bert_embedder},
                # we'll be ignoring masks so we'll need to set this to True
                allow_unmatched_keys=True)
            self.encoder = BertSentencePooler(
                vocab, self.word_embeddings.get_output_dim())
        else:
            # prepare embeddings
            token_embedding = Embedding(num_embeddings=max_vocab_size + 2,
                                        embedding_dim=300,
                                        padding_index=0)
            self.word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
                {"tokens": token_embedding})

            self.encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
                nn.LSTM(self.word_embeddings.get_output_dim(),
                        hidden_size=int(self.sk_dim / 2),
                        bidirectional=True,
                        batch_first=True))

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)  # layer_norm

        # shifted temporal attentions
        self.spans = spans
        self.span_temp_atts = nn.ModuleList()
        for span in self.spans:
            self.span_temp_atts.append(
                ShiftTempAttention(self.num_authors, self.sk_dim, date_span,
                                   num_shift, span, self.ignore_time))
        self.span_projection = nn.Linear(len(spans), 1)
        self.num_shift = num_shift

        # temporal encoder: used only for adding temporal information into token embedding
        self.time_encoder = TimeEncoder(self.sk_dim,
                                        dropout=0.1,
                                        span=spans[0],
                                        date_range=date_span)

        # loss
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        # self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
        self.visual_id = 0

    def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any,
                date: Any, accept_usr: Any) -> torch.Tensor:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden2 = self.encoder(embeddings, mask)

        if self.encode_type == "bert":
            token_hidden = self.encoder(embeddings,
                                        mask).transpose(-1, -2)  # (n, d, l)
            token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        else:
            token_embed = self.encoder(embeddings, mask)  # (n, d)

        time_embed = gen_time_encoding(self.time_encoder, date)

        token_temp_embed = token_embed if self.ignore_time else token_embed + time_embed
        # if self.ignore_time:
        #     token_temp_embed = token_embed
        # else:
        #     token_temp_embed = token_embed + time_embed  # add time embedding

        # generate the token_embed with temporal information
        # time_embed_zs = [self.time_encoder.get_time_encoding(d, num_shift=0) for d in date]
        # time_embed_zs = torch.stack(time_embed_zs, dim=0)  # (n, d)
        # token_temp_embed = token_embed + time_embed_zs

        if self.ns_mode:
            author_ctx_embed = self.author_embeddings.unsqueeze(0).expand(
                token_embed.size(0), -1, -1)  # (n, m, d)
        else:
            # token_embed = token_hidden[:, :, -1]
            author_ctx_embed = self.ctx_attention(
                token_temp_embed, self.author_embeddings,
                self.author_embeddings)  # (n, m, d)

            # add layer norm for author context embedding
            author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)

        # multi-span shifted time attention layer
        span_temp_ctx_embeds, history_embeds = [], []
        for i in range(len(self.spans)):
            temp_ctx_embed, history_embed = self.span_temp_atts[i](
                token_embed, author_ctx_embed, date)  # (n, m, d)
            span_temp_ctx_embeds.append(temp_ctx_embed)
            history_embeds.append(history_embed)
        temp_ctx_embed_sp = torch.stack(span_temp_ctx_embeds, dim=-1)
        # temp_ctx_embed_sp = torch.transpose(torch.stack(temp_ctx_embed_splist), 0, -1)
        temp_ctx_embed = torch.squeeze(self.span_projection(temp_ctx_embed_sp),
                                       dim=-1)

        # print temporal context-aware embedding for visualization
        for i, answerer in enumerate(answerers):

            # generate the visualization embedding file
            if len(answerer) > 10:
                print("QID:", id[i], "Answerers:", len(answerer))
                embed_pq = temp_ctx_embed[i].cpu().numpy()
                qid = id[i]
                answerer_set = set([j[0] for j in answerer])

                with open("./exp_results/ve_" + str(qid), 'a') as f:
                    for j in range(embed_pq.shape[0]):
                        embed_pa = embed_pq[j]
                        embed_dump = "\t".join([str(i) for i in embed_pa])
                        category = 1 if j in answerer_set else 0
                        f.write(str(category) + "\t" + embed_dump + "\n")
                self.visual_id += 1

        # generate loss
        # loss, coherence = self.cohere_loss(token_embed, temp_ctx_embed, label)
        # triplet_loss, coherence = self.triplet_loss(token_embed, temp_ctx_embed, label)
        triplet_loss, coherence = self.rank_loss(token_embed, temp_ctx_embed,
                                                 answerers, accept_usr)

        truth = [[j[0] for j in i] for i in answerers]
        if self.num_shift > 2:  # no temporal loss between 1st and 2nd shifts
            temp_loss = sum([
                self.temp_loss(token_embed, history_embed, truth)
                for history_embed in history_embeds
            ])
        else:
            temp_loss = 0
        loss = triplet_loss + temp_loss * self.weight_temp
        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)

        #print("Truth:", accept_usr)
        self.mrr(predict, accept_usr)

        return output
Code example #14
class TransformerQA(Model):
    """
    This class implements a reading comprehension model patterned after the proposed model in
    https://arxiv.org/abs/1810.04805 (Devlin et al), with improvements borrowed from the SQuAD model in the
    transformers project.

    It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

    Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could
    be more than one instance per question, these metrics are not the official numbers on the SQuAD task. To get
    official numbers, run the script in scripts/transformer_qa_eval.py.

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model_name : ``str``, optional (default=``bert-base-cased``)
        This model chooses the embedder according to this setting. You probably want to make sure this is set to
        the same thing as the reader.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 transformer_model_name: str = "bert-base-cased",
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": PretrainedTransformerEmbedder(transformer_model_name)})
        self._linear_layer = nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question_with_context : Dict[str, torch.LongTensor]
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : ``torch.IntTensor``
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : ``torch.IntTensor``, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output directory.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : torch.FloatTensor
            The score for each of the best spans.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context))
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = 1

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss for training.
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            start_loss = cross_entropy(span_start_logits,
                                       span_start,
                                       ignore_index=-1)
            if torch.any(start_loss > 1e9):
                logger.critical("Start loss too high (%r)", start_loss)
                logger.critical("span_start_logits: %r", span_start_logits)
                logger.critical("span_start: %r", span_start)
                assert False

            end_loss = cross_entropy(span_end_logits,
                                     span_end,
                                     ignore_index=-1)
            if torch.any(end_loss > 1e9):
                logger.critical("End loss too high (%r)", end_loss)
                logger.critical("span_end_logits: %r", span_end_logits)
                logger.critical("span_end: %r", span_end)
                assert False

            loss = (start_loss + end_loss) / 2

            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            context_tokens = []
            for metadata_entry, best_span in zip(metadata, best_spans):
                context_tokens_for_question = metadata_entry["context_tokens"]
                context_tokens.append(context_tokens_for_question)

                best_span -= 1 + len(metadata_entry["question_tokens"]) + 2
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is
                       None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]
                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if len(answers) > 0:
                    self._per_instance_metrics(best_span_string, answers)
            output_dict["context_tokens"] = context_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._per_instance_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "per_instance_em": exact_match,
            "per_instance_f1": f1_score,
        }
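
A hypothetical instantiation of the model above. An empty Vocabulary is used here purely for illustration; the wordpiece vocabulary comes from the pretrained transformer's tokenizer rather than from AllenNLP's Vocabulary.

vocab = Vocabulary()
model = TransformerQA(vocab, transformer_model_name="bert-base-cased")
print(model.get_metrics())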
Code example #15
            This feedforward network computes the output logits.
        dropout : ``float``, optional (default=0.5)
            Dropout percentage to use.
        initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
            Used to initialize the model parameters.
        regularizer : ``RegularizerApplicator``, optional (default=``None``)
            If provided, will be used to calculate the regularization penalty during training.
    """

    lstm = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
    inference = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
    # esim = PytorchSeq2SeqWrapper(torch.nn.ESIM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

    encoder_dim = word_embeddings.get_output_dim()

    projection_feedforward = FeedForward(encoder_dim * 4, 1,
                                         inference.get_input_dim(),
                                         Activation.by_name("elu")())

    # (batch_size, model_dim * 2 * 4)
    output_feedforward = FeedForward(lstm.get_output_dim() * 4, 1, 2,
                                     Activation.by_name("elu")())

    output_logit = torch.nn.Linear(in_features=2, out_features=2)

    simfunc = BilinearSimilarity(encoder_dim, encoder_dim)

    model = ESIM(vocab=vocab,
                 text_field_embedder=word_embeddings,
Code example #16
    
    def forward(self, tokens, id, label):
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)
        out = {'class_logits': class_logits}
        out['loss'] = self.loss(class_logits, label)
        return out

from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

token_embedding = Embedding(num_embeddings=config.max_vocab_size + 2, embedding_dim=300, padding_index=0)
word_embeddings = BasicTextFieldEmbedder({'tokens': token_embedding})
encoder = PytorchSeq2VecWrapper(nn.LSTM(word_embeddings.get_output_dim(), config.batch_size,
                                        bidirectional=True, batch_first=True))

model = BaslinModel(word_embeddings, encoder)

if USE_CUDA:
    model = model.cuda()

import allennlp.nn.util as nn_util

batch = nn_util.move_to_device(batch, 0 if USE_CUDA else -1)

tokens = batch['tokens']
label = batch['label']

mask = get_text_field_mask(tokens)
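
Continuing the fragment above, a single forward pass on the moved batch might look like the following sketch; the id argument is filled with None purely for illustration, since the fragment does not show it being present in the batch.

# Sketch only: `model`, `tokens` and `label` come from the fragment above.
output = model(tokens, None, label)   # forward(tokens, id, label)
print(output['loss'])
output['loss'].backward()             # a typical training step would follow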
Code example #17
class TransformerQA(Model):
    """
    Registered as `"transformer_qa"`, this class implements a reading comprehension model patterned
    after the proposed model in [Devlin et al](https://arxiv.org/abs/1810.04805),
    with improvements borrowed from the SQuAD model in the transformers project.

    It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

    If you want to use this model on SQuAD datasets, you can use it with the
    [`TransformerSquadReader`](../../dataset_readers/transformer_squad#transformersquadreader)
    dataset reader, registered as `"transformer_squad"`.

    Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could
    be more than one instance per question, these metrics are not the official numbers on either SQuAD task.

    To get official numbers for SQuAD v1.1, for example, you can run

    ```
    python -m allennlp_models.rc.tools.transformer_qa_eval
    ```

    # Parameters

    vocab : `Vocabulary`

    transformer_model_name : `str`, optional (default=`'bert-base-cased'`)
        This model chooses the embedder according to this setting. You probably want to make sure this is set to
        the same thing as the reader.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 transformer_model_name: str = "bert-base-cased",
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": PretrainedTransformerEmbedder(transformer_model_name)})
        self._linear_layer = nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        cls_index: torch.LongTensor = None,
        answer_span: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, torch.LongTensor]`
            From a `TextField`. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including
            `[CLS]` and `[SEP]`) has type id 1.

        context_span : `torch.IntTensor`
            From a `SpanField`. This marks the span of word pieces in `question` from which answers can come.

        cls_index : `torch.LongTensor`, optional
            A tensor of shape `(batch_size,)` that provides the index of the `[CLS]` token
            in the `question_with_context` for each instance.

            This is needed because the `[CLS]` token is used to indicate that the question
            is impossible.

            If this is `None`, it's assumed that the `[CLS]` token is at index 0 for each instance
            in the batch.

        answer_span : `torch.IntTensor`, optional
            From a `SpanField`. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output directory.

        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the `metadata` list should be the
            batch size, and each dictionary should have the keys `id`, `question`, `context`,
            `question_tokens`, `context_tokens`, and `answers`.

        # Returns

        `Dict[str, torch.Tensor]` :
            An output dictionary with the following fields:

            - span_start_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span start position.
            - span_end_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span end position (inclusive).
            - best_span_scores (`torch.FloatTensor`) :
              The score for each of the best spans.
            - loss (`torch.FloatTensor`, optional) :
              A scalar loss to be optimised, evaluated against `answer_span`.
            - best_span (`torch.IntTensor`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              The result of a constrained inference over `span_start_logits` and
              `span_end_logits` to find the most probable span.  Shape is `(batch_size, 2)`
              and each offset is a token index, unless the best span for an instance
              was predicted to be the `[CLS]` token, in which case the span will be (-1, -1).
            - best_span_str (`List[str]`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              This is the string from the original passage that the model thinks is the best answer
              to the question.

        """
        embedded_question = self._text_field_embedder(question_with_context)
        # shape: (batch_size, sequence_length, 2)
        logits = self._linear_layer(embedded_question)
        # shape: (batch_size, sequence_length, 1)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        # shape: (batch_size, sequence_length)
        span_start_logits = span_start_logits.squeeze(-1)
        # shape: (batch_size, sequence_length)
        span_end_logits = span_end_logits.squeeze(-1)

        # Create a mask for `question_with_context` to mask out tokens that are not part
        # of the context.
        # shape: (batch_size, sequence_length)
        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context),
            dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True
            # Also unmask the [CLS] token since that token is used to indicate that
            # the question is impossible.
            possible_answer_mask[
                i, 0 if cls_index is None else cls_index[i]] = True

        # Replace the masked values with a very negative constant since we're in log-space.
        # shape: (batch_size, sequence_length)
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask)
        # shape: (batch_size, sequence_length)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask)

        # Now calculate the best span.
        # shape: (batch_size, 2)
        best_spans = get_best_span(span_start_logits, span_end_logits)

        # Sum the span start score with the span end score to get an overall score for the span.
        # shape: (batch_size,)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss.
        if answer_span is not None:
            output_dict["loss"] = self._evaluate_span(best_spans,
                                                      span_start_logits,
                                                      span_end_logits,
                                                      answer_span)

        # Gather the string of the best span and compute the EM and F1 against the gold span,
        # if given.
        if not self.training and metadata is not None:
            (
                output_dict["best_span_str"],
                output_dict["best_span"],
            ) = self._collect_best_span_strings(best_spans, context_span,
                                                metadata, cls_index)

        return output_dict

    def _evaluate_span(
        self,
        best_spans: torch.Tensor,
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        answer_span: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate the loss against the `answer_span` and also update the span metrics.
        """
        span_start = answer_span[:, 0]
        span_end = answer_span[:, 1]
        self._span_accuracy(best_spans, answer_span)

        start_loss = cross_entropy(span_start_logits,
                                   span_start,
                                   ignore_index=-1)
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        assert not torch.any(start_loss > big_constant), "Start loss too high"

        end_loss = cross_entropy(span_end_logits, span_end, ignore_index=-1)
        assert not torch.any(end_loss > big_constant), "End loss too high"

        self._span_start_accuracy(span_start_logits, span_start)
        self._span_end_accuracy(span_end_logits, span_end)

        return (start_loss + end_loss) / 2

    def _collect_best_span_strings(
        self,
        best_spans: torch.Tensor,
        context_span: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        cls_index: Optional[torch.LongTensor],
    ) -> Tuple[List[str], torch.Tensor]:
        """
        Collect the string of the best predicted span from the context metadata and
        update `self._per_instance_metrics`, which in the case of SQuAD v1.1 / v2.0
        includes the EM and F1 score.

        This returns a `Tuple[List[str], torch.Tensor]`, where the `List[str]` is the
        predicted answer for each instance in the batch, and the tensor is just the input
        tensor `best_spans` after adjustments so that each answer span corresponds to the
        context tokens only, and not the question tokens. Spans that correspond to the
        `[CLS]` token, i.e. the question was predicted to be impossible, will be set
        to `(-1, -1)`.
        """
        _best_spans = best_spans.detach().cpu().numpy()

        best_span_strings = []
        for (metadata_entry, best_span, cspan,
             cls_ind) in zip(metadata, _best_spans, context_span, cls_index
                             or (0 for _ in range(len(metadata)))):
            context_tokens_for_question = metadata_entry["context_tokens"]

            if best_span[0] == cls_ind:
                # Predicting [CLS] is interpreted as predicting the question as unanswerable.
                best_span_string = ""
                # NOTE: even though we've "detached" 'best_spans' above, this still
                # modifies the original tensor in-place.
                best_span[0], best_span[1] = -1, -1
            else:
                best_span -= int(cspan[0])
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is
                       None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]

            best_span_strings.append(best_span_string)

            answers = metadata_entry.get("answers")
            if answers:
                self._per_instance_metrics(best_span_string, answers)

        return best_span_strings, best_spans

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        output = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
        }
        if not self.training:
            exact_match, f1_score = self._per_instance_metrics.get_metric(
                reset)
            output["per_instance_em"] = exact_match
            output["per_instance_f1"] = f1_score
        return output

    default_predictor = "transformer_qa"
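
# The span bookkeeping in `_collect_best_span_strings` can be tried in isolation.
# A minimal standalone sketch (made-up indices, not part of the model above) of the
# two adjustments it performs: spans that point at the [CLS] token become (-1, -1),
# and all other spans are shifted by the start of the context window so they index
# context tokens only.
import torch

def adjust_spans(best_spans: torch.Tensor, context_start: int,
                 cls_index: int = 0) -> torch.Tensor:
    # hypothetical helper for illustration only
    adjusted = best_spans.clone()
    impossible = adjusted[:, 0] == cls_index
    adjusted[impossible] = -1
    adjusted[~impossible] -= context_start
    return adjusted

# question wordpieces occupy positions 0..14, context starts at position 15
spans = torch.tensor([[20, 23], [0, 0]])   # second instance predicts [CLS]
print(adjust_spans(spans, context_start=15))   # -> [[5, 8], [-1, -1]]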
コード例 #18
0
def multitask_learning():
    # load datasetreader 
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 10 
    max_seq_len = 512
    max_span_width = 30

    #import pdb
    #pdb.set_trace()    

    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    #token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
    # the token indexer is responsible for mapping tokens to integers
    token_indexer = ELMoTokenCharactersIndexer()
    
    def tokenizer(x: str):
        return [w.text for w in SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words(x)[:max_seq_len]]


    #conll_reader = ConllCorefBertReader(max_span_width = max_span_width, token_indexers = {"tokens": token_indexer}) 
    conll_reader = ConllCorefReader(max_span_width = max_span_width, token_indexers = {"tokens": token_indexer})
    swag_reader = SWAGDatasetReader(tokenizer=tokenizer, token_indexers = token_indexer)
    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    conll_datasets, swag_datasets = load_datasets(conll_reader, swag_reader, directory)
    conll_vocab = Vocabulary()
    conll_iterator = BasicIterator(batch_size=batch_size)
    conll_iterator.index_with(conll_vocab)

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)

    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import ElmoTokenEmbedder

    #bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",top_layer_only=True, requires_grad=True)

    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
 
    elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)
    word_embedding = BasicTextFieldEmbedder({"tokens": elmo_embedder})#, allow_unmatched_keys=True)

    #word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder}, allow_unmatched_keys=True)
    #BERT_DIM = word_embedding.get_output_dim()
    ELMO_DIM = word_embedding.get_output_dim()

    seq2seq = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(ELMO_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    seq2vec = PytorchSeq2VecWrapper(
        torch.nn.LSTM(ELMO_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(
        input_dim=2336, num_layers=2, hidden_dims=150, activations=torch.nn.ReLU())
    antecedent_feedforward = FeedForward(
        input_dim=7776, num_layers=2, hidden_dims=150, activations=torch.nn.ReLU())
    model1 = CoreferenceResolver(
        vocab=conll_vocab, text_field_embedder=word_embedding, context_layer=seq2seq,
        mention_feedforward=mention_feedforward,
        antecedent_feedforward=antecedent_feedforward,
        feature_size=768, max_span_width=max_span_width, spans_per_word=0.4,
        max_antecedents=250, lexical_dropout=0.2)

    model2 = SWAGExampleModel(vocab=swag_vocab, text_field_embedder=word_embedding, phrase_encoder=seq2vec)
    optimizer1 = optim.Adam(model1.parameters(), lr=lr)
    optimizer2 = optim.Adam(model2.parameters(), lr=lr)

    swag_train_iterator = swag_iterator(swag_datasets[0], num_epochs=1, shuffle=True)
    conll_train_iterator = conll_iterator(conll_datasets[0], num_epochs=1, shuffle=True)
    swag_val_iterator = swag_iterator(swag_datasets[1], num_epochs=1, shuffle=True)
    conll_val_iterator = conll_iterator(conll_datasets[1], num_epochs=1, shuffle=True)
    task_infos = {"swag": {"model": model2, "optimizer": optimizer2, "loss": 0.0, "iterator": swag_iterator, "train_data": swag_datasets[0], "val_data": swag_datasets[1], "num_train": len(swag_datasets[0]), "num_val": len(swag_datasets[1]), "lr": lr, "score": {"accuracy":0.0}}, \
                    "conll": {"model": model1, "iterator": conll_iterator, "loss": 0.0, "val_data": conll_datasets[1], "train_data": conll_datasets[0], "optimizer": optimizer1, "num_train": len(conll_datasets[0]), "num_val": len(conll_datasets[1]),"lr": lr, "score": {"coref_prediction": 0.0, "coref_recall": 0.0, "coref_f1": 0.0,"mention_recall": 0.0}}}
    USE_GPU = 1
    trainer = MultiTaskTrainer(
        task_infos=task_infos, 
        num_epochs=epochs,
        serialization_dir=directory + "saved_models/multitask/"
    ) 
    metrics = trainer.train()
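
# Illustration only: `MultiTaskTrainer` above is project-specific, so the following
# is a rough sketch (an assumption, not its actual implementation) of the alternating
# update idea behind `task_infos`: sample a task each step, pull a batch from its
# iterator, and update only that task's model.
import random

def alternate_steps(task_infos, steps=100):
    generators = {
        name: info["iterator"](info["train_data"], num_epochs=None, shuffle=True)
        for name, info in task_infos.items()
    }
    for _ in range(steps):
        name = random.choice(list(task_infos))
        info = task_infos[name]
        batch = next(generators[name])
        info["optimizer"].zero_grad()
        loss = info["model"](**batch)["loss"]
        loss.backward()
        info["optimizer"].step()
        info["loss"] += loss.item()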
コード例 #19
0
class TransformerWorldTree(Model):
    """

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model : ``str``, optional (default=``"roberta-large"``)
        This model chooses the embedder according to this setting. You probably want to make sure this matches the
        setting in the reader.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model: str = "roberta-large",
        override_weights_file: Optional[str] = None,
        override_weights_strip_prefix: Optional[str] = None,
        **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = PretrainedTransformerEmbedder(
            transformer_model,
            override_weights_file=override_weights_file,
            override_weights_strip_prefix=override_weights_strip_prefix,
        )
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": self._text_field_embedder})
        self._pooler = BertPooler(
            transformer_model,
            override_weights_file=override_weights_file,
            override_weights_strip_prefix=override_weights_strip_prefix,
            dropout=0.1,
        )

        self._linear_layer = torch.nn.Linear(
            self._text_field_embedder.get_output_dim(), 1)
        self._linear_layer.weight.data.normal_(mean=0.0, std=0.02)
        self._linear_layer.bias.data.zero_()

        self._loss = torch.nn.CrossEntropyLoss()
        self._accuracy = CategoricalAccuracy()

    def forward(
        self,  # type: ignore
        qc_pairs: TextFieldTensors,
        answer_idx: Optional[torch.IntTensor] = None,
        metadata: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        qc_pairs : ``Dict[str, torch.LongTensor]``
            From a ``ListField[TextField]``. Contains a list of question-choice pairs to evaluate for every instance.
        answer_idx : ``Optional[torch.IntTensor]``
            From an ``IndexField``. Contains the index of the correct answer for every instance.
        metadata : `Optional[Dict[str, Any]]`
            The meta information for the questions, like question_id, original_text, and so on.
            
        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. This is only returned when `answer_idx` is not `None`.
        logits : ``torch.FloatTensor``
            The logits for every possible answer choice.
        prediction : ``List[int]``
            The index of the highest scoring choice for every instance in the batch.
        """

        # Bert embedding
        # Shape: (batch_size, num_choices, seq_length, embedding_dim)
        embedded_pairs = self._text_field_embedder(
            qc_pairs, num_wrapping_dims=1)
        batch_size, num_choices, seq_length, embedding_dim = embedded_pairs.size()

        # Flatten the choices
        # Shape: (batch_size*num_choices, seq_length, embedding_dim)
        flattened = embedded_pairs.view(
            batch_size * num_choices,
            seq_length,
            embedding_dim,
        )

        # Get the embedding of the [CLS] for classification
        # Shape: (batch_size*num_choices, embedding_dim)
        pooled = self._pooler(flattened)

        # Pass through a linear layer to predict the logits
        # Shape: (batch_size*num_choices, 1)
        logits = self._linear_layer(pooled)

        # Restore the shapes
        # Shape: (batch_size, num_choices)
        logits = logits.view(
            batch_size, num_choices
        )
        prediction = logits.argmax(-1)

        # If answer_idx is passed, calculate the loss
        if answer_idx is not None:
            answer_idx = answer_idx.squeeze(1)
            loss = self._loss(
                logits, answer_idx)
            self._accuracy(logits, answer_idx)
        else:
            loss = None

        return {
            "logits": logits,
            "prediction": prediction,
            "loss": loss,
        }

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            "acc": self._accuracy.get_metric(reset),
        }
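
# Standalone illustration (not part of the model above) of the multiple-choice
# reshaping in `forward`: every (question, choice) pair is scored independently from
# its [CLS] vector, and the scores are folded back to one logit per choice so that
# softmax/argmax compare choices within the same question.
import torch

batch_size, num_choices, seq_length, dim = 2, 4, 7, 16
embedded = torch.randn(batch_size, num_choices, seq_length, dim)

flattened = embedded.view(batch_size * num_choices, seq_length, dim)
cls_vectors = flattened[:, 0, :]                 # stand-in for the pooler
scores = torch.nn.Linear(dim, 1)(cls_vectors)    # one score per pair
logits = scores.view(batch_size, num_choices)    # (batch_size, num_choices)
print(logits.shape, logits.argmax(-1))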
コード例 #20
0
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)
# Configure the datasets to index with the vocabulary built above
train_dataset.index_with(vocab)
validation_dataset.index_with(vocab)

# Build the embedders that turn text into feature vectors
source_embedding = Embedding(
    num_embeddings=vocab.get_vocab_size(namespace="tokens"), embedding_dim=512)
source_text_embedder = BasicTextFieldEmbedder(
    token_embedders={"tokens": source_embedding})
target_embedding = Embedding(
    num_embeddings=vocab.get_vocab_size(namespace="target_tokens"),
    embedding_dim=512)

# Sequence-to-Sequence Model (LSTM, Transformer)
encoder = PytorchTransformer(input_dim=source_text_embedder.get_output_dim(),
                             feedforward_hidden_dim=512,
                             num_layers=4,
                             num_attention_heads=8)
decoder_net = StackedSelfAttentionDecoderNet(
    decoding_dim=target_embedding.get_output_dim(),
    target_embedding_dim=target_embedding.get_output_dim(),
    feedforward_hidden_dim=512,
    num_layers=4,
    num_attention_heads=8)
decoder = AutoRegressiveSeqDecoder(vocab=vocab,
                                   decoder_net=decoder_net,
                                   max_decoding_steps=128,
                                   target_embedder=target_embedding,
                                   beam_size=args.beam_size,
                                   target_namespace='target_tokens')
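
# The snippet above stops after building the decoder. A sketch of how these parts are
# typically composed into a trainable model (assumes allennlp-models is installed; the
# exact import path may differ between versions):
from allennlp_models.generation import ComposedSeq2Seq

model = ComposedSeq2Seq(vocab=vocab,
                        source_text_embedder=source_text_embedder,
                        encoder=encoder,
                        decoder=decoder)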
コード例 #21
0
if args.embedding_type == 'glove':
    param_dict = {
        "pretrained_file":
        "(https://nlp.stanford.edu/data/glove.6B.zip)#glove.6B.300d.txt",
        "embedding_dim": 300
    }
    params = Params(params=param_dict)
    token_embedding = Embedding.from_params(vocab=vocab, params=params)
elif args.embedding_type == 'elmo':
    token_embedding = ElmoTokenEmbedder(args.options_file,
                                        args.weights_file,
                                        requires_grad=args.finetune_embeddings)

word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

if args.encoder_type == 'bag':
    encoder = BagOfEmbeddingsEncoder(word_embeddings.get_output_dim())
elif args.encoder_type == 'lstm':
    encoder = PytorchSeq2VecWrapper(
        torch.nn.LSTM(word_embeddings.get_output_dim(),
                      config.hidden_sz,
                      bidirectional=True,
                      batch_first=True))

num_classes = vocab.get_vocab_size("labels")
decoder_input_dim = encoder.get_output_dim()

if args.decoder_type == 'linear':
    decoder = torch.nn.Linear(decoder_input_dim, num_classes)

model = TextClassifier(word_embeddings, encoder, decoder, vocab)
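
# Note: the linear decoder's input size is taken from encoder.get_output_dim(), so it
# automatically follows whichever encoder was chosen -- the embedding dimension for
# the bag-of-embeddings encoder, or 2 * config.hidden_sz for the bidirectional LSTM
# wrapper.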
コード例 #22
0
def train_only_lee():
    # This is WORKING!
    # load datasetreader
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory + "/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 100
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased",
                                          do_lowercase=False)
    reader = ConllCorefBertReader(max_span_width=max_span_width,
                                  token_indexers={"tokens": token_indexer})

    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    processed_reader_dir = Path(directory + "processed/")

    train_ds = None
    if processed_reader_dir.is_dir():
        print("Loading indexed from checkpoints")
        train_path = Path(directory + "processed/train_d")
        if train_path.exists():
            train_ds = pickle.load(
                open(directory + "processed/conll/train_d", "rb"))
            val_ds = pickle.load(
                open(directory + "processed/conll/val_d", "rb"))
            test_ds = pickle.load(
                open(directory + "processed/conll/test_d", "rb"))
        else:
            print("checkpoints not found")
            train_ds, val_ds, test_ds = (
                reader.read(dataset_folder + fname) for fname in [
                    "train.english.v4_gold_conll", "dev.english.v4_gold_conll",
                    "test.english.v4_gold_conll"
                ])
            pickle.dump(train_ds, open(directory + "processed/train_d", "wb"))
            pickle.dump(val_ds, open(directory + "processed/val_d", "wb"))
            pickle.dump(test_ds, open(directory + "processed/test_d", "wb"))
            print("saved checkpoints")
    # restore checkpoint here

    #vocab = Vocabulary.from_instances(train_ds + val_ds)
    vocab = Vocabulary()
    iterator = BasicIterator(batch_size=batch_size)
    iterator.index_with(vocab)

    val_iterator = BasicIterator(batch_size=batch_size)
    val_iterator.index_with(vocab)
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    bert_embedder = PretrainedBertEmbedder(
        pretrained_model="bert-base-cased",
        top_layer_only=True,  # conserve memory
        requires_grad=True)
    # here, allow_unmatched_keys=True since we don't pass in offsets: we allow
    # word embeddings over the BERT-tokenized wordpieces, not necessarily the
    # original tokens.
    # See the documentation for offsets here for more info:
    # https://github.com/allenai/allennlp/blob/master/allennlp/modules/token_embedders/bert_token_embedder.py
    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                            allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()
    # at each batch, sample from the two, and load the LSTM
    shared_layer = torch.nn.LSTM(BERT_DIM,
                                 HIDDEN_DIM,
                                 batch_first=True,
                                 bidirectional=True)
    seq2seq = PytorchSeq2SeqWrapper(shared_layer)
    mention_feedforward = FeedForward(input_dim=2336,
                                      num_layers=2,
                                      hidden_dims=150,
                                      activations=torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim=7776,
                                         num_layers=2,
                                         hidden_dims=150,
                                         activations=torch.nn.ReLU())
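    # Note (based on AllenNLP's CoreferenceResolver, stated here as an assumption):
    # a span representation concatenates the endpoint features from the context layer
    # (2 * 400), a 768-d span-width embedding, and a 768-d attended head over the BERT
    # embeddings, giving 2336; each antecedent pairing concatenates three such vectors
    # plus a 768-d distance embedding, giving 3 * 2336 + 768 = 7776.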

    model = CoreferenceResolver(vocab=vocab,
                                text_field_embedder=word_embedding,
                                context_layer=seq2seq,
                                mention_feedforward=mention_feedforward,
                                antecedent_feedforward=antecedent_feedforward,
                                feature_size=768,
                                max_span_width=max_span_width,
                                spans_per_word=0.4,
                                max_antecedents=250,
                                lexical_dropout=0.2)
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # and then we can do the shared loss
    #
    # Get
    USE_GPU = 0
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=iterator,
        validation_iterator=val_iterator,
        train_dataset=train_ds,
        validation_dataset=val_ds,
        validation_metric="+coref_f1",
        cuda_device=0 if USE_GPU else -1,
        serialization_dir=directory + "saved_models/only_lee",
        num_epochs=epochs,
    )

    metrics = trainer.train()
    # save the model
    with open(directory + "saved_models/current_run_model_state", 'wb') as f:
        torch.save(model.state_dict(), f)
コード例 #23
0
def predict_only_lee():
    # load datasetreader
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory + "/log.log"))

    batch_size = 1
    epochs = 10
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased",
                                          do_lowercase=False)
    conll_reader = ConllCorefBertReader(
        max_span_width=max_span_width,
        token_indexers={"tokens": token_indexer})
    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    processed_reader_dir = Path(directory + "processed/")

    train_ds = None
    test_ds = None
    if processed_reader_dir.is_dir():
        print("Loading indexed from checkpoints")
        train_path = Path(directory + "processed/train_d")
        if train_path.exists():
            train_ds = pickle.load(
                open(directory + "processed/conll/train_d", "rb"))
            val_ds = pickle.load(
                open(directory + "processed/conll/val_d", "rb"))
            test_ds = pickle.load(
                open(directory + "processed/conll/test_d", "rb"))
        else:
            print("checkpoints not found")
            train_ds, val_ds, test_ds = (
                conll_reader.read(dataset_folder + fname) for fname in [
                    "train.english.v4_gold_conll", "dev.english.v4_gold_conll",
                    "test.english.v4_gold_conll"
                ])
            pickle.dump(train_ds, open(directory + "processed/train_d", "wb"))
            pickle.dump(val_ds, open(directory + "processed/val_d", "wb"))
            pickle.dump(test_ds, open(directory + "processed/test_d", "wb"))
            print("saved checkpoints")

    vocab = Vocabulary()
    iterator = BasicIterator(batch_size=batch_size)
    iterator.index_with(vocab)

    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",
                                           top_layer_only=True,
                                           requires_grad=True)

    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                            allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()

    shared_layer = torch.nn.LSTM(BERT_DIM,
                                 HIDDEN_DIM,
                                 batch_first=True,
                                 bidirectional=True)

    seq2seq = PytorchSeq2SeqWrapper(shared_layer)
    #seq2vec = PytorchSeq2VecWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(input_dim=2336,
                                      num_layers=2,
                                      hidden_dims=150,
                                      activations=torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim=7776,
                                         num_layers=2,
                                         hidden_dims=150,
                                         activations=torch.nn.ReLU())
    model1 = CoreferenceResolver(vocab=vocab,
                                 text_field_embedder=word_embedding,
                                 context_layer=seq2seq,
                                 mention_feedforward=mention_feedforward,
                                 antecedent_feedforward=antecedent_feedforward,
                                 feature_size=768,
                                 max_span_width=max_span_width,
                                 spans_per_word=0.4,
                                 max_antecedents=250,
                                 lexical_dropout=0.2)

    conll_test_iterator = iterator(test_ds, num_epochs=1, shuffle=False)
    USE_GPU = 1

    #serialization_dir=directory + "saved_models/multitask/"

    #TRAINED_MODEL_PATH = directory + "saved_models/multitask/conll/model_state_epoch_9.th"

    TRAINED_MODEL_PATH = directory + "saved_models/current_run_model_state/model_state_epoch_99.th"

    model1.load_state_dict(torch.load(TRAINED_MODEL_PATH))
    model1.eval()

    num_batches = len(test_ds)

    for i in range(20):
        batch = next(conll_test_iterator, None)
        output = model1.forward(**batch)

        #let us print out the predictions in the first document of this batch
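        # Output structure (as in AllenNLP's coreference model; stated as our reading,
        # not verified against this exact version): 'top_spans' is (batch, spans_kept, 2)
        # with token-level boundaries, 'predicted_antecedents' is (batch, spans_kept)
        # where -1 means "no antecedent", and 'antecedent_indices' maps an antecedent
        # rank back to an index into 'top_spans'.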
        pairs = []
        for index, j in enumerate(output['predicted_antecedents'][0]):
            if j != -1:
                i1 = output['top_spans'][0][index]
                i2 = output['top_spans'][0][
                    output['antecedent_indices'][index][j]]
                d0 = output['document'][0]
                pairs.append([d0[i1[0]:i1[1] + 1], d0[i2[0]:i2[1] + 1]])

        #pairs
        #print(pairs)
        metrics = model1.get_metrics()
        print(metrics['coref_f1'])
コード例 #24
0
        output["loss"] = self.loss(class_logits, label)

        return output


USE_CUDA = torch.cuda.is_available()
# token_embedding = Embedding(num_embeddings=config.max_vocab_size + 2,
#                             embedding_dim=300, padding_index=0)
# word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": token_embedding})

options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'

elmo_embedder = ElmoTokenEmbedder(options_file, weight_file, requires_grad=True)
word_embeddings = BasicTextFieldEmbedder({"tokens": elmo_embedder})
encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(nn.LSTM(word_embeddings.get_output_dim(),
                                                        config.hidden_sz, bidirectional=True, batch_first=True))
model = BaselineModel(word_embeddings, encoder)
if USE_CUDA:
    model.cuda()


optimizer = optim.Adam(model.parameters(), lr=config.lr)
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    cuda_device=0 if USE_CUDA else -1,
    num_epochs=config.epochs
)
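
# The snippet ends after constructing the Trainer; as in the other examples here,
# training would then be started with:
# metrics = trainer.train()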
コード例 #25
0
class BertCtxRanker(ModelBase):
    def __init__(self, args, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.sk_dim, self.num_sk),
                                              requires_grad=True)  # (m, d, k)

        self.attention = nn.Parameter(torch.randn(
            self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

    def build_author_ctx_embed(self, token_hidden, author_embeds):

        # token_hidden (n, d, l)
        # author_embeds (m, d, k)

        n, _, l = token_hidden.shape
        m = author_embeds.shape[0]

        F_sim = torch.einsum('ndl,de,mek->nmlk',
                             [token_hidden, self.attention, author_embeds])
        F_tanh = self.tanh(F_sim.contiguous().view(
            n * m, l, self.num_sk))  # (n * m, l, k)
        F_tanh = F_tanh.view(n, m, l, self.num_sk)  # (n, m, l, k)
        g_u = torch.mean(F_tanh, 2)  # (n, m, k)
        a_u = self.softmax(g_u)  # (n, m, k)

        author_ctx_embed = torch.einsum('mdk,nmk->nmd',
                                        [author_embeds, a_u])  # (n, m, d)

        return author_ctx_embed

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                id: Any,
                answerers: Any,
                date: Any,
                accept_usr: Any,
                att_l=False) -> torch.Tensor:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch

        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings, mask)  # (n, d, l)

        author_ctx_embed = self.build_author_ctx_embed(
            token_hidden, self.author_embeddings)  # (n, m, d)
        token_embed = torch.mean(token_hidden, 2)  # (n, d)

        # coherence = torch.einsum('nd,nmd->nm', [token_embed, author_ctx_embed])  # (n, m)
        # loss, coherence = self.cohere_loss(token_embed, author_ctx_embed, label)
        # loss, coherence = self.triplet_loss(token_embed, author_ctx_embed, label)
        loss, coherence = self.rank_loss(token_embed, author_ctx_embed,
                                         answerers, accept_usr)

        # generate positive loss
        # all_labels = list(range(self.num_authors))
        # loss = 0
        # for i, pos_labels in enumerate(label):
        #
        #     num_pos = len(pos_labels)
        #     if num_pos == 0:
        #         continue
        #
        #     # BR-DEV relation
        #     pos_labels = torch.tensor(pos_labels)
        #     if torch.cuda.is_available(): pos_labels = pos_labels.cuda()
        #     pos_coherence = coherence[i, pos_labels]
        #     pos_loss = torch.sum(-torch.log(self.sigmoid(pos_coherence))) / num_pos
        #
        #     neg_labels = torch.tensor([item for item in all_labels if item not in pos_labels])
        #     num_neg = len(neg_labels)
        #     if torch.cuda.is_available(): neg_labels = neg_labels.cuda()
        #     neg_coherence = coherence[i, neg_labels]
        #     neg_loss = torch.sum(-torch.log(self.sigmoid(-neg_coherence))) / num_neg
        #
        #     loss += (pos_loss + neg_loss)
        #
        #     # DEV-DEV relation
        #     pos_authors = author_ctx_embed[i, pos_labels]  # (pos, d)
        #     neg_authors = author_ctx_embed[i, neg_labels]  # (neg, d)
        #
        #     auth_pos_coherence = torch.einsum('pd,qd->pq', [pos_authors, pos_authors])  # (pos, pos)
        #     auth_neg_coherence = torch.einsum('pd,nd->pn', [pos_authors, neg_authors])  # (pos, neg)
        #
        #     log_sig_auth = -torch.log(self.sigmoid(auth_pos_coherence))
        #     auth_pos_loss = (torch.sum(log_sig_auth) - torch.sum(torch.diagonal(log_sig_auth, 0)))
        #     if num_pos > 1:
        #         auth_pos_loss /= (num_pos * num_pos - num_pos)
        #
        #     auth_neg_loss = torch.sum(-torch.log(self.sigmoid(-auth_neg_coherence))) / (num_pos * num_neg)
        #
        #     # loss += (auth_pos_loss + auth_neg_loss)
        #     loss += (auth_pos_loss)
        #
        #     if torch.isnan(loss):
        #         raise ValueError("nan loss encountered")

        output = {"loss": loss, "coherence": coherence}
        # output = {"class_logits": class_logits}
        # output["loss"] = self.loss(class_logits, label)

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]

        self.mrr(predict, accept_usr)
        return output
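
# Standalone shape check (illustration only, made-up sizes) of the einsum-based
# author attention in build_author_ctx_embed above.
import torch

n, d, l = 2, 8, 5        # batch, hidden dim, text length
m, k = 3, 4              # authors, skills
token_hidden = torch.randn(n, d, l)
attention = torch.randn(d, d)            # plays the role of self.attention (e == d)
author_embeds = torch.randn(m, d, k)

F_sim = torch.einsum('ndl,de,mek->nmlk', token_hidden, attention, author_embeds)
a_u = torch.softmax(torch.tanh(F_sim).mean(dim=2), dim=2)        # (n, m, k)
author_ctx_embed = torch.einsum('mdk,nmk->nmd', author_embeds, a_u)
print(F_sim.shape, a_u.shape, author_ctx_embed.shape)
# torch.Size([2, 3, 5, 4]) torch.Size([2, 3, 4]) torch.Size([2, 3, 8])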
コード例 #26
0
def run_experiment(use_similarity_targets, embedding_type, rnn_type, hparams):
    log = {}
    log["name"] = "{} {} {} {}".format(
        rnn_type, embedding_type, "similarity_target" if use_similarity_targets else "hard_target", hparams["update_targets"]
    )

    vocab = Vocabulary().from_files(hparams["vocab_path"])
    if embedding_type == "Chord":
        # data reader
        reader = CpmDatasetReader()

        # chord embedder
        token_embedding = Embedding(
            num_embeddings=vocab.get_vocab_size("tokens"),
            embedding_dim=hparams["chord_token_embedding_dim"],
        )
        chord_embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

    elif embedding_type == "Note":
        # data reader
        note_tokenizer = NoteTokenizer()
        note_indexer = TokenCharactersIndexer(
            namespace="notes", min_padding_length=4, character_tokenizer=note_tokenizer
        )
        reader = CpmDatasetReader(
            token_indexers={"tokens": SingleIdTokenIndexer(),
                            "notes": note_indexer}
        )

        # chord embedder
        token_embedding = Embedding(
            num_embeddings=vocab.get_vocab_size("tokens"),
            embedding_dim=hparams["chord_token_embedding_dim"],
        )
        note_token_embedding = Embedding(
            vocab.get_vocab_size("notes"), hparams["note_embedding_dim"]
        )
        note_encoder = CnnEncoder(
            num_filters=hparams["cnn_encoder_num_filters"],
            ngram_filter_sizes=hparams["cnn_encoder_n_gram_filter_sizes"],
            embedding_dim=hparams["note_embedding_dim"],
            output_dim=hparams["note_level_embedding_dim"],
        )
        note_embedding = TokenCharactersEncoder(
            note_token_embedding, note_encoder)
        chord_embedder = BasicTextFieldEmbedder(
            {"tokens": token_embedding, "notes": note_embedding}
        )
    else:
        raise ValueError("Unknown embedding type:", embedding_type)

    # read data
    train_dataset = reader.read(os.path.join(
        hparams["data_path"], "train.txt"))
    val_dataset = reader.read(os.path.join(hparams["data_path"], "val.txt"))
    test_dataset = reader.read(os.path.join(hparams["data_path"], "test.txt"))

    # contextualizer
    contextual_input_dim = chord_embedder.get_output_dim()
    if rnn_type == "RNN":
        contextualizer = PytorchSeq2SeqWrapper(
            torch.nn.RNN(
                contextual_input_dim, hparams["rnn_hidden_dim"], batch_first=True, bidirectional=False
            )
        )
    elif rnn_type == "LSTM":
        contextualizer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(
                contextual_input_dim, hparams["lstm_hidden_dim"], batch_first=True, bidirectional=False
            )
        )
    elif rnn_type == "GRU":
        contextualizer = PytorchSeq2SeqWrapper(
            torch.nn.GRU(
                contextual_input_dim, hparams["gru_hidden_dim"], batch_first=True, bidirectional=False
            )
        )
    else:
        raise ValueError("Unknown rnn type:", rnn_type)

    if use_similarity_targets:
        vocab_size = vocab.get_vocab_size("tokens")
        similarity_targets = Embedding(
            num_embeddings=vocab_size,
            embedding_dim=vocab_size,
            weight=torch.load(hparams["similarity_target_path"]),
            trainable=False,
        )
    else:
        similarity_targets = None

    iterator = BucketIterator(
        batch_size=hparams["batch_size"], sorting_keys=[
            ("input_tokens", "num_tokens")]
    )
    iterator.index_with(vocab)

    batches_per_epoch = math.ceil(len(train_dataset) / hparams["batch_size"])

    model_hparams = {
        "dropout": None,
        "similarity_targets": similarity_targets,
        "update_targets": hparams["update_targets"],
        "T_initial": hparams["T_initial"],
        "decay_rate": hparams["decay_rate"],
        "batches_per_epoch": batches_per_epoch,
        "fc_hidden_dim": hparams["fc_hidden_dim"]
    }
    # chord progression model
    model = Cpm(
        vocab,
        chord_embedder,
        contextualizer,
        model_hparams
    )

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        print("GPU available.")
    else:
        cuda_device = -1

    optimizer = optim.Adam(model.parameters(), lr=hparams["lr"])

    ts = time.gmtime()
    saved_model_path = os.path.join(
        hparams["saved_model_path"], time.strftime("%Y-%m-%d %H-%M-%S", ts))
    serialization_dir = os.path.join(saved_model_path, "checkpoints")

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=iterator,
        train_dataset=train_dataset,
        validation_dataset=val_dataset,
        serialization_dir=serialization_dir,
        patience=hparams["patience"],
        num_epochs=hparams["num_epochs"],
        cuda_device=cuda_device,
    )
    trainer.train()
    saved_model_path = os.path.join(
        saved_model_path, "{}.th".format(log["name"]))
    torch.save(model.state_dict(), saved_model_path)

    predictor = Predictor(model=model, iterator=iterator,
                          cuda_device=cuda_device)
    pred_metrics = predictor.predict(test_dataset)
    log["metrics"] = pred_metrics
    log["saved_mode_path"] = saved_model_path

    return log
コード例 #27
0
def train_only_lee():
    # This is WORKING!
    # load datasetreader
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 100
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    reader = ConllCorefBertReader(max_span_width = max_span_width, token_indexers = {"tokens": token_indexer})

    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    processed_reader_dir = Path(directory+"processed/")
    
    train_ds, val_ds, test_ds = load_lee(reader, directory)
    # restore checkpoint here
    from allennlp.modules.token_embedders import ElmoTokenEmbedder
    #vocab = Vocabulary.from_instances(train_ds + val_ds)
    vocab = Vocabulary()
    iterator = BasicIterator(batch_size=batch_size)
    iterator.index_with(vocab)

    val_iterator = BasicIterator(batch_size=batch_size)
    val_iterator.index_with(vocab)
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    # Note kept from the BERT variant (the ELMo embedder below does not need it):
    # allow_unmatched_keys=True is used when offsets are not passed in, i.e. word
    # embeddings are taken over the BERT-tokenized wordpieces rather than the
    # original tokens. See the documentation for offsets here for more info:
    # https://github.com/allenai/allennlp/blob/master/allennlp/modules/token_embedders/bert_token_embedder.py
    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
 
    elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)
    word_embedding = BasicTextFieldEmbedder({"tokens": elmo_embedder})#, allow_unmatched_keys=True)

    #word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder}, allow_unmatched_keys=True)
    #BERT_DIM = word_embedding.get_output_dim()
    ELMO_DIM = word_embedding.get_output_dim()
    # at each batch, sample from the two, and load the LSTM
    shared_layer = torch.nn.LSTM(ELMO_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True)
    seq2seq = PytorchSeq2SeqWrapper(shared_layer)
    mention_feedforward = FeedForward(
        input_dim=512, num_layers=2, hidden_dims=150, activations=torch.nn.ReLU())
    antecedent_feedforward = FeedForward(
        input_dim=2304, num_layers=2, hidden_dims=150, activations=torch.nn.ReLU())

    model = CoreferenceResolver(
        vocab=vocab, text_field_embedder=word_embedding, context_layer=seq2seq,
        mention_feedforward=mention_feedforward,
        antecedent_feedforward=antecedent_feedforward,
        feature_size=768, max_span_width=max_span_width, spans_per_word=0.4,
        max_antecedents=250, lexical_dropout=0.2)
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # and then we can do the shared loss
    # 
    # Get 
    USE_GPU = 1
    trainer = Trainer(
        model=model.cuda(),
        optimizer=optimizer,
        iterator=iterator,
        validation_iterator=val_iterator,
        train_dataset=train_ds,
        validation_dataset=val_ds,
        validation_metric="+coref_f1",
        cuda_device=0 if USE_GPU else -1,
        serialization_dir=directory + "saved_models/only_lee",
        num_epochs=epochs,
    )

    metrics = trainer.train()
    # save the model
    with open(directory + "saved_models/current_run_model_state", 'wb') as f:
        torch.save(model.state_dict(), f)
コード例 #28
0
def train_and_evaluate(hparams):
    # vocabulary
    vocab = Vocabulary().from_files(hparams["vocab_path"])
    vocab_size = vocab.get_vocab_size("tokens")

    # chord embedding
    token_embedding = Embedding(
        num_embeddings=vocab_size,
        embedding_dim=hparams["chord_token_embedding_dim"])
    chord_embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

    # data readers
    reader = CpmDatasetReader()
    train_dataset = reader.read(os.path.join(hparams["data_path"],
                                             "train.txt"))
    val_dataset = reader.read(os.path.join(hparams["data_path"], "val.txt"))
    test_dataset = reader.read(os.path.join(hparams["data_path"], "test.txt"))

    # contextualizer
    input_dim = chord_embedder.get_output_dim()
    hidden_dim = hparams["rnn_hidden_dim"]
    rnn_type = hparams["rnn_type"]
    contextualizer = get_contextualizer(rnn_type, input_dim, hidden_dim)

    # similarity matrix
    similarity_matrix_path = hparams["similarity_matrix_path"]
    if similarity_matrix_path is not None:
        similarity_matrix = Embedding(
            num_embeddings=vocab_size,
            embedding_dim=vocab_size,
            weight=torch.load(similarity_matrix_path),
            trainable=False,
        )
    else:
        similarity_matrix = None

    # training iterator
    batch_size = hparams["batch_size"]
    iterator = BucketIterator(batch_size=batch_size,
                              sorting_keys=[("input_tokens", "num_tokens")])
    iterator.index_with(vocab)
    batches_per_epoch = math.ceil(len(train_dataset) / batch_size)

    # model parameters
    model_hparams = {
        "dropout": None,
        "training_mode": hparams["training_mode"],
        "similarity_matrix": similarity_matrix,
        "T_initial": hparams["T_initial"],
        "decay_rate": hparams["decay_rate"],
        "batches_per_epoch": batches_per_epoch,
        "fc_hidden_dim": hparams["fc_hidden_dim"],
    }

    # chord progression model
    model = Cpm(vocab, chord_embedder, contextualizer, model_hparams)

    # check gpu available
    if torch.cuda.is_available():
        cuda_device = GPU_NO
        model = model.cuda(cuda_device)
    else:
        cuda_device = -1

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=hparams["lr"])

    # trainer
    ts = time.gmtime()
    saved_model_path = os.path.join(hparams["saved_model_path"],
                                    time.strftime("%Y-%m-%d %H-%M-%S", ts))
    serialization_dir = os.path.join(saved_model_path, "checkpoints")
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=iterator,
        train_dataset=train_dataset,
        validation_dataset=val_dataset,
        serialization_dir=serialization_dir,
        patience=hparams["patience"],
        num_epochs=hparams["num_epochs"],
        cuda_device=cuda_device,
    )
    trainer.train()
    torch.save(model.state_dict(), os.path.join(saved_model_path, "best.th"))
    with open(os.path.join(saved_model_path, "hparams.json"), "w") as f:
        json.dump(hparams, f, indent=4)

    predictor = Predictor(model=model,
                          iterator=iterator,
                          cuda_device=cuda_device)
    train_metrics = predictor.predict(train_dataset)
    test_metrics = predictor.predict(test_dataset)
    log(hparams, train_metrics, test_metrics, saved_model_path)
コード例 #29
0
)

vocab = Vocabulary.from_files(
    "/tmp/vocabulary")  #preloaded vocab, required to do lazy computations

token_embedder = Embedding.from_params(
    vocab=vocab,
    params=Params({
        'pretrained_file':
        '/home/dkeren/Documents/Spring2019/CIS520/project/glove.twitter.27B.50d.txt',
        'embedding_dim': 50
    }))
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedder})

lstm = PytorchSeq2VecWrapper(
    nn.LSTM(word_embeddings.get_output_dim(),
            config.hidden_sz,
            bidirectional=True,
            batch_first=True))

save_file = "model_v12.th"  ## models saved by lstm training
model2 = LSTM_Model(word_embeddings, lstm, 2)
with open(save_file, 'rb') as f:
    model2.load_state_dict(torch.load(f))

# iterate over the dataset without changing its order
seq_iterator = BasicIterator(config.batch_size)
seq_iterator.index_with(vocab)
predictor = Predictor(model2, seq_iterator)
prob, labels = predictor.predict(test_dataset)
test_preds = 1 * (prob > .525)  #optimal threshold
コード例 #30
0
class MultiHeadCtxModel(ModelBase):
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.multihead_att = TempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(
            self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                id: Any,
                label: Any,
                date: Any,
                att_l=False) -> torch.Tensor:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch

        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        token_hidden = self.encoder(embeddings,
                                    mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze()  # (n, d)
        # token_embed = token_hidden[:, :, -1]
        if att_l:
            author_ctx_embed = self.multihead_att(token_hidden,
                                                  self.author_embeddings,
                                                  self.author_embeddings)
        else:
            author_ctx_embed = self.multihead_att(token_embed,
                                                  self.author_embeddings,
                                                  self.author_embeddings)

        # generate loss
        loss, coherence = self.cohere_loss(token_embed, author_ctx_embed,
                                           label)
        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        self.rank_recall(predict, label)

        return output
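
# Illustrative stand-in (an assumption, since TempCtxAttention is project-specific):
# the call above uses the text embedding as the query and the per-author skill
# embeddings as keys/values, which resembles standard multi-head attention applied
# per author, e.g. with torch.nn.MultiheadAttention:
import torch

n, m, k, d = 2, 5, 20, 768          # batch, authors, skills, hidden dim
token_embed = torch.randn(n, d)
author_embeddings = torch.randn(m, k, d)

mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=8, batch_first=True)
query = token_embed.unsqueeze(1)                                  # (n, 1, d)
attended_per_author = []
for a in range(m):
    kv = author_embeddings[a].unsqueeze(0).expand(n, k, d)        # (n, k, d)
    attended, _ = mha(query, kv, kv)                              # (n, 1, d)
    attended_per_author.append(attended.squeeze(1))
author_ctx_embed = torch.stack(attended_per_author, dim=1)        # (n, m, d)
print(author_ctx_embed.shape)                                     # torch.Size([2, 5, 768])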