Example #1
    def __init__(
        self,
        dictionary,
        num_chars,
        char_embed_dim,
        token_embed_dim,
        normalize_embed,
        char_rnn_units,
        char_rnn_layers,
        hidden_dim,
        num_layers,
        dropout_in,
        dropout_out,
        residual_level,
        bidirectional,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in

        # Character-level RNN that produces one embedding per source word.
        self.embed_chars = char_encoder.CharRNNModel(
            dictionary=dictionary,
            num_chars=num_chars,
            char_embed_dim=char_embed_dim,
            char_rnn_units=char_rnn_units,
            char_rnn_layers=char_rnn_layers,
        )

        # Optional word-level embedding table; skipped when token_embed_dim == 0.
        self.embed_tokens = None
        if token_embed_dim > 0:
            self.embed_tokens = rnn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=token_embed_dim,
                padding_idx=dictionary.pad(),
                freeze_embed=False,
                normalize_embed=normalize_embed,
            )

        # Each word is represented by its char-RNN output concatenated with its token embedding.
        self.word_dim = char_rnn_units + token_embed_dim

        # Stacked (optionally bidirectional) LSTM over the word-level representations.
        self.bilstm = rnn.BiLSTM(
            num_layers=num_layers,
            bidirectional=bidirectional,
            embed_dim=self.word_dim,
            hidden_dim=hidden_dim,
            dropout=dropout_out,
            residual_level=residual_level,
        )

        # disables sorting and word-length thresholding if True
        # (enables ONNX tracing of length-sorted input with batch_size = 1)
        self.onnx_export_model = False
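
A minimal construction sketch for the encoder above. The enclosing class name (`CharRNNEncoder`) and all argument values are assumptions for illustration only; `dictionary` stands for a fairseq-style Dictionary providing pad().

# Hypothetical usage sketch, not part of the original example.
encoder = CharRNNEncoder(
    dictionary=dictionary,   # assumed fairseq-style Dictionary
    num_chars=50,
    char_embed_dim=32,
    token_embed_dim=256,     # set to 0 to disable the word-level embedding table
    normalize_embed=False,
    char_rnn_units=128,
    char_rnn_layers=1,
    hidden_dim=512,
    num_layers=2,
    dropout_in=0.1,
    dropout_out=0.1,
    residual_level=None,
    bidirectional=True,
)
# Width of the word representation fed to the BiLSTM:
# char_rnn_units + token_embed_dim = 128 + 256 = 384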
Example #2
    def __init__(
        self,
        dictionary,
        num_chars=50,
        unk_only_char_encoding=False,
        embed_dim=32,
        token_embed_dim=256,
        freeze_embed=False,
        normalize_embed=False,
        char_cnn_params="[(128, 3), (128, 5)]",
        char_cnn_nonlinear_fn="tanh",
        char_cnn_pool_type="max",
        char_cnn_num_highway_layers=0,
        char_cnn_output_dim=-1,
        hidden_dim=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        residual_level=None,
        bidirectional=False,
        word_dropout_params=None,
        use_pretrained_weights=False,
        finetune_pretrained_weights=False,
        weights_file=None,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in

        # Parse the convolution spec string (via ast.literal_eval), e.g.
        # "[(128, 3), (128, 5)]", into (output_dim, kernel_width) tuples.
        convolutions_params = literal_eval(char_cnn_params)
        self.char_cnn_encoder = char_encoder.CharCNNModel(
            dictionary,
            num_chars,
            embed_dim,
            convolutions_params,
            char_cnn_nonlinear_fn,
            char_cnn_pool_type,
            char_cnn_num_highway_layers,
            char_cnn_output_dim,
            use_pretrained_weights,
            finetune_pretrained_weights,
            weights_file,
        )

        # Optional word-level embedding table alongside the character CNN.
        self.embed_tokens = None
        num_tokens = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.unk_idx = dictionary.unk()
        if token_embed_dim > 0:
            self.embed_tokens = rnn.Embedding(
                num_embeddings=num_tokens,
                embedding_dim=token_embed_dim,
                padding_idx=self.padding_idx,
                freeze_embed=freeze_embed,
                normalize_embed=normalize_embed,
            )
        # Character-level output width: either the explicit projection size or the
        # sum of the CNN filter output dimensions.
        self.word_dim = (
            char_cnn_output_dim
            if char_cnn_output_dim != -1
            else sum(out_dim for (out_dim, _) in convolutions_params)
        )
        self.token_embed_dim = token_embed_dim

        # If set, character encodings are used only for UNK tokens and must match
        # the token embedding size; otherwise they are concatenated with it.
        self.unk_only_char_encoding = unk_only_char_encoding
        if self.unk_only_char_encoding:
            assert char_cnn_output_dim == token_embed_dim, (
                "char_cnn_output_dim (%d) must equal token_embed_dim (%d)"
                % (char_cnn_output_dim, token_embed_dim)
            )
            self.word_dim = token_embed_dim
        else:
            self.word_dim = self.word_dim + token_embed_dim

        self.bilstm = rnn.BiLSTM(
            num_layers=num_layers,
            bidirectional=bidirectional,
            embed_dim=self.word_dim,
            hidden_dim=hidden_dim,
            dropout=dropout_out,
            residual_level=residual_level,
        )

        # Variable tracker
        self.tracker = VariableTracker()
        # Start with adversarial gradient tracking and embedding noising disabled.
        self.set_gradient_tracking_mode(False)
        self.set_embed_noising_mode(False)
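
A minimal construction sketch for the CNN-based encoder above, leaning on its defaults. The enclosing class name (`CharCNNEncoder`) is an assumption for illustration; `dictionary` again stands for a fairseq-style Dictionary.

# Hypothetical usage sketch, not part of the original example.
encoder = CharCNNEncoder(
    dictionary,
    char_cnn_params="[(128, 3), (128, 5)]",  # two filter banks with widths 3 and 5
    char_cnn_output_dim=-1,                  # -1: use the summed filter widths (128 + 128 = 256)
    token_embed_dim=256,
    bidirectional=True,
)
# With unk_only_char_encoding left False, the BiLSTM input width is
# 256 (char CNN) + 256 (token embedding) = 512.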