Example 1
from allennlp.data import DatasetReader, Vocabulary
from allennlp.models import Model
from allennlp.training.metrics import SequenceAccuracy

# GNNCopyTransformer2 is a project-specific model class, not part of AllenNLP.

def build_gnn_parsing_model2(
    flags,
    data_reader: DatasetReader,
    vocab: Vocabulary,
    is_test: bool = False,
    source_namespace: str = 'source_tokens',
    target_namespace: str = 'target_tokens',
    segment_namespace: str = 'segment_tokens',
) -> Model:
    """Build a GNN copy-Transformer parsing model from the command-line flags."""
    # Exact-match accuracy over decoded target sequences.
    metric = SequenceAccuracy()
    model = GNNCopyTransformer2(
        vocab=vocab,
        source_namespace=source_namespace,
        target_namespace=target_namespace,
        segment_namespace=segment_namespace,
        max_decoding_step=flags.max_decode_length,
        token_based_metric=metric,
        source_embedding_dim=flags.source_embedding_dim,
        target_embedding_dim=flags.target_embedding_dim,
        encoder_d_model=flags.transformer_encoder_hidden_dim,
        decoder_d_model=flags.transformer_decoder_hidden_dim,
        encoder_nhead=flags.transformer_encoder_nhead,
        decoder_nhead=flags.transformer_decoder_nhead,
        num_decoder_layers=flags.transformer_num_decoder_layers,
        num_encoder_layers=flags.transformer_num_encoder_layers,
        encoder_dim_feedforward=flags.transformer_encoder_feedforward_dim,
        decoder_dim_feedforward=flags.transformer_decoder_feedforward_dim,
        dropout=flags.dropout,
        beam_size=1,
        nlabels=flags.gnn_transformer_num_edge_labels,
        max_decode_clip_range=flags.gnn_max_decode_clip_range,
        encode_edge_label_with_matrix=flags.gnn_encode_edge_label_with_matrix,
        is_test=is_test)
    return model
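For reference, the builder might be exercised as follows; a minimal sketch with an ad-hoc flags namespace (the field names mirror the flags.* accesses above, the values are placeholders) and a vocabulary built elsewhere. The function never reads data_reader, so None suffices here:

from types import SimpleNamespace

from allennlp.data import Vocabulary

flags = SimpleNamespace(
    max_decode_length=100,
    source_embedding_dim=256, target_embedding_dim=256,
    transformer_encoder_hidden_dim=256, transformer_decoder_hidden_dim=256,
    transformer_encoder_nhead=8, transformer_decoder_nhead=8,
    transformer_num_encoder_layers=4, transformer_num_decoder_layers=4,
    transformer_encoder_feedforward_dim=512,
    transformer_decoder_feedforward_dim=512,
    dropout=0.1,
    gnn_transformer_num_edge_labels=8,
    gnn_max_decode_clip_range=10,
    gnn_encode_edge_label_with_matrix=True,
)
vocab = Vocabulary()  # in practice: Vocabulary.from_instances(train_instances)
model = build_gnn_parsing_model2(flags, data_reader=None, vocab=vocab)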
Example 2
import torch

from allennlp.data import DatasetReader, Vocabulary
from allennlp.models import Model
from allennlp.modules.attention import DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.training.metrics import SequenceAccuracy

# Seq2SeqModel is a project-specific model class built on AllenNLP components.

def build_parsing_seq2seq_model(
        flags,
        data_reader: DatasetReader,
        vocab: Vocabulary,
        source_namespace: str = 'source_tokens',
        target_namespace: str = 'target_tokens') -> Model:
    """Build an LSTM encoder-decoder parser with dot-product attention."""
    source_embedding = Embedding(
        vocab.get_vocab_size(namespace=source_namespace),
        embedding_dim=flags.source_embedding_dim)
    source_embedder = BasicTextFieldEmbedder({'tokens': source_embedding})
    # Output dim is encoder_hidden_dim * 2 if the LSTM is bidirectional.
    lstm_encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(flags.source_embedding_dim,
                      flags.encoder_hidden_dim,
                      batch_first=True,
                      bidirectional=flags.encoder_bidirectional))
    attention = DotProductAttention()
    metric = SequenceAccuracy()
    model = Seq2SeqModel(vocab,
                         source_embedder,
                         lstm_encoder,
                         flags.max_decode_length,
                         target_embedding_dim=flags.decoder_hidden_dim,
                         target_namespace=target_namespace,
                         attention=attention,
                         beam_size=flags.beam_size,
                         use_bleu=False,
                         seq_metrics=metric)
    return model
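The builder relies on standard AllenNLP shape conventions between the encoder and the attention. A minimal, self-contained sketch of that wiring, with toy dimensions and AllenNLP 1.x-style boolean masks (none of this is from the original repository):

import torch
from allennlp.modules.attention import DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

batch, seq_len, embedding_dim, hidden_dim = 2, 7, 32, 64
encoder = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(embedding_dim, hidden_dim,
                  batch_first=True, bidirectional=True))

embedded = torch.randn(batch, seq_len, embedding_dim)
mask = torch.ones(batch, seq_len, dtype=torch.bool)
encodings = encoder(embedded, mask)  # (batch, seq_len, 2 * hidden_dim)

# DotProductAttention scores a decoder state against every encoder position,
# so the query width must equal the encoder output width (2 * hidden_dim).
query = torch.randn(batch, 2 * hidden_dim)
weights = DotProductAttention()(query, encodings, mask)  # (batch, seq_len)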
Example 3
import torch

from allennlp.common import Params
from allennlp.data import DatasetReader, Vocabulary
from allennlp.models import Model
from allennlp.modules.attention import BilinearAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn import InitializerApplicator
from allennlp.training.metrics import SequenceAccuracy

# RecombinationSeq2SeqWithCopy is a project-specific model class.

def build_parsing_recombination_seq2seq_copy_model(
        flags,
        data_reader: DatasetReader,
        vocab: Vocabulary,
        source_namespace: str = 'source_tokens',
        target_namespace: str = 'target_tokens') -> Model:
    """Build the recombination seq2seq parsing model with a copy mechanism."""
    source_embedding = Embedding(
        vocab.get_vocab_size(namespace=source_namespace),
        embedding_dim=flags.source_embedding_dim)
    lstm = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(flags.source_embedding_dim,
                      flags.encoder_hidden_dim,
                      batch_first=True,
                      bidirectional=flags.encoder_bidirectional))
    attention = BilinearAttention(flags.attention_hidden_dim,
                                  flags.attention_hidden_dim,
                                  normalize=False)
    source_embedder = BasicTextFieldEmbedder({'tokens': source_embedding})
    # Zero every bias, then uniform-initialize all remaining parameters;
    # each parameter gets the first initializer whose regex matches it.
    initializer = InitializerApplicator.from_params([
        (".*bias", Params({"type": "constant", "val": 0})),
        (".*", Params({"type": "uniform", "a": -0.1, "b": 0.1})),
    ])
    metric = SequenceAccuracy()
    model = RecombinationSeq2SeqWithCopy(
        vocab,
        source_embedder,
        lstm,
        flags.max_decode_length,
        seq_metrics=metric,
        source_namespace=source_namespace,
        target_namespace=target_namespace,
        target_embedding_dim=flags.target_embedding_dim,
        attention=attention,
        beam_size=flags.beam_size,
        use_bleu=False,
        encoder_input_dropout=flags.encoder_input_dropout,
        encoder_output_dropout=flags.encoder_output_dropout,
        dropout=flags.dropout,
        feed_output_attention_to_decoder=True,
        keep_decoder_output_dim_same_as_encoder=True,
        initializer=initializer)
    return model
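In newer AllenNLP versions (1.x and later) the same applicator can be built directly from initializer objects instead of Params; a sketch of the equivalent construction, version-dependent and not from the original repository:

from allennlp.nn import InitializerApplicator
from allennlp.nn.initializers import ConstantInitializer, UniformInitializer

initializer = InitializerApplicator([
    (".*bias", ConstantInitializer(val=0.0)),   # biases first: first match wins
    (".*", UniformInitializer(a=-0.1, b=0.1)),  # everything else
])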
Example 4
import torch

from allennlp.data import DatasetReader, Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.training.metrics import SequenceAccuracy

# Grammar, AST, LSTMGrammarCopyDecoder, and GrammarModel are project-specific.

def build_grammar_copy_model(
    flags,
    data_reader: DatasetReader,
    vocab: Vocabulary,
    grammar: Grammar,
    source_namespace: str = 'source_tokens',
) -> Model:
    """Build a grammar-constrained copy model driven by the given Grammar."""
    source_embedding = Embedding(
        vocab.get_vocab_size(namespace=source_namespace),
        embedding_dim=flags.source_embedding_dim)
    source_embedder = BasicTextFieldEmbedder({'tokens': source_embedding})
    lstm_encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(flags.source_embedding_dim,
                      flags.encoder_hidden_dim,
                      batch_first=True,
                      bidirectional=flags.encoder_bidirectional))
    decoder = LSTMGrammarCopyDecoder(
        grammar,
        AST,
        lstm_hidden_dim=flags.decoder_hidden_dim,
        num_lstm_layers=flags.decoder_num_layers,
        rule_pad_index=data_reader.rule_pad_index,
        rule_embedding_dim=flags.rule_embedding_dim,
        nonterminal_pad_index=data_reader.nonterminal_pad_index,
        nonterminal_end_index=data_reader.nonterminal_end_index,
        nonterminal_embedding_dim=flags.nonterminal_embedding_dim,
        # Factor of 2 assumes the encoder is bidirectional.
        source_encoding_dim=flags.encoder_hidden_dim * 2,
        dropout=flags.dropout,
        max_target_length=flags.max_decode_length)
    metric = SequenceAccuracy()
    model = GrammarModel(vocab,
                         source_embedder,
                         lstm_encoder,
                         decoder,
                         metric,
                         flags,
                         regularizer=None)
    return model
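One caveat in the builder above: source_encoding_dim hard-codes the factor of 2 and thus silently assumes flags.encoder_bidirectional is True. A more defensive sketch (a suggested alternative, not the original code) queries the wrapper instead:

# PytorchSeq2SeqWrapper reports hidden_size * num_directions, so this stays
# correct whether or not the underlying LSTM is bidirectional.
source_encoding_dim = lstm_encoder.get_output_dim()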