Example #1
    def __init__(self, vocab, config, elmo_shape):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.PAD = vocab.PAD
        self.word_dims = config.word_dims
        self.elmo_layers = elmo_shape[0]
        self.elmo_dims = elmo_shape[1]

        weights = torch.randn(self.elmo_layers)
        self.weights = torch.nn.Parameter(weights, requires_grad=True)
        self.mlp_elmo = nn.Linear(self.elmo_dims, self.word_dims, bias=False)

        self.word_embed = nn.Embedding(vocab.vocab_size, config.word_dims, padding_idx=0)
        word_init = np.random.randn(vocab.vocab_size, config.word_dims).astype(np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))

        self.predicate_embed = nn.Embedding(3, config.predict_dims, padding_idx=0)
        nn.init.normal_(self.predicate_embed.weight, 0.0, 1.0 / (config.predict_dims ** 0.5))

        self.lstm_input_dims = config.word_dims + config.predict_dims

        self.bilstm = MyLSTM(
            input_size=self.lstm_input_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.outlayer = nn.Linear(2 * config.lstm_hiddens, vocab.label_size, bias=False)
        nn.init.normal_(self.outlayer.weight, 0.0, 1.0 / ((2 * config.lstm_hiddens) ** 0.5))

        self.crf = CRF(vocab.label_size)
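
The learnable self.weights above only pays off in forward(), where the ELMo layers are mixed before mlp_elmo projects them down to word_dims. That step is not part of this snippet; the sketch below shows one common way to do it, assuming the ELMo tensor is laid out as (batch, elmo_layers, seq_len, elmo_dims).

import torch.nn.functional as F

def mix_elmo_layers(elmo, layer_weights):
    # Softmax-normalize the learned per-layer weights, then take a weighted
    # sum over the layer dimension.
    probs = F.softmax(layer_weights, dim=0)              # (elmo_layers,)
    return (elmo * probs.view(1, -1, 1, 1)).sum(dim=1)   # (batch, seq_len, elmo_dims)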
Example #2
    def __init__(self, vocab, config, elmo_shape):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.word_dims = config.word_dims
        self.elmo_layers = elmo_shape[0]
        self.elmo_dims = elmo_shape[1]

        weights = torch.randn(self.elmo_layers)
        self.weights = torch.nn.Parameter(weights, requires_grad=True)
        self.mlp_elmo = nn.Linear(self.elmo_dims, self.word_dims, bias=False)

        self.word_embed = nn.Embedding(vocab.vocab_size,
                                       config.word_dims,
                                       padding_idx=0)
        word_init = np.random.randn(vocab.vocab_size,
                                    config.word_dims).astype(np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))

        self.lstm = MyLSTM(
            input_size=self.word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
Example #3
    def __init__(self, vocab, config, pretrained_embedding):
        super(BiLSTMModel, self).__init__()
        self.config = config
        vocab_size, word_dims = pretrained_embedding.shape
        if vocab.vocab_size != vocab_size:
            print("word vocab size does not match, check word embedding file")
        self.word_embed = CPUEmbedding(vocab.vocab_size,
                                       word_dims,
                                       padding_idx=vocab.PAD)
        self.word_embed.weight.data.copy_(
            torch.from_numpy(pretrained_embedding))
        self.word_embed.weight.requires_grad = False
        self.use_cosine = config.use_cosine

        self.lstm = MyLSTM(
            input_size=word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )
        self.sent_dim = 2 * config.lstm_hiddens
        self.atten_guide = Parameter(torch.Tensor(self.sent_dim))
        self.atten_guide.data.normal_(0, 1)
        self.atten = LinearAttention(tensor_1_dim=self.sent_dim,
                                     tensor_2_dim=self.sent_dim)
        self.proj = NonLinear(self.sent_dim, vocab.tag_size)
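
The atten_guide parameter and the LinearAttention module are only wired together in forward(), which is not shown. LinearAttention is project-specific, so the sketch below uses a plain dot-product score as a stand-in to show how a learned guide vector can pool the BiLSTM states into a single sentence vector; the mask handling is an assumption.

import torch.nn.functional as F

def guided_pool(hidden, guide, mask):
    # hidden: (batch, seq_len, sent_dim), guide: (sent_dim,), mask: (batch, seq_len)
    scores = hidden @ guide                              # (batch, seq_len)
    scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return (attn.unsqueeze(-1) * hidden).sum(dim=1)      # (batch, sent_dim)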
Example #4
    def __init__(self, vocab, config, input_dims, bert_layers):
        super(SAModel, self).__init__()
        self.config = config
        self.input_dims = input_dims
        self.word_dims = config.word_dims  # used below by rel_embed and the tree GRUs
        self.input_depth = bert_layers if config.bert_tune == 0 else 1
        self.hidden_dims = 2 * config.lstm_hiddens
        self.projections = nn.ModuleList([
            NonLinear(self.input_dims, self.hidden_dims, activation=GELU())
            for i in range(self.input_depth)
        ])
        self.rescale = ScalarMix(mixture_size=self.input_depth)

        self.rel_embed = nn.Embedding(vocab.rel_size,
                                      self.word_dims,
                                      padding_idx=vocab.PAD)
        rel_init = np.random.randn(vocab.rel_size,
                                   config.word_dims).astype(np.float32)
        self.rel_embed.weight.data.copy_(torch.from_numpy(rel_init))

        self.dt_tree = DTTreeGRU(2 * self.word_dims, config.lstm_hiddens)
        self.td_tree = TDTreeGRU(2 * self.word_dims, config.lstm_hiddens)

        self.lstm = MyLSTM(
            input_size=2 * config.lstm_hiddens,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
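
ScalarMix here blends the projected BERT layers into a single representation. The project's class is not shown; a ScalarMix of this kind is usually a softmax-weighted sum of equally shaped tensors scaled by a learned gamma (the AllenNLP recipe). The self-contained sketch below illustrates that computation and is an assumption about what rescale does, not the project's own implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleScalarMix(nn.Module):
    # Softmax-weighted sum of mixture_size equally shaped tensors.
    def __init__(self, mixture_size):
        super().__init__()
        self.scalars = nn.Parameter(torch.zeros(mixture_size))
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, tensors):
        weights = F.softmax(self.scalars, dim=0)
        return self.gamma * sum(w * t for w, t in zip(weights, tensors))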
Example #5
    def __init__(self, vocab, config, pretrained_embedding):
        super(BiLSTMModel, self).__init__()
        self.config = config
        extvocab_size, extword_dims = pretrained_embedding.shape
        self.word_dims = extword_dims
        if config.word_dims != extword_dims:
            print("word dim size does not match, check config file")
        self.word_embed = nn.Embedding(vocab.vocab_size, self.word_dims, padding_idx=vocab.PAD)
        self.rel_embed = nn.Embedding(vocab.rel_size, self.word_dims, padding_idx=vocab.PAD)
        if vocab.extvocab_size != extvocab_size:
            print("word vocab size does not match, check word embedding file")
        self.extword_embed = CPUEmbedding(vocab.extvocab_size, self.word_dims, padding_idx=vocab.PAD)

        word_init = np.zeros((vocab.vocab_size, self.word_dims), dtype=np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))
        self.extword_embed.weight.data.copy_(torch.from_numpy(pretrained_embedding))
        self.extword_embed.weight.requires_grad = False

        self.dt_tree = DTTreeGRU(2*self.word_dims, config.lstm_hiddens)
        self.td_tree = TDTreeGRU(2*self.word_dims, config.lstm_hiddens)

        self.lstm = MyLSTM(
            input_size=2*config.lstm_hiddens,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens, vocab.tag_size, bias=True)
Example #6
    def __init__(self, vocab, config, input_dims, bert_layers):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.PAD = vocab.PAD
        self.input_dims = input_dims
        self.input_depth = bert_layers if config.bert_tune == 0 else 1
        self.hidden_dims = config.word_dims
        self.projections = nn.ModuleList([
            NonLinear(self.input_dims, self.hidden_dims, activation=GELU())
            for i in range(self.input_depth)
        ])

        self.rescale = ScalarMix(mixture_size=self.input_depth)

        self.word_embed = nn.Embedding(vocab.vocab_size, config.word_dims, padding_idx=0)
        word_init = np.random.randn(vocab.vocab_size, config.word_dims).astype(np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))

        self.predicate_embed = nn.Embedding(3, config.predict_dims, padding_idx=0)
        nn.init.normal_(self.predicate_embed.weight, 0.0, 1.0 / (config.predict_dims ** 0.5))

        self.lstm_input_dims = config.word_dims + config.predict_dims

        self.bilstm = MyLSTM(
            input_size=self.lstm_input_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.outlayer = nn.Linear(2 * config.lstm_hiddens, vocab.label_size, bias=False)
        nn.init.normal_(self.outlayer.weight, 0.0, 1.0 / ((2 * config.lstm_hiddens) ** 0.5))

        self.crf = CRF(vocab.label_size)
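
The line lstm_input_dims = word_dims + predict_dims implies that the word representation and the predicate-indicator embedding are concatenated feature-wise before the BiLSTM. Below is a minimal sketch of that step as a hypothetical helper method on the model above; the predicate_ids tensor of 0/1/2 indicator indices is an assumption.

    def build_lstm_input(self, word_repr, predicate_ids):
        # word_repr: (batch, seq_len, word_dims); predicate_ids: (batch, seq_len)
        pred = self.predicate_embed(predicate_ids)       # (batch, seq_len, predict_dims)
        return torch.cat([word_repr, pred], dim=-1)      # (batch, seq_len, lstm_input_dims)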
Example #7
    def __init__(self, vocab, config, parser_config, pretrained_embedding):
        super(BiLSTMModel, self).__init__()
        self.config = config
        extvocab_size, extword_dims = pretrained_embedding.shape
        self.word_dims = extword_dims
        if config.word_dims != extword_dims:
            print("word dim size does not match, check config file")
        self.word_embed = nn.Embedding(vocab.vocab_size,
                                       self.word_dims,
                                       padding_idx=vocab.PAD)
        if vocab.extvocab_size != extvocab_size:
            print("word vocab size does not match, check word embedding file")
        self.extword_embed = CPUEmbedding(vocab.extvocab_size,
                                          self.word_dims,
                                          padding_idx=vocab.PAD)

        word_init = np.zeros((vocab.vocab_size, self.word_dims),
                             dtype=np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))

        self.extword_embed.weight.data.copy_(
            torch.from_numpy(pretrained_embedding))
        self.extword_embed.weight.requires_grad = False

        self.transformer_emb = nn.Linear(parser_config.word_dims,
                                         self.word_dims,
                                         bias=False)

        parser_dim = 2 * parser_config.lstm_hiddens
        transformer_lstm = []
        for layer in range(parser_config.lstm_layers):
            transformer_lstm.append(
                nn.Linear(parser_dim, self.word_dims, bias=False))
        self.transformer_lstm = nn.ModuleList(transformer_lstm)

        parser_mlp_dim = parser_config.mlp_arc_size + parser_config.mlp_rel_size
        self.transformer_dep = nn.Linear(parser_mlp_dim,
                                         self.word_dims,
                                         bias=False)
        self.transformer_head = nn.Linear(parser_mlp_dim,
                                          self.word_dims,
                                          bias=False)

        self.parser_lstm_layers = parser_config.lstm_layers
        self.synscale = ScalarMix(mixture_size=3 + parser_config.lstm_layers)

        self.lstm = MyLSTM(
            input_size=2 * self.word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
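
The synscale mixture size of 3 + parser_config.lstm_layers suggests that the transformed parser word embedding, each parser LSTM layer, and the dep/head MLP outputs are all projected to word_dims and blended into one syntax feature. Below is a sketch of that composition as a hypothetical method on the model above; the parser output names and the list-based ScalarMix call are assumptions.

    def mix_syntax_features(self, parser_emb, parser_lstm_hiddens, parser_dep, parser_head):
        feats = [self.transformer_emb(parser_emb)]
        feats += [proj(h) for proj, h in zip(self.transformer_lstm, parser_lstm_hiddens)]
        feats += [self.transformer_dep(parser_dep), self.transformer_head(parser_head)]
        return self.synscale(feats)                      # (batch, seq_len, word_dims)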
Example #8
    def __init__(self, vocab, config, parser_config, input_dims, bert_layers):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.PAD = vocab.PAD
        self.input_dims = input_dims
        self.input_depth = bert_layers if config.bert_tune == 0 else 1
        self.hidden_dims = 2 * config.lstm_hiddens
        self.projections = nn.ModuleList([
            NonLinear(self.input_dims, self.hidden_dims, activation=GELU())
            for i in range(self.input_depth)
        ])

        self.rescale = ScalarMix(mixture_size=self.input_depth)

        parser_dim = 2 * parser_config.lstm_hiddens
        self.transformer_lstm = nn.ModuleList([
            NonLinear(parser_dim, self.hidden_dims, activation=GELU())
            for i in range(parser_config.lstm_layers)
        ])

        parser_mlp_dim = parser_config.mlp_arc_size + parser_config.mlp_rel_size
        self.transformer_dep = NonLinear(parser_mlp_dim,
                                         self.hidden_dims,
                                         activation=GELU())
        self.transformer_head = NonLinear(parser_mlp_dim,
                                          self.hidden_dims,
                                          activation=GELU())

        self.parser_lstm_layers = parser_config.lstm_layers
        self.synscale = ScalarMix(mixture_size=3 + parser_config.lstm_layers)

        self.predicate_embed = nn.Embedding(3,
                                            config.predict_dims,
                                            padding_idx=0)
        nn.init.normal_(self.predicate_embed.weight, 0.0,
                        1.0 / (config.predict_dims**0.5))

        self.lstm_input_dims = 2 * self.hidden_dims + config.predict_dims

        self.bilstm = MyLSTM(
            input_size=self.lstm_input_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.outlayer = nn.Linear(2 * config.lstm_hiddens,
                                  vocab.label_size,
                                  bias=False)
        nn.init.normal_(self.outlayer.weight, 0.0,
                        1.0 / ((2 * config.lstm_hiddens)**0.5))

        self.crf = CRF(vocab.label_size)
Example #9
    def __init__(self, vocab, config, pretrained_embedding):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.PAD = vocab.PAD
        extvocab_size, extword_dims = pretrained_embedding.shape
        self.word_dims = extword_dims
        if config.word_dims != extword_dims:
            print("word dim size does not match, check config file")
        self.word_embed = nn.Embedding(vocab.vocab_size,
                                       self.word_dims,
                                       padding_idx=vocab.PAD)
        if vocab.extvocab_size != extvocab_size:
            print("word vocab size does not match, check word embedding file")
        self.extword_embed = CPUEmbedding(vocab.extvocab_size,
                                          self.word_dims,
                                          padding_idx=vocab.PAD)

        word_init = np.zeros((vocab.vocab_size, self.word_dims),
                             dtype=np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))
        self.extword_embed.weight.data.copy_(
            torch.from_numpy(pretrained_embedding))
        self.extword_embed.weight.requires_grad = False

        self.predicate_embed = nn.Embedding(3,
                                            config.predict_dims,
                                            padding_idx=0)
        nn.init.normal_(self.predicate_embed.weight, 0.0,
                        1.0 / (config.predict_dims**0.5))

        self.lstm_input_dims = config.word_dims + config.predict_dims

        self.bilstm = MyLSTM(
            input_size=self.lstm_input_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.outlayer = nn.Linear(2 * config.lstm_hiddens,
                                  vocab.label_size,
                                  bias=False)
        nn.init.normal_(self.outlayer.weight, 0.0,
                        1.0 / ((2 * config.lstm_hiddens)**0.5))

        self.crf = CRF(vocab.label_size)
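
The outlayer plus CRF(vocab.label_size) pairing is normally trained with the CRF negative log-likelihood over emission scores and decoded with Viterbi. Below is a sketch as a hypothetical method on the model above; the call convention (log-likelihood from __call__, paths from decode) is an assumption about the project's CRF class, not taken from the source.

    def crf_loss_and_decode(self, lstm_out, tags, mask):
        emissions = self.outlayer(lstm_out)              # (batch, seq_len, label_size)
        loss = -self.crf(emissions, tags, mask=mask)     # negative log-likelihood
        best_paths = self.crf.decode(emissions, mask=mask)
        return loss, best_paths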
Example #10
    def __init__(self, vocab, config, pretrained_embedding):
        super(ParserModel, self).__init__()
        self.config = config
        self.word_embed = nn.Embedding(vocab.vocab_size,
                                       config.word_dims,
                                       padding_idx=vocab.PAD)
        self.extword_embed = CPUEmbedding(vocab.extvocab_size,
                                          config.word_dims,
                                          padding_idx=vocab.PAD)

        word_init = np.zeros((vocab.vocab_size, config.word_dims),
                             dtype=np.float32)
        self.word_embed.weight.data.copy_(torch.from_numpy(word_init))

        self.extword_embed.weight.data.copy_(
            torch.from_numpy(pretrained_embedding))
        self.extword_embed.weight.requires_grad = False

        self.lstm = MyLSTM(
            input_size=config.word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.mlp_arc_dep = NonLinear(input_size=2 * config.lstm_hiddens,
                                     hidden_size=config.mlp_arc_size +
                                     config.mlp_rel_size,
                                     activation=nn.Tanh())
        self.mlp_arc_head = NonLinear(input_size=2 * config.lstm_hiddens,
                                      hidden_size=config.mlp_arc_size +
                                      config.mlp_rel_size,
                                      activation=nn.Tanh())

        self.total_num = int((config.mlp_arc_size + config.mlp_rel_size) / 100)
        self.arc_num = int(config.mlp_arc_size / 100)
        self.rel_num = int(config.mlp_rel_size / 100)

        self.arc_biaffine = Biaffine(config.mlp_arc_size, config.mlp_arc_size,
                                     1, bias=(True, False))
        self.rel_biaffine = Biaffine(config.mlp_rel_size, config.mlp_rel_size,
                                     vocab.rel_size, bias=(True, True))
        self.arc_biaffine.linear.weight.requires_grad = False
        self.rel_biaffine.linear.weight.requires_grad = False
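
bias=(True, False) on the arc biaffine commonly means a bias column is appended on the dependent side only. The Biaffine class itself is project-specific; the self-contained sketch below only illustrates the underlying bilinear arc score that such a layer computes.

import torch

def biaffine_arc_scores(dep, head, U):
    # dep, head: (batch, seq_len, mlp_arc_size); U: (mlp_arc_size + 1, mlp_arc_size)
    ones = dep.new_ones(dep.size(0), dep.size(1), 1)
    dep = torch.cat([dep, ones], dim=-1)                 # bias on the dependent side
    return dep @ U @ head.transpose(1, 2)                # (batch, seq_len, seq_len) arc scores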
Example #11
    def __init__(self, vocab, config, parser_config, input_dims, bert_layers):
        super(SAModel, self).__init__()
        self.config = config
        self.input_dims = input_dims
        self.input_depth = bert_layers if config.bert_tune == 0 else 1
        self.hidden_dims = 2 * config.lstm_hiddens
        self.projections = nn.ModuleList([
            NonLinear(self.input_dims, self.hidden_dims, activation=GELU())
            for i in range(self.input_depth)
        ])

        self.rescale = ScalarMix(mixture_size=self.input_depth)

        self.transformer_emb = NonLinear(parser_config.word_dims,
                                         self.hidden_dims,
                                         activation=GELU())

        parser_dim = 2 * parser_config.lstm_hiddens
        self.transformer_lstm = nn.ModuleList([
            NonLinear(parser_dim, self.hidden_dims, activation=GELU())
            for i in range(parser_config.lstm_layers)
        ])

        parser_mlp_dim = parser_config.mlp_arc_size + parser_config.mlp_rel_size
        self.transformer_dep = NonLinear(parser_mlp_dim,
                                         self.hidden_dims,
                                         activation=GELU())
        self.transformer_head = NonLinear(parser_mlp_dim,
                                          self.hidden_dims,
                                          activation=GELU())

        self.parser_lstm_layers = parser_config.lstm_layers
        self.synscale = ScalarMix(mixture_size=3 + parser_config.lstm_layers)

        self.lstm = MyLSTM(
            input_size=2 * self.hidden_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
Example #12
    def __init__(self, vocab, config, parser_config, elmo_shape):
        super(SAModel, self).__init__()
        self.config = config
        self.word_dims = config.word_dims
        self.elmo_layers = elmo_shape[0]
        self.elmo_dims = elmo_shape[1]

        weights = torch.randn(self.elmo_layers)
        self.weights = torch.nn.Parameter(weights, requires_grad=True)
        self.mlp_elmo = nn.Linear(self.elmo_dims, self.word_dims, bias=False)

        self.transformer_emb = nn.Linear(parser_config.word_dims,
                                         self.word_dims,
                                         bias=False)

        parser_dim = 2 * parser_config.lstm_hiddens
        transformer_lstm = []
        for layer in range(parser_config.lstm_layers):
            transformer_lstm.append(
                nn.Linear(parser_dim, self.word_dims, bias=False))
        self.transformer_lstm = nn.ModuleList(transformer_lstm)

        parser_mlp_dim = parser_config.mlp_arc_size + parser_config.mlp_rel_size
        self.transformer_dep = nn.Linear(parser_mlp_dim,
                                         self.word_dims,
                                         bias=False)
        self.transformer_head = nn.Linear(parser_mlp_dim,
                                          self.word_dims,
                                          bias=False)

        self.parser_lstm_layers = parser_config.lstm_layers
        self.synscale = ScalarMix(mixture_size=3 + parser_config.lstm_layers)

        self.lstm = MyLSTM(
            input_size=2 * self.word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
Example #13
    def __init__(self, vocab, config, parser_config, elmo_shape):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.PAD = vocab.PAD
        self.word_dims = config.word_dims
        self.elmo_layers = elmo_shape[0]
        self.elmo_dims = elmo_shape[1]

        weights = torch.randn(self.elmo_layers)
        self.weights = torch.nn.Parameter(weights, requires_grad=True)
        self.mlp_elmo = nn.Linear(self.elmo_dims, self.word_dims, bias=False)

        self.transformer_emb = nn.Linear(parser_config.word_dims, self.word_dims, bias=False)

        parser_dim = 2 * parser_config.lstm_hiddens
        transformer_lstm = []
        for layer in range(parser_config.lstm_layers):
            transformer_lstm.append(nn.Linear(parser_dim, self.word_dims, bias=False))
        self.transformer_lstm = nn.ModuleList(transformer_lstm)

        parser_mlp_dim = parser_config.mlp_arc_size + parser_config.mlp_rel_size
        self.transformer_dep = nn.Linear(parser_mlp_dim, self.word_dims, bias=False)
        self.transformer_head = nn.Linear(parser_mlp_dim, self.word_dims, bias=False)

        self.parser_lstm_layers = parser_config.lstm_layers
        self.synscale = ScalarMix(mixture_size=3+parser_config.lstm_layers)

        self.predicate_embed = nn.Embedding(3, config.predict_dims, padding_idx=0)
        nn.init.normal_(self.predicate_embed.weight, 0.0, 1.0 / (config.predict_dims ** 0.5))

        self.lstm_input_dims = 2 * config.word_dims + config.predict_dims

        self.bilstm = MyLSTM(
            input_size=self.lstm_input_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.outlayer = nn.Linear(2 * config.lstm_hiddens, vocab.label_size, bias=False)
        nn.init.normal_(self.outlayer.weight, 0.0, 1.0 / ((2 * config.lstm_hiddens) ** 0.5))

        self.crf = CRF(vocab.label_size)
Example #14
    def __init__(self, vocab, config, elmo_shape):
        super(BiLSTMModel, self).__init__()
        self.config = config
        self.word_dims = config.word_dims
        self.elmo_layers = elmo_shape[0]
        self.elmo_dims = elmo_shape[1]

        weights = torch.randn(self.elmo_layers)
        self.weights = torch.nn.Parameter(weights, requires_grad=True)
        self.mlp_elmo = nn.Linear(self.elmo_dims, self.word_dims, bias=False)

        self.lstm = MyLSTM(
            input_size=self.word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
Example #15
    def __init__(self, vocab, config, input_dims, bert_layers):
        super(SAModel, self).__init__()
        self.config = config
        self.input_dims = input_dims
        self.input_depth = bert_layers if config.bert_tune == 0 else 1
        self.hidden_dims = 2 * config.lstm_hiddens
        self.projections = nn.ModuleList([
            NonLinear(self.input_dims, self.hidden_dims, activation=GELU())
            for i in range(self.input_depth)
        ])

        self.rescale = ScalarMix(mixture_size=self.input_depth)

        self.lstm = MyLSTM(
            input_size=self.hidden_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

        self.proj = nn.Linear(2 * config.lstm_hiddens,
                              vocab.tag_size,
                              bias=False)
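
The per-layer projections and the rescale ScalarMix above are typically chained before the BiLSTM as shown in the hypothetical helper below; bert_hiddens stands for a list of input_depth tensors of shape (batch, seq_len, input_dims), and the list-based ScalarMix call is an assumption.

    def encode_bert_layers(self, bert_hiddens):
        projected = [proj(h) for proj, h in zip(self.projections, bert_hiddens)]
        return self.rescale(projected)                   # (batch, seq_len, hidden_dims)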