def test_can_construct_from_params(self):
     params = Params({"embedding_dim": 5})
     encoder = BagOfEmbeddingsEncoder.from_params(params)
     assert encoder.get_input_dim() == 5
     assert encoder.get_output_dim() == 5
     params = Params({"embedding_dim": 12, "averaged": True})
     encoder = BagOfEmbeddingsEncoder.from_params(params)
     assert encoder.get_input_dim() == 12
     assert encoder.get_output_dim() == 12
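For reference, a minimal standalone sketch of the same API outside the test class (assuming AllenNLP is installed; the input tensor is illustrative):

import torch
from allennlp.common import Params
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder

# Build the encoder from a Params dict, exactly as the test above does.
encoder = BagOfEmbeddingsEncoder.from_params(Params({"embedding_dim": 5, "averaged": True}))

# A batch of 2 sequences, each 3 tokens long, with 5-dimensional embeddings.
embeddings = torch.randn(2, 3, 5)
pooled = encoder(embeddings)  # averages over the token dimension
print(pooled.shape)           # torch.Size([2, 5])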
Example #2
 def __init__(self,
              word_embedder: BasicTextFieldEmbedder,
              word_embedding_dropout: float = 0.05):
     super(ChiveEntityEncoder, self).__init__()
     self.sec2vec_for_title = BagOfEmbeddingsEncoder(embedding_dim=300,
                                                     averaged=True)
     self.sec2vec_for_ent_desc = BagOfEmbeddingsEncoder(embedding_dim=300,
                                                        averaged=True)
     # LstmSeq2VecEncoder(input_size=300, hidden_size=300, num_layers=1, bidirectional=True)
     self.linear = nn.Linear(600, 300)
     self.linear2 = nn.Linear(300, 300)
     self.word_embedder = word_embedder
     self.word_embedding_dropout = nn.Dropout(word_embedding_dropout)
 def test_get_dimension_is_correct(self):
     encoder = BagOfEmbeddingsEncoder(embedding_dim=5)
     assert encoder.get_input_dim() == 5
     assert encoder.get_output_dim() == 5
     encoder = BagOfEmbeddingsEncoder(embedding_dim=12)
     assert encoder.get_input_dim() == 12
     assert encoder.get_output_dim() == 12
 def test_can_construct_from_params(self):
     params = Params({
             'embedding_dim': 5,
             })
     encoder = BagOfEmbeddingsEncoder.from_params(params)
     assert encoder.get_input_dim() == 5
     assert encoder.get_output_dim() == 5
     params = Params({
             'embedding_dim': 12,
             'averaged': True
             })
     encoder = BagOfEmbeddingsEncoder.from_params(params)
     assert encoder.get_input_dim() == 12
     assert encoder.get_output_dim() == 12
    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        tokens_pooler: Optional[Seq2VecEncoderConfiguration] = None,
        sentences_encoder: Optional[Seq2SeqEncoderConfiguration] = None,
        sentences_pooler: Seq2VecEncoderConfiguration = None,
        feedforward: Optional[FeedForwardConfiguration] = None,
        multilabel: bool = False,
    ) -> None:

        super(DocumentClassification, self).__init__(
            backbone, labels=labels, multilabel=multilabel
        )

        self.backbone.encoder = TimeDistributedEncoder(backbone.encoder)

        # layers
        self.tokens_pooler = TimeDistributedEncoder(
            BagOfEmbeddingsEncoder(embedding_dim=self.backbone.encoder.get_output_dim())
            if not tokens_pooler
            else tokens_pooler.input_dim(
                self.backbone.encoder.get_output_dim()
            ).compile()
        )
        self.sentences_encoder = (
            PassThroughEncoder(self.tokens_pooler.get_output_dim())
            if not sentences_encoder
            else sentences_encoder.input_dim(
                self.tokens_pooler.get_output_dim()
            ).compile()
        )
        self.sentences_pooler = (
            BagOfEmbeddingsEncoder(self.sentences_encoder.get_output_dim())
            if not sentences_pooler
            else sentences_pooler.input_dim(
                self.sentences_encoder.get_output_dim()
            ).compile()
        )
        self.feedforward = (
            None
            if not feedforward
            else feedforward.input_dim(self.sentences_pooler.get_output_dim()).compile()
        )

        self._classification_layer = torch.nn.Linear(
            (self.feedforward or self.sentences_pooler).get_output_dim(),
            self.num_labels,
        )
Example #6
    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        entities_embedder: EmbeddingConfiguration,
        pooler: Optional[Seq2VecEncoderConfiguration] = None,
        feedforward: Optional[FeedForwardConfiguration] = None,
        multilabel: bool = False,
        entity_encoding: Optional[str] = "BIOUL"
        # self_attention: Optional[MultiheadSelfAttentionEncoder] = None
    ) -> None:

        super(RelationClassification, self).__init__(
            backbone,
            labels,
            pooler=pooler,
            feedforward=feedforward,
            multilabel=multilabel,
        )

        self._label_encoding = entity_encoding
        self._entity_tags_namespace = "entities"

        self.entities_embedder = entities_embedder.compile()

        encoding_output_dim = (self.backbone.encoder.get_output_dim() +
                               self.entities_embedder.get_output_dim())
        self.pooler = (pooler.input_dim(encoding_output_dim).compile() if
                       pooler else BagOfEmbeddingsEncoder(encoding_output_dim))

        self.feedforward = (None if not feedforward else feedforward.input_dim(
            self.pooler.get_output_dim()).compile())
        self._classification_layer = torch.nn.Linear(
            (self.feedforward
             or self.pooler).get_output_dim(), self.num_labels)
def get_encoder(input_dim, output_dim, encoder_type, args):
    if encoder_type == "bag":
        return BagOfEmbeddingsEncoder(input_dim)
    if encoder_type == "bilstm":
        return PytorchSeq2VecWrapper(
            AllenNLPSequential(torch.nn.ModuleList(
                [get_encoder(input_dim, output_dim, "bilstm-unwrapped",
                             args)]),
                               input_dim,
                               output_dim,
                               bidirectional=True,
                               residual_connection=args.residual_connection,
                               dropout=args.dropout))
    if encoder_type == "bilstm-unwrapped":
        return torch.nn.LSTM(
            input_dim,
            output_dim,
            batch_first=True,
            bidirectional=True,
            dropout=args.dropout,
        )
    if encoder_type == "cnn":
        return CnnEncoder(embedding_dim=input_dim, num_filters=output_dim)
    if encoder_type == "cnn_highway":
        filter_size: int = output_dim // 4
        return CnnHighwayEncoder(
            embedding_dim=input_dim,
            filters=[(2, filter_size), (3, filter_size), (4, filter_size),
                     (5, filter_size)],
            projection_dim=output_dim,
            num_highway=3,
            do_layer_norm=True,
        )
    raise RuntimeError(f"Unknown encoder type={encoder_type}")
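A hedged usage sketch for get_encoder above; args is a hypothetical namespace carrying only the fields the function reads, and only the paths built from stock AllenNLP modules are exercised:

from argparse import Namespace

# Hypothetical args object with just the attributes get_encoder touches.
args = Namespace(dropout=0.0, residual_connection=False)

bag_encoder = get_encoder(input_dim=300, output_dim=300, encoder_type="bag", args=args)
cnn_encoder = get_encoder(input_dim=300, output_dim=64, encoder_type="cnn", args=args)

print(bag_encoder.get_output_dim())  # 300: bag-of-embeddings keeps the input dimension
print(cnn_encoder.get_output_dim())  # num_filters * len(ngram_filter_sizes) for CnnEncoder without a projection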
Example #8
def build_model(vocab: Vocabulary) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(embedding_dim=10, num_embeddings=vocab_size)})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=10)
    return SimpleClassifier(vocab, embedder, encoder)
Example #9
    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        pooler: Optional[Seq2VecEncoderConfiguration] = None,
        dropout: float = 0.0,
        feedforward: Optional[FeedForwardConfiguration] = None,
        multilabel: bool = False,
        label_weights: Optional[Union[List[float], Dict[str, float]]] = None,
    ) -> None:

        super().__init__(backbone,
                         labels,
                         multilabel,
                         label_weights=label_weights)

        self._empty_prediction = TextClassificationPrediction(labels=[],
                                                              probabilities=[])

        self.pooler = (pooler.input_dim(
            self.backbone.encoder.get_output_dim()).compile()
                       if pooler else BagOfEmbeddingsEncoder(
                           self.backbone.encoder.get_output_dim()))
        self.dropout = torch.nn.Dropout(dropout)
        self.feedforward = (None if not feedforward else feedforward.input_dim(
            self.pooler.get_output_dim()).compile())
        self._classification_layer = torch.nn.Linear(
            (self.feedforward
             or self.pooler).get_output_dim(), self.num_labels)
 def test_forward_does_correct_computation(self):
     encoder = BagOfEmbeddingsEncoder(embedding_dim=2)
     input_tensor = torch.FloatTensor([[[.7, .8], [.1, 1.5], [.3, .6]], [[.5, .3], [1.4, 1.1], [.3, .9]]])
     mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0]])
     encoder_output = encoder(input_tensor, mask)
     assert_almost_equal(encoder_output.data.numpy(),
                         numpy.asarray([[.7 + .1 + .3, .8 + 1.5 + .6], [.5 + 1.4, .3 + 1.1]]))
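The test above passes a ByteTensor mask, which matches older AllenNLP releases; more recent versions expect a boolean mask. A minimal sketch of the same computation under that assumption:

import torch
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder

encoder = BagOfEmbeddingsEncoder(embedding_dim=2)
input_tensor = torch.tensor([[[0.7, 0.8], [0.1, 1.5], [0.3, 0.6]]])
mask = torch.tensor([[True, True, False]])  # boolean mask instead of ByteTensor
# Only unmasked timesteps are summed: [0.7 + 0.1, 0.8 + 1.5]
print(encoder(input_tensor, mask))  # tensor([[0.8000, 2.3000]])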
Example #11
 def build_model(vocab: Vocabulary) -> Model:
     vocab_size = vocab.get_vocab_size(
         "tokens")  # "tokens" from data_reader.token_indexers ??
     embedder = BasicTextFieldEmbedder(
         {"tokens": Embedding(embedding_dim=10, num_embeddings=vocab_size)})
     encoder = BagOfEmbeddingsEncoder(embedding_dim=10)
     return SimpleClassifier(vocab, embedder, encoder)
Example #12
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        seq2vec_encoder: Optional[Seq2VecEncoder] = None,
        feedforward: Optional[FeedForward] = None,
        miner: Optional[PyTorchMetricLearningMiner] = None,
        loss: Optional[PyTorchMetricLearningLoss] = None,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        # (HACK): Prevents the user from having to specify the tokenizer / masked language modeling
        # objective. In the future it would be great to come up with something more elegant.
        token_embedder = self._text_field_embedder._token_embedders["tokens"]
        self._masked_language_modeling = token_embedder.masked_language_modeling
        if self._masked_language_modeling:
            self._tokenizer = token_embedder.tokenizer

        # Default to mean BOW pooler. This performs well and so it serves as a sensible default.
        self._seq2vec_encoder = seq2vec_encoder or BagOfEmbeddingsEncoder(
            text_field_embedder.get_output_dim(), averaged=True)
        self._feedforward = feedforward

        self._miner = miner
        self._loss = loss
        if self._loss is None and not self._masked_language_modeling:
            raise ValueError((
                "No loss function provided. You must provide a contrastive loss (DeCLUTR.loss)"
                " and/or specify `masked_language_modeling=True` in the config when training."
            ))
        initializer(self)
Example #13
def build_model(
        vocab: Vocabulary,
        embedding_dim: int,
        pretrained_file: str = None,
        initializer: InitializerApplicator = None,
        regularizer: RegularizerApplicator = None
        ) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    word_vec = Embedding(embedding_dim=embedding_dim,
                          num_embeddings=vocab_size,
                          pretrained_file=pretrained_file,
                          vocab=vocab)
    embedding = BasicTextFieldEmbedder({"tokens": word_vec})

    # Use ELMo
    # options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    # weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
    # elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)
    # embedding = BasicTextFieldEmbedder({"tokens": elmo_embedder})

    # Use BERT
    # bert_embedder = PretrainedTransformerEmbedder(
    #     model_name='bert-base-uncased',
    #     max_length=512,
    #     train_parameters=False
    # )
    # embedding = BasicTextFieldEmbedder({"tokens": bert_embedder})

    encoder = BagOfEmbeddingsEncoder(embedding_dim=embedding_dim)
    return SimpleClassifier(vocab, embedding, encoder, initializer, regularizer=regularizer)
Example #14
def build_model(vocab: Vocabulary, bert_model: str = None) -> Model:
    if bert_model:
        embedder = BasicTextFieldEmbedder({"bert": PretrainedTransformerEmbedder(model_name=bert_model,
                                                                                 train_parameters=True)})
        encoder = BertPooler(pretrained_model=bert_model, requires_grad=True)
    else:
        # (3) How to get vectors for each Token ID:
        # (3.1) embed each token
        token_embedding = Embedding(embedding_dim=10, num_embeddings=vocab.get_vocab_size("token_vocab"))
        # pretrained_file='https://allennlp.s3.amazonaws.com/datasets/glove/glove.6B.50d.txt.gz'

        # (3.2) embed each character in each token
        character_embedding = Embedding(embedding_dim=3, num_embeddings=vocab.get_vocab_size("character_vocab"))
        cnn_encoder = CnnEncoder(embedding_dim=3, num_filters=4, ngram_filter_sizes=[3,])
        token_encoder = TokenCharactersEncoder(character_embedding, cnn_encoder)
        # (3.3) embed the POS of each token
        pos_tag_embedding = Embedding(embedding_dim=10, num_embeddings=vocab.get_vocab_size("pos_tag_vocab"))

        # Each TokenEmbedders embeds its input, and the result is concatenated in an arbitrary (but consistent) order
        # cf: https://docs.allennlp.org/master/api/modules/text_field_embedders/basic_text_field_embedder/
        embedder = BasicTextFieldEmbedder(
            token_embedders={"tokens": token_embedding,
                             "token_characters": token_encoder,
                             "pos_tags": pos_tag_embedding}
        )  # emb_dim = 10 + 4 + 10 = 24
        encoder = BagOfEmbeddingsEncoder(embedding_dim=24, averaged=True)
        #                                                  ^
        # average the embeddings across time, rather than simply summing
        # (i.e. we will divide the summed embeddings by the length of the sentence).
    return SimpleClassifier(vocab, embedder, encoder)
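The embedding_dim=24 above is just the sum of the three embedder output sizes (10 + 4 + 10). A small sketch that checks this, assuming stock AllenNLP modules and arbitrary vocabulary sizes chosen only for illustration:

from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder, CnnEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding, TokenCharactersEncoder

token_embedding = Embedding(embedding_dim=10, num_embeddings=100)
character_embedding = Embedding(embedding_dim=3, num_embeddings=50)
cnn_encoder = CnnEncoder(embedding_dim=3, num_filters=4, ngram_filter_sizes=(3,))
token_encoder = TokenCharactersEncoder(character_embedding, cnn_encoder)
pos_tag_embedding = Embedding(embedding_dim=10, num_embeddings=20)

embedder = BasicTextFieldEmbedder(
    token_embedders={"tokens": token_embedding,
                     "token_characters": token_encoder,
                     "pos_tags": pos_tag_embedding})

print(embedder.get_output_dim())  # 24 = 10 (tokens) + 4 (character CNN) + 10 (POS tags)
encoder = BagOfEmbeddingsEncoder(embedding_dim=embedder.get_output_dim(), averaged=True)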
    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        token_pooler: Optional[Seq2VecEncoderConfiguration] = None,
        sentence_encoder: Optional[Seq2SeqEncoderConfiguration] = None,
        sentence_pooler: Seq2VecEncoderConfiguration = None,
        feedforward: Optional[FeedForwardConfiguration] = None,
        dropout: float = 0.0,
        multilabel: bool = False,
        label_weights: Optional[Union[List[float], Dict[str, float]]] = None,
    ) -> None:

        super().__init__(
            backbone,
            labels=labels,
            multilabel=multilabel,
            label_weights=label_weights,
        )

        self._empty_prediction = DocumentClassificationPrediction(
            labels=[], probabilities=[])

        self.backbone.encoder = TimeDistributedEncoder(backbone.encoder)

        # layers
        self.token_pooler = TimeDistributedEncoder(
            BagOfEmbeddingsEncoder(
                embedding_dim=self.backbone.encoder.get_output_dim(
                )) if not token_pooler else token_pooler.
            input_dim(self.backbone.encoder.get_output_dim()).compile())
        self.sentence_encoder = (
            PassThroughEncoder(self.token_pooler.get_output_dim())
            if not sentence_encoder else sentence_encoder.input_dim(
                self.token_pooler.get_output_dim()).compile())
        self.sentence_pooler = (
            BagOfEmbeddingsEncoder(self.sentence_encoder.get_output_dim())
            if not sentence_pooler else sentence_pooler.input_dim(
                self.sentence_encoder.get_output_dim()).compile())
        self.feedforward = (None if not feedforward else feedforward.input_dim(
            self.sentence_pooler.get_output_dim()).compile())
        self.dropout = torch.nn.Dropout(dropout)

        self._classification_layer = torch.nn.Linear(
            (self.feedforward or self.sentence_pooler).get_output_dim(),
            self.num_labels,
        )
def build_pool_transformer_model(vocab: Vocabulary, transformer_model: str) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedding = PretrainedTransformerEmbedder(model_name=transformer_model)
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=embedder.get_output_dim(), averaged=True)
    #encoder = ClsPooler(embedding_dim=embedder.get_output_dim())
    return SimpleClassifier(vocab, embedder, encoder)
def build_elmo_model(vocab: Vocabulary) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedding = ElmoTokenEmbedder()
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=embedder.get_output_dim(), averaged=True)
    
    return SimpleClassifier(vocab, embedder, encoder)
 def test_forward_does_correct_computation_with_average_no_mask(self):
     encoder = BagOfEmbeddingsEncoder(embedding_dim=2, averaged=True)
     input_tensor = torch.FloatTensor([
             [[.7, .8], [.1, 1.5], [.3, .6]], [[.5, .3], [1.4, 1.1], [.3, .9]]
     ])
     encoder_output = encoder(input_tensor)
     assert_almost_equal(encoder_output.data.numpy(),
                         numpy.asarray([[(.7 + .1 + .3)/3, (.8 + 1.5 + .6)/3],
                                        [(.5 + 1.4 + .3)/3, (.3 + 1.1 + .9)/3]]))
 def test_forward_does_correct_computation_with_average(self):
     encoder = BagOfEmbeddingsEncoder(embedding_dim=2, averaged=True)
     input_tensor = Variable(
         torch.FloatTensor([[[.7, .8], [.1, 1.5], [.3, .6]],
                            [[.5, .3], [1.4, 1.1], [.3, .9]],
                            [[.4, .3], [.4, .3], [1.4, 1.7]]]))
     mask = Variable(torch.ByteTensor([[1, 1, 1], [1, 1, 0], [0, 0, 0]]))
     encoder_output = encoder(input_tensor, mask)
     assert_almost_equal(
         encoder_output.data.numpy(),
         numpy.asarray([[(.7 + .1 + .3) / 3, (.8 + 1.5 + .6) / 3],
                        [(.5 + 1.4) / 2, (.3 + 1.1) / 2], [0., 0.]]))
Example #20
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        seq2vec_encoder: Seq2VecEncoder = None,
        seq2seq_encoder: Seq2SeqEncoder = None,
        feedforward: Optional[FeedForward] = None,
        dropout: float = None,
        num_labels: int = None,
        label_namespace: str = "labels",
        namespace: str = "tokens",
        threshold: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder

        self._seq2seq_encoder = seq2seq_encoder

        # Default to mean BOW pooler. This performs well and so it serves as a sensible default.
        self._seq2vec_encoder = seq2vec_encoder or BagOfEmbeddingsEncoder(
            text_field_embedder.get_output_dim(), averaged=True)

        self._feedforward = feedforward
        if self._feedforward is not None:
            self._classifier_input_dim = self._feedforward.get_output_dim()
        else:
            self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = None
        self._label_namespace = label_namespace
        self._namespace = namespace

        if num_labels:
            self._num_labels = num_labels
        else:
            self._num_labels = vocab.get_vocab_size(
                namespace=self._label_namespace)

        self._classification_layer = torch.nn.Linear(
            self._classifier_input_dim, self._num_labels)
        self._threshold = threshold
        self._micro_f1 = F1MultiLabelMeasure(average="micro",
                                             threshold=self._threshold)
        self._macro_f1 = F1MultiLabelMeasure(average="macro",
                                             threshold=self._threshold)
        self._loss = torch.nn.BCEWithLogitsLoss()
        initializer(self)
 def test_get_dimension_is_correct(self):
     encoder = BagOfEmbeddingsEncoder(embedding_dim=5)
     assert encoder.get_input_dim() == 5
     assert encoder.get_output_dim() == 5
     encoder = BagOfEmbeddingsEncoder(embedding_dim=12)
     assert encoder.get_input_dim() == 12
     assert encoder.get_output_dim() == 12
Example #22
    def __init__(
        self,
        *,
        vocab: Vocabulary,
        hidden_dim: int = 1000,
        n_hidden_layers: int = 1,
        emb_dim: int = 300,
        dropout: float = 0.5,
        pool: str = "avg",
        label_namespace: str = "page_labels",
    ):
        super().__init__(vocab=vocab,
                         dropout=dropout,
                         label_namespace=label_namespace,
                         hidden_dim=hidden_dim)
        self._embedder = BasicTextFieldEmbedder({
            "text":
            Embedding(
                num_embeddings=vocab.get_vocab_size(),
                embedding_dim=emb_dim,
                trainable=True,
                pretrained_file=
                f"https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.{emb_dim}d.txt.gz",
            )
        })
        self._pool = pool
        if pool == "avg":
            averaged = True
        elif pool == "sum":
            averaged = False
        else:
            raise ValueError("Invalid value for pool type")
        self._boe = BagOfEmbeddingsEncoder(emb_dim, averaged=averaged)
        encoder_layers = []
        for i in range(n_hidden_layers):
            if i == 0:
                input_dim = emb_dim
            else:
                input_dim = hidden_dim

            encoder_layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
        self._encoder = nn.Sequential(*encoder_layers)
 def test_forward_does_correct_computation_with_average(self):
     encoder = BagOfEmbeddingsEncoder(embedding_dim=2, averaged=True)
     input_tensor = torch.FloatTensor([
         [[0.7, 0.8], [0.1, 1.5], [0.3, 0.6]],
         [[0.5, 0.3], [1.4, 1.1], [0.3, 0.9]],
         [[0.4, 0.3], [0.4, 0.3], [1.4, 1.7]],
     ])
     mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0], [0, 0, 0]])
     encoder_output = encoder(input_tensor, mask)
     assert_almost_equal(
         encoder_output.data.numpy(),
         numpy.asarray([
             [(0.7 + 0.1 + 0.3) / 3, (0.8 + 1.5 + 0.6) / 3],
             [(0.5 + 1.4) / 2, (0.3 + 1.1) / 2],
             [0.0, 0.0],
         ]),
     )
Example #24
 def __init__(self,
              vocab: Vocabulary,
              text_field_embedder: TextFieldEmbedder,
              shared_encoder: Seq2SeqEncoder,
              private_encoder: Seq2SeqEncoder,
              with_domain_embedding: bool = True,
              domain_embeddings: Embedding = None,
              input_dropout: float = 0.0,
              regularizer: RegularizerApplicator = None) -> None:
     super(RNNEncoder, self).__init__(vocab, regularizer)
     self._text_field_embedder = text_field_embedder
     self._shared_encoder = shared_encoder
     self._private_encoder = private_encoder
     self._domain_embeddings = domain_embeddings
     self._with_domain_embedding = with_domain_embedding
     self._seq2vec = BagOfEmbeddingsEncoder(
         embedding_dim=self.get_output_dim())
     self._input_dropout = Dropout(input_dropout)
Example #25
    def __init__(self, args, input_dim, hidden_dim, word_embedder):
        super(DefinitionSentenceEncoder, self).__init__()
        self.config = args
        self.args = args
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.projection_dim = input_dim
        self.feedforward_hidden_dim = input_dim
        self.num_layers = self.args.num_layers_for_stackatt
        self.num_attention_heads = self.args.num_atthead_for_stackatt

        self.word_embedder = word_embedder
        self.word_embedding_dropout = nn.Dropout(
            self.args.word_embedding_dropout)

        self.mentiontransformer = StackedSelfAttentionEncoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            projection_dim=self.projection_dim,
            feedforward_hidden_dim=self.feedforward_hidden_dim,
            num_layers=self.num_layers,
            num_attention_heads=self.num_attention_heads)

        self.senttransformer = StackedSelfAttentionEncoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            projection_dim=self.projection_dim,
            feedforward_hidden_dim=self.feedforward_hidden_dim,
            num_layers=self.num_layers,
            num_attention_heads=self.num_attention_heads)

        self.ff_seq2vecs = nn.Linear(input_dim, input_dim)

        self.rnn = PytorchSeq2VecWrapper(
            nn.LSTM(bidirectional=True,
                    num_layers=2,
                    input_size=input_dim,
                    hidden_size=hidden_dim // 2,
                    batch_first=True,
                    dropout=self.args.lstmdropout))

        self.bow = BagOfEmbeddingsEncoder(input_dim, self.args.bow_avg)
Example #26
    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        pooler: Optional[Seq2VecEncoderConfiguration] = None,
        feedforward: Optional[FeedForwardConfiguration] = None,
        multilabel: bool = False,
    ) -> None:

        super(TextClassification, self).__init__(backbone, labels, multilabel)

        self.pooler = (pooler.input_dim(
            self.backbone.encoder.get_output_dim()).compile()
                       if pooler else BagOfEmbeddingsEncoder(
                           self.backbone.encoder.get_output_dim()))
        self.feedforward = (None if not feedforward else feedforward.input_dim(
            self.pooler.get_output_dim()).compile())
        self._classification_layer = torch.nn.Linear(
            (self.feedforward
             or self.pooler).get_output_dim(), self.num_labels)
Example #27
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        seq2vec_encoder: Optional[Seq2VecEncoder] = None,
        feedforward: Optional[FeedForward] = None,
        miner: Optional[PyTorchMetricLearningMiner] = None,
        loss: Optional[PyTorchMetricLearningLoss] = None,
        scale_fix: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        # Prevents the user from having to specify the tokenizer / masked language modeling
        # objective. In the future it would be great to come up with something more elegant.
        token_embedder = self._text_field_embedder._token_embedders["tokens"]
        self._masked_language_modeling = token_embedder.masked_language_modeling
        if self._masked_language_modeling:
            self._tokenizer = token_embedder.tokenizer

        # Default to mean BOW pooler. This performs well and so it serves as a sensible default.
        self._seq2vec_encoder = seq2vec_encoder or BagOfEmbeddingsEncoder(
            text_field_embedder.get_output_dim(), averaged=True)
        self._feedforward = feedforward

        self._miner = miner
        self._loss = loss
        if self._loss is None and not self._masked_language_modeling:
            raise ValueError((
                "No loss function provided. You must provide a contrastive loss (DeCLUTR.loss)"
                " and/or specify `masked_language_modeling=True` in the config when training."
            ))
        # There was a small bug in the original implementation that caused gradients derived from
        # the contrastive loss to be scaled by 1/N, where N is the number of GPUs used during
        # training. This has been fixed. To reproduce results from the paper, set `model.scale_fix`
        # to `False` in your config. Note that this will have no effect if you are not using
        # distributed training with more than 1 GPU.
        self._scale_fix = scale_fix
        initializer(self)
Example #28
    def __init__(self,
                 embedding_dim: int,
                 hidden_dim: Optional[List[int]] = None) -> None:
        super(BoWMaxAndMeanEncoder, self).__init__()
        self._embedding_dim = embedding_dim
        self.maxer = BoWMaxEncoder(self._embedding_dim)
        self.meaner = BagOfEmbeddingsEncoder(self._embedding_dim, True)
        self._hidden_dim = hidden_dim
        if self._hidden_dim is not None:
            layers = [
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self._embedding_dim * 2, self._hidden_dim[0])
            ]

            for i, hid_dim in enumerate(self._hidden_dim[1:]):
                layers.append(torch.nn.LeakyReLU())
                layers.append(torch.nn.Linear(self._hidden_dim[i], hid_dim))

            self.linear = torch.nn.Sequential(*layers)
        else:
            self.linear = None
Example #29
    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        entities_embedder: EmbeddingConfiguration,
        entity_encoding: Optional[str] = "BIOUL",
        pooler: Optional[Seq2VecEncoderConfiguration] = None,
        feedforward: Optional[FeedForwardConfiguration] = None,
        multilabel: bool = False,
        label_weights: Optional[Union[List[float], Dict[str, float]]] = None,
        # self_attention: Optional[MultiheadSelfAttentionEncoder] = None
    ) -> None:

        super().__init__(
            backbone=backbone,
            labels=labels,
            multilabel=multilabel,
            label_weights=label_weights,
        )

        self._empty_prediction = RelationClassificationPrediction(
            labels=[], probabilities=[])

        self._label_encoding = entity_encoding
        self._entity_tags_namespace = "entities"

        self.entities_embedder = entities_embedder.compile()

        encoding_output_dim = (self.backbone.encoder.get_output_dim() +
                               self.entities_embedder.get_output_dim())
        self.pooler = (pooler.input_dim(encoding_output_dim).compile() if
                       pooler else BagOfEmbeddingsEncoder(encoding_output_dim))

        self.feedforward = (None if not feedforward else feedforward.input_dim(
            self.pooler.get_output_dim()).compile())

        self._classification_layer = torch.nn.Linear(
            (self.feedforward
             or self.pooler).get_output_dim(), self.num_labels)
    def _get_initial_state_and_scores(self,
                                      question,
                                      table,
                                      world,
                                      actions,
                                      example_lisp_string=None,
                                      add_world_to_initial_state=False,
                                      checklist_states=None):
        u"""
        Does initial preparation and creates an initial state for both semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table[u'text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                u'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings +
                                       projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table[u'linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
            actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(
            batch_indices=range(batch_size),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            output_action_embeddings=output_action_embeddings,
            action_biases=action_biases,
            action_indices=action_indices,
            possible_actions=actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_type_dict,
            world=initial_state_world,
            example_lisp_string=example_lisp_string,
            checklist_state=checklist_states,
            debug_info=None)
        return {
            u"initial_state": initial_state,
            u"linking_scores": linking_scores,
            u"feature_scores": feature_scores,
            u"similarity_scores": question_entity_similarity_max_score
        }
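The TimeDistributed(BagOfEmbeddingsEncoder(..., averaged=True)) construction used above for a masked average over each entity's neighbors can be sketched in isolation; the shapes are illustrative, and the boolean mask assumes a recent AllenNLP release:

import torch
from allennlp.modules import TimeDistributed
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder

embedding_dim = 4
# (batch_size, num_entities, num_neighbors, embedding_dim)
embedded_neighbors = torch.randn(2, 3, 5, embedding_dim)
# (batch_size, num_entities, num_neighbors): only the first 3 neighbors are real
neighbor_mask = torch.zeros(2, 3, 5, dtype=torch.bool)
neighbor_mask[:, :, :3] = True

# TimeDistributed folds the entity dimension into the batch, so the wrapped
# encoder averages over the neighbor dimension only.
neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(embedding_dim, averaged=True))
averaged_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
print(averaged_neighbors.shape)  # torch.Size([2, 3, 4])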
Example #31
    def _get_initial_rnn_and_grammar_state(
        self,
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRuleArray]],
        outputs: Dict[str, Any],
    ) -> Tuple[List[RnnStatelet], List[GrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and GrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table["text"]
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question)
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        if neighbor_indices is not None:
            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    "ignored": {
                        "ignored": neighbor_indices + 1
                    }
                },
                num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())

            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2),
        )

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table["linking"]

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = (projected_question_entity_similarity +
                              projected_question_neighbor_similarity)

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_question,
                    encoder_output_list,
                    question_mask_list,
                ))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs["linking_scores"] = linking_scores
            if feature_scores is not None:
                outputs["feature_scores"] = feature_scores
            outputs["similarity_scores"] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
Example #32
    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        """KAIMARY"""
        # TextFieldEmbedder needs a "token" key in the Dict
        """
        embedded_schema:torch.Size([batch_size, num_entities, max_num_entity_tokens, embedding_dim])
        schema_mask:torch.Size([batch_size, num_entities, max_num_entity_tokens])
        embedded_utterance:torch.Size([batch_size, max_utterance_size, embedding_dim])
        entity_type_embeddings:torch.Size([batch_size, num_entities, embedding_dim])
        """
        embedded_schema = self._question_embedder(schema_text,
                                                  num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_schema.view(batch_size, num_entities * num_entity_tokens,
                                 self._embedding_dim),
            torch.transpose(embedded_utterance, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)
        """KAIMARY"""
        # Variable: linking_scores
        # The entity linking score s(e, i) from Krishnamurthy et al. 2017
        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)

        linking_scores = linking_scores + feature_scores
        """KAIMARY"""
        # linking_probabilities
        # The scores s(e,i) are then fed into a softmax layer over all entities e of the same type
        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            worlds, linking_scores.transpose(1, 2), utterance_mask,
            entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities,
                                                      linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            """KAIMARY"""
            # The Seq2VecEncoder takes the hidden state of the last step as its single output
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    'ignored': neighbor_indices + 1
                }, num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())
            """KAIMARY"""
            # Variable: entity_embedding
            # Rv in B Bogin 2019
            # Is a learned embedding for the schema item v, which bases the embedding on the type of v and its schema neighbors only
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)
        """KAIMARY"""
        # Variable: link_embedding
        # Li in B Bogin 2019
        # Is an average of entity vectors weighted by the resulting distribution
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        """KAIMARY"""
        # Variable: encoder_input
        # [Wi, Li] in B Bogin 2019
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))
        """KAIMARY"""
        # Variable: max_entities_relevance
        # ρv = maxi plink(v | xi) in B Bogin 2019
        # Is the maximum probability of v for any word xi
        max_entities_relevance = linking_probabilities.max(dim=1)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
        """KAIMARY"""
        # entity_type_embeddings ???
        # Variable: graph_initial_embedding
        # hv(0) in B Bogin 2019
        # Is an initial embedding conditioned on the relevance score, which is then fed into the GNN
        graph_initial_embedding = entity_type_embeddings * entities_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            """KAIMARY"""
            # Variable: entities_graph_encoding
            # φv in  B Bogin 2019
            # Is the final representation of each schema item after L steps
            entities_graph_encoding = self._get_schema_graph_encoding(
                worlds, graph_initial_embedding)
            """KAIMARY"""
            # Variable: graph_link_embedding
            # Lφ,i in B Bogin 2019
            graph_link_embedding = util.weighted_sum(entities_graph_encoding,
                                                     linking_probabilities)
            encoder_outputs = torch.cat(
                (encoder_outputs, graph_link_embedding), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim(
            )
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(
                worlds[i], actions[i], linking_scores[i],
                linked_actions_linking_scores[i], entity_types[i],
                entities_graph_encoding[i]
                if entities_graph_encoding is not None else None)
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state