def __init__(self, vocab_size, embedding_dim, sos_id, eos_id, hidden_size, num_layers, bidirectional_encoder=True):
    """Attentional LSTM decoder.

    Args:
        vocab_size: size of the output vocabulary.
        embedding_dim: dimension of the token embeddings.
        sos_id: start-of-sentence token id.
        eos_id: end-of-sentence token id.
        hidden_size: hidden size of each LSTMCell; the encoder hidden
            size is tied to this value (see ``encoder_hidden_size``).
        num_layers: number of stacked LSTMCell layers.
        bidirectional_encoder: kept for interface compatibility only;
            the current implementation does not use it.
    """
    super(Decoder, self).__init__()

    # Hyperparameters: embedding / output side.
    self.vocab_size = vocab_size
    self.embedding_dim = embedding_dim
    self.sos_id = sos_id  # start-of-sentence token id
    self.eos_id = eos_id  # end-of-sentence token id

    # Hyperparameters: recurrent side.
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional_encoder = bidirectional_encoder  # currently unused
    # Decoder and encoder hidden sizes must be equal for now.
    self.encoder_hidden_size = hidden_size

    # Components.
    self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)

    # Layer 0 consumes the token embedding concatenated with the
    # attention context; every deeper layer consumes the layer below.
    cells = [nn.LSTMCell(self.embedding_dim + self.encoder_hidden_size, self.hidden_size)]
    cells.extend(
        nn.LSTMCell(self.hidden_size, self.hidden_size)
        for _ in range(1, self.num_layers)
    )
    self.rnn = nn.ModuleList(cells)

    self.attention = DotProductAttention()
    self.mlp = nn.Sequential(
        nn.Linear(self.encoder_hidden_size + self.hidden_size, self.hidden_size),
        nn.Tanh(),
        nn.Linear(self.hidden_size, self.vocab_size),
    )
def __init__(self, vocabs, external, settings):
    """Build the full model: base LSTM encoder, optional ELMo scaling,
    optional auxiliary ("other") scorer, an optional bridge between the
    two tasks, and the primary scorer.

    Args:
        vocabs: mapping from task keys (``settings.pt``, ``settings.ot``)
            to label vocabularies (``None`` allowed for the "other" task).
        external: external-embedding resources forwarded to ``BaseLSTM``.
        settings: run configuration; fields read here include ``pt``,
            ``ot``, ``helpers``, ``use_elmo``, ``vec_dim``, ``bridge``,
            ``dim_mlp``, ``hidden_lstm``, ``gcn_layers``, ``unfactorized``.
    """
    super().__init__()
    # "only_lonely" stays True unless an auxiliary task or helpers are
    # enabled; it is forwarded to the primary Scorer below.
    only_lonely = True
    self.settings = settings
    self.n_labels_other = 0 if vocabs[settings.ot] is None else len(
        vocabs[settings.ot])
    self.n_labels = len(vocabs[settings.pt])
    self.other_scorer = None
    self.bridge = None
    self.combine = None
    self.helpers = settings.helpers
    self.base = BaseLSTM(vocabs, external, settings)

    # Optional projection of ELMo vectors down to 100 dims.
    if settings.use_elmo:
        self.scalelmo = nn.Linear(settings.vec_dim, 100)
    else:
        self.scalelmo = None

    # Auxiliary task scorer, only when an "other" task is configured.
    if settings.ot:
        self.other_scorer = Scorer(self.n_labels_other, settings, False, True)

    if settings.ot or settings.helpers:
        only_lonely = False

    if settings.bridge == "dpa":
        self.bridge = DotProductAttention(settings.dim_mlp)
    elif settings.bridge == "dpa+":
        self.combine = nn.Linear(settings.hidden_lstm * 4,
                                 settings.hidden_lstm * 2)
        # BUG FIX: the original closure constructed fresh
        # DotProductAttention modules on every call, so their parameters
        # were re-initialized each forward pass and were never registered
        # as submodules (invisible to the optimizer and to state_dict()).
        # Instantiate them once here so they are trained and saved.
        self.dpa_fwd = DotProductAttention(settings.dim_mlp)
        self.dpa_bwd = DotProductAttention(settings.dim_mlp)

        def bridge(x, y):
            a = self.dpa_fwd(x, y)
            b = self.dpa_bwd(x.transpose(-2, -1), y)
            return self.combine(torch.cat((a, b), -1))

        self.bridge = bridge
    elif settings.bridge == "gcn":
        self.bridge = GCN(settings.hidden_lstm * 2,
                          int(settings.hidden_lstm / 2),
                          settings.hidden_lstm * 2, settings.gcn_layers,
                          settings, self.n_labels_other)
    elif settings.bridge == "simple":
        # Plain bilinear-style combination without learned parameters.
        self.bridge = lambda x, y: x.transpose(-2, -1).float() @ y

    self.scorer = Scorer(self.n_labels, settings, settings.unfactorized,
                         only_lonely)
    # NOTE(review): debug leftover — consider replacing with logging.
    print(self.n_labels)
def bridge(x, y):
    # NOTE(review): a brand-new DotProductAttention is constructed on
    # EVERY call, so its parameters are randomly re-initialized each
    # forward pass and are never registered as submodules of the model
    # (invisible to the optimizer and to state_dict()). Presumably these
    # modules should be created once in __init__ — confirm and hoist.
    a = DotProductAttention(settings.dim_mlp)(x, y)
    # Second attention pass over x with its last two dims swapped.
    b = DotProductAttention(settings.dim_mlp)(x.transpose(
        -2, -1), y)
    # self.combine (a Linear layer) fuses the two attention outputs.
    c = self.combine(torch.cat((a, b), -1))
    return c