def __init__(self, opts, word2syllable, pretrained_emb_weights):
    super(SequentialTransformerCVAE, self).__init__()
    self.opts = opts
    self.PAD_token = opts.PAD_token
    self.SOS_token = opts.SOS_token
    self.EOS_token = opts.EOS_token
    self.UNK_token = opts.UNK_token
    self.SEP_token = opts.SEP_token
    self.KEYWORD_token = opts.KEYWORD_token
    self.keyword_approaches = opts.keyword_approaches
    self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)
    self.prior_encoder = _build_stacked_transformer_layers(
        opts, opts.num_layers_before_latent, opts.num_self_attn_per_layer_before_latent)
    self.recognition_encoder = _build_stacked_transformer_layers(
        opts, opts.num_layers_before_latent, opts.num_self_attn_per_layer_before_latent,
        prebuilt_layers=self.prior_encoder)
    self.latent_module = VTLatentModule(opts.d_model, opts.latent_dim, opts.latent_use_tanh)
    self.drop = nn.Dropout(opts.dropout)
    self.norm = nn.LayerNorm(opts.d_model)
    if opts.use_bow_loss:
        self.bow_proj_layer = nn.Linear(opts.latent_dim + opts.d_model, opts.vocab_size)
    self.decoder = _build_stacked_transformer_layers(
        opts, opts.num_layers - opts.num_layers_before_latent, opts.num_self_attn_per_layer_after_latent)
    self.hidden2emb, self.out_proj = _build_output_layers(
        opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())
    if opts.use_logits_mask:
        self.logits_mask_layer = LogitsMaskLayer(
            self.embedding.get_word2syllable_buffer(), opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)
    self._reset_parameters()

class TransformerCLM(ModelBase):
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(TransformerCLM, self).__init__()
        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches
        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)
        self.decoder = _build_stacked_transformer_layers(opts, opts.num_layers, opts.num_self_attn_per_layer)
        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())
        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)
        self._reset_parameters()

    def _reset_parameters(self):
        if not self.opts.weight_typing:
            xavier_uniform_fan_in_(self.out_proj.weight)

    def _forward_all(self, input, keyword_ids, segment_ids=None, remain_syllables=None, use_cache=False):
        padding_mask = self._get_padding_mask(
            input, revise_for_khead="khead" in self.opts.keyword_approaches, revise_for_cls=False)
        attn_masks = self._get_attn_masks(
            segment_ids, input.size(0), input.device, triangle=True,
            revise_for_khead="khead" in self.opts.keyword_approaches, revise_for_cls=False)
        embedded = self.embedding(input, (0, 0), segment_ids=segment_ids)
        if "khead" in self.keyword_approaches:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
            embedded = torch.cat([keyword_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_model:
            dec_out = self._forward_layers_hierarchical("decoder", embedded, attn_masks, padding_mask)
        else:
            dec_out = self._forward_layers("decoder", embedded, attn_masks, padding_mask)
        if "khead" in self.keyword_approaches:
            dec_out = dec_out[1:]
        logits = self.out_proj(self.hidden2emb(dec_out))
        if self.opts.use_logits_mask:
            logits = self.logits_mask_layer(
                logits, use_cache=use_cache, remain_syllables=remain_syllables, decoder_input=input,
                solve_ktoken="ktoken" in self.keyword_approaches, keyword_ids=keyword_ids, sample_n_to_check=1)
        return logits

    # src (without <SOS>): [seq_len, batch_size]
    # keyword_ids: [batch_size]
    def forward(self, src, keyword_ids, segment_ids=None, remain_syllables=None, mode="train"):
        assert mode in ("train", "valid", "test")
        if mode != "train":
            assert not self.training
        else:
            assert self.training
        sos = torch.full((src.size(1),), fill_value=self.SOS_token, dtype=torch.long, device=src.device)
        input = torch.cat([sos.unsqueeze(0), src[:-1]], dim=0)
        if self.opts.need_segment_ids:
            if segment_ids is not None:
                segment_ids = torch.cat([torch.zeros_like(segment_ids[:1]), segment_ids[:-1]], dim=0)
            else:
                segment_ids = self.embedding.get_segment_ids(input)
        logits = self._forward_all(input, keyword_ids, segment_ids, remain_syllables)
        return (logits,)

    # keyword_ids: [batch_size]
    def generate(self, keyword_ids, approach, gen_options):
        assert not self.training
        assert keyword_ids is not None
        assert approach in ("beam", "greedy")
        return getattr(self, "_gen_{}".format(approach))(keyword_ids, **gen_options)

    # input: [seq_len, batch_size], the first token of each sequence should be <SOS>
    # keyword_ids: [batch_size]
    def _gen_forward_step(self, input, keyword_ids, use_cache=False):
        segment_ids = None
        if self.opts.need_segment_ids:
            segment_ids = self.embedding.get_segment_ids(input, use_cache=use_cache, restrict=False)
        return self._forward_all(input, keyword_ids, segment_ids, use_cache=use_cache)

    def _gen_greedy(self, keyword_ids, **kwargs):
        batch_size = keyword_ids.size(0)
        device = keyword_ids.device
        max_seq_len = self.opts.gen_max_seq_len
        input = torch.full((1, batch_size), self.SOS_token, dtype=torch.long, device=device)
        lens = torch.full((batch_size,), max_seq_len, dtype=torch.long, device=device)
        output_steps = []
        for step in range(max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, use_cache=True)
            out_step = logits[-1].argmax(dim=-1, keepdim=False)
            output_steps.append(out_step.clone())
            # record the length at the first <EOS> produced by each sequence
            lens[(out_step == self.EOS_token) & (lens == max_seq_len)] = logits.size(0)
            if step == max_seq_len - 1 or (lens < max_seq_len).all():
                break
            if "ktoken" in self.keyword_approaches:
                # replace the <KEYWORD> placeholder with the actual keyword id before feeding back
                mask = out_step == self.KEYWORD_token
                out_step[mask] = keyword_ids[mask]
            input = torch.cat([input, out_step.unsqueeze(0)], dim=0)
        output = torch.stack(output_steps, dim=0)
        if self.opts.need_segment_ids:
            self.embedding.clear_segment_emb_cache()
        if self.opts.need_remain_syllables:
            self.logits_mask_layer.clear_cache()
        return output

    def _gen_beam(self, keyword_ids, **kwargs):
        device = keyword_ids.device
        batch_size = keyword_ids.size(0)
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]
        input = torch.full((1, batch_size), fill_value=self.SOS_token, dtype=torch.long, device=device)
        output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best) for _ in range(batch_size)]
        # first step
        logits_step = self._gen_forward_step(input, keyword_ids, use_cache=False)[-1]
        step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
        if keyword_ids is not None:
            keyword_ids = keyword_ids.repeat_interleave(beam_width, dim=0)
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
        # remaining steps
        input = input.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        for _ in range(1, max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, use_cache=False)
            logits_step = logits[-1].view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams, logits_step, output_step, back_pointers, func="update_beams")
            if all(b.done for b in batch_beams):
                break
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]
        return output
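
# Usage sketch (not part of the original source): a minimal, hedged example of how
# TransformerCLM might be trained and sampled, assuming `opts`, `word2syllable`, and
# `pretrained_emb_weights` are prepared elsewhere and that `src` / `keyword_ids` follow the
# shapes documented above (src: [seq_len, batch_size] without <SOS>, keyword_ids: [batch_size]).
# The helper name `_example_clm_usage` and the loss settings are illustrative only.
def _example_clm_usage(opts, word2syllable, pretrained_emb_weights, src, keyword_ids):
    model = TransformerCLM(opts, word2syllable, pretrained_emb_weights)
    model.train()
    # forward returns a 1-tuple of logits; the targets are simply `src`,
    # since the decoder input is <SOS> + src[:-1]
    logits, = model(src, keyword_ids, mode="train")
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, opts.vocab_size), src.reshape(-1), ignore_index=opts.PAD_token)
    loss.backward()
    # greedy decoding; beam search would pass approach="beam" with
    # gen_options={"beam_width": ..., "length_norm": ..., "n_best": ...}
    model.eval()
    with torch.no_grad():
        generated = model.generate(keyword_ids, approach="greedy", gen_options={})
    return loss, generated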

class TransformerCVAE(ModelBase):
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(TransformerCVAE, self).__init__()
        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.CLS_token = opts.CLS_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches
        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)
        self.encoder = _build_stacked_transformer_layers(
            opts, opts.num_layers_before_latent, opts.num_self_attn_per_layer_before_latent)
        self.latent_module = TLatentModule(
            opts.d_model, opts.d_model, opts.latent_dim, opts.latent_use_tanh,
            "klatent" in opts.keyword_approaches)
        if opts.use_bow_loss:
            bow_inp_dim = opts.latent_dim
            if "klatent" in opts.keyword_approaches:
                bow_inp_dim += opts.d_model
            self.bow_proj_layer = nn.Linear(bow_inp_dim, opts.vocab_size)
        self.decoder = _build_stacked_transformer_layers(
            opts, opts.num_layers - opts.num_layers_before_latent, opts.num_self_attn_per_layer_after_latent)
        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())
        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)
        self._reset_parameters()

    def _reset_parameters(self):
        if hasattr(self, "bow_proj_layer"):
            xavier_uniform_fan_in_(self.bow_proj_layer.weight)
            nn.init.zeros_(self.bow_proj_layer.bias)
        if not self.opts.weight_typing:
            xavier_uniform_fan_in_(self.out_proj.weight)

    # src (without <SOS>): [seq_len, batch_size]
    # keyword_ids: [batch_size]
    def forward(self, src, keyword_ids, segment_ids=None, remain_syllables=None, mode="train"):
        assert mode in ("train", "valid", "test")
        if mode != "train":
            assert not self.training
        else:
            assert self.training
        if self.opts.need_segment_ids and segment_ids is None:
            segment_ids = self.embedding.get_segment_ids(src)
        padding_mask = self._get_padding_mask(src, revise_for_khead=False, revise_for_cls=True)
        attn_masks = self._get_attn_masks(
            segment_ids, src.size(0), src.device, triangle=False,
            revise_for_khead=False, revise_for_cls=True)
        embedded = self.embedding(src, (0, 0), segment_ids=segment_ids)
        cls = torch.full((src.size(1),), fill_value=self.CLS_token, dtype=torch.long, device=src.device)
        cls_embs = self.embedding.forward_word_emb(cls)
        embedded = torch.cat([cls_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_before_latent:
            enc_out = self._forward_layers_hierarchical("encoder", embedded, attn_masks, padding_mask)
        elif self.opts.hierarchical_model:
            enc_out = self._forward_layers("encoder", embedded, attn_masks[-1], padding_mask)
        else:
            enc_out = self._forward_layers("encoder", embedded, attn_masks, padding_mask)
        enc_hidden = enc_out[0]
        sample_n = self.opts.train_sample_n if mode == "train" else self.opts.test_sample_n
        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        # mu/log_var: [batch_size, latent_dim]
        # latent_out/latent_vector: [sample_n, batch_size, embedding_dim/latent_dim]
        mu_p, log_var_p, mu_r, log_var_r, latent_vector, latent_out = self.latent_module.forward_train_path(
            enc_hidden, keyword_embs, sample_n=sample_n)
        bow_logits = None
        if self.opts.use_bow_loss:
            if keyword_embs is not None:
                keyword_embs_expanded = keyword_embs.unsqueeze(0).expand(sample_n, -1, -1)
                bow_inp = torch.cat([latent_vector, keyword_embs_expanded], dim=-1)
            else:
                bow_inp = latent_vector
            bow_inp = bow_inp.view(sample_n * bow_inp.size(1), -1)
            bow_logits = self.bow_proj_layer(bow_inp)
            bow_logits = bow_logits.unsqueeze(0).expand(src.size(0), -1, -1)
        sos = torch.full((src.size(1),), fill_value=self.SOS_token, dtype=torch.long, device=src.device)
        dec_input = torch.cat([sos.unsqueeze(0), src[:-1]], dim=0)
        if self.opts.need_segment_ids:
            segment_ids = torch.cat([torch.zeros_like(segment_ids[:1]), segment_ids[:-1]], dim=0)
        padding_mask = self._get_padding_mask(
            dec_input, revise_for_khead="khead" in self.opts.keyword_approaches, revise_for_cls=False)
        attn_masks = self._get_attn_masks(
            segment_ids, dec_input.size(0), dec_input.device, triangle=True,
            revise_for_khead="khead" in self.opts.keyword_approaches, revise_for_cls=False)
        embedded = self.embedding(dec_input, (0, 0), segment_ids=segment_ids)
        if sample_n > 1:
            padding_mask = self._expand_padding_mask(padding_mask, sample_n)
            attn_masks = self._expand_attn_masks(attn_masks, sample_n)
            embedded = embedded.repeat(1, sample_n, 1)
            if "khead" in self.keyword_approaches:
                keyword_embs = keyword_embs.repeat(sample_n, 1)
        embedded[0] = embedded[0] + latent_out.view(sample_n * latent_out.size(1), latent_out.size(-1))
        if "khead" in self.keyword_approaches:
            embedded = torch.cat([keyword_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_after_latent:
            dec_out = self._forward_layers_hierarchical("decoder", embedded, attn_masks, padding_mask)
        elif self.opts.hierarchical_model:
            dec_out = self._forward_layers("decoder", embedded, attn_masks[-1], padding_mask)
        else:
            dec_out = self._forward_layers("decoder", embedded, attn_masks, padding_mask)
        if "khead" in self.keyword_approaches:
            dec_out = dec_out[1:]
        logits = self.out_proj(self.hidden2emb(dec_out))
        if self.opts.use_logits_mask:
            logits = self.logits_mask_layer(
                logits, remain_syllables=remain_syllables, decoder_input=dec_input,
                solve_ktoken="ktoken" in self.keyword_approaches, keyword_ids=keyword_ids,
                sample_n_to_check=sample_n)
        return logits, mu_p, log_var_p, mu_r, log_var_r, bow_logits

    # normal_vector ~ N(0,1): [batch_size, latent_dim]
    # keyword_ids: [batch_size]
    def generate(self, keyword_ids, normal_vector, approach, gen_options):
        assert not self.training
        assert approach in ("beam", "greedy")
        return getattr(self, "_gen_{}".format(approach))(keyword_ids, normal_vector, **gen_options)

    # input: [seq_len, batch_size], the first token of each sequence should be <SOS>
    # keyword_ids: [batch_size]
    # latent_out: [batch_size, embedding_dim]
    def _gen_forward_step(self, input, keyword_ids, keyword_embs, latent_out, use_cache=False):
        segment_ids = None
        if self.opts.need_segment_ids:
            segment_ids = self.embedding.get_segment_ids(input, use_cache=use_cache, restrict=False)
        padding_mask = self._get_padding_mask(
            input, revise_for_khead="khead" in self.opts.keyword_approaches, revise_for_cls=False)
        attn_masks = self._get_attn_masks(
            segment_ids, input.size(0), input.device, triangle=True,
            revise_for_khead="khead" in self.opts.keyword_approaches, revise_for_cls=False)
        embedded = self.embedding(input, (0, 0), segment_ids=segment_ids)
        embedded[0] = embedded[0] + latent_out
        if "khead" in self.keyword_approaches:
            embedded = torch.cat([keyword_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_after_latent:
            dec_out = self._forward_layers_hierarchical("decoder", embedded, attn_masks, padding_mask)
        elif self.opts.hierarchical_model:
            dec_out = self._forward_layers("decoder", embedded, attn_masks[-1], padding_mask)
        else:
            dec_out = self._forward_layers("decoder", embedded, attn_masks, padding_mask)
        if "khead" in self.keyword_approaches:
            dec_out = dec_out[1:]
        logits = self.out_proj(self.hidden2emb(dec_out))
        if self.opts.use_logits_mask:
            logits = self.logits_mask_layer(
                logits, use_cache=use_cache, decoder_input=input,
                solve_ktoken="ktoken" in self.keyword_approaches, keyword_ids=keyword_ids,
                sample_n_to_check=1)
        return logits

    def _gen_greedy(self, keyword_ids, normal_vector, **kwargs):
        batch_size = normal_vector.size(0)
        dtype = normal_vector.dtype
        device = normal_vector.device
        max_seq_len = self.opts.gen_max_seq_len
        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        latent_out = self.latent_module.forward_gen_path(
            keyword_embs, normal_vector, head_dims=[], batch_size=batch_size,
            dtype=dtype, device=device)[1].squeeze(0)
        input = torch.full((1, batch_size), self.SOS_token, dtype=torch.long, device=device)
        lens = torch.full((batch_size,), max_seq_len, dtype=torch.long, device=device)
        output_steps = []
        for step in range(max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=True)
            out_step = logits[-1].argmax(dim=-1, keepdim=False)
            output_steps.append(out_step.clone())
            lens[(out_step == self.EOS_token) & (lens == max_seq_len)] = logits.size(0)
            if step == max_seq_len - 1 or (lens < max_seq_len).all():
                break
            if "ktoken" in self.keyword_approaches:
                mask = out_step == self.KEYWORD_token
                out_step[mask] = keyword_ids[mask]
            input = torch.cat([input, out_step.unsqueeze(0)], dim=0)
        output = torch.stack(output_steps, dim=0)
        if self.opts.need_segment_ids:
            self.embedding.clear_segment_emb_cache()
        if self.opts.need_remain_syllables:
            self.logits_mask_layer.clear_cache()
        return output

    def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
        dtype = normal_vector.dtype
        device = normal_vector.device
        batch_size, latent_dim = normal_vector.size()
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]
        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        latent_out = self.latent_module.forward_gen_path(
            keyword_embs, normal_vector, head_dims=[], batch_size=batch_size,
            dtype=dtype, device=device)[1].squeeze(0)
        input = torch.full((1, batch_size), fill_value=self.SOS_token, dtype=torch.long, device=device)
        output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best) for _ in range(batch_size)]
        # first step
        logits_step = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=False)[-1]
        step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
        if keyword_ids is not None:
            keyword_ids = keyword_ids.repeat_interleave(beam_width, dim=0)
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
        # remaining steps
        input = input.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        latent_out = latent_out.repeat_interleave(beam_width, dim=0)
        if keyword_embs is not None:
            keyword_embs = keyword_embs.repeat_interleave(beam_width, dim=0)
        for _ in range(1, max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=False)
            logits_step = logits[-1].view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams, logits_step, output_step, back_pointers, func="update_beams")
            if all(b.done for b in batch_beams):
                break
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]
        return output
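
# Usage sketch (not part of the original source): a hedged example of the TransformerCVAE
# training and generation interfaces, assuming prepared `opts`, `word2syllable`,
# `pretrained_emb_weights`, a batch `src` of shape [seq_len, batch_size] (without <SOS>) and
# `keyword_ids` of shape [batch_size]. The helper name `_example_cvae_usage` is hypothetical,
# and the weighting of reconstruction/KL/BoW terms is left to the training loop.
def _example_cvae_usage(opts, word2syllable, pretrained_emb_weights, src, keyword_ids):
    model = TransformerCVAE(opts, word2syllable, pretrained_emb_weights)
    model.train()
    logits, mu_p, log_var_p, mu_r, log_var_r, bow_logits = model(src, keyword_ids, mode="train")
    # standard Gaussian KL between the recognition posterior N(mu_r, var_r)
    # and the prior N(mu_p, var_p), summed over latent dims and averaged over the batch
    kl = 0.5 * (log_var_p - log_var_r
                + (log_var_r.exp() + (mu_r - mu_p).pow(2)) / log_var_p.exp()
                - 1).sum(dim=-1).mean()
    # generation draws one latent sample per sequence: normal_vector is [batch_size, latent_dim]
    model.eval()
    with torch.no_grad():
        normal_vector = torch.randn(keyword_ids.size(0), opts.latent_dim, device=keyword_ids.device)
        generated = model.generate(keyword_ids, normal_vector, approach="greedy", gen_options={})
    return kl, bow_logits, generated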

class CVAD(ModelBase):
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(ModelBase, self).__init__()
        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights)
        self.latent_module = VADLatentModule(opts.dec_hidden_size, opts.latent_dim, opts.latent_use_tanh)
        if opts.use_bow_loss:
            self.bow_proj_layer = nn.Linear(opts.latent_dim + opts.dec_hidden_size, opts.vocab_size)
        self.emb2states_fwd = nn.Linear(opts.emb_out_dim, opts.dec_num_layers * opts.dec_hidden_size)
        self.emb2states_bwd = nn.Linear(opts.emb_out_dim, opts.dec_num_layers * opts.dec_hidden_size)
        self.fwd_decoder = StackedGRUCell(
            opts.emb_out_dim + opts.latent_dim, opts.dec_hidden_size, opts.dec_num_layers,
            opts.dropout, opts.use_layer_norm, opts.layer_norm_trainable)
        self.bwd_decoder = StackedGRUCell(
            opts.emb_out_dim, opts.dec_hidden_size, opts.dec_num_layers,
            opts.dropout, opts.use_layer_norm, opts.layer_norm_trainable)
        out_dim = opts.dec_hidden_size
        if opts.latent_out_attach:
            out_dim += opts.latent_dim
        self.fwd_out_proj = nn.Linear(out_dim, opts.vocab_size, bias=False)
        if opts.need_bwd_out_proj_layer:
            self.bwd_out_proj = nn.Linear(opts.dec_hidden_size, opts.vocab_size, bias=False)
        if opts.fwd_use_logits_mask:
            self.fwd_logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token, opts.UNK_token)
        if opts.bwd_use_logits_mask:
            self.bwd_logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token, opts.UNK_token)
        self._reset_parameters()

    def _reset_parameters(self):
        if hasattr(self, "bow_proj_layer"):
            xavier_uniform_fan_in_(self.bow_proj_layer.weight)
            nn.init.zeros_(self.bow_proj_layer.bias)
        xavier_uniform_fan_in_(self.emb2states_fwd.weight)
        xavier_uniform_fan_in_(self.emb2states_bwd.weight)
        nn.init.zeros_(self.emb2states_fwd.bias)
        nn.init.zeros_(self.emb2states_bwd.bias)
        if hasattr(self, "fwd_out_proj"):
            xavier_uniform_fan_in_(self.fwd_out_proj.weight)
        if hasattr(self, "bwd_out_proj"):
            xavier_uniform_fan_in_(self.bwd_out_proj.weight)

    def _init_states(self, keyword_embs, direction):
        assert direction in ("fwd", "bwd")
        batch_size = keyword_embs.size(0)
        if direction == "fwd":
            hidden = self.emb2states_fwd(keyword_embs)
        else:
            if self.opts.detach_bwd_decoder_from_embedding:
                keyword_embs = keyword_embs.detach()
            hidden = self.emb2states_bwd(keyword_embs)
        hidden = hidden.view(batch_size, self.opts.dec_num_layers, self.opts.dec_hidden_size)
        return hidden.transpose(0, 1).contiguous()

    def _forward_bwd_decoder(self, bwd_input, bwd_segment_ids, bwd_remain_syllables, initial_hidden):
        seq_len, batch_size = bwd_input.size()
        device = bwd_input.device
        embedded = self.embedding(bwd_input, 0, segment_ids=bwd_segment_ids)
        if self.opts.detach_bwd_decoder_from_embedding:
            embedded = embedded.detach()
        bwd_last_layer_states = []
        prev_hidden = initial_hidden
        for step in range(seq_len):
            cur_hidden = self.bwd_decoder(embedded[step], prev_hidden)
            # keep the previous hidden state wherever this step's input is padding
            pad_indexes = torch.arange(batch_size, device=device)[bwd_input[step] == self.opts.PAD_token]
            cur_hidden = cur_hidden.index_copy(1, pad_indexes, prev_hidden[:, pad_indexes])
            bwd_last_layer_states.append(cur_hidden[-1])
            prev_hidden = cur_hidden
        bwd_last_layer_hidden = torch.stack(bwd_last_layer_states, dim=0)
        logits = None
        if hasattr(self, "bwd_out_proj"):
            logits = self.bwd_out_proj(bwd_last_layer_hidden)
            if self.opts.bwd_use_logits_mask:
                logits = self.bwd_logits_mask_layer(
                    logits, remain_syllables=bwd_remain_syllables, decoder_input=bwd_input,
                    sample_n_to_check=1)
        return logits, bwd_last_layer_hidden

    def _forward_fwd_decoder(self, fwd_input, fwd_segment_ids, fwd_remain_syllables,
                             initial_hidden, bwd_hidden, sample_n):
        fwd_last_layer_states = []
        mu_p_list = []
        log_var_p_list = []
        mu_r_list = []
        log_var_r_list = []
        latent_vector_list = []
        embedded = self.embedding(fwd_input, 0, segment_ids=fwd_segment_ids)
        prev_hidden = initial_hidden
        if sample_n > 1:
            embedded = embedded.repeat(1, sample_n, 1)
            prev_hidden = prev_hidden.repeat(1, sample_n, 1)
            bwd_hidden = bwd_hidden.repeat(1, sample_n, 1)
        for step in range(fwd_input.size(0)):
            mu_p, log_var_p, mu_r, log_var_r, z = self.latent_module.forward_train_path(
                prev_hidden[-1], bwd_hidden[-(step + 1)])
            cur_hidden = self.fwd_decoder(torch.cat([embedded[step], z], dim=-1), prev_hidden)
            fwd_last_layer_states.append(cur_hidden[-1])
            prev_hidden = cur_hidden
            mu_p_list.append(mu_p)
            log_var_p_list.append(log_var_p)
            mu_r_list.append(mu_r)
            log_var_r_list.append(log_var_r)
            latent_vector_list.append(z)
        fwd_last_layer_hidden = torch.stack(fwd_last_layer_states, dim=0)
        latent_vector = torch.stack(latent_vector_list, dim=0)
        out_proj_inp = fwd_last_layer_hidden
        if self.opts.latent_out_attach:
            out_proj_inp = torch.cat([fwd_last_layer_hidden, latent_vector], dim=-1)
        logits = self.fwd_out_proj(out_proj_inp)
        if self.opts.fwd_use_logits_mask:
            logits = self.fwd_logits_mask_layer(
                logits, remain_syllables=fwd_remain_syllables, decoder_input=fwd_input,
                sample_n_to_check=sample_n)
        mu_p = torch.stack(mu_p_list, dim=0)
        log_var_p = torch.stack(log_var_p_list, dim=0)
        mu_r = torch.stack(mu_r_list, dim=0)
        log_var_r = torch.stack(log_var_r_list, dim=0)
        return logits, fwd_last_layer_hidden, latent_vector, mu_p, log_var_p, mu_r, log_var_r

    def forward(self, inputs, keyword_ids, segment_ids=None, remain_syllables=None, mode="train"):
        assert mode in ("train", "valid", "test")
        if mode != "train":
            assert not self.training
        else:
            assert self.training
        fwd_tgt, bwd_inp = inputs
        sos = torch.full((fwd_tgt.size(1),), fill_value=self.SOS_token, dtype=torch.long, device=fwd_tgt.device)
        fwd_inp = torch.cat([sos.unsqueeze(0), fwd_tgt[:-1]], dim=0)
        bwd_tgt = fwd_inp.flip(0)
        if self.opts.need_segment_ids:
            if segment_ids is None:
                fwd_seg_ids = self.embedding.get_segment_ids(fwd_inp)
                bwd_seg_ids = self.embedding.get_segment_ids(bwd_inp)
            else:
                fwd_seg_ids, bwd_seg_ids = segment_ids
                fwd_seg_ids = torch.cat([torch.zeros_like(fwd_seg_ids[:1]), fwd_seg_ids[:-1]], dim=0)
        else:
            fwd_seg_ids = bwd_seg_ids = None
        fwd_rem_syls = bwd_rem_syls = None
        if self.opts.fwd_need_remain_syllables or self.opts.bwd_need_remain_syllables:
            if remain_syllables is None:
                if self.opts.fwd_need_remain_syllables:
                    fwd_rem_syls = self.fwd_logits_mask_layer.get_remain_syllables(fwd_inp)
                if self.opts.bwd_need_remain_syllables:
                    bwd_rem_syls = self.bwd_logits_mask_layer.get_remain_syllables(bwd_inp)
            else:
                fwd_rem_syls, bwd_rem_syls = remain_syllables
        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        fwd_initial_states = self._init_states(keyword_embs, "fwd")
        bwd_initial_states = self._init_states(keyword_embs, "bwd")
        bwd_logits, bwd_hidden = self._forward_bwd_decoder(
            bwd_inp, bwd_seg_ids, bwd_rem_syls, bwd_initial_states)
        sample_n = self.opts.train_sample_n if mode == "train" else self.opts.test_sample_n
        fwd_logits, fwd_last_layer_hidden, latent_vector, mu_p, log_var_p, mu_r, log_var_r = self._forward_fwd_decoder(
            fwd_inp, fwd_seg_ids, fwd_rem_syls, fwd_initial_states, bwd_hidden, sample_n)
        bow_logits = None
        if self.opts.use_bow_loss:
            bow_inp = torch.cat([latent_vector, fwd_last_layer_hidden], dim=-1)
            bow_logits = self.bow_proj_layer(bow_inp)
            expand_dim = bow_logits.size(0) if self.opts.bow_window is None else self.opts.bow_window
            bow_logits = bow_logits.unsqueeze(0).expand(expand_dim, -1, -1, -1)
        fwd_tgt = self.expand_tgt(fwd_tgt, sample_n)
        return (fwd_logits, bwd_logits), (fwd_tgt, bwd_tgt), bow_logits, mu_p, log_var_p, mu_r, log_var_r

    # normal_vector ~ N(0,1): [seq_len, batch_size, latent_dim]
    # keyword_ids: [batch_size]
    def generate(self, keyword_ids, normal_vector, approach, gen_options):
        assert not self.training
        assert approach in ("beam", "greedy")
        return getattr(self, "_gen_{}".format(approach))(keyword_ids, normal_vector, **gen_options)

    # input: [seq_len, batch_size], the first token of each sequence should be <SOS>
    # hidden_step: [num_layers, batch_size, hidden_size]
    # normal_vector_step ~ N(0,1): [batch_size, latent_dim]
    def _gen_forward_step(self, input, hidden_step, normal_vector_step, use_cache=False):
        segment_ids = None
        if self.opts.need_segment_ids:
            segment_ids = self.embedding.get_segment_ids(input, use_cache=use_cache, restrict=False)
        embedded = self.embedding(input, 0, segment_ids=segment_ids, segment_emb_restrict=False)
        z = self.latent_module.forward_gen_path(hidden_step[-1], normal_vector_step)
        hidden_step = self.fwd_decoder(torch.cat([embedded[-1], z], dim=-1), hidden_step)
        out_proj_inp = hidden_step[-1]
        if self.opts.latent_out_attach:
            out_proj_inp = torch.cat([hidden_step[-1], z], dim=-1)
        logits_step = self.fwd_out_proj(out_proj_inp)
        if self.opts.fwd_use_logits_mask:
            logits_step = self.fwd_logits_mask_layer(
                logits_step, use_cache=use_cache, decoder_input=input,
                sample_n_to_check=1, only_last_step=True)
        return logits_step, hidden_step

    # keyword_ids: [batch_size]
    # normal_vector: [seq_len, batch_size, latent_dim]
    def _gen_greedy(self, keyword_ids, normal_vector, **kwargs):
        batch_size = normal_vector.size(1)
        device = normal_vector.device
        max_seq_len = self.opts.gen_max_seq_len
        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        hidden = self._init_states(keyword_embs, "fwd")
        input = torch.full((1, batch_size), self.SOS_token, dtype=torch.long, device=device)
        lens = torch.full((batch_size,), max_seq_len, dtype=torch.long, device=device)
        output_steps = []
        for step in range(max_seq_len):
            logits_step, hidden = self._gen_forward_step(input, hidden, normal_vector[step], use_cache=True)
            out_step = logits_step.argmax(dim=-1, keepdim=False)
            output_steps.append(out_step)
            lens[(out_step == self.EOS_token) & (lens == max_seq_len)] = step + 1
            if step == max_seq_len - 1 or (lens < max_seq_len).all():
                break
            input = torch.cat([input, out_step.unsqueeze(0)], dim=0)
        output = torch.stack(output_steps, dim=0)
        if self.opts.need_segment_ids:
            self.embedding.clear_segment_emb_cache()
        if self.opts.fwd_need_remain_syllables:
            self.fwd_logits_mask_layer.clear_cache()
        return output

    def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
        device = normal_vector.device
        _, batch_size, latent_dim = normal_vector.size()
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]
        input = torch.full((1, batch_size), fill_value=self.SOS_token, dtype=torch.long, device=device)
        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        hidden = self._init_states(keyword_embs, "fwd")
        output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best) for _ in range(batch_size)]
        # first step
        logits_step, hidden = self._gen_forward_step(input, hidden, normal_vector[0], use_cache=False)
        step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
        # remaining steps
        input = input.repeat_interleave(beam_width, dim=1)
        normal_vector = normal_vector.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        hidden = hidden.repeat_interleave(beam_width, dim=1)
        for step in range(1, max_seq_len):
            logits_step, hidden = self._gen_forward_step(input, hidden, normal_vector[step], use_cache=False)
            logits_step = logits_step.view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams, logits_step, output_step, back_pointers, func="update_beams")
            if all(b.done for b in batch_beams):
                break
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
            hidden = hidden.index_select(dim=1, index=back_pointers)
        output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]
        return output
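
# Usage sketch (not part of the original source): a hedged example of the CVAD interfaces,
# assuming prepared `opts`, `word2syllable`, `pretrained_emb_weights`, forward targets
# `fwd_tgt` and backward-decoder input `bwd_inp` (both [seq_len, batch_size]), and
# `keyword_ids` of shape [batch_size]. Because the latent variable is drawn per decoding
# step, generation takes a [seq_len, batch_size, latent_dim] noise tensor. The helper name
# `_example_cvad_usage` is hypothetical.
def _example_cvad_usage(opts, word2syllable, pretrained_emb_weights, fwd_tgt, bwd_inp, keyword_ids):
    model = CVAD(opts, word2syllable, pretrained_emb_weights)
    model.train()
    # forward returns per-direction logits and targets plus the per-step prior/posterior stats
    (fwd_logits, bwd_logits), (fwd_tgt_expanded, bwd_tgt), bow_logits, mu_p, log_var_p, mu_r, log_var_r = model(
        (fwd_tgt, bwd_inp), keyword_ids, mode="train")
    # greedy decoding; beam search would pass approach="beam" with
    # gen_options={"beam_width": ..., "length_norm": ..., "n_best": ...}
    model.eval()
    with torch.no_grad():
        normal_vector = torch.randn(
            opts.gen_max_seq_len, keyword_ids.size(0), opts.latent_dim, device=keyword_ids.device)
        generated = model.generate(keyword_ids, normal_vector, approach="greedy", gen_options={})
    return fwd_logits, bwd_logits, generated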