def test_encoder(self):
    docs = torch.Tensor([[1, 2, 3, 4], [1, 2, 2, 4]]).long()
    batch_size = docs.size(0)
    encoder = Encoder(embedding=create_my_embedding(np.random.rand(10, 5)))
    h_n, c_n, _ = encoder(docs)
    # Default encoder is expected to produce 6 layer-direction slots
    # (num_layers * num_directions) with a hidden size of 512.
    self.assertEqual(h_n.shape, (6, batch_size, 512))
    self.assertEqual(c_n.shape, (6, batch_size, 512))
def test_encoder_step_by_step(self):
    batch_size = 2
    seq_len = 2
    vocab_size = 100
    docs = torch.randint(vocab_size, size=(batch_size, seq_len))
    encoder = Encoder(embedding=create_my_embedding(
        np.random.rand(vocab_size, 5)),
                      is_bidirectional=False)
    encoder.eval()
    with torch.no_grad():
        h_n_1, c_n_1, _ = encoder(docs)
        h_n_2, c_n_2, _ = encoder(docs[:, 0:1])
        for step in range(1, seq_len):
            h_n_2, c_n_2, _ = encoder(docs[:, step:step + 1], (h_n_2, c_n_2))
    self.assertEqual(torch.norm(h_n_1 - h_n_2), 0)
    self.assertEqual(torch.norm(c_n_1 - c_n_2), 0)
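# A minimal sketch of the create_my_embedding helper the tests above rely on,
# assuming it only wraps a pretrained (vocab_size, embedding_size) weight
# matrix into an nn.Embedding; the project's actual helper may differ.
import numpy as np
import torch
import torch.nn as nn


def create_my_embedding(weight):
    # weight: numpy array or tensor of shape (vocab_size, embedding_size)
    weight = torch.as_tensor(np.asarray(weight), dtype=torch.float)
    return nn.Embedding.from_pretrained(weight, freeze=False)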
def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, end_idx):
    super(Seq2SeqHugeFeedingAttn, self).__init__()
    self.lr_rate = 1e-3
    self.max_length = 100
    self.__start_idx_int = start_idx
    self.encoder = Encoder(vocab_size=src_vocab_size,
                           lstm_size=1024,
                           lstm_num_layer=4)
    _enc_output_size = (2 * self.encoder.lstm_size
                        if self.encoder.is_bidirectional
                        else self.encoder.lstm_size)
    self.flatten_hidden_lstm = FlattenHiddenLSTM(
        lstm_num_layer=self.encoder.lstm_num_layer,
        is_bidirectional=self.encoder.is_bidirectional)
    self.core_decoder = AttnRawDecoder(vocab_size=tgt_vocab_size,
                                       enc_output_size=_enc_output_size,
                                       lstm_size=1024,
                                       lstm_num_layer=4)
    self.greedy_infer = DecoderGreedyInfer(core_decoder=self.core_decoder,
                                           max_length=self.max_length,
                                           start_idx=start_idx)
    self.xent = None
    self.optimizer = None
    self.register_buffer('start_idx', torch.Tensor([[start_idx]]).long())
    self.register_buffer('end_idx', torch.Tensor([[end_idx]]).long())
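# Worked example of the _enc_output_size arithmetic used above: with
# lstm_size=1024, a bidirectional encoder concatenates forward and backward
# states, so the encoder output fed to attention is twice as wide; a
# unidirectional encoder keeps it at lstm_size.
def enc_output_size(lstm_size, is_bidirectional):
    return 2 * lstm_size if is_bidirectional else lstm_size


assert enc_output_size(1024, True) == 2048
assert enc_output_size(1024, False) == 1024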
class Seq2SeqChunk(nn.Module):

    def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, padding_idx,
                 max_length):
        super(Seq2SeqChunk, self).__init__()
        pytorch_utils.register_buffer(self, 'lr_rate', 1e-3)
        pytorch_utils.register_buffer(self, 'max_length', max_length)
        pytorch_utils.register_buffer(self, 'chunk_size', 10)
        self.__start_idx_int = start_idx
        self.__padding_idx_int = padding_idx
        self.encoder = Encoder(vocab_size=src_vocab_size,
                               is_bidirectional=False)
        _enc_output_size = (2 * self.encoder.lstm_size.item()
                            if self.encoder.is_bidirectional.item()
                            else self.encoder.lstm_size.item())
        self.flatten_hidden_lstm = FlattenHiddenLSTM(
            lstm_num_layer=3,
            is_bidirectional=bool(self.encoder.is_bidirectional.item()))
        self.core_decoder = AttnRawDecoderWithSrc(
            vocab_size=tgt_vocab_size,
            enc_output_size=_enc_output_size,
            enc_embedding_size=self.encoder.embedding_size.item())
        self.greedy_infer = DecoderGreedyWithSrcInfer(
            core_decoder=self.core_decoder)
        self.xent = None
        self.optimizer = None
        self.register_buffer('start_idx', torch.Tensor([start_idx]).long())
        self.register_buffer('padding_idx',
                             torch.Tensor([[padding_idx]]).long())

    def chunk_forward(self, word_input, h_c, starts_idx, *args):
        """
        Encoding works on the whole chunk, but only the first half of the
        chunk is decoded.
        :param word_input: shape == (batch_size, chunk_size)
        :param h_c: tuple of (h, c). Pass None to indicate the start of the sequence
        :param starts_idx: Tensor shape == (batch)
        :param args:
        :return: Tensor shape == (batch, chunk_size / 2), tuple of (h, c)
        """
        if h_c is None:
            h_n, c_n, outputs = self.encoder(word_input)
        else:
            h_n, c_n, outputs = self.encoder(word_input, h_c)
        h_n, c_n = self.flatten_hidden_lstm(h_n, c_n)
        seq_len = word_input.size(1)
        assert seq_len % 2 == 0
        word_input = word_input[:, :int(seq_len / 2)]
        enc_inputs = self.encoder.embedding(word_input)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        output = self.greedy_infer(h_n, c_n, outputs, enc_inputs, starts_idx)
        return output, (h_n, c_n)

    def forward(self, word_input, *args):
        """
        :param word_input: shape == (batch_size, seq_len)
        :param args:
        :return: Tensor shape == (batch, seq_len)
        """
        __batch_size = word_input.size(0)
        input_chunks = self.__chunking_sequence(word_input)
        h_c = None
        output = []
        previous_starts_idx = self.start_idx.repeat(__batch_size)
        for i_chunk in input_chunks:
            output_chunk, h_c = self.chunk_forward(i_chunk, h_c,
                                                   previous_starts_idx)
            output.append(output_chunk)
            previous_starts_idx = output_chunk[:, -1]
        output = torch.cat(output, dim=1)
        seq_len = word_input.size(1)
        output = output[:, :seq_len]
        return output

    def train(self, mode=True):
        if self.xent is None:
            self.xent = nn.CrossEntropyLoss(reduction='none')
        if self.optimizer is None:
            self.optimizer = optim.Adam(self.parameters(),
                                        lr=self.lr_rate.item())
        super().train(mode)

    def get_loss_chunk(self, word_input, target, length, previous_starts_idx):
        """
        :param word_input: shape == (batch, seq_len)
        :param target: shape == (batch, seq_len/2)
        :param length: shape == (batch)
        :param previous_starts_idx: shape == (batch)
        :return: Tensor shape == (batch, seq_len/2)
        """
        assert target.size(1) * 2 == word_input.size(1)
        __half_seq_len = target.size(1)
        __batch_size = word_input.size(0)
        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)
        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)
        # shape == (batch_size, seq_len/2)
        dec_inputs = torch.cat(
            (previous_starts_idx.view(__batch_size, 1), target[:, :-1]), dim=1)
        # shape == (seq_len/2, batch_size)
        dec_inputs = dec_inputs.permute(1, 0)
        enc_inputs = self.encoder.embedding(word_input[:, :__half_seq_len])
        # shape == (seq_len/2, batch, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        # shape == (seq_len/2, batch_size, tgt_vocab_size)
        predict, _ = self.core_decoder(dec_inputs, (enc_h_n, enc_c_n),
                                       enc_outputs, enc_inputs, step=None)
        # shape == (batch_size, tgt_vocab_size, seq_len/2)
        predict = predict.permute(1, 2, 0)
        dec_target = target
        loss = self.xent(predict, dec_target)
        __chunk_size = self.chunk_size.item()
        assert __chunk_size == word_input.size(1)
        loss_mask = pytorch_utils.length_to_mask(length,
                                                 max_len=__half_seq_len,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        return loss

    def get_loss(self, word_input, target, length):
        """
        :param word_input: shape == (batch, seq_len)
        :param target: shape == (batch, seq_len)
        :param length: shape == (batch)
        :return: scalar loss Tensor
        """
        __max_length = self.max_length.item()
        __half_chunk_size = int(self.chunk_size.item() / 2)
        __batch_size = word_input.size(0)
        input_chunks = self.__chunking_sequence(word_input)
        target_chunks = self.__chunking_sequence(target)
        # shape == (batch, __max_length)
        mask = pytorch_utils.length_to_mask(length, max_len=__max_length)
        length_chunks = [
            torch.sum(mask[:, i:i + __half_chunk_size], dim=1)
            for i in range(0, __max_length, __half_chunk_size)
        ]
        loss = []
        previous_starts_idx = self.start_idx.repeat(__batch_size)
        for idx, (i_chunk, t_chunk) in enumerate(
                zip(input_chunks, target_chunks)):
            t_chunk = t_chunk[:, :__half_chunk_size]
            length_chunk = length_chunks[idx]
            loss.append(
                self.get_loss_chunk(i_chunk, t_chunk, length_chunk,
                                    previous_starts_idx))
            previous_starts_idx = t_chunk[:, -1]
        loss = torch.cat(loss, dim=1)
        loss = torch.div(loss.sum(dim=1), length.float())
        loss = loss.mean(dim=0)
        return loss

    def __chunking_sequence(self, word_input):
        """
        Pad the sequence so it splits into overlapping chunks of `chunk_size`,
        advancing by `chunk_size / 2` at each step.
        :param word_input: shape == (batch, seq_len)
        :return: List of chunks, each of shape == (batch, chunk_size)
        """
        assert self.chunk_size % 2 == 0
        seq_len = word_input.size(1)
        batch_size = word_input.size(0)
        must_have = int(
            int(seq_len / (self.chunk_size / 2)) * (self.chunk_size / 2) +
            self.chunk_size)
        no_padding = must_have - seq_len  # number of padding tokens to append
        padding = self.padding_idx.repeat(batch_size, no_padding)
        word_input = torch.cat((word_input, padding), dim=1)
        input_chunks = [
            word_input[:, i:i + self.chunk_size]
            for i in range(0, seq_len, int(self.chunk_size / 2))
        ]
        return input_chunks

    def train_batch(self, word_input, target, length):
        """
        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return: scalar loss value
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.get_loss(word_input, target, length)
        loss.backward()
        self.optimizer.step()
        return loss.item()
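# A minimal sketch of the pytorch_utils.length_to_mask call used in
# get_loss/get_loss_chunk above, assuming it turns per-example lengths into a
# (batch, max_len) mask that is 1 for valid positions and 0 for padding; the
# project's real utility may use a different signature or defaults.
import torch


def length_to_mask(length, max_len, dtype=torch.bool):
    # length: shape == (batch,); mask[i, j] is True while j < length[i]
    positions = torch.arange(max_len, device=length.device)
    mask = positions.unsqueeze(0) < length.unsqueeze(1)
    return mask.to(dtype)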
class MainModel(nn.Module):

    def __init__(self, enc_embedding_weight, dec_embedding_weight, start_idx):
        super(MainModel, self).__init__()
        self.lr_rate = 1e-3
        self.max_length = 100
        self.__start_idx_int = start_idx
        self.encoder = Encoder(
            embedding=create_my_embedding(enc_embedding_weight),
            lstm_num_layer=1,
            lstm_size=1024)
        _enc_output_size = (2 * self.encoder.lstm_size
                            if self.encoder.is_bidirectional
                            else self.encoder.lstm_size)
        self.flatten_hidden_lstm = FlattenHiddenLSTM(
            lstm_num_layer=self.encoder.lstm_num_layer,
            is_bidirectional=self.encoder.is_bidirectional)
        self.core_decoder = AttnRawDecoderWithSrc(
            embedding=create_my_embedding(dec_embedding_weight),
            enc_output_size=_enc_output_size,
            use_pred_prob=0.1,
            lstm_size=self.encoder.lstm_size,
            lstm_num_layer=1,
            enc_embedding_size=self.encoder.embedding_size)
        self.greedy_infer = DecoderGreedyInfer(core_decoder=self.core_decoder,
                                               start_idx=start_idx)
        self.xent = None
        self.optimizer = None
        self.register_buffer('start_idx', torch.Tensor([[start_idx]]).long())
        # self.register_buffer('end_idx', torch.Tensor([[end_idx]]).long())

    def forward(self, word_input, *args):
        """
        :param word_input: shape == (batch_size, max_len)
        :param args:
        :return: Tensor of greedily decoded token ids
        """
        h_n, c_n, outputs = self.encoder(word_input)
        h_n, c_n = self.flatten_hidden_lstm(h_n, c_n)
        h_c = (h_n, c_n)
        enc_inputs = self.encoder.embedding(word_input)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        output = self.greedy_infer(h_c, outputs, enc_inputs)
        return output

    def train(self, mode=True):
        if self.xent is None:
            self.xent = nn.CrossEntropyLoss(reduction='none')
        if self.optimizer is None:
            self.optimizer = optim.Adam(self.parameters(), lr=self.lr_rate)
        super().train(mode)

    def get_loss(self, word_input, target, length):
        """
        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return: scalar loss Tensor
        """
        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)
        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)
        batch_size = enc_h_n.size(1)
        init_words = self.start_idx.repeat(batch_size, 1)
        # shape == (batch_size, max_len + 1)
        dec_input = torch.cat((init_words, target), dim=1)
        # shape == (max_len, batch_size) after dropping the last shifted step
        dec_input = dec_input.permute(1, 0)[:-1]
        # shape == (batch_size, max_len, _)
        enc_inputs = self.encoder.embedding(word_input)
        # shape == (max_len, batch_size, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        # shape == (max_len, batch_size, tgt_vocab_size)
        predict, _, _ = self.core_decoder(dec_input, (enc_h_n, enc_c_n),
                                          enc_outputs, None, enc_inputs)
        # shape == (batch_size, tgt_vocab_size, max_len)
        predict = predict.permute(1, 2, 0)
        # end_words = self.end_idx.repeat(batch_size, 1)
        # dec_target = torch.cat((target, end_words), dim=1)
        dec_target = target
        loss = self.xent(predict, dec_target)
        loss_mask = pytorch_utils.length_to_mask(length,
                                                 max_len=self.max_length,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        loss = torch.div(loss.sum(dim=1), length.float())
        loss = loss.mean(dim=0)
        return loss

    def train_batch(self, word_input, target, length):
        """
        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return: scalar loss value
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.get_loss(word_input, target, length)
        loss.backward()
        self.optimizer.step()
        return loss.item()
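# A hedged usage sketch for MainModel: the vocabulary sizes, embedding
# dimension, start index, and the toy batch below are made up for
# illustration, and Encoder / AttnRawDecoderWithSrc / DecoderGreedyInfer are
# assumed to be importable from this project exactly as used above.
import numpy as np
import torch

if __name__ == '__main__':
    enc_weight = np.random.rand(1000, 300)  # (src_vocab_size, emb_size), assumed
    dec_weight = np.random.rand(1200, 300)  # (tgt_vocab_size, emb_size), assumed
    model = MainModel(enc_weight, dec_weight, start_idx=1)

    word_input = torch.randint(1000, size=(4, model.max_length))
    target = torch.randint(1200, size=(4, model.max_length))
    length = torch.randint(1, model.max_length + 1, size=(4,))

    loss = model.train_batch(word_input, target, length)  # one optimizer step
    model.eval()
    with torch.no_grad():
        decoded = model(word_input)  # greedy decoding of the source batch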