コード例 #1
0
    def __init__(
            self,
            # Vocabulary.
            vocab: Vocabulary,

            # Embeddings.
            source_field_embedder: TextFieldEmbedder,
            target_embedding_size: int,

            # Encoders and Decoders.
            encoder: Seq2SeqEncoder,
            decoder_type: str,
            output_projection_layer: FeedForward,
            source_namespace: str = "source",
            target_namespace: str = "target",

            # Hyperparameters and flags.
            decoder_attention_function: Optional[BilinearAttention] = None,
            decoder_is_bidirectional: bool = False,
            decoder_num_layers: int = 1,
            apply_attention: Optional[bool] = False,
            max_decoding_steps: int = 100,
            scheduled_sampling_ratio: float = 0.4,

            # Logistical.
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Validate the configuration and build the encoder/decoder components.

        Args:
            vocab: Vocabulary passed to the parent constructor; also used to
                look up the source/target vocabulary sizes and the start/end
                token indices in ``target_namespace``.
            source_field_embedder: Embedder for the source text. Its output
                dimension must equal ``encoder.get_input_dim()``.
            target_embedding_size: Embedding dimension of the target-side
                ``Embedding`` created here.
            encoder: Source encoder. Its output dimension is used as the
                decoder hidden size and as both dimensions of the default
                attention.
            decoder_type: Key into ``SequenceToSequence.DECODERS`` selecting
                the decoder class.
            output_projection_layer: Projection whose output dimension must
                equal the vocabulary size of ``target_namespace``.
            source_namespace: Vocabulary namespace of the source tokens.
            target_namespace: Vocabulary namespace of the target tokens.
            decoder_attention_function: Attention module to store on the
                model. When ``None``, a ``BilinearAttention`` sized by the
                encoder output dimension is created.
            decoder_is_bidirectional: Forwarded to the decoder constructor as
                ``bidirectional``.
            decoder_num_layers: Forwarded to the decoder constructor as
                ``num_layers``.
            apply_attention: When truthy, the decoder input size is the target
                embedding size plus the encoder output dimension.
            max_decoding_steps: Stored as ``_max_decoding_steps``.
            scheduled_sampling_ratio: Stored as ``_scheduled_sampling_ratio``.
            initializer: Applied to the fully constructed model as the last
                step.
            regularizer: Passed to the parent constructor.

        Raises:
            ConfigurationError: If the encoder input dimension does not match
                the source embedder output dimension, if the projection layer
                output dimension does not match the target vocabulary size, or
                if ``decoder_type`` is not a key of
                ``SequenceToSequence.DECODERS``.
        """
        super().__init__(vocab, regularizer)

        target_vocab_size = vocab.get_vocab_size(target_namespace)

        if encoder.get_input_dim() != source_field_embedder.get_output_dim():
            # Trailing space added: the adjacent literals previously rendered
            # as "embeddingsize".
            raise ConfigurationError(
                "The input dimension of the encoder must match the embedding "
                "size of the source_field_embedder. Found {} and {}, respectively."
                .format(encoder.get_input_dim(),
                        source_field_embedder.get_output_dim()))
        if output_projection_layer.get_output_dim() != target_vocab_size:
            # The target namespace is configurable, so the message no longer
            # names a specific language.
            raise ConfigurationError(
                "The output dimension of the output_projection_layer must match the "
                "size of the target vocabulary. Found {} and {}, "
                "respectively.".format(
                    output_projection_layer.get_output_dim(),
                    target_vocab_size))
        if decoder_type not in SequenceToSequence.DECODERS:
            raise ConfigurationError(
                "Unrecognized decoder option '{}'".format(decoder_type))

        encoder_output_dim = encoder.get_output_dim()

        # For dealing with input.
        self.source_vocab_size = vocab.get_vocab_size(source_namespace)
        self.target_vocab_size = target_vocab_size
        # No fallback embedder: `source_field_embedder` was already
        # dereferenced by the dimension check above, so it cannot be None here.
        self.source_field_embedder = source_field_embedder
        self.encoder = encoder

        # For dealing with / producing output.
        self.target_embedder = Embedding(self.target_vocab_size,
                                         target_embedding_size)

        # Input size will either be the target embedding size or the target
        # embedding size plus the encoder hidden size to attend on the input.
        #
        # When making a custom attention function that uses neither of those
        # input sizes, you will have to define the decoder yourself.
        decoder_input_size = target_embedding_size
        if apply_attention:
            decoder_input_size += encoder_output_dim

        # Hidden size of the encoder and decoder should match.
        decoder_hidden_size = encoder_output_dim
        self.decoder = SequenceToSequence.DECODERS[decoder_type](
            decoder_input_size,
            decoder_hidden_size,
            num_layers=decoder_num_layers,
            batch_first=True,
            bias=True,
            bidirectional=decoder_is_bidirectional)
        self.output_projection_layer = output_projection_layer
        self.apply_attention = apply_attention
        # The default attention is built whenever none is supplied, regardless
        # of `apply_attention`, so the set of registered submodules stays the
        # same as before.
        if decoder_attention_function is None:
            decoder_attention_function = BilinearAttention(
                matrix_dim=encoder_output_dim,
                vector_dim=encoder_output_dim)
        self.decoder_attention_function = decoder_attention_function

        # Hyperparameters.
        self._max_decoding_steps = max_decoding_steps
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # Used for prepping the translation primer (initialization of the
        # target word-level decoder's hidden state).
        #
        # If the decoder is an LSTM, both hidden states and cell states must
        # be initialized. Also, hidden states that prime translation via this
        # encoder must be duplicated across the number of layers the decoder
        # has.
        self._decoder_is_lstm = isinstance(self.decoder, torch.nn.LSTM)
        self._decoder_num_layers = decoder_num_layers

        self._start_index = vocab.get_token_index(START_SYMBOL,
                                                  target_namespace)
        self._end_index = vocab.get_token_index(END_SYMBOL, target_namespace)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._batch_size = None

        # Run last so every submodule registered above gets initialized.
        initializer(self)
コード例 #2
0
#question_lstm_mask = None; passage_lstm_mask = None
"""
################### EMBEDDING LAYER  #########################################
"""
# Build the embedder for this run, then push the question and passage batches
# through it. This section reads `use_ELMO`, `load_ELMO_experiments_flag`,
# `character_ids_question` and `character_ids_passage`, none of which are
# defined in this section.
print("-------------- EMBEDDING LAYER ---------------")
if use_ELMO:
    # NOTE(review): when `load_ELMO_experiments_flag` is falsy nothing is
    # assigned in this branch, so `text_field_embedder` must already exist or
    # the `get_output_dim()` call below raises NameError -- confirm intent.
    if load_ELMO_experiments_flag:
        options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

        print("Loading ELMO")
        # Third positional argument requests 2 output representations; the
        # code below reads the one at index 1. Dropout is disabled.
        text_field_embedder = Elmo(options_file, weight_file, 2, dropout=0)
        print("ELMO weights loaded")
else:
    # The throwaway `TextFieldEmbedder()` instance that used to be assigned
    # here was overwritten by the `Embedding` below before being read, so it
    # has been dropped.
    token_embedders = {}
    # NOTE(review): no vocabulary size is passed to `Embedding` here, and the
    # ELMo-specific indexing below is applied to this embedder's output as
    # well -- this branch looks unfinished; verify before relying on it.
    text_field_embedder = Embedding(embedding_dim=100, trainable=False)

## Parameters needed for the next layer
embedder_out_dim = text_field_embedder.get_output_dim()

print("Embedder output dimensions: ", embedder_out_dim)
## Propagate the batch through the embedder
# Index 1 selects the second entry of the "elmo_representations" list.
embeddings_batch_question = text_field_embedder(
    character_ids_question)["elmo_representations"][1]
embeddings_batch_passage = text_field_embedder(
    character_ids_passage)["elmo_representations"][1]

#print (embeddings_batch_question)
print("Question representations: ", embeddings_batch_question.shape)