def __init__(self, hidden_size, vocab_size, num_layers=1, init_scale=0.1, dropout=None):
    super(PtbModel, self).__init__()
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        sparse=False,
        weight_attr=paddle.ParamAttr(
            name="embedding_para",
            initializer=paddle.nn.initializer.Uniform(
                low=-init_scale, high=init_scale)))
    self.rnn = nn.GRU(input_size=hidden_size,
                      hidden_size=hidden_size,
                      num_layers=num_layers)
    self.classifier = nn.Linear(in_features=hidden_size,
                                out_features=vocab_size)
    self.dropout_layer = nn.Dropout(p=dropout, mode='upscale_in_train') \
        if dropout is not None and dropout > 0.0 else None
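# Hedged usage sketch (not from the original source): how the PtbModel layers
# above would compose in a forward pass. The batch size, sequence length, and
# the forward wiring are illustrative assumptions.
import paddle
import paddle.nn as nn

hidden_size, vocab_size = 200, 10000
embedding = nn.Embedding(vocab_size, hidden_size)
rnn = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, num_layers=1)
classifier = nn.Linear(hidden_size, vocab_size)

ids = paddle.randint(0, vocab_size, shape=[4, 35], dtype='int64')  # [batch, seq_len]
emb = embedding(ids)             # [4, 35, hidden_size]
outputs, last_hidden = rnn(emb)  # outputs: [4, 35, hidden_size]
logits = classifier(outputs)     # [4, 35, vocab_size], one score per vocab entry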
def __init__(self, word_emb_dim, hidden_size, vocab_size, num_labels,
             emb_lr=2.0, crf_lr=0.2, with_start_stop_tag=True):
    super(BiGruCrf, self).__init__()
    self.word_emb_dim = word_emb_dim
    self.vocab_size = vocab_size
    self.num_labels = num_labels
    self.hidden_size = hidden_size
    self.emb_lr = emb_lr
    self.crf_lr = crf_lr
    self.init_bound = 0.1
    self.word_embedding = nn.Embedding(
        num_embeddings=self.vocab_size,
        embedding_dim=self.word_emb_dim,
        weight_attr=paddle.ParamAttr(
            learning_rate=self.emb_lr,
            initializer=nn.initializer.Uniform(
                low=-self.init_bound, high=self.init_bound)))
    self.gru = nn.GRU(
        input_size=self.word_emb_dim,
        hidden_size=self.hidden_size,
        num_layers=2,
        direction='bidirectional',
        weight_ih_attr=paddle.ParamAttr(
            initializer=nn.initializer.Uniform(
                low=-self.init_bound, high=self.init_bound),
            regularizer=paddle.regularizer.L2Decay(coeff=1e-4)),
        weight_hh_attr=paddle.ParamAttr(
            initializer=nn.initializer.Uniform(
                low=-self.init_bound, high=self.init_bound),
            regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
    self.fc = nn.Linear(
        in_features=self.hidden_size * 2,
        out_features=self.num_labels + 2
        if with_start_stop_tag else self.num_labels,
        weight_attr=paddle.ParamAttr(
            initializer=nn.initializer.Uniform(
                low=-self.init_bound, high=self.init_bound),
            regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
    self.crf = LinearChainCrf(self.num_labels, self.crf_lr, with_start_stop_tag)
    self.crf_loss = LinearChainCrfLoss(self.crf)
    self.viterbi_decoder = ViterbiDecoder(self.crf.transitions, with_start_stop_tag)
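# The emb_lr/crf_lr hooks above rely on paddle.ParamAttr's per-parameter
# learning_rate, which multiplies the optimizer's global rate. A minimal,
# self-contained sketch with placeholder sizes:
import paddle
import paddle.nn as nn

emb = nn.Embedding(
    num_embeddings=100,
    embedding_dim=16,
    weight_attr=paddle.ParamAttr(
        learning_rate=2.0,  # this weight trains at 2x the optimizer's base lr
        initializer=nn.initializer.Uniform(low=-0.1, high=0.1)))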
def __init__(self, vocab, model_config):
    super(VAE, self).__init__()
    self.config = model_config
    self.vocabulary = vocab

    # Special symbols
    for ss in ('bos', 'eos', 'unk', 'pad'):
        setattr(self, ss, getattr(vocab, ss))

    # Word embeddings layer
    n_vocab, d_emb = len(vocab), vocab.vectors.shape[1]
    self.x_emb = nn.Embedding(n_vocab, d_emb, self.pad)
    self.x_emb.weight.set_value(paddle.to_tensor(vocab.vectors))
    if self.config['freeze_embeddings']:
        self.x_emb.weight.stop_gradient = True

    # Encoder
    self.encoder_rnn = nn.GRU(
        d_emb,
        self.config['q_d_h'],
        num_layers=self.config['q_n_layers'],
        dropout=self.config['q_dropout'] if self.config['q_n_layers'] > 1 else 0,
        direction='bidirectional' if self.config['q_bidir'] else 'forward')
    q_d_last = self.config['q_d_h'] * (2 if self.config['q_bidir'] else 1)
    self.q_mu = nn.Linear(q_d_last, self.config['d_z'])
    self.q_logvar = nn.Linear(q_d_last, self.config['d_z'])

    # Decoder
    self.decoder_rnn = nn.GRU(
        d_emb + self.config['d_z'],
        self.config['d_d_h'],
        num_layers=self.config['d_n_layers'],
        dropout=self.config['d_dropout'] if self.config['d_n_layers'] > 1 else 0)
    self.decoder_lat = nn.Linear(self.config['d_z'], self.config['d_d_h'])
    self.decoder_fc = nn.Linear(self.config['d_d_h'], n_vocab)
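# Hedged sketch of how the encoder half above is typically driven. Pooling the
# two directions' final states and the reparameterization step follow standard
# VAE practice; they are assumptions here, not the original forward code.
import paddle
import paddle.nn as nn

d_emb, q_d_h, d_z = 64, 128, 32
encoder_rnn = nn.GRU(d_emb, q_d_h, direction='bidirectional')
q_mu, q_logvar = nn.Linear(q_d_h * 2, d_z), nn.Linear(q_d_h * 2, d_z)

emb = paddle.randn([4, 20, d_emb])           # [batch, seq, d_emb]
_, h = encoder_rnn(emb)                      # h: [num_layers * 2, batch, q_d_h]
h = paddle.concat([h[-2], h[-1]], axis=-1)   # join forward/backward final states
mu, logvar = q_mu(h), q_logvar(h)
z = mu + paddle.randn(mu.shape) * paddle.exp(0.5 * logvar)  # reparameterization trick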
def __init__(self, embedding_name):
    super(Embedding, self).__init__()
    self.embedding = TokenEmbedding(embedding_name)
    self.embedding_dim = self.embedding.embedding_dim
    weight_attr = paddle.framework.ParamAttr(
        name="linear_weight",
        initializer=paddle.nn.initializer.XavierNormal())
    bias_attr = paddle.framework.ParamAttr(
        name="linear_bias",
        initializer=paddle.nn.initializer.XavierNormal())
    self.mlp = paddle.nn.Linear(self.embedding_dim * 2,
                                self.embedding_dim,
                                weight_attr=weight_attr,
                                bias_attr=bias_attr)
    self.gru = nn.GRU(input_size=self.embedding_dim,
                      hidden_size=self.embedding_dim // 2,
                      num_layers=1,
                      direction="bidirectional")
def __init__(self, max_len, latent_dim, rnn_type):
    super(StateDecoder, self).__init__()
    self.latent_dim = latent_dim
    self.max_len = max_len
    self.z_to_latent = nn.Linear(self.latent_dim, self.latent_dim)
    if rnn_type == 'gru':
        self.gru = nn.GRU(self.latent_dim, 501, 3)
    else:
        raise NotImplementedError
    self.decoded_logits = nn.Linear(501, DECISION_DIM)
    weights_init(self)
def __init__(self, emb_size, hidden_size, word_num, label_num, use_w2v_emb=False):
    super(BiGRUWithCRF, self).__init__()
    if use_w2v_emb:
        self.word_emb = TokenEmbedding(
            extended_vocab_path='./conf/word.dic', unknown_token='OOV')
    else:
        self.word_emb = nn.Embedding(word_num, emb_size)
    self.gru = nn.GRU(emb_size,
                      hidden_size,
                      num_layers=2,
                      direction='bidirectional')
    self.fc = nn.Linear(hidden_size * 2, label_num + 2)  # BOS, EOS
    self.crf = LinearChainCrf(label_num)
    self.decoder = ViterbiDecoder(self.crf.transitions)
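# Hedged sketch of the tagging path implied by the layers above: embeddings ->
# BiGRU -> per-token emission scores of width label_num + 2 (the two extras
# are the BOS/EOS tags the CRF introduces). All sizes are placeholders.
import paddle
import paddle.nn as nn

emb_size, hidden_size, word_num, label_num = 128, 128, 10000, 13
word_emb = nn.Embedding(word_num, emb_size)
gru = nn.GRU(emb_size, hidden_size, num_layers=2, direction='bidirectional')
fc = nn.Linear(hidden_size * 2, label_num + 2)

ids = paddle.randint(0, word_num, shape=[4, 30], dtype='int64')
feats, _ = gru(word_emb(ids))  # [4, 30, hidden_size * 2]
emissions = fc(feats)          # [4, 30, label_num + 2], consumed by the CRF decoder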
def __init__(self, x_size, y_size, hidden_size, hop=1, dropout_rate=0, normalize=True):
    super(MemoryAnsPointer, self).__init__()
    self.normalize = normalize
    self.hidden_size = hidden_size
    self.hop = hop
    self.x_size = x_size
    self.y_size = y_size
    self.dropout_rate = dropout_rate
    self.ques_encoder = nn.GRU(input_size=y_size,
                               hidden_size=x_size,
                               num_layers=1,
                               direction='forward',
                               time_major=False)
    self.FFNs_start = nn.LayerList()
    self.SFUs_start = nn.LayerList()
    self.FFNs_end = nn.LayerList()
    self.SFUs_end = nn.LayerList()
    for i in range(self.hop):
        self.FFNs_start.append(
            FeedForwardNetwork(3 * x_size, hidden_size, 1, dropout_rate))
        self.SFUs_start.append(SFU(x_size, x_size))
        self.FFNs_end.append(
            FeedForwardNetwork(3 * x_size, hidden_size, 1, dropout_rate))
        self.SFUs_end.append(SFU(x_size, x_size))
def __init__(self, input_size, hidden_size, num_layers=1, direction="forward",
             dropout=0.0, pooling_type=None, **kwargs):
    super().__init__()
    self._input_size = input_size
    self._hidden_size = hidden_size
    self._direction = direction
    self._pooling_type = pooling_type
    self.gru_layer = nn.GRU(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            direction=direction,
                            dropout=dropout,
                            **kwargs)
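# Sketch of the forward logic such a wrapper usually pairs with (mirroring
# paddlenlp's GRUEncoder): return the last time step when pooling_type is
# None, otherwise pool over the time axis. The forward itself is an
# assumption; only nn.GRU's call signature is taken as given.
import paddle
import paddle.nn as nn

gru = nn.GRU(input_size=16, hidden_size=32, direction='bidirectional')
x = paddle.randn([4, 10, 16])               # [batch, time, input_size]
outputs, _ = gru(x)                         # [4, 10, 64]
mean_pooled = paddle.mean(outputs, axis=1)  # pooling_type='mean' -> [4, 64]
max_pooled = paddle.max(outputs, axis=1)    # pooling_type='max'  -> [4, 64]
last_step = outputs[:, -1, :]               # pooling_type=None: final time step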
def __init__(self, skep, num_classes):
    super().__init__()
    self.num_classes = num_classes
    self.skep = skep  # allow skep to be config
    gru_hidden_size = 128
    self.gru = nn.GRU(self.skep.config["hidden_size"],
                      gru_hidden_size,
                      num_layers=2,
                      direction='bidirect')
    self.fc = nn.Linear(
        gru_hidden_size * 2,
        self.num_classes,
        weight_attr=paddle.ParamAttr(
            initializer=nn.initializer.Uniform(low=-0.1, high=0.1),
            regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
    self.crf = LinearChainCrf(self.num_classes,
                              crf_lr=0.2,
                              with_start_stop_tag=False)
    self.crf_loss = LinearChainCrfLoss(self.crf)
    self.viterbi_decoder = ViterbiDecoder(self.crf.transitions, False)
def __init__(self, emb_size, hidden_size, word_num, label_num, use_w2v_emb=False):
    super(BiGRUWithCRF, self).__init__()
    if use_w2v_emb:
        self.word_emb = TokenEmbedding(
            extended_vocab_path='./data/word.dic', unknown_token='OOV')
    else:
        self.word_emb = nn.Embedding(word_num, emb_size)
    self.gru = nn.GRU(emb_size,
                      hidden_size,
                      num_layers=2,
                      direction='bidirect')
    # We need `label_num + 2` for appending the BOS and EOS tags
    self.fc = nn.Linear(hidden_size * 2, label_num + 2)
    self.crf = LinearChainCrf(label_num)
    self.crf_loss = LinearChainCrfLoss(self.crf)
    self.viterbi_decoder = ViterbiDecoder(self.crf.transitions)
def __init__(self, in_channels=3, out_classes=5, hid=64, num=64):
    super(FallNet, self).__init__()
    self.cnn0 = Block(in_channels, hid, 7, 0)
    self.cnn1 = Block(hid, hid, 5, 0)
    self.cnn2 = Block(hid, hid, 3, 0)
    self.cnn3 = Block(hid, hid, 1, 0)
    self.avg = nn.AdaptiveAvgPool1D(output_size=num)
    # self.rnn0 = nn.LSTM(input_size=145, hidden_size=num, dropout=.2, num_layers=3)
    # Note: RNN dropout only applies between stacked layers, so with
    # num_layers=1 the dropout=0.2 below has no effect.
    self.rnn0 = nn.GRU(input_size=145, hidden_size=num, num_layers=1, dropout=0.2)
    self.rnn1 = Block(hid, hid, 1, 0)
    self.rnn2 = Block(hid, 4, 3, 0)
    self.cls = nn.Sequential(
        nn.Linear(in_features=1016, out_features=128),
        nn.Dropout(p=.2),
        nn.Linear(in_features=128, out_features=out_classes),
        nn.Softmax(axis=1))
def __init__(self,
             enc_bi_rnn=False,
             enc_drop_rnn=0.1,
             enc_gru=False,
             d_model=512,
             d_enc=512,
             mask=True,
             **kwargs):
    super().__init__()
    assert isinstance(enc_bi_rnn, bool)
    assert isinstance(enc_drop_rnn, (int, float))
    assert 0 <= enc_drop_rnn < 1.0
    assert isinstance(enc_gru, bool)
    assert isinstance(d_model, int)
    assert isinstance(d_enc, int)
    assert isinstance(mask, bool)

    self.enc_bi_rnn = enc_bi_rnn
    self.enc_drop_rnn = enc_drop_rnn
    self.mask = mask

    # RNN encoder (LSTM by default, GRU when enc_gru=True)
    if enc_bi_rnn:
        direction = 'bidirectional'
    else:
        direction = 'forward'
    kwargs = dict(input_size=d_model,
                  hidden_size=d_enc,
                  num_layers=2,
                  time_major=False,
                  dropout=enc_drop_rnn,
                  direction=direction)
    if enc_gru:
        self.rnn_encoder = nn.GRU(**kwargs)
    else:
        self.rnn_encoder = nn.LSTM(**kwargs)

    # Global feature transformation
    encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
    self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
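# Illustrative shape check (an assumption, not part of the source) that the
# encoder's output width matches encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1):
import paddle
import paddle.nn as nn

d_model, d_enc = 512, 512
rnn_encoder = nn.GRU(input_size=d_model, hidden_size=d_enc, num_layers=2,
                     direction='bidirectional')
feat = paddle.randn([2, 25, d_model])  # [batch, seq, d_model]
out, _ = rnn_encoder(feat)
print(out.shape)                       # [2, 25, 1024], i.e. d_enc * 2 when bidirectional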
def func_test_layer_str(self):
    module = nn.ELU(0.2)
    self.assertEqual(str(module), 'ELU(alpha=0.2)')

    module = nn.CELU(0.2)
    self.assertEqual(str(module), 'CELU(alpha=0.2)')

    module = nn.GELU(True)
    self.assertEqual(str(module), 'GELU(approximate=True)')

    module = nn.Hardshrink()
    self.assertEqual(str(module), 'Hardshrink(threshold=0.5)')

    module = nn.Hardswish(name="Hardswish")
    self.assertEqual(str(module), 'Hardswish(name=Hardswish)')

    module = nn.Tanh(name="Tanh")
    self.assertEqual(str(module), 'Tanh(name=Tanh)')

    module = nn.Hardtanh(name="Hardtanh")
    self.assertEqual(str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')

    module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW")
    self.assertEqual(
        str(module),
        'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
    )

    module = nn.ReLU()
    self.assertEqual(str(module), 'ReLU()')

    module = nn.ReLU6()
    self.assertEqual(str(module), 'ReLU6()')

    module = nn.SELU()
    self.assertEqual(
        str(module),
        'SELU(scale=1.0507009873554805, alpha=1.6732632423543772)')

    module = nn.LeakyReLU()
    self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)')

    module = nn.Sigmoid()
    self.assertEqual(str(module), 'Sigmoid()')

    module = nn.Hardsigmoid()
    self.assertEqual(str(module), 'Hardsigmoid()')

    module = nn.Softplus()
    self.assertEqual(str(module), 'Softplus(beta=1, threshold=20)')

    module = nn.Softshrink()
    self.assertEqual(str(module), 'Softshrink(threshold=0.5)')

    module = nn.Softsign()
    self.assertEqual(str(module), 'Softsign()')

    module = nn.Swish()
    self.assertEqual(str(module), 'Swish()')

    module = nn.Tanhshrink()
    self.assertEqual(str(module), 'Tanhshrink()')

    module = nn.ThresholdedReLU()
    self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)')

    module = nn.LogSigmoid()
    self.assertEqual(str(module), 'LogSigmoid()')

    module = nn.Softmax()
    self.assertEqual(str(module), 'Softmax(axis=-1)')

    module = nn.LogSoftmax()
    self.assertEqual(str(module), 'LogSoftmax(axis=-1)')

    module = nn.Maxout(groups=2)
    self.assertEqual(str(module), 'Maxout(groups=2, axis=1)')

    module = nn.Linear(2, 4, name='linear')
    self.assertEqual(
        str(module),
        'Linear(in_features=2, out_features=4, dtype=float32, name=linear)')

    module = nn.Upsample(size=[12, 12])
    self.assertEqual(
        str(module),
        'Upsample(size=[12, 12], mode=nearest, align_corners=False, align_mode=0, data_format=NCHW)'
    )

    module = nn.UpsamplingNearest2D(size=[12, 12])
    self.assertEqual(
        str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)')

    module = nn.UpsamplingBilinear2D(size=[12, 12])
    self.assertEqual(
        str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)')

    module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000)
    self.assertEqual(
        str(module),
        'Bilinear(in1_features=5, in2_features=4, out_features=1000, dtype=float32)'
    )

    module = nn.Dropout(p=0.5)
    self.assertEqual(str(module),
                     'Dropout(p=0.5, axis=None, mode=upscale_in_train)')

    module = nn.Dropout2D(p=0.5)
    self.assertEqual(str(module), 'Dropout2D(p=0.5, data_format=NCHW)')

    module = nn.Dropout3D(p=0.5)
    self.assertEqual(str(module), 'Dropout3D(p=0.5, data_format=NCDHW)')

    module = nn.AlphaDropout(p=0.5)
    self.assertEqual(str(module), 'AlphaDropout(p=0.5)')

    module = nn.Pad1D(padding=[1, 2], mode='constant')
    self.assertEqual(
        str(module),
        'Pad1D(padding=[1, 2], mode=constant, value=0.0, data_format=NCL)')

    module = nn.Pad2D(padding=[1, 0, 1, 2], mode='constant')
    self.assertEqual(
        str(module),
        'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)'
    )

    module = nn.ZeroPad2D(padding=[1, 0, 1, 2])
    self.assertEqual(str(module),
                     'ZeroPad2D(padding=[1, 0, 1, 2], data_format=NCHW)')

    module = nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant')
    self.assertEqual(
        str(module),
        'Pad3D(padding=[1, 0, 1, 2, 0, 0], mode=constant, value=0.0, data_format=NCDHW)'
    )

    module = nn.CosineSimilarity(axis=0)
    self.assertEqual(str(module), 'CosineSimilarity(axis=0, eps=1e-08)')

    module = nn.Embedding(10, 3, sparse=True)
    self.assertEqual(str(module), 'Embedding(10, 3, sparse=True)')

    module = nn.Conv1D(3, 2, 3)
    self.assertEqual(str(module),
                     'Conv1D(3, 2, kernel_size=[3], data_format=NCL)')

    module = nn.Conv1DTranspose(2, 1, 2)
    self.assertEqual(
        str(module), 'Conv1DTranspose(2, 1, kernel_size=[2], data_format=NCL)')

    module = nn.Conv2D(4, 6, (3, 3))
    self.assertEqual(str(module),
                     'Conv2D(4, 6, kernel_size=[3, 3], data_format=NCHW)')

    module = nn.Conv2DTranspose(4, 6, (3, 3))
    self.assertEqual(
        str(module),
        'Conv2DTranspose(4, 6, kernel_size=[3, 3], data_format=NCHW)')

    module = nn.Conv3D(4, 6, (3, 3, 3))
    self.assertEqual(
        str(module), 'Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

    module = nn.Conv3DTranspose(4, 6, (3, 3, 3))
    self.assertEqual(
        str(module),
        'Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

    module = nn.PairwiseDistance()
    self.assertEqual(str(module), 'PairwiseDistance(p=2.0)')

    module = nn.InstanceNorm1D(2)
    self.assertEqual(str(module),
                     'InstanceNorm1D(num_features=2, epsilon=1e-05)')

    module = nn.InstanceNorm2D(2)
    self.assertEqual(str(module),
                     'InstanceNorm2D(num_features=2, epsilon=1e-05)')

    module = nn.InstanceNorm3D(2)
    self.assertEqual(str(module),
                     'InstanceNorm3D(num_features=2, epsilon=1e-05)')

    module = nn.GroupNorm(num_channels=6, num_groups=6)
    self.assertEqual(
        str(module), 'GroupNorm(num_groups=6, num_channels=6, epsilon=1e-05)')

    module = nn.LayerNorm([2, 2, 3])
    self.assertEqual(
        str(module), 'LayerNorm(normalized_shape=[2, 2, 3], epsilon=1e-05)')

    module = nn.BatchNorm1D(1)
    self.assertEqual(
        str(module),
        'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)'
    )

    module = nn.BatchNorm2D(1)
    self.assertEqual(
        str(module),
        'BatchNorm2D(num_features=1, momentum=0.9, epsilon=1e-05)')

    module = nn.BatchNorm3D(1)
    self.assertEqual(
        str(module),
        'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)'
    )

    module = nn.SyncBatchNorm(2)
    self.assertEqual(
        str(module),
        'SyncBatchNorm(num_features=2, momentum=0.9, epsilon=1e-05)')

    module = nn.LocalResponseNorm(size=5)
    self.assertEqual(
        str(module),
        'LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1.0)')

    module = nn.AvgPool1D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'AvgPool1D(kernel_size=2, stride=2, padding=0)')

    module = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'AvgPool2D(kernel_size=2, stride=2, padding=0)')

    module = nn.AvgPool3D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'AvgPool3D(kernel_size=2, stride=2, padding=0)')

    module = nn.MaxPool1D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'MaxPool1D(kernel_size=2, stride=2, padding=0)')

    module = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'MaxPool2D(kernel_size=2, stride=2, padding=0)')

    module = nn.MaxPool3D(kernel_size=2, stride=2, padding=0)
    self.assertEqual(str(module),
                     'MaxPool3D(kernel_size=2, stride=2, padding=0)')

    module = nn.AdaptiveAvgPool1D(output_size=16)
    self.assertEqual(str(module), 'AdaptiveAvgPool1D(output_size=16)')

    module = nn.AdaptiveAvgPool2D(output_size=3)
    self.assertEqual(str(module), 'AdaptiveAvgPool2D(output_size=3)')

    module = nn.AdaptiveAvgPool3D(output_size=3)
    self.assertEqual(str(module), 'AdaptiveAvgPool3D(output_size=3)')

    module = nn.AdaptiveMaxPool1D(output_size=16, return_mask=True)
    self.assertEqual(
        str(module), 'AdaptiveMaxPool1D(output_size=16, return_mask=True)')

    module = nn.AdaptiveMaxPool2D(output_size=3, return_mask=True)
    self.assertEqual(str(module),
                     'AdaptiveMaxPool2D(output_size=3, return_mask=True)')

    module = nn.AdaptiveMaxPool3D(output_size=3, return_mask=True)
    self.assertEqual(str(module),
                     'AdaptiveMaxPool3D(output_size=3, return_mask=True)')

    module = nn.SimpleRNNCell(16, 32)
    self.assertEqual(str(module), 'SimpleRNNCell(16, 32)')

    module = nn.LSTMCell(16, 32)
    self.assertEqual(str(module), 'LSTMCell(16, 32)')

    module = nn.GRUCell(16, 32)
    self.assertEqual(str(module), 'GRUCell(16, 32)')

    module = nn.PixelShuffle(3)
    self.assertEqual(str(module), 'PixelShuffle(upscale_factor=3)')

    module = nn.SimpleRNN(16, 32, 2)
    self.assertEqual(
        str(module),
        'SimpleRNN(16, 32, num_layers=2\n  (0): RNN(\n    (cell): SimpleRNNCell(16, 32)\n  )\n  (1): RNN(\n    (cell): SimpleRNNCell(32, 32)\n  )\n)'
    )

    module = nn.LSTM(16, 32, 2)
    self.assertEqual(
        str(module),
        'LSTM(16, 32, num_layers=2\n  (0): RNN(\n    (cell): LSTMCell(16, 32)\n  )\n  (1): RNN(\n    (cell): LSTMCell(32, 32)\n  )\n)'
    )

    module = nn.GRU(16, 32, 2)
    self.assertEqual(
        str(module),
        'GRU(16, 32, num_layers=2\n  (0): RNN(\n    (cell): GRUCell(16, 32)\n  )\n  (1): RNN(\n    (cell): GRUCell(32, 32)\n  )\n)'
    )

    module1 = nn.Sequential(('conv1', nn.Conv2D(1, 20, 5)),
                            ('relu1', nn.ReLU()),
                            ('conv2', nn.Conv2D(20, 64, 5)),
                            ('relu2', nn.ReLU()))
    self.assertEqual(
        str(module1),
        'Sequential(\n  '\
        '(conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n  '\
        '(relu1): ReLU()\n  '\
        '(conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n  '\
        '(relu2): ReLU()\n)'
    )

    module2 = nn.Sequential(nn.Conv3DTranspose(4, 6, (3, 3, 3)),
                            nn.AvgPool3D(kernel_size=2, stride=2, padding=0),
                            nn.Tanh(name="Tanh"), module1,
                            nn.Conv3D(4, 6, (3, 3, 3)),
                            nn.MaxPool3D(kernel_size=2, stride=2, padding=0),
                            nn.GELU(True))
    self.assertEqual(
        str(module2),
        'Sequential(\n  '\
        '(0): Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n  '\
        '(1): AvgPool3D(kernel_size=2, stride=2, padding=0)\n  '\
        '(2): Tanh(name=Tanh)\n  '\
        '(3): Sequential(\n    (conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n    (relu1): ReLU()\n'\
        '    (conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n    (relu2): ReLU()\n  )\n  '\
        '(4): Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n  '\
        '(5): MaxPool3D(kernel_size=2, stride=2, padding=0)\n  '\
        '(6): GELU(approximate=True)\n)'
    )
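# The assertions above pin down each layer's extra_repr; printing a module
# reproduces the same string, e.g. for the stacked GRU:
import paddle.nn as nn

print(nn.GRU(16, 32, 2))
# GRU(16, 32, num_layers=2
#   (0): RNN(
#     (cell): GRUCell(16, 32)
#   )
#   (1): RNN(
#     (cell): GRUCell(32, 32)
#   )
# )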
def __init__(self,
             out_channels,  # 90 + unknown + start + padding
             enc_bi_rnn=False,
             dec_bi_rnn=False,
             dec_drop_rnn=0.0,
             dec_gru=False,
             d_model=512,
             d_enc=512,
             d_k=64,
             pred_dropout=0.1,
             max_text_length=30,
             mask=True,
             pred_concat=True,
             **kwargs):
    super().__init__()

    self.num_classes = out_channels
    self.enc_bi_rnn = enc_bi_rnn
    self.d_k = d_k
    self.start_idx = out_channels - 2
    self.padding_idx = out_channels - 1
    self.max_seq_len = max_text_length
    self.mask = mask
    self.pred_concat = pred_concat

    encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
    decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

    # 2D attention layer
    self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
    self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)
    self.conv1x1_2 = nn.Linear(d_k, 1)

    # Decoder RNN layer
    if dec_bi_rnn:
        direction = 'bidirectional'
    else:
        direction = 'forward'
    kwargs = dict(input_size=encoder_rnn_out_size,
                  hidden_size=encoder_rnn_out_size,
                  num_layers=2,
                  time_major=False,
                  dropout=dec_drop_rnn,
                  direction=direction)
    if dec_gru:
        self.rnn_decoder = nn.GRU(**kwargs)
    else:
        self.rnn_decoder = nn.LSTM(**kwargs)

    # Decoder input embedding
    self.embedding = nn.Embedding(self.num_classes,
                                  encoder_rnn_out_size,
                                  padding_idx=self.padding_idx)

    # Prediction layer
    self.pred_dropout = nn.Dropout(pred_dropout)
    pred_num_classes = self.num_classes - 1
    if pred_concat:
        fc_in_channel = decoder_rnn_out_size + d_model + d_enc
    else:
        fc_in_channel = d_model
    self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
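# Hedged shape sketch for the decoder pieces above (dummy sizes; the 2D
# attention forward is omitted). With enc_bi_rnn=False and dec_bi_rnn=False,
# both the embedding width and the decoder RNN width equal d_enc:
import paddle
import paddle.nn as nn

out_channels, d_enc = 93, 512
encoder_rnn_out_size = d_enc  # d_enc * (int(enc_bi_rnn) + 1) with enc_bi_rnn=False
embedding = nn.Embedding(out_channels, encoder_rnn_out_size,
                         padding_idx=out_channels - 1)
rnn_decoder = nn.LSTM(encoder_rnn_out_size, encoder_rnn_out_size, num_layers=2)

tokens = paddle.randint(0, out_channels, shape=[2, 30], dtype='int64')
dec_out, _ = rnn_decoder(embedding(tokens))  # [2, 30, 512]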