def __init__(self, args, vocab, enc_dim, dec_hidden):
    super().__init__()
    self.bridger = MLPBridger(
        rnn_type=args.rnn_type,
        mapper_type=args.mapper_type,
        encoder_dim=enc_dim,
        encoder_layer=args.enc_num_layers,
        decoder_dim=dec_hidden,
        decoder_layer=args.dec_num_layers,
    )
    if "stack_mlp" not in args:
        self.scorer = nn.Sequential(
            nn.Dropout(args.dec_rd),
            nn.Linear(in_features=dec_hidden * args.dec_num_layers, out_features=len(vocab.src), bias=True),
            nn.LogSoftmax(dim=-1))
    else:
        self.scorer = nn.Sequential(
            nn.Dropout(args.dec_rd),
            nn.Linear(in_features=dec_hidden * args.dec_num_layers, out_features=dec_hidden, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=dec_hidden, out_features=len(vocab.src), bias=True),
            nn.LogSoftmax(dim=-1))
    self.semantic_nll = nn.NLLLoss(ignore_index=vocab.src.pad_id)
class BridgeMLP(nn.Module):
    def __init__(self, args, vocab, enc_dim, dec_hidden):
        super().__init__()
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.scorer = nn.Sequential(
            nn.Dropout(args.dec_rd),
            nn.Linear(
                in_features=dec_hidden * args.dec_num_layers,
                out_features=len(vocab.src),
                bias=True
            ),
            nn.LogSoftmax(dim=-1)
        )
        self.semantic_nll = nn.NLLLoss(ignore_index=vocab.src.pad_id, reduction="none")

    def forward(self, hidden, tgt_var):
        batch_size = tgt_var.size(0)
        semantic_decode_init = self.bridger.forward(input_tensor=hidden)
        xx_hidden = semantic_decode_init.permute(1, 0, 2).contiguous().view(batch_size, -1)
        score = self.scorer.forward(input=xx_hidden)
        sem_loss = bag_of_word_loss(score, tgt_var, self.semantic_nll)
        sem_loss = sem_loss.contiguous().view(batch_size, -1).sum(-1)
        return sem_loss.sum()
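# `bag_of_word_loss` is called by BridgeMLP above (and MLPPredictor below) but is
# not defined in this file; it is presumably imported with the other helpers. A
# minimal sketch of such a helper, assuming the scorer emits one log-probability
# vector per example and the NLL criterion is applied to every target position
# independently. This is an illustrative assumption, not the repository's
# definition; the tensor shapes in the comments are assumptions as well.
def bag_of_word_loss(score, tgt_var, criterion):
    # score:     (batch, vocab) log-probabilities from an nn.LogSoftmax head
    # tgt_var:   (batch, tgt_len) padded target token ids
    # criterion: nn.NLLLoss(reduction="none", ignore_index=pad_id)
    # Returns a flat (batch * tgt_len,) loss vector; callers reshape it back to
    # (batch, tgt_len) and sum over the token dimension.
    batch_size, tgt_len = tgt_var.size()
    vocab_size = score.size(-1)
    expanded = score.unsqueeze(1).expand(batch_size, tgt_len, vocab_size)
    return criterion(expanded.reshape(-1, vocab_size), tgt_var.reshape(-1))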
class MLPPredictor(nn.Module):
    def __init__(self, args, vocab, enc_dim, dec_hidden):
        super().__init__()
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.scorer = nn.Sequential(
            nn.Dropout(args.dec_rd),
            nn.Linear(
                in_features=dec_hidden * args.dec_num_layers,
                out_features=len(vocab.src),
                bias=True
            ),
            nn.LogSoftmax(dim=-1)
        )
        self.pad_id = vocab.src.pad_id
        self.semantic_nll = nn.NLLLoss(ignore_index=vocab.src.pad_id, reduction='none')

    def detach(self):
        for p in self.parameters():
            p.requires_grad = False

    def reload(self):
        for p in self.parameters():
            p.requires_grad = True

    def forward(self, hidden, tgt_var, div_by_word=False):
        batch_size = tgt_var.size(0)
        semantic_decode_init = self.bridger.forward(input_tensor=hidden)
        xx_hidden = semantic_decode_init.permute(1, 0, 2).contiguous().view(batch_size, -1)
        score = self.scorer.forward(input=xx_hidden)
        sem_loss = bag_of_word_loss(score, tgt_var, self.semantic_nll)
        sem_loss = sem_loss.contiguous().view(batch_size, -1)
        if div_by_word:
            tgt_len = tgt_var.ne(self.pad_id).sum(-1).float()
            sem_loss = sem_loss.div(1e-9 + tgt_len)
        return sem_loss.sum()
class VanillaVAE(BaseVAE):
    def __init__(self, args, vocab, word_embed=None):
        super(VanillaVAE, self).__init__(args, vocab, name='MySentVAE')
        print("This is {} with parameter\n{}".format(self.name, self.base_information()))
        if word_embed is None:
            src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            src_embed = word_embed
        if args.share_embed:
            tgt_embed = src_embed
            args.dec_embed_dim = args.enc_embed_dim
        else:
            tgt_embed = None
        self.latent_size = args.latent_size
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = (2 if args.bidirectional else 1) * args.enc_num_layers
        args.use_attention = False

        # layer size setting
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)  # single layer unit size
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.encoder = RNNEncoder(
            vocab_size=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
            embedding=src_embed
        )
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=False,
            embedding=tgt_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )
        # Latent projections are sized from the flattened encoder hidden state,
        # matching the reshape done in hidden_to_latent / latent_for_init.
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor, args.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor, args.latent_size)
        self.latent2hidden = nn.Linear(args.latent_size, self.hidden_size * self.hidden_factor)

    def encode(self, input_var, length):
        if self.training and self.args.src_wd:
            input_var = unk_replace(input_var, self.step_unk_rate, self.vocab.src)
        encoder_output, encoder_hidden = self.encoder.forward(input_var, length)
        return encoder_output, encoder_hidden

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        return self.decoder.forward(
            inputs=inputs,
            encoder_outputs=encoder_outputs,
            encoder_hidden=encoder_hidden
        )

    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sent_words = [e.src for e in examples]
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = to_input_variable(sent_words, self.vocab.src, training=False, cuda=self.args.cuda,
                                    append_boundary_sym=True, batch_first=True)
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd > 0.:
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -torch.sum(self.decoder.score_decoding_results(tgt_token_scores, tgt_var))
        else:
            reconstruct_loss = -torch.sum(self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=None,
                encoder_hidden=decode_init,
            ))
        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        self.step_unk_rate = wd_anneal_function(unk_max=self.unk_rate, anneal_function=self.args.unk_schedule,
                                                step=step, x0=self.args.x0, k=self.args.k)
        explore = self.forward(examples)
        batch_size = explore['batch_size']
        kl_loss, kl_weight = self.compute_kl_loss(explore['mean'], explore['logv'], step)
        kl_weight *= self.args.kl_factor
        nll_loss = explore['nll_loss'] / batch_size
        kl_loss = kl_loss / batch_size
        kl_item = kl_loss * kl_weight
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'KL Weight': kl_weight,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss,
            'KL Item': kl_item,
        }

    def sample_latent(self, batch_size):
        z = to_var(torch.randn([batch_size, self.latent_size]))
        return {
            "latent": z
        }

    def latent_for_init(self, ret):
        z = ret['latent']
        batch_size = z.size(0)
        hidden = self.latent2hidden(z)
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_factor, self.hidden_size)
            hidden = hidden.permute(1, 0, 2)
        else:
            hidden = hidden.unsqueeze(0)
        ret['decode_init'] = hidden
        return ret

    def batch_beam_decode(self, examples):
        raise NotImplementedError

    def hidden_to_latent(self, ret, is_sampling=True):
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            hidden = hidden.squeeze()
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def conditional_generating(self, condition='sem', examples=None):
        if not isinstance(examples, list):
            examples = [examples]
        # Handle the unconditional case first; condition.startswith would fail on None.
        if condition is None:
            return {
                "res": self.unsupervised_generating(sample_num=len(examples))
            }
        if condition.startswith("sem"):
            ret = self.encode_to_hidden(examples)
            ret = self.hidden_to_latent(ret=ret, is_sampling=True)
            ret = self.latent_for_init(ret=ret)
            return {
                'res': self.decode_to_sentence(ret=ret)
            }

    def eval_adv(self, sem_in, syn_ref):
        sem_ret = self.encode_to_hidden(sem_in)
        sem_ret = self.hidden_to_latent(sem_ret, is_sampling=self.training)
        syn_ret = self.encode_to_hidden(syn_ref, need_sort=True)
        syn_ret = self.hidden_to_latent(syn_ret, is_sampling=self.training)
        sem_ret = self.latent_for_init(ret=sem_ret)
        syn_ret = self.latent_for_init(ret=syn_ret)
        ret = dict()
        ret["latent"] = (sem_ret['latent'] + syn_ret['latent']) * 0.5
        ret = self.latent_for_init(ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret
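# A minimal sketch of a training step using VanillaVAE's loss dictionary. The
# `model`, `optimizer`, `batch`, and `global_step` names are illustrative
# assumptions, not part of this module; only the dictionary keys come from
# get_loss() above.
def example_vae_training_step(model, optimizer, batch, global_step):
    model.train()
    losses = model.get_loss(batch, step=global_step)
    optimizer.zero_grad()
    losses['Loss'].backward()  # annealed ELBO: KL Item + NLL Loss
    optimizer.step()
    return losses['NLL Loss'].item(), losses['KL Loss'].item(), losses['KL Weight']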
class BridgeRNN(nn.Module):
    def __init__(self, args, vocab, enc_hidden_dim, dec_hidden_dim, embed, mode='src'):
        super().__init__()
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=enc_hidden_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=dec_hidden_dim,
            decoder_layer=args.dec_num_layers,
        )
        if mode == 'src':
            self.decoder = RNNDecoder(
                vocab=len(vocab.src),
                max_len=args.src_max_time_step,
                input_size=args.dec_embed_dim,
                hidden_size=dec_hidden_dim,
                embed_droprate=args.dec_ed,
                rnn_droprate=args.dec_rd,
                n_layers=args.dec_num_layers,
                rnn_cell=args.rnn_type,
                use_attention=args.use_attention,
                embedding=embed,
                eos_id=vocab.src.eos_id,
                sos_id=vocab.src.sos_id,
            )
        else:
            self.decoder = RNNDecoder(
                vocab=len(vocab.tgt),
                max_len=args.tgt_max_time_step,
                input_size=args.dec_embed_dim,
                hidden_size=dec_hidden_dim,
                embed_droprate=args.dec_ed,
                rnn_droprate=args.dec_rd,
                n_layers=args.dec_num_layers,
                rnn_cell=args.rnn_type,
                use_attention=args.use_attention,
                embedding=embed,
                eos_id=vocab.tgt.eos_id,
                sos_id=vocab.tgt.sos_id,
            )

    def forward(self, hidden, tgt_var):
        decode_init = self.bridger.forward(input_tensor=hidden)
        _loss = -torch.sum(
            self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=None,
                encoder_hidden=decode_init,
            ))
        return _loss

    def predict(self, hidden):
        decode_init = self.bridger.forward(input_tensor=hidden)
        decoder_outputs, decoder_hidden, ret_dict, enc_states = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        result = torch.stack(ret_dict['sequence']).squeeze()
        if result.dim() < 2:
            result = result.unsqueeze(1)
        return result
class BaseSeq2seq(nn.Module, BaseGenerator):
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(BaseSeq2seq, self).__init__()
        self.vocab = vocab
        self.src_vocab = vocab.src
        self.tgt_vocab = vocab.tgt
        self.args = args
        self.encoder = RNNEncoder(vocab_size=len(self.src_vocab),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)
        self.enc_factor = 2 if args.bidirectional else 1
        self.enc_dim = args.enc_hidden_dim * self.enc_factor
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.decoder = RNNDecoder(
            vocab=len(self.tgt_vocab),
            max_len=args.tgt_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=tgt_embed,
            eos_id=self.tgt_vocab.eos_id,
            sos_id=self.tgt_vocab.sos_id,
        )
        self.beam_decoder = TopKDecoder(decoder_rnn=self.decoder, k=args.sample_size)
        print("enc layer: {}, dec layer: {}, type: {}, with attention: {}".format(
            args.enc_num_layers, args.dec_num_layers, args.rnn_type, args.use_attention))

    def get_loss(self, **kwargs):
        return {"Loss": -self.score(**kwargs)}

    def forward(self, seqs_x, x_length, to_word=False):
        pass

    def init(self):
        self.encoder.rnn.flatten_parameters()
        self.decoder.rnn.flatten_parameters()

    def encode(self, src_var, src_length):
        encoder_outputs, encoder_hidden = self.encoder.forward(
            input_var=src_var, input_lengths=src_length)
        return encoder_outputs, encoder_hidden

    def bridge(self, encoder_hidden):
        # batch_size = encoder_hidden.size(1)
        # convert = encoder_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1)
        return self.bridger.forward(encoder_hidden)

    def get_hidden(self, examples):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]
        input_dict = to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=True,
            use_tag=False,
            use_dst=False,
        )
        src_var = input_dict['src']
        tgt_var = input_dict['tgt']
        src_length = input_dict['src_len']
        encoder_outputs, encoder_hidden = self.encode(src_var=src_var, src_length=src_length)
        encoder_hidden = self.bridge(encoder_hidden)
        return encoder_hidden

    def score(self, examples, return_enc_state=False, **kwargs):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]
        input_dict = to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=True,
            use_tag=False,
            use_dst=False,
        )
        src_var = input_dict['src']
        tgt_var = input_dict['tgt']
        src_length = input_dict['src_len']
        encoder_outputs, encoder_hidden = self.encode(src_var=src_var, src_length=src_length)
        encoder_hidden = self.bridge(encoder_hidden)
        scores = self.decoder.score(inputs=tgt_var,
                                    encoder_hidden=encoder_hidden,
                                    encoder_outputs=encoder_outputs)
        if return_enc_state:
            return scores, encoder_hidden
        else:
            return scores

    def predict(self, examples, to_word=True):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]
        input_dict = to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=False,
            use_tag=False,
            use_dst=False,
        )
        src_var = input_dict['src']
        src_length = input_dict['src_len']
        encoder_outputs, encoder_hidden = self.encode(src_var=src_var, src_length=src_length)
        encoder_hidden = self.bridge(encoder_hidden)
        decoder_output, decoder_hidden, ret_dict, _ = self.decoder.forward(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            teacher_forcing_ratio=0.0)
        result = torch.stack(ret_dict['sequence']).squeeze()
        final_result = []
        if len(result.size()) < 2:
            result = result.view(-1, 1)
        example_nums = result.size(-1)
        if to_word:
            for i in range(example_nums):
                hyp = result[:, i].data.tolist()
                res = id2word(hyp, self.vocab.tgt)
                seems = [[res], [len(res)]]
                final_result.append(seems)
        return final_result

    def beam_search(self, src_sent, beam_size=5, dmts=None):
        if dmts is None:
            dmts = self.args.decode_max_time_step
        src_var = to_input_variable(src_sent, self.src_vocab,
                                    cuda=self.args.cuda, training=False,
                                    append_boundary_sym=False, batch_first=True)
        src_length = [len(src_sent)]
        encoder_outputs, encoder_hidden = self.encode(src_var=src_var, src_length=src_length)
        encoder_hidden = self.bridger.forward(input_tensor=encoder_hidden)
        meta_data = self.beam_decoder.beam_search(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            beam_size=beam_size,
            decode_max_time_step=dmts)
        topk_sequence = meta_data['sequence']
        topk_score = meta_data['score'].squeeze()
        completed_hypotheses = torch.cat(topk_sequence, dim=-1)
        number_return = completed_hypotheses.size(0)
        final_result = []
        final_scores = []
        for i in range(number_return):
            hyp = completed_hypotheses[i, :].data.tolist()
            res = id2word(hyp, self.tgt_vocab)
            final_result.append(res)
            final_scores.append(topk_score[i].item())
        return final_result, final_scores

    def load_state_dict(self, state_dict, strict=True):
        return super().load_state_dict(state_dict, strict)

    def save(self, path):
        dir_name = os.path.dirname(path)
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)
        params = {
            'args': self.args,
            'vocab': self.vocab,
            'state_dict': self.state_dict(),
        }
        torch.save(params, path)

    @classmethod
    def load(cls, load_path):
        params = torch.load(load_path, map_location=lambda storage, loc: storage)
        args = params['args']
        vocab = params['vocab']
        model = cls(args, vocab)
        model.load_state_dict(params['state_dict'])
        if args.cuda:
            model = model.cuda()
        return model
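# Illustrative sketch of a typical inference round-trip with BaseSeq2seq. The
# checkpoint path and the pre-tokenized `sentence` are assumptions made for the
# example; only save() / load() / beam_search() above belong to this module.
def example_seq2seq_inference(load_path, sentence, beam_size=5):
    model = BaseSeq2seq.load(load_path)  # restores args, vocab and weights
    model.eval()
    hypotheses, scores = model.beam_search(sentence, beam_size=beam_size)
    # hypotheses[i] is a list of target-side words, scores[i] its beam score
    return list(zip(hypotheses, scores))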
class RNNPredictor(nn.Module):
    def __init__(self, args, vocab, enc_hidden_dim, dec_hidden_dim, embed, mode='src'):
        super().__init__()
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=enc_hidden_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=dec_hidden_dim,
            decoder_layer=args.dec_num_layers,
        )
        if mode == 'src':
            self.decoder = RNNDecoder(
                vocab=len(vocab.src),
                max_len=args.src_max_time_step,
                input_size=args.dec_embed_dim,
                hidden_size=dec_hidden_dim,
                embed_droprate=args.dec_ed,
                rnn_droprate=args.dec_rd,
                n_layers=args.dec_num_layers,
                rnn_cell=args.rnn_type,
                use_attention=args.use_attention,
                embedding=embed,
                eos_id=vocab.src.eos_id,
                sos_id=vocab.src.sos_id,
            )
            self.pad_id = vocab.src.pad_id
        else:
            self.decoder = RNNDecoder(
                vocab=len(vocab.tgt),
                max_len=args.tgt_max_time_step,
                input_size=args.dec_embed_dim,
                hidden_size=dec_hidden_dim,
                embed_droprate=args.dec_ed,
                rnn_droprate=args.dec_rd,
                n_layers=args.dec_num_layers,
                rnn_cell=args.rnn_type,
                use_attention=args.use_attention,
                embedding=embed,
                eos_id=vocab.tgt.eos_id,
                sos_id=vocab.tgt.sos_id,
            )
            self.pad_id = vocab.tgt.pad_id

    def forward(self, hidden, tgt_var, div_by_word=False):
        decode_init = self.bridger.forward(input_tensor=hidden)
        _loss = self.decoder.score(
            inputs=tgt_var,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        if div_by_word:
            tgt_len = tgt_var.ne(self.pad_id).sum(-1).float()
            _loss = _loss.div(tgt_len + 1e-9)
        return -_loss.sum()

    def detach(self):
        for p in self.parameters():
            p.requires_grad = False

    def reload(self):
        for p in self.parameters():
            p.requires_grad = True

    def predict(self, hidden):
        decode_init = self.bridger.forward(input_tensor=hidden)
        decoder_outputs, decoder_hidden, ret_dict, enc_states = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        result = torch.stack(ret_dict['sequence']).squeeze()
        if result.dim() < 2:
            result = result.unsqueeze(1)
        return result
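# detach()/reload() on the predictors above toggle requires_grad on their
# parameters, which suggests an alternating update scheme: first train the
# predictor on (detached) hidden states, then freeze it while its loss is
# back-propagated into the encoder that produced the hidden states. The sketch
# below is one plausible reading of that pattern, not the repository's training
# loop; `encoder_opt` and `predictor_opt` are hypothetical optimizers.
def example_alternating_round(predictor, hidden, tgt_var, encoder_opt, predictor_opt):
    # (1) update the predictor on hidden states cut off from the encoder graph
    predictor_opt.zero_grad()
    predictor(hidden.detach(), tgt_var).backward()
    predictor_opt.step()
    # (2) freeze the predictor and push its loss signal back into the encoder
    predictor.detach()
    encoder_opt.zero_grad()
    predictor(hidden, tgt_var).backward()
    encoder_opt.step()
    predictor.reload()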
class SyntaxGuideVAE(BaseVAE):
    def encode(self, input_var, length):
        pass

    def conditional_generating(self, condition=None, **kwargs):
        pass

    def generating(self, sample_num, batch_size=50):
        return super().generating(sample_num, batch_size)

    def unsupervised_generating(self, sample_num, batch_size=50):
        return super().unsupervised_generating(sample_num, batch_size)

    def predict(self, examples, to_word=True):
        return super().predict(examples, to_word)

    def base_information(self):
        return super().base_information()

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(SyntaxGuideVAE, self).__init__(args, vocab, name='SyntaxGuideVAE')
        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         input_size=args.enc_embed_dim,
                                         hidden_size=args.enc_hidden_dim,
                                         embed_droprate=args.enc_ed,
                                         rnn_droprate=args.enc_rd,
                                         n_layers=args.enc_num_layers,
                                         bidirectional=args.bidirectional,
                                         rnn_cell=args.rnn_type,
                                         variable_lengths=True,
                                         embedding=tgt_embed)
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = (2 if args.bidirectional else 1) * args.enc_num_layers
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor, self.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor, self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size, self.hidden_size * self.hidden_factor)
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

    def syntax_encode(self, syntax_var, length):
        syntax_outputs, syntax_hidden = self.syntax_encoder.forward(syntax_var, length)
        return syntax_outputs, syntax_hidden

    def sentence_encode(self, sent_words):
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]
        sorted_example_ids = sorted(range(batch_size), key=lambda x: -sent_lengths[x])
        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos
        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words, self.vocab.src,
                                            cuda=self.args.cuda, batch_first=True)
        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate, self.vocab.src)
        sorted_sent_lengths = [len(sent_word) for sent_word in sorted_sent_words]
        _, sent_hidden = self.encoder.forward(sorted_sent_var, sorted_sent_lengths)
        hidden = sent_hidden[:, example_old_pos_map, :]
        return hidden

    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = ret['tgt_var']
        syntax_output = ret['syn_output']
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -self.decoder.score_decoding_results(tgt_token_scores, tgt_var)
        else:
            reconstruct_loss = -self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
            )
        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=self.args.unk_schedule,
            step=step,
            x0=self.args.x0,
            k=self.args.k)
        explore = self.forward(examples)
        kl_loss, kl_weight = self.compute_kl_loss(explore['mean'], explore['logv'], step)
        kl_weight *= self.args.kl_factor
        nll_loss = torch.sum(explore['nll_loss']) / explore['batch_size']
        kl_loss = kl_loss / explore['batch_size']
        kl_item = kl_loss * kl_weight
        loss = kl_item + nll_loss
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'Model Score': nll_loss + kl_loss,
            'Loss': loss,
            'ELBO': loss,
            'KL Weight': kl_weight,
            'WD Drop': self.step_unk_rate,
            'KL Item': kl_item,
        }

    def batch_beam_decode(self, **kwargs):
        pass

    def latent_for_init(self, ret):
        z = ret['latent']
        batch_size = z.size(0)
        hidden = self.latent2hidden(z)
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_factor, self.args.enc_hidden_dim)
            hidden = hidden.permute(1, 0, 2)
        else:
            hidden = hidden.unsqueeze(0)
        ret['decode_init'] = (hidden + ret['syn_hidden']) / 2
        return ret

    def sample_latent(self, batch_size):
        z = to_var(torch.randn([batch_size, self.latent_size]))
        return {"latent": z}

    def encode_to_hidden(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sorted_example_ids = sorted(range(batch_size), key=lambda x: -len(examples[x].tgt))
        example_old_pos_map = [-1] * batch_size
        sorted_examples = [examples[i] for i in sorted_example_ids]
        syntax_word = [e.tgt for e in sorted_examples]
        syntax_var = to_input_variable(syntax_word, self.vocab.tgt, training=False,
                                       cuda=self.args.cuda, batch_first=True)
        length = [len(e.tgt) for e in sorted_examples]
        syntax_output, syntax_hidden = self.syntax_encode(syntax_var, length)
        sent_words = [e.src for e in sorted_examples]
        sentence_hidden = self.sentence_encode(sent_words)
        tgt_var = to_input_variable(sent_words, self.vocab.src, training=False,
                                    cuda=self.args.cuda, append_boundary_sym=True,
                                    batch_first=True)
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos
        return {
            'hidden': sentence_hidden,
            "syn_output": syntax_output,
            "syn_hidden": syntax_hidden,
            'tgt_var': tgt_var,
            'old_pos': example_old_pos_map
        }

    def hidden_to_latent(self, ret, is_sampling):
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.args.enc_hidden_dim * self.hidden_factor)
        else:
            hidden = hidden.squeeze()
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def decode_to_sentence(self, ret):
        sentence_decode_init = ret['decode_init']
        sentence_decode_init = self.bridger.forward(input_tensor=sentence_decode_init)
        decoder_outputs, decoder_hidden, ret_dict, enc_states = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=sentence_decode_init,
        )
        result = torch.stack(ret_dict['sequence']).squeeze()
        temp_result = []
        if result.dim() < 2:
            result = result.unsqueeze(1)
        example_nums = result.size(1)
        for i in range(example_nums):
            hyp = result[:, i].data.tolist()
            res = id2word(hyp, self.vocab.src)
            seems = [[res], [len(res)]]
            temp_result.append(seems)
        final_result = [temp_result[i] for i in ret['old_pos']]
        return final_result
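# Sketch of syntax-guided generation with SyntaxGuideVAE, following the same
# pipeline its forward() uses (encode_to_hidden -> hidden_to_latent ->
# latent_for_init -> decode_to_sentence). The `examples` argument is assumed to
# be a list of objects with .src (sentence tokens) and .tgt (syntax tokens), as
# read by encode_to_hidden() above; the function name is illustrative only.
def example_syntax_guided_generation(model, examples):
    model.eval()
    ret = model.encode_to_hidden(examples)
    ret = model.hidden_to_latent(ret, is_sampling=True)  # sample z ~ N(mean, std)
    ret = model.latent_for_init(ret)                     # average z's hidden with the syntax hidden state
    return model.decode_to_sentence(ret)                 # word sequences restored to the input order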