def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             num_highway_layers: int,
             phrase_layer: Seq2SeqEncoder,
             modeling_layer: Seq2SeqEncoder,
             span_end_encoder: Seq2SeqEncoder,
             dropout: float = 0.2,
             mask_lstms: bool = True,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: RegularizerApplicator = RegularizerApplicator()):
    """Assemble the BiDAF sub-modules, validate their sizes and set up metrics.

    Parameters
    ----------
    vocab : Vocabulary
        Forwarded, together with ``regularizer``, to the parent constructor.
    text_field_embedder : TextFieldEmbedder
        Token embedder. Its output dimension sizes the highway stack and must
        equal ``phrase_layer``'s input dimension.
    num_highway_layers : int
        Depth of the ``Highway`` stack (wrapped in ``TimeDistributed``).
    phrase_layer : Seq2SeqEncoder
        Encoder whose output dimension (``encoding_dim``) sizes the later
        layers.
    modeling_layer : Seq2SeqEncoder
        Must accept ``4 * encoding_dim`` input features.
    span_end_encoder : Seq2SeqEncoder
        Must accept ``4 * encoding_dim + 3 * modeling_dim`` input features.
    dropout : float
        When ``> 0`` a ``torch.nn.Dropout`` with this probability is used;
        otherwise ``self._dropout`` is an identity function.
    mask_lstms : bool
        Stored unchanged on ``self._mask_lstms``.
    initializer : InitializerApplicator
        Applied to ``self`` as the very last step.
    regularizer : RegularizerApplicator
        Forwarded to the parent constructor.
    """
    super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

    # Sub-modules. The assignment order fixes torch's registration order (and
    # hence parameter iteration order), so keep it stable.
    self._text_field_embedder = text_field_embedder
    embedded_dim = text_field_embedder.get_output_dim()
    self._highway_layer = TimeDistributed(Highway(embedded_dim, num_highway_layers))
    self._phrase_layer = phrase_layer
    self._matrix_attention = CosineMatrixAttention()
    self._modeling_layer = modeling_layer
    self._span_end_encoder = span_end_encoder

    # Width of the four concatenated phrase-encoding blocks that feed both
    # span predictors.
    phrase_dim = phrase_layer.get_output_dim()
    merged_dim = 4 * phrase_dim
    modeled_dim = modeling_layer.get_output_dim()

    self._span_start_predictor = TimeDistributed(
        torch.nn.Linear(merged_dim + modeled_dim, 1))
    self._span_end_predictor = TimeDistributed(
        torch.nn.Linear(merged_dim + span_end_encoder.get_output_dim(), 1))

    # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
    # obvious from the configuration files, so we check here.
    check_dimensions_match(modeling_layer.get_input_dim(), merged_dim,
                           "modeling layer input dim", "4 * encoding dim")
    check_dimensions_match(embedded_dim, phrase_layer.get_input_dim(),
                           "text field embedder output dim", "phrase layer input dim")
    check_dimensions_match(span_end_encoder.get_input_dim(), merged_dim + 3 * modeled_dim,
                           "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

    # Metrics.
    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # NOTE(review): the identity lambda cannot be pickled, so torch.save(model)
    # of the whole module fails when dropout <= 0; consider torch.nn.Identity
    # if the project's torch version provides it.
    self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
    self._mask_lstms = mask_lstms

    initializer(self)
def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    seq2vec_encoder: Seq2VecEncoder,
    seq2seq_encoder: Seq2SeqEncoder = None,
    dropout: float = None,
    num_labels: int = None,
    label_namespace: str = "labels",
    regularizer: RegularizerApplicator = RegularizerApplicator(),
    initializer: InitializerApplicator = InitializerApplicator()
) -> None:
    """Build a classifier: embedder -> optional seq2seq -> seq2vec -> linear.

    Parameters
    ----------
    vocab : Vocabulary
        Forwarded to the parent constructor. It is also used to size the
        output layer when ``num_labels`` is not given.
    text_field_embedder : TextFieldEmbedder
        Token embedder, stored on ``self._text_field_embedder``.
    seq2vec_encoder : Seq2VecEncoder
        Encoder whose output dimension is the classifier's input dimension.
    seq2seq_encoder : Seq2SeqEncoder, optional
        Extra contextualising encoder. ``None`` (the default) disables it.
    dropout : float, optional
        Dropout probability. A falsy value (``None`` or ``0``) leaves
        ``self._dropout`` as ``None``.
    num_labels : int, optional
        Number of output classes. A falsy value (``None`` or ``0``) falls back
        to the vocabulary size of ``label_namespace``.
    label_namespace : str
        Vocabulary namespace holding the labels.
    regularizer : RegularizerApplicator
        Forwarded to the parent constructor.
    initializer : InitializerApplicator
        Applied to ``self`` as the final step.
    """
    super().__init__(vocab, regularizer)
    self._text_field_embedder = text_field_embedder

    # Store the optional encoder as given. The previous ``if seq2seq_encoder:``
    # test relied on the truthiness of an nn.Module and would silently drop a
    # non-None encoder whose __bool__/__len__ happened to be falsy.
    self._seq2seq_encoder = seq2seq_encoder

    self._seq2vec_encoder = seq2vec_encoder
    self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()

    # A falsy dropout (None or 0) means "no dropout layer at all".
    if dropout:
        self._dropout = torch.nn.Dropout(dropout)
    else:
        self._dropout = None

    self._label_namespace = label_namespace
    # A falsy num_labels (None or 0) means "infer from the label vocabulary".
    if num_labels:
        self._num_labels = num_labels
    else:
        self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace)

    self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
    self._accuracy = CategoricalAccuracy()
    self._loss = torch.nn.CrossEntropyLoss()
    initializer(self)