def build_encoder(opt, embeddings):
    """
    Various encoder dispatcher function.

    Args:
        opt: the options of the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    opt.input_size = 1
    if opt.encoder_type == "transformer":
        encoder = TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                     opt.heads, opt.transformer_ff,
                                     opt.dropout, opt.input_size, embeddings)
    elif opt.encoder_type == "ctransformer":
        encoder = CTransformerEncoder([2, 2, 2, 2], opt.enc_layers,
                                      opt.enc_rnn_size, opt.heads,
                                      opt.transformer_ff, opt.dropout,
                                      embeddings)
    elif opt.encoder_type == "cnn":
        encoder = CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                             opt.cnn_kernel_width, opt.dropout,
                             opt.input_size, embeddings)
    elif opt.encoder_type == "mean":
        encoder = MeanEncoder(opt.enc_layers, opt.input_size, embeddings)
    elif opt.encoder_type == "nano":
        encoder = NanoEncoder(opt.rnn_type, opt.enc_layers, opt.dec_layers,
                              opt.enc_rnn_size, opt.dec_rnn_size,
                              opt.audio_enc_pooling, opt.dropout,
                              opt.sample_rate, opt.window_size,
                              opt.input_size)
    elif opt.encoder_type == "crnn":
        encoder = CRNNEncoder([2, 2, 2, 2], opt.rnn_type, opt.enc_layers,
                              opt.dec_layers, opt.enc_rnn_size,
                              opt.dec_rnn_size, opt.audio_enc_pooling,
                              opt.dropout, opt.sample_rate, opt.window_size)
    elif opt.encoder_type == "resnet":
        if opt.decoder_type == 'cnn':
            encoder = ResNetEncoder(opt.enc_layers, opt.enc_rnn_size,
                                    [2, 2, 2, 2], opt.input_size, embeddings)
        else:
            encoder = ResNetForRNNEncoder(opt.dec_layers, opt.enc_rnn_size,
                                          opt.dec_rnn_size, [2, 2, 2, 2],
                                          opt.input_size, embeddings,
                                          opt.rnn_type)
    else:
        encoder = RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                             opt.enc_rnn_size, opt.dropout, opt.input_size,
                             embeddings, opt.bridge)
    return encoder
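
# Usage sketch (illustration only, not part of the original file): the
# dispatcher above branches purely on opt.encoder_type, so any attribute
# container works as `opt`. The option values below are hypothetical
# placeholders; the "cnn" branch assumes CNNEncoder is importable with the
# positional signature used above.
def _demo_build_encoder():
    from argparse import Namespace
    opt = Namespace(encoder_type="cnn", enc_layers=2, enc_rnn_size=256,
                    cnn_kernel_width=3, dropout=0.1)
    # falls into the "cnn" branch; build_encoder also sets opt.input_size = 1
    return build_encoder(opt, embeddings=None)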
class SyntaxGuideVAE(BaseVAE):
    def encode(self, input_var, length):
        pass

    def conditional_generating(self, condition=None, **kwargs):
        pass

    def generating(self, sample_num, batch_size=50):
        return super().generating(sample_num, batch_size)

    def unsupervised_generating(self, sample_num, batch_size=50):
        return super().unsupervised_generating(sample_num, batch_size)

    def predict(self, examples, to_word=True):
        return super().predict(examples, to_word)

    def base_information(self):
        return super().base_information()

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(SyntaxGuideVAE, self).__init__(args, vocab, name='SyntaxGuideVAE')
        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         input_size=args.enc_embed_dim,
                                         hidden_size=args.enc_hidden_dim,
                                         embed_droprate=args.enc_ed,
                                         rnn_droprate=args.enc_rd,
                                         n_layers=args.enc_num_layers,
                                         bidirectional=args.bidirectional,
                                         rnn_cell=args.rnn_type,
                                         variable_lengths=True,
                                         embedding=tgt_embed)
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = (2 if args.bidirectional else 1) * args.enc_num_layers
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor, self.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor, self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size, self.hidden_size * self.hidden_factor)
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim
        self.bridger = Bridge(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

    def syntax_encode(self, syntax_var, length):
        syntax_outputs, syntax_hidden = self.syntax_encoder.forward(syntax_var, length)
        return syntax_outputs, syntax_hidden

    def sentence_encode(self, sent_words):
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]
        sorted_example_ids = sorted(range(batch_size), key=lambda x: -sent_lengths[x])
        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos
        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)
        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate, self.vocab.src)
        sorted_sent_lengths = [len(sent_word) for sent_word in sorted_sent_words]
        _, sent_hidden = self.encoder.forward(sorted_sent_var, sorted_sent_lengths)
        # restore the original example order along the batch axis
        hidden = sent_hidden[:, example_old_pos_map, :]
        return hidden

    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = ret['tgt_var']
        syntax_output = ret['syn_output']
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -self.decoder.score_decoding_results(tgt_token_scores, tgt_var)
        else:
            reconstruct_loss = -self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
            )
        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        self.step_unk_rate = wd_anneal_function(unk_max=self.unk_rate,
                                                anneal_function=self.args.unk_schedule,
                                                step=step,
                                                x0=self.args.x0,
                                                k=self.args.k)
        explore = self.forward(examples)
        kl_loss, kl_weight = self.compute_kl_loss(explore['mean'], explore['logv'], step)
        kl_weight *= self.args.kl_factor
        nll_loss = torch.sum(explore['nll_loss']) / explore['batch_size']
        kl_loss = kl_loss / explore['batch_size']
        kl_item = kl_loss * kl_weight
        loss = kl_item + nll_loss
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'Model Score': nll_loss + kl_loss,
            'Loss': loss,
            'ELBO': loss,
            'KL Weight': kl_weight,
            'WD Drop': self.step_unk_rate,
            'KL Item': kl_item,
        }

    def batch_beam_decode(self, **kwargs):
        pass

    def latent_for_init(self, ret):
        z = ret['latent']
        batch_size = z.size(0)
        hidden = self.latent2hidden(z)
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_factor, self.args.enc_hidden_dim)
            hidden = hidden.permute(1, 0, 2)
        else:
            hidden = hidden.unsqueeze(0)
        # average the latent-derived state with the syntax encoder state
        ret['decode_init'] = (hidden + ret['syn_hidden']) / 2
        return ret

    def sample_latent(self, batch_size):
        z = to_var(torch.randn([batch_size, self.latent_size]))
        return {"latent": z}

    def encode_to_hidden(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sorted_example_ids = sorted(range(batch_size), key=lambda x: -len(examples[x].tgt))
        example_old_pos_map = [-1] * batch_size
        sorted_examples = [examples[i] for i in sorted_example_ids]
        syntax_word = [e.tgt for e in sorted_examples]
        syntax_var = to_input_variable(syntax_word,
                                       self.vocab.tgt,
                                       training=False,
                                       cuda=self.args.cuda,
                                       batch_first=True)
        length = [len(e.tgt) for e in sorted_examples]
        syntax_output, syntax_hidden = self.syntax_encode(syntax_var, length)
        sent_words = [e.src for e in sorted_examples]
        sentence_hidden = self.sentence_encode(sent_words)
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos
        return {
            'hidden': sentence_hidden,
            "syn_output": syntax_output,
            "syn_hidden": syntax_hidden,
            'tgt_var': tgt_var,
            'old_pos': example_old_pos_map
        }

    def hidden_to_latent(self, ret, is_sampling):
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.args.enc_hidden_dim * self.hidden_factor)
        else:
            # squeeze only the layer*direction axis so a batch of one survives
            hidden = hidden.squeeze(1)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def decode_to_sentence(self, ret):
        sentence_decode_init = ret['decode_init']
        sentence_decode_init = self.bridger.forward(input_tensor=sentence_decode_init)
        decoder_outputs, decoder_hidden, ret_dict, enc_states = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=sentence_decode_init,
        )
        result = torch.stack(ret_dict['sequence']).squeeze()
        temp_result = []
        if result.dim() < 2:
            result = result.unsqueeze(1)
        example_nums = result.size(1)
        for i in range(example_nums):
            hyp = result[:, i].data.tolist()
            res = id2word(hyp, self.vocab.src)
            seems = [[res], [len(res)]]
            temp_result.append(seems)
        final_result = [temp_result[i] for i in ret['old_pos']]
        return final_result
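
# Hedged illustration (not from the original source): hidden_to_latent above
# uses the standard VAE reparameterization trick. This self-contained sketch
# repeats the computation on plain tensors, together with the closed-form KL
# to a standard normal prior that compute_kl_loss-style helpers commonly use
# (the base class implementation is not shown here, so that is an assumption).
def _demo_reparameterize():
    import torch

    batch_size, latent_size = 4, 16
    mean = torch.zeros(batch_size, latent_size, requires_grad=True)
    logv = torch.zeros(batch_size, latent_size, requires_grad=True)  # log(sigma^2)
    std = torch.exp(0.5 * logv)
    eps = torch.randn(batch_size, latent_size)
    z = eps * std + mean  # differentiable w.r.t. mean and logv, noise stays stochastic
    kl = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
    return z, kl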
class DisentangleVAE(BaseVAE):
    """
    Encode the sentence, predict the parse.
    """

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        return self.decoder.forward(inputs=inputs,
                                    encoder_outputs=encoder_outputs,
                                    encoder_hidden=encoder_hidden)

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(DisentangleVAE, self).__init__(args, vocab, name="Disentangle VAE with deep encoder")
        print("This is {} with parameters\n{}".format(self.name, self.base_information()))
        if src_embed is None:
            self.src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            self.src_embed = src_embed
        if tgt_embed is None:
            self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size)
        else:
            self.tgt_embed = tgt_embed
        self.pad_idx = vocab.src.sos_id
        self.latent_size = int(args.latent_size)
        self.rnn_type = args.rnn_type
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.direction_num = 2 if args.bidirectional else 1
        self.enc_hidden_dim = args.enc_hidden_dim
        self.enc_layer_dim = args.enc_hidden_dim * self.direction_num
        self.enc_hidden_factor = self.direction_num * args.enc_num_layers
        self.dec_hidden_factor = args.dec_num_layers
        args.use_attention = False
        if args.mapper_type == "link":
            self.dec_layer_dim = self.enc_layer_dim
        elif args.use_attention:
            self.dec_layer_dim = self.enc_layer_dim
        else:
            self.dec_layer_dim = args.dec_hidden_dim
        # syntactic and semantic variables each take half of the encoder state
        syn_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        sem_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        task_enc_dim = int(self.enc_layer_dim / 2)
        task_dec_dim = int(self.dec_layer_dim / 2)
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=self.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=self.src_embed)
        # output: [layer * direction, batch_size, enc_hidden_dim]
        pack_decoder = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=self.enc_layer_dim,
                                 dec_hidden_dim=self.dec_layer_dim,
                                 embed=self.src_embed if args.share_embed else None,
                                 mode='src')
        self.bridger = pack_decoder.bridger
        self.decoder = pack_decoder.decoder
        if "report" in self.args:
            syn_common = nn.Sequential(
                nn.Linear(syn_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.syn_mean = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.syn_logv = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))
            sem_common = nn.Sequential(
                nn.Linear(sem_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.sem_mean = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.sem_logv = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
        else:
            self.syn_mean = nn.Linear(syn_var_dim, self.latent_size)
            self.syn_logv = nn.Linear(syn_var_dim, self.latent_size)
            self.sem_mean = nn.Linear(sem_var_dim, self.latent_size)
            self.sem_logv = nn.Linear(sem_var_dim, self.latent_size)
        self.syn_to_h = nn.Linear(self.latent_size, syn_var_dim)
        self.sem_to_h = nn.Linear(self.latent_size, sem_var_dim)
        self.sup_syn = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=task_enc_dim,
                                 dec_hidden_dim=task_dec_dim,
                                 embed=tgt_embed,
                                 mode='tgt')
        self.sup_sem = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )
        self.syn_adv = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=task_enc_dim,
                                 dec_hidden_dim=task_dec_dim,
                                 embed=self.tgt_embed if args.share_embed else None,
                                 mode='tgt')
        self.syn_infer = BridgeRNN(args,
                                   vocab,
                                   enc_hidden_dim=task_enc_dim,
                                   dec_hidden_dim=task_dec_dim,
                                   embed=self.src_embed if args.share_embed else None,
                                   mode='src')
        self.sem_adv = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )
        self.sem_infer = BridgeRNN(args,
                                   vocab,
                                   enc_hidden_dim=task_enc_dim,
                                   dec_hidden_dim=task_dec_dim,
                                   embed=self.src_embed if args.share_embed else None,
                                   mode='src')

    def base_information(self):
        origin = super().base_information()
        return origin \
            + "mul_syn:{}\n" \
              "mul_sem:{}\n" \
              "adv_syn:{}\n" \
              "adv_sem:{}\n" \
              "inf_syn:{}\n" \
              "inf_sem:{}\n" \
              "kl_syn:{}\n" \
              "kl_sem:{}\n".format(str(self.args.mul_syn),
                                   str(self.args.mul_sem),
                                   str(self.args.adv_syn),
                                   str(self.args.adv_sem),
                                   str(self.args.inf_syn * self.args.infer_weight),
                                   str(self.args.inf_sem * self.args.infer_weight),
                                   str(self.args.syn_weight),
                                   str(self.args.sem_weight))

    def get_gpu(self):
        model_list = [
            self.encoder, self.bridger, self.decoder, self.syn_mean,
            self.syn_logv, self.syn_to_h, self.sem_mean, self.sem_logv,
            self.sem_to_h, self.sup_syn, self.sup_sem, self.syn_adv,
            self.syn_infer, self.sem_adv, self.sem_infer
        ]
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        for model in model_list:
            # the DataParallel wrapper is only bound to a local name here;
            # keep a reference to it if multi-GPU execution is actually wanted
            model = torch.nn.DataParallel(model)
            model.to(device)

    def encode(self, input_var, length):
        if self.training and self.args.src_wd > 0.:
            input_var = unk_replace(input_var, self.step_unk_rate, self.vocab.src)
        encoder_output, encoder_hidden = self.encoder.forward(input_var, length)
        return encoder_output, encoder_hidden

    def forward(self, examples, is_dis=False):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        words = [e.src for e in examples]
        tgt_var = to_input_variable(words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        syn_seqs = [e.tgt for e in examples]
        syn_var = to_input_variable(syn_seqs,
                                    self.vocab.tgt,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        syn_hidden = ret['syn_hidden']
        sem_hidden = ret['sem_hidden']
        if is_dis:
            dis_syn_loss, dis_sem_loss = self.get_dis_loss(syntax_hidden=syn_hidden,
                                                           semantic_hidden=sem_hidden,
                                                           syn_tgt=syn_var,
                                                           sem_tgt=tgt_var)
            ret['dis syn'] = dis_syn_loss
            ret['dis sem'] = dis_sem_loss
            return ret
        decode_init = ret['decode_init']
        sentence_decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_log_score = self.decoder.generate(con_inputs=input_var,
                                                  encoder_hidden=sentence_decode_init,
                                                  encoder_outputs=None,
                                                  teacher_forcing_ratio=1.0)
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_log_score, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(inputs=tgt_var,
                                   encoder_outputs=None,
                                   encoder_hidden=sentence_decode_init))
        mul_syn_loss, mul_sem_loss = self.get_mul_loss(syntax_hidden=syn_hidden,
                                                       semantic_hidden=sem_hidden,
                                                       syn_tgt=syn_var,
                                                       sem_tgt=tgt_var)
        adv_syn_loss, adv_sem_loss = self.get_adv_loss(syntax_hidden=syn_hidden,
                                                       semantic_hidden=sem_hidden,
                                                       syn_tgt=syn_var,
                                                       sem_tgt=tgt_var)
        ret['adv'] = adv_syn_loss + adv_sem_loss
        ret['mul'] = mul_syn_loss + mul_sem_loss
        ret['nll_loss'] = reconstruct_loss
        ret['sem_loss'] = mul_sem_loss
        ret['syn_loss'] = mul_syn_loss
        ret['batch_size'] = batch_size
        return ret

    def get_loss(self, examples, step, is_dis=False):
        self.step_unk_rate = wd_anneal_function(unk_max=self.unk_rate,
                                                anneal_function=self.args.unk_schedule,
                                                step=step,
                                                x0=self.args.x0,
                                                k=self.args.k)
        explore = self.forward(examples, is_dis)
        if is_dis:
            return explore
        sem_kl, kl_weight = self.compute_kl_loss(mean=explore['sem_mean'],
                                                 logv=explore['sem_logv'],
                                                 step=step)
        syn_kl, _ = self.compute_kl_loss(mean=explore['syn_mean'],
                                         logv=explore['syn_logv'],
                                         step=step)
        batch_size = explore['batch_size']
        kl_weight *= self.args.kl_factor
        kl_loss = (self.args.sem_weight * sem_kl + self.args.syn_weight * syn_kl) \
            / (self.args.sem_weight + self.args.syn_weight)
        kl_loss /= batch_size
        mul_loss = explore['mul'] / batch_size
        adv_loss = explore['adv'] / batch_size
        nll_loss = explore['nll_loss'] / batch_size
        kl_item = kl_loss * kl_weight
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'MUL Loss': mul_loss,
            'ADV Loss': adv_loss,
            'KL Weight': kl_weight,
            'KL Item': kl_item,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss + mul_loss - adv_loss,
            'SYN KL Loss': syn_kl / batch_size,
            'SEM KL Loss': sem_kl / batch_size,
        }

    def get_adv_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        # at training time the adversaries' losses are evaluated without
        # tracking gradients; the weighting logic is identical in both modes
        if self.training:
            with torch.no_grad():
                loss_dict = self._dis_loss(syntax_hidden, semantic_hidden,
                                           syn_tgt, sem_tgt)
        else:
            loss_dict = self._dis_loss(syntax_hidden, semantic_hidden,
                                       syn_tgt, sem_tgt)
        if self.args.infer_weight > 0.:
            adv_syn = self.args.adv_syn * loss_dict['adv_syn_sup'] \
                + self.args.infer_weight * self.args.inf_sem * loss_dict['adv_sem_inf']
            adv_sem = self.args.adv_sem * loss_dict['adv_sem_sup'] \
                + self.args.infer_weight * self.args.inf_syn * loss_dict['adv_syn_inf']
        else:
            adv_syn = self.args.adv_syn * loss_dict['adv_syn_sup']
            adv_sem = self.args.adv_sem * loss_dict['adv_sem_sup']
        return adv_syn, adv_sem

    def get_dis_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        syntax_hid = syntax_hidden.detach()
        semantic_hid = semantic_hidden.detach()
        loss_dict = self._dis_loss(syntax_hid, semantic_hid, syn_tgt, sem_tgt)
        if self.args.infer_weight > 0.:
            return (loss_dict['adv_syn_sup'] + loss_dict['adv_sem_inf'],
                    loss_dict['adv_sem_sup'] + loss_dict['adv_syn_inf'])
        else:
            return loss_dict['adv_syn_sup'], loss_dict['adv_sem_sup']

    def _dis_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        dis_syn_sup = self.syn_adv.forward(hidden=semantic_hidden, tgt_var=syn_tgt)
        dis_sem_sup = self.sem_adv.forward(hidden=syntax_hidden, tgt_var=sem_tgt)
        if self.args.infer_weight > 0.:
            dis_syn_inf = self.syn_infer.forward(hidden=syntax_hidden, tgt_var=sem_tgt)
            dis_sem_inf = self.sem_infer.forward(hidden=semantic_hidden, tgt_var=sem_tgt)
            return {
                'adv_syn_sup': dis_syn_sup if self.args.adv_syn > 0. else 0.,
                'adv_sem_sup': dis_sem_sup if self.args.adv_sem > 0. else 0.,
                'adv_syn_inf': dis_syn_inf if self.args.inf_syn > 0. else 0.,
                "adv_sem_inf": dis_sem_inf if self.args.inf_sem > 0. else 0.
            }
        else:
            return {
                'adv_syn_sup': dis_syn_sup,
                'adv_sem_sup': dis_sem_sup,
            }

    def get_mul_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        syn_loss = self.sup_syn.forward(hidden=syntax_hidden, tgt_var=syn_tgt)
        sem_loss = self.sup_sem.forward(hidden=semantic_hidden, tgt_var=sem_tgt)
        return self.args.mul_syn * syn_loss, self.args.mul_sem * sem_loss

    def sample_latent(self, batch_size):
        syntax_latent = to_var(torch.randn([batch_size, self.latent_size]))
        semantic_latent = to_var(torch.randn([batch_size, self.latent_size]))
        return {
            "syn_z": syntax_latent,
            "sem_z": semantic_latent,
        }

    def hidden_to_latent(self, ret, is_sampling=True):
        hidden = ret['hidden']
        batch_size = hidden.size(1)

        def sampling(mean, logv):
            if is_sampling:
                std = torch.exp(0.5 * logv)
                z = to_var(torch.randn([batch_size, self.latent_size]))
                z = z * std + mean
            else:
                z = mean
            return z

        def split_hidden(encode_hidden):
            # split each layer*direction state into a semantic half and a
            # syntactic half along the feature dimension
            bs = encode_hidden.size(1)
            factor = encode_hidden.size(0)
            hid = encode_hidden.permute(1, 0, 2).contiguous().view(bs, factor, 2, -1)
            return (hid[:, :, 0, :].contiguous().view(bs, -1),
                    hid[:, :, 1, :].contiguous().view(bs, -1))

        sem_hid, syn_hid = split_hidden(hidden)
        semantic_mean = self.sem_mean(sem_hid)
        semantic_logv = self.sem_logv(sem_hid)
        syntax_mean = self.syn_mean(syn_hid)
        syntax_logv = self.syn_logv(syn_hid)
        syntax_latent = sampling(syntax_mean, syntax_logv)
        semantic_latent = sampling(semantic_mean, semantic_logv)
        ret['syn_mean'] = syntax_mean
        ret['syn_logv'] = syntax_logv
        ret['sem_mean'] = semantic_mean
        ret['sem_logv'] = semantic_logv
        ret['syn_z'] = syntax_latent
        ret['sem_z'] = semantic_latent
        return ret

    def latent_for_init(self, ret):
        def reshape(xx_hidden):
            # integer division: each half carries enc_hidden_dim // 2 features
            # per state (view() requires integer sizes)
            xx_hidden = xx_hidden.view(batch_size, self.enc_hidden_factor,
                                       self.enc_hidden_dim // 2)
            return xx_hidden.permute(1, 0, 2)

        syntax_latent = ret['syn_z']
        semantic_latent = ret['sem_z']
        batch_size = semantic_latent.size(0)
        syntax_hidden = reshape(self.syn_to_h(syntax_latent))
        semantic_hidden = reshape(self.sem_to_h(semantic_latent))
        ret['syn_hidden'] = syntax_hidden
        ret['sem_hidden'] = semantic_hidden
        ret['decode_init'] = torch.cat([syntax_hidden, semantic_hidden], dim=-1)
        return ret

    def evaluate_(self, examples, beam_size=5):
        if not isinstance(examples, list):
            examples = [examples]
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret

    def predict_syntax(self, hidden, predictor):
        result = predictor.predict(hidden)
        numbers = result.size(1)
        final_result = []
        for i in range(numbers):
            hyp = result[:, i].data.tolist()
            res = id2word(hyp, self.vocab.tgt)
            seems = [[res], [len(res)]]
            final_result.append(seems)
        return final_result

    def extract_variable(self, examples):
        pass

    def eval_syntax(self, examples):
        ret = self.encode_to_hidden(examples, need_sort=True)
        ret = self.hidden_to_latent(ret, is_sampling=False)
        ret = self.latent_for_init(ret)
        return self.predict_syntax(hidden=ret['syn_hidden'], predictor=self.sup_syn)

    def eval_adv(self, sem_in, syn_ref):
        sem_ret = self.encode_to_hidden(sem_in)
        sem_ret = self.hidden_to_latent(sem_ret, is_sampling=self.training)
        syn_ret = self.encode_to_hidden(syn_ref, need_sort=True)
        syn_ret = self.hidden_to_latent(syn_ret, is_sampling=self.training)
        sem_ret = self.latent_for_init(ret=sem_ret)
        syn_ret = self.latent_for_init(ret=syn_ret)
        # recombine the semantics of one input with the syntax of the other
        ret = dict(sem_z=sem_ret['sem_z'], syn_z=syn_ret['syn_z'])
        ret = self.latent_for_init(ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        ret['ori syn'] = self.predict_syntax(hidden=sem_ret['syn_hidden'],
                                             predictor=self.sup_syn)
        ret['ref syn'] = self.predict_syntax(hidden=syn_ret['syn_hidden'],
                                             predictor=self.sup_syn)
        return ret

    def conditional_generating(self, condition="sem", examples=None):
        ref_ret = self.encode_to_hidden(examples)
        ref_ret = self.hidden_to_latent(ref_ret, is_sampling=True)
        if condition.startswith("sem"):
            ref_ret['sem_z'] = ref_ret['sem_mean']
        else:
            ref_ret['syn_z'] = ref_ret['syn_mean']
        if condition == "sem-only":
            sam_ref = self.sample_latent(batch_size=ref_ret['batch_size'])
            ref_ret['syn_z'] = sam_ref['syn_z']
        ret = self.latent_for_init(ret=ref_ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret
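
# Hedged illustration (not part of the original module): split_hidden inside
# hidden_to_latent above carves each encoder state of size enc_hidden_dim into
# a semantic half and a syntactic half. This standalone sketch reproduces the
# same reshaping on random data with hypothetical sizes, so the shapes are
# explicit.
def _demo_split_hidden():
    import torch

    factor, batch_size, enc_hidden_dim = 4, 3, 8  # hypothetical sizes
    hidden = torch.randn(factor, batch_size, enc_hidden_dim)
    hid = hidden.permute(1, 0, 2).contiguous().view(batch_size, factor, 2, -1)
    sem_hid = hid[:, :, 0, :].contiguous().view(batch_size, -1)
    syn_hid = hid[:, :, 1, :].contiguous().view(batch_size, -1)
    # each half flattens to [batch_size, factor * enc_hidden_dim // 2]
    assert sem_hid.shape == (batch_size, factor * enc_hidden_dim // 2)
    return sem_hid, syn_hid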
class VanillaVAE(BaseVAE):
    def __init__(self, args, vocab, word_embed=None):
        super(VanillaVAE, self).__init__(args, vocab, name='MySentVAE')
        print("This is {} with parameters\n{}".format(self.name, self.base_information()))
        if word_embed is None:
            src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            src_embed = word_embed
        if args.share_embed:
            tgt_embed = src_embed
            args.dec_embed_dim = args.enc_embed_dim
        else:
            tgt_embed = None
        self.latent_size = args.latent_size
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = (2 if args.bidirectional else 1) * args.enc_num_layers
        args.use_attention = False
        # layer size setting
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)  # single layer unit size
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)
        self.bridger = Bridge(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )
        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=False,
            embedding=tgt_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )
        # self.hidden_size matches the per-state size used in hidden_to_latent
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor, args.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor, args.latent_size)
        self.latent2hidden = nn.Linear(args.latent_size, self.hidden_size * self.hidden_factor)

    def encode(self, input_var, length):
        if self.training and self.args.src_wd:
            input_var = unk_replace(input_var, self.step_unk_rate, self.vocab.src)
        encoder_output, encoder_hidden = self.encoder.forward(input_var, length)
        return encoder_output, encoder_hidden

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        return self.decoder.forward(inputs=inputs,
                                    encoder_outputs=encoder_outputs,
                                    encoder_hidden=encoder_hidden)

    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sent_words = [e.src for e in examples]
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd > 0.:
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_token_scores, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(
                    inputs=tgt_var,
                    encoder_outputs=None,
                    encoder_hidden=decode_init,
                ))
        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        self.step_unk_rate = wd_anneal_function(unk_max=self.unk_rate,
                                                anneal_function=self.args.unk_schedule,
                                                step=step,
                                                x0=self.args.x0,
                                                k=self.args.k)
        explore = self.forward(examples)
        batch_size = explore['batch_size']
        kl_loss, kl_weight = self.compute_kl_loss(explore['mean'], explore['logv'], step)
        kl_weight *= self.args.kl_factor
        nll_loss = explore['nll_loss'] / batch_size
        kl_loss = kl_loss / batch_size
        kl_item = kl_loss * kl_weight
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'KL Weight': kl_weight,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss,
            'KL Item': kl_item,
        }

    def sample_latent(self, batch_size):
        z = to_var(torch.randn([batch_size, self.latent_size]))
        return {"latent": z}

    def latent_for_init(self, ret):
        z = ret['latent']
        batch_size = z.size(0)
        hidden = self.latent2hidden(z)
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_factor, self.hidden_size)
            hidden = hidden.permute(1, 0, 2)
        else:
            hidden = hidden.unsqueeze(0)
        ret['decode_init'] = hidden
        return ret

    def batch_beam_decode(self, examples):
        raise NotImplementedError

    def hidden_to_latent(self, ret, is_sampling=True):
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            # squeeze only the layer*direction axis so a batch of one survives
            hidden = hidden.squeeze(1)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def conditional_generating(self, condition='sem', examples=None):
        if not isinstance(examples, list):
            examples = [examples]
        # handle the None case before calling startswith on condition
        if condition is None:
            return {
                "res": self.unsupervised_generating(sample_num=len(examples))
            }
        if condition.startswith("sem"):
            ret = self.encode_to_hidden(examples)
            ret = self.hidden_to_latent(ret=ret, is_sampling=True)
            ret = self.latent_for_init(ret=ret)
            return {'res': self.decode_to_sentence(ret=ret)}

    def eval_adv(self, sem_in, syn_ref):
        sem_ret = self.encode_to_hidden(sem_in)
        sem_ret = self.hidden_to_latent(sem_ret, is_sampling=self.training)
        syn_ret = self.encode_to_hidden(syn_ref, need_sort=True)
        syn_ret = self.hidden_to_latent(syn_ret, is_sampling=self.training)
        sem_ret = self.latent_for_init(ret=sem_ret)
        syn_ret = self.latent_for_init(ret=syn_ret)
        ret = dict()
        # interpolate the two latents rather than recombining split halves
        ret["latent"] = (sem_ret['latent'] + syn_ret['latent']) * 0.5
        ret = self.latent_for_init(ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret
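
# Hedged sketch (not from the original source): wd_anneal_function, used in
# the get_loss methods above, is defined elsewhere in this codebase. A common
# formulation for such schedules, which this code may or may not match
# exactly, anneals a weight from 0 toward a maximum via a logistic or linear
# curve in the training step; the parameter names mirror the call sites above.
def _demo_anneal(step, unk_max=0.5, anneal_function="logistic", x0=2500, k=0.0025):
    import math

    if anneal_function == "logistic":
        weight = float(1 / (1 + math.exp(-k * (step - x0))))
    else:  # "linear"
        weight = min(1.0, step / x0)
    return unk_max * weight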