Example #1
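# Constructor of TransformerCLM (identical to the one in Example #5): builds the
# embeddings (with positional embeddings), a stack of transformer decoder layers,
# output layers that can share weights with the word embedding, and an optional
# syllable-aware LogitsMaskLayer.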
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(TransformerCLM, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches

        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)

        self.decoder = _build_stacked_transformer_layers(opts, opts.num_layers, opts.num_self_attn_per_layer)

        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())

        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(),
                opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)

        self._reset_parameters()
Example #2
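# This constructor is identical to the CVAD constructor in Example #7: token
# embeddings, a VADLatentModule, linear layers mapping an embedding to the initial
# hidden states of the forward and backward StackedGRUCell decoders, output
# projections, and optional bag-of-words and syllable-mask layers.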
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(ModelBase, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token

        self.embedding = _build_embeddings(opts, word2syllable,
                                           pretrained_emb_weights)

        self.latent_module = VADLatentModule(opts.dec_hidden_size,
                                             opts.latent_dim,
                                             opts.latent_use_tanh)
        if opts.use_bow_loss:
            self.bow_proj_layer = nn.Linear(
                opts.latent_dim + opts.dec_hidden_size, opts.vocab_size)

        self.emb2states_fwd = nn.Linear(
            opts.emb_out_dim, opts.dec_num_layers * opts.dec_hidden_size)
        self.emb2states_bwd = nn.Linear(
            opts.emb_out_dim, opts.dec_num_layers * opts.dec_hidden_size)

        self.fwd_decoder = StackedGRUCell(opts.emb_out_dim + opts.latent_dim,
                                          opts.dec_hidden_size,
                                          opts.dec_num_layers, opts.dropout,
                                          opts.use_layer_norm,
                                          opts.layer_norm_trainable)
        self.bwd_decoder = StackedGRUCell(opts.emb_out_dim,
                                          opts.dec_hidden_size,
                                          opts.dec_num_layers, opts.dropout,
                                          opts.use_layer_norm,
                                          opts.layer_norm_trainable)

        out_dim = opts.dec_hidden_size
        if opts.latent_out_attach:
            out_dim += opts.latent_dim
        self.fwd_out_proj = nn.Linear(out_dim, opts.vocab_size, bias=False)
        if opts.need_bwd_out_proj_layer:
            self.bwd_out_proj = nn.Linear(opts.dec_hidden_size,
                                          opts.vocab_size,
                                          bias=False)

        if opts.fwd_use_logits_mask:
            self.fwd_logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token,
                opts.UNK_token)
        if opts.bwd_use_logits_mask:
            self.bwd_logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token,
                opts.UNK_token)

        self._reset_parameters()
Example #3
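# Constructor of SequentialTransformerCVAE: embeddings with positional embeddings, a
# prior encoder and a recognition encoder (the latter constructed with the prior
# encoder passed as prebuilt layers), a VTLatentModule with dropout and LayerNorm,
# the remaining transformer layers as decoder, output layers, and optional
# bag-of-words and logits-mask layers.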
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(SequentialTransformerCVAE, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches

        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)

        self.prior_encoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_before_latent)

        self.recognition_encoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_before_latent,
            prebuilt_layers=self.prior_encoder)

        self.latent_module = VTLatentModule(opts.d_model, opts.latent_dim, opts.latent_use_tanh)
        self.drop = nn.Dropout(opts.dropout)
        self.norm = nn.LayerNorm(opts.d_model)
        if opts.use_bow_loss:
            self.bow_proj_layer = nn.Linear(opts.latent_dim + opts.d_model, opts.vocab_size)

        self.decoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers - opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_after_latent)

        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())

        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(),
                opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)

        self._reset_parameters()
Example #4
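# Constructor of TransformerCVAE (identical to the one in Example #6): embeddings,
# pre-latent encoder layers, a TLatentModule that is optionally keyword-conditioned
# ("klatent" approach), an optional bag-of-words projection, post-latent decoder
# layers, output layers, and an optional syllable-aware LogitsMaskLayer.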
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(TransformerCVAE, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.CLS_token = opts.CLS_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches

        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)

        self.encoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_before_latent)

        self.latent_module = TLatentModule(
            opts.d_model, opts.d_model, opts.latent_dim, opts.latent_use_tanh, "klatent" in opts.keyword_approaches)

        if opts.use_bow_loss:
            bow_inp_dim = opts.latent_dim
            if "klatent" in opts.keyword_approaches:
                bow_inp_dim += opts.d_model
            self.bow_proj_layer = nn.Linear(bow_inp_dim, opts.vocab_size)

        self.decoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers - opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_after_latent)

        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())

        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(),
                opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)

        self._reset_parameters()
Example #5
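# TransformerCLM: a keyword-conditioned causal transformer language model. forward()
# teacher-forces on the <SOS>-shifted input; generate() dispatches to greedy or
# beam-search decoding.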
class TransformerCLM(ModelBase):
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(TransformerCLM, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches

        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)

        self.decoder = _build_stacked_transformer_layers(opts, opts.num_layers, opts.num_self_attn_per_layer)

        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())

        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(),
                opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)

        self._reset_parameters()

    def _reset_parameters(self):
        if not self.opts.weight_typing:
            xavier_uniform_fan_in_(self.out_proj.weight)

    def _forward_all(self, input, keyword_ids, segment_ids=None, remain_syllables=None, use_cache=False):
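        # Build the padding mask and causal attention masks, embed the input
        # (optionally prepending the keyword embedding as an extra "khead" position),
        # run the decoder stack, and project to vocabulary logits with an optional
        # syllable-aware logits mask.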
        padding_mask = self._get_padding_mask(input,
                                              revise_for_khead="khead" in self.opts.keyword_approaches,
                                              revise_for_cls=False)
        attn_masks = self._get_attn_masks(segment_ids, input.size(0), input.device,
                                          triangle=True,
                                          revise_for_khead="khead" in self.opts.keyword_approaches,
                                          revise_for_cls=False)
        embedded = self.embedding(input, (0, 0), segment_ids=segment_ids)
        if "khead" in self.keyword_approaches:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
            embedded = torch.cat([keyword_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_model:
            dec_out = self._forward_layers_hierarchical("decoder", embedded, attn_masks, padding_mask)
        else:
            dec_out = self._forward_layers("decoder", embedded, attn_masks, padding_mask)
        if "khead" in self.keyword_approaches:
            dec_out = dec_out[1:]
        logits = self.out_proj(self.hidden2emb(dec_out))
        if self.opts.use_logits_mask:
            logits = self.logits_mask_layer(
                logits,
                use_cache=use_cache,
                remain_syllables=remain_syllables,
                decoder_input=input,
                solve_ktoken="ktoken" in self.keyword_approaches,
                keyword_ids=keyword_ids,
                sample_n_to_check=1)
        return logits

    # src (without <SOS>): [seq_len, batch_size]
    # keyword_ids: [batch_size]
    def forward(self, src, keyword_ids, segment_ids=None, remain_syllables=None, mode="train"):
        assert mode in ("train", "valid", "test")
        if mode != "train":
            assert not self.training
        else:
            assert self.training

        sos = torch.full((src.size(1),), fill_value=self.SOS_token, dtype=torch.long, device=src.device)
        input = torch.cat([sos.unsqueeze(0), src[:-1]], dim=0)
        if self.opts.need_segment_ids:
            if segment_ids is not None:
                segment_ids = torch.cat([torch.zeros_like(segment_ids[:1]), segment_ids[:-1]], dim=0)
            else:
                segment_ids = self.embedding.get_segment_ids(input)
        logits = self._forward_all(input, keyword_ids, segment_ids, remain_syllables)

        return (logits,)

    # keyword_ids: [batch_size]
    def generate(self, keyword_ids, approach, gen_options):
        assert not self.training
        assert keyword_ids is not None
        assert approach in ("beam", "greedy")
        return getattr(self, "_gen_{}".format(approach))(keyword_ids, **gen_options)

    # input: [seq_len, batch_size], the first token of each sequence should be <SOS>
    # keyword_ids: [batch_size]
    def _gen_forward_step(self, input, keyword_ids, use_cache=False):
        segment_ids = None
        if self.opts.need_segment_ids:
            segment_ids = self.embedding.get_segment_ids(input, use_cache=use_cache, restrict=False)
        return self._forward_all(input, keyword_ids, segment_ids, use_cache=use_cache)

    def _gen_greedy(self, keyword_ids, **kwargs):
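        # Greedy decoding: re-feed the growing prefix each step, take the argmax of
        # the last position, record each sequence's length at its first <EOS>, and
        # stop once every sequence has finished or max_seq_len is reached. With the
        # "ktoken" approach, generated KEYWORD placeholders are replaced by the real
        # keyword ids before being fed back.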
        batch_size = keyword_ids.size(0)
        device = keyword_ids.device
        max_seq_len = self.opts.gen_max_seq_len

        input = torch.full((1, batch_size), self.SOS_token, dtype=torch.long, device=device)
        lens = torch.full((batch_size,), max_seq_len, dtype=torch.long, device=device)
        output_steps = []
        for step in range(max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, use_cache=True)
            out_step = logits[-1].argmax(dim=-1, keepdim=False)
            output_steps.append(out_step.clone())
            lens[(out_step == self.EOS_token) & (lens == max_seq_len)] = logits.size(0)
            if step == max_seq_len - 1 or (lens < max_seq_len).all():
                break
            if "ktoken" in self.keyword_approaches:
                mask = out_step == self.KEYWORD_token
                out_step[mask] = keyword_ids[mask]
            input = torch.cat([input, out_step.unsqueeze(0)], dim=0)
        output = torch.stack(output_steps, dim=0)

        if self.opts.need_segment_ids:
            self.embedding.clear_segment_emb_cache()
        if self.opts.need_remain_syllables:
            self.logits_mask_layer.clear_cache()

        return output

    def _gen_beam(self, keyword_ids, **kwargs):
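        # Beam search with one Beam per batch element. After the first step the input
        # and keyword ids are repeated beam_width times; each later step reshapes the
        # last-position logits to [batch_size, beam_width, vocab_size], advances the
        # beams, and reorders the running prefixes with the returned back-pointers.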
        device = keyword_ids.device
        batch_size = keyword_ids.size(0)
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]

        input = torch.full((1, batch_size), fill_value=self.SOS_token, dtype=torch.long, device=device)
        output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best) for _ in range(batch_size)]

        # first step
        logits_step = self._gen_forward_step(input, keyword_ids, use_cache=False)[-1]
        step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
        if keyword_ids is not None:
            keyword_ids = keyword_ids.repeat_interleave(beam_width, dim=0)
        if "ktoken" in self.keyword_approaches:
            mask = output_step == self.KEYWORD_token
            output_step[mask] = keyword_ids[mask]

        # remaining steps
        input = input.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        for _ in range(1, max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, use_cache=False)
            logits_step = logits[-1].view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams, logits_step, output_step, back_pointers, func="update_beams")
            if all(b.done for b in batch_beams):
                break
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)

        output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]

        return output
Example #6
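# TransformerCVAE: a conditional variational autoencoder built from transformer
# layers. The encoder runs over a <CLS>-prefixed input, the sampled latent is added
# to the first decoder position, and generation draws the latent from a supplied
# standard-normal vector.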
class TransformerCVAE(ModelBase):
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(TransformerCVAE, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token
        self.CLS_token = opts.CLS_token
        self.KEYWORD_token = opts.KEYWORD_token
        self.keyword_approaches = opts.keyword_approaches

        self.embedding = _build_embeddings(opts, word2syllable, pretrained_emb_weights, use_pos_emb=True)

        self.encoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_before_latent)

        self.latent_module = TLatentModule(
            opts.d_model, opts.d_model, opts.latent_dim, opts.latent_use_tanh, "klatent" in opts.keyword_approaches)

        if opts.use_bow_loss:
            bow_inp_dim = opts.latent_dim
            if "klatent" in opts.keyword_approaches:
                bow_inp_dim += opts.d_model
            self.bow_proj_layer = nn.Linear(bow_inp_dim, opts.vocab_size)

        self.decoder = _build_stacked_transformer_layers(
            opts,
            opts.num_layers - opts.num_layers_before_latent,
            opts.num_self_attn_per_layer_after_latent)

        self.hidden2emb, self.out_proj = _build_output_layers(
            opts, self.embedding.get_word_emb_out_proj_weight(), self.embedding.get_word_emb_weight())

        if opts.use_logits_mask:
            self.logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(),
                opts.SEP_token, opts.UNK_token, opts.KEYWORD_token)

        self._reset_parameters()

    def _reset_parameters(self):
        if hasattr(self, "bow_proj_layer"):
            xavier_uniform_fan_in_(self.bow_proj_layer.weight)
            nn.init.zeros_(self.bow_proj_layer.bias)
        if not self.opts.weight_typing:
            xavier_uniform_fan_in_(self.out_proj.weight)

    # src (without <SOS>): [seq_len, batch_size]
    # keyword_ids: [batch_size]
    def forward(self, src, keyword_ids, segment_ids=None, remain_syllables=None, mode="train"):
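        # Training path: encode the <CLS>-prefixed source and take the <CLS> position
        # as the summary hidden state, obtain prior/recognition statistics and a
        # sampled latent (optionally conditioned on the keyword embedding), optionally
        # compute bag-of-words logits, then decode the <SOS>-shifted input with the
        # latent output added to the first decoder position (tiling the decoder-side
        # tensors when sample_n > 1).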
        assert mode in ("train", "valid", "test")
        if mode != "train":
            assert not self.training
        else:
            assert self.training

        if self.opts.need_segment_ids and segment_ids is None:
            segment_ids = self.embedding.get_segment_ids(src)
        padding_mask = self._get_padding_mask(src, revise_for_khead=False, revise_for_cls=True)
        attn_masks = self._get_attn_masks(segment_ids, src.size(0), src.device,
                                          triangle=False,
                                          revise_for_khead=False,
                                          revise_for_cls=True)
        embedded = self.embedding(src, (0, 0), segment_ids=segment_ids)
        cls = torch.full((src.size(1),), fill_value=self.CLS_token, dtype=torch.long, device=src.device)
        cls_embs = self.embedding.forward_word_emb(cls)
        embedded = torch.cat([cls_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_before_latent:
            enc_out = self._forward_layers_hierarchical("encoder", embedded, attn_masks, padding_mask)
        elif self.opts.hierarchical_model:
            enc_out = self._forward_layers("encoder", embedded, attn_masks[-1], padding_mask)
        else:
            enc_out = self._forward_layers("encoder", embedded, attn_masks, padding_mask)
        enc_hidden = enc_out[0]

        sample_n = self.opts.train_sample_n if mode == "train" else self.opts.test_sample_n
        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        mu_p, log_var_p, mu_r, log_var_r, latent_vector, latent_out = self.latent_module.forward_train_path(
            enc_hidden, keyword_embs, sample_n=sample_n)
        # mu/log_var: [batch_size, latent_dim]; latent_out/latent_vec: [sample_n, batch_size, embedding_dim/latent_dim]
        bow_logits = None
        if self.opts.use_bow_loss:
            if keyword_embs is not None:
                keyword_embs_expanded = keyword_embs.unsqueeze(0).expand(sample_n, -1, -1)
                bow_inp = torch.cat([latent_vector, keyword_embs_expanded], dim=-1)
            else:
                bow_inp = latent_vector
            bow_inp = bow_inp.view(sample_n * bow_inp.size(1), -1)
            bow_logits = self.bow_proj_layer(bow_inp)
            bow_logits = bow_logits.unsqueeze(0).expand(src.size(0), -1, -1)

        sos = torch.full((src.size(1),), fill_value=self.SOS_token, dtype=torch.long, device=src.device)
        dec_input = torch.cat([sos.unsqueeze(0), src[:-1]], dim=0)
        if self.opts.need_segment_ids:
            segment_ids = torch.cat([torch.zeros_like(segment_ids[:1]), segment_ids[:-1]], dim=0)
        padding_mask = self._get_padding_mask(dec_input,
                                              revise_for_khead="khead" in self.opts.keyword_approaches,
                                              revise_for_cls=False)
        attn_masks = self._get_attn_masks(segment_ids, dec_input.size(0), dec_input.device,
                                          triangle=True,
                                          revise_for_khead="khead" in self.opts.keyword_approaches,
                                          revise_for_cls=False)
        embedded = self.embedding(dec_input, (0, 0), segment_ids=segment_ids)
        if sample_n > 1:
            padding_mask = self._expand_padding_mask(padding_mask, sample_n)
            attn_masks = self._expand_attn_masks(attn_masks, sample_n)
            embedded = embedded.repeat(1, sample_n, 1)
            if "khead" in self.keyword_approaches:
                keyword_embs = keyword_embs.repeat(sample_n, 1)
        embedded[0] = embedded[0] + latent_out.view(sample_n * latent_out.size(1), latent_out.size(-1))
        if "khead" in self.keyword_approaches:
            embedded = torch.cat([keyword_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_after_latent:
            dec_out = self._forward_layers_hierarchical("decoder", embedded, attn_masks, padding_mask)
        elif self.opts.hierarchical_model:
            dec_out = self._forward_layers("decoder", embedded, attn_masks[-1], padding_mask)
        else:
            dec_out = self._forward_layers("decoder", embedded, attn_masks, padding_mask)
        if "khead" in self.keyword_approaches:
            dec_out = dec_out[1:]

        logits = self.out_proj(self.hidden2emb(dec_out))
        if self.opts.use_logits_mask:
            logits = self.logits_mask_layer(
                logits,
                remain_syllables=remain_syllables,
                decoder_input=dec_input,
                solve_ktoken="ktoken" in self.keyword_approaches,
                keyword_ids=keyword_ids,
                sample_n_to_check=sample_n)

        return logits, mu_p, log_var_p, mu_r, log_var_r, bow_logits

    # normal_vector ~ N(0,1): [batch_size, latent_dim]
    # keyword_ids: [batch_size]
    def generate(self, keyword_ids, normal_vector, approach, gen_options):
        assert not self.training
        assert approach in ("beam", "greedy")
        return getattr(self, "_gen_{}".format(approach))(keyword_ids, normal_vector, **gen_options)

    # input: [seq_len, batch_size], the first token of each sequence should be <SOS>
    # keyword_ids: [batch_size]
    # latent_out: [batch_size, embedding_dim]
    def _gen_forward_step(self, input, keyword_ids, keyword_embs, latent_out, use_cache=False):
        segment_ids = None
        if self.opts.need_segment_ids:
            segment_ids = self.embedding.get_segment_ids(input, use_cache=use_cache, restrict=False)
        padding_mask = self._get_padding_mask(input,
                                              revise_for_khead="khead" in self.opts.keyword_approaches,
                                              revise_for_cls=False)
        attn_masks = self._get_attn_masks(segment_ids, input.size(0), input.device,
                                          triangle=True,
                                          revise_for_khead="khead" in self.opts.keyword_approaches,
                                          revise_for_cls=False)
        embedded = self.embedding(input, (0, 0), segment_ids=segment_ids)
        embedded[0] = embedded[0] + latent_out
        if "khead" in self.keyword_approaches:
            embedded = torch.cat([keyword_embs.unsqueeze(0), embedded], dim=0)
        if self.opts.hierarchical_after_latent:
            dec_out = self._forward_layers_hierarchical("decoder", embedded, attn_masks, padding_mask)
        elif self.opts.hierarchical_model:
            dec_out = self._forward_layers("decoder", embedded, attn_masks[-1], padding_mask)
        else:
            dec_out = self._forward_layers("decoder", embedded, attn_masks, padding_mask)
        if "khead" in self.keyword_approaches:
            dec_out = dec_out[1:]

        logits = self.out_proj(self.hidden2emb(dec_out))
        if self.opts.use_logits_mask:
            logits = self.logits_mask_layer(
                logits,
                use_cache=use_cache,
                decoder_input=input,
                solve_ktoken="ktoken" in self.keyword_approaches,
                keyword_ids=keyword_ids,
                sample_n_to_check=1)

        return logits

    def _gen_greedy(self, keyword_ids, normal_vector, **kwargs):
        batch_size = normal_vector.size(0)
        dtype = normal_vector.dtype
        device = normal_vector.device
        max_seq_len = self.opts.gen_max_seq_len

        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        latent_out = self.latent_module.forward_gen_path(keyword_embs, normal_vector,
                                                         head_dims=[], batch_size=batch_size,
                                                         dtype=dtype, device=device)[1].squeeze(0)

        input = torch.full((1, batch_size), self.SOS_token, dtype=torch.long, device=device)
        lens = torch.full((batch_size,), max_seq_len, dtype=torch.long, device=device)
        output_steps = []
        for step in range(max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=True)
            out_step = logits[-1].argmax(dim=-1, keepdim=False)
            output_steps.append(out_step.clone())
            lens[(out_step == self.EOS_token) & (lens == max_seq_len)] = logits.size(0)
            if step == max_seq_len - 1 or (lens < max_seq_len).all():
                break
            if "ktoken" in self.keyword_approaches:
                mask = out_step == self.KEYWORD_token
                out_step[mask] = keyword_ids[mask]
            input = torch.cat([input, out_step.unsqueeze(0)], dim=0)
        output = torch.stack(output_steps, dim=0)

        if self.opts.need_segment_ids:
            self.embedding.clear_segment_emb_cache()
        if self.opts.need_remain_syllables:
            self.logits_mask_layer.clear_cache()

        return output

    def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
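        # Beam search with a fixed latent: latent_out is computed once from the
        # supplied normal vector and, like keyword_embs, repeated beam_width times
        # after the first step; otherwise this mirrors TransformerCLM._gen_beam in
        # Example #5.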
        dtype = normal_vector.dtype
        device = normal_vector.device
        batch_size, latent_dim = normal_vector.size()
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]

        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        latent_out = self.latent_module.forward_gen_path(keyword_embs, normal_vector,
                                                         head_dims=[], batch_size=batch_size,
                                                         dtype=dtype, device=device)[1].squeeze(0)

        input = torch.full((1, batch_size), fill_value=self.SOS_token, dtype=torch.long, device=device)
        output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best) for _ in range(batch_size)]

        # first step
        logits_step = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=False)[-1]
        step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
        if keyword_ids is not None:
            keyword_ids = keyword_ids.repeat_interleave(beam_width, dim=0)
        if "ktoken" in self.keyword_approaches:
            mask = output_step == self.KEYWORD_token
            output_step[mask] = keyword_ids[mask]

        # remaining steps
        input = input.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        latent_out = latent_out.repeat_interleave(beam_width, dim=0)
        if keyword_embs is not None:
            keyword_embs = keyword_embs.repeat_interleave(beam_width, dim=0)
        for _ in range(1, max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=False)
            logits_step = logits[-1].view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams, logits_step, output_step, back_pointers, func="update_beams")
            if all(b.done for b in batch_beams):
                break
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)

        output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]

        return output
Example #7
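# CVAD: a keyword-conditioned variational autoregressive decoder based on stacked GRU
# cells. A backward decoder is run over the backward input first; at every step of
# the forward decoder the latent module combines the previous forward hidden state
# with the corresponding backward hidden state.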
class CVAD(ModelBase):
    def __init__(self, opts, word2syllable, pretrained_emb_weights):
        super(ModelBase, self).__init__()

        self.opts = opts
        self.PAD_token = opts.PAD_token
        self.SOS_token = opts.SOS_token
        self.EOS_token = opts.EOS_token
        self.UNK_token = opts.UNK_token
        self.SEP_token = opts.SEP_token

        self.embedding = _build_embeddings(opts, word2syllable,
                                           pretrained_emb_weights)

        self.latent_module = VADLatentModule(opts.dec_hidden_size,
                                             opts.latent_dim,
                                             opts.latent_use_tanh)
        if opts.use_bow_loss:
            self.bow_proj_layer = nn.Linear(
                opts.latent_dim + opts.dec_hidden_size, opts.vocab_size)

        self.emb2states_fwd = nn.Linear(
            opts.emb_out_dim, opts.dec_num_layers * opts.dec_hidden_size)
        self.emb2states_bwd = nn.Linear(
            opts.emb_out_dim, opts.dec_num_layers * opts.dec_hidden_size)

        self.fwd_decoder = StackedGRUCell(opts.emb_out_dim + opts.latent_dim,
                                          opts.dec_hidden_size,
                                          opts.dec_num_layers, opts.dropout,
                                          opts.use_layer_norm,
                                          opts.layer_norm_trainable)
        self.bwd_decoder = StackedGRUCell(opts.emb_out_dim,
                                          opts.dec_hidden_size,
                                          opts.dec_num_layers, opts.dropout,
                                          opts.use_layer_norm,
                                          opts.layer_norm_trainable)

        out_dim = opts.dec_hidden_size
        if opts.latent_out_attach:
            out_dim += opts.latent_dim
        self.fwd_out_proj = nn.Linear(out_dim, opts.vocab_size, bias=False)
        if opts.need_bwd_out_proj_layer:
            self.bwd_out_proj = nn.Linear(opts.dec_hidden_size,
                                          opts.vocab_size,
                                          bias=False)

        if opts.fwd_use_logits_mask:
            self.fwd_logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token,
                opts.UNK_token)
        if opts.bwd_use_logits_mask:
            self.bwd_logits_mask_layer = LogitsMaskLayer(
                self.embedding.get_word2syllable_buffer(), opts.SEP_token,
                opts.UNK_token)

        self._reset_parameters()

    def _reset_parameters(self):
        if hasattr(self, "bow_proj_layer"):
            xavier_uniform_fan_in_(self.bow_proj_layer.weight)
            nn.init.zeros_(self.bow_proj_layer.bias)
        xavier_uniform_fan_in_(self.emb2states_fwd.weight)
        xavier_uniform_fan_in_(self.emb2states_bwd.weight)
        nn.init.zeros_(self.emb2states_fwd.bias)
        nn.init.zeros_(self.emb2states_bwd.bias)
        if hasattr(self, "fwd_out_proj"):
            xavier_uniform_fan_in_(self.fwd_out_proj.weight)
        if hasattr(self, "bwd_out_proj"):
            xavier_uniform_fan_in_(self.bwd_out_proj.weight)

    def _init_states(self, keyword_embs, direction):
        assert direction in ("fwd", "bwd")
        batch_size = keyword_embs.size(0)
        if direction == "fwd":
            hidden = self.emb2states_fwd(keyword_embs)
        else:
            if self.opts.detach_bwd_decoder_from_embedding:
                keyword_embs = keyword_embs.detach()
            hidden = self.emb2states_bwd(keyword_embs)
        hidden = hidden.view(batch_size, self.opts.dec_num_layers,
                             self.opts.dec_hidden_size)
        return hidden.transpose(0, 1).contiguous()

    def _forward_bwd_decoder(self, bwd_input, bwd_segment_ids,
                             bwd_remain_syllables, initial_hidden):
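        # Step the backward stacked GRU over the backward input; hidden states at PAD
        # positions are carried over unchanged via index_copy. Optionally project the
        # top-layer states to backward logits and apply the syllable mask.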
        seq_len, batch_size = bwd_input.size()
        device = bwd_input.device
        embedded = self.embedding(bwd_input, 0, segment_ids=bwd_segment_ids)
        if self.opts.detach_bwd_decoder_from_embedding:
            embedded = embedded.detach()
        bwd_last_layer_states = []
        prev_hidden = initial_hidden
        for step in range(seq_len):
            cur_hidden = self.bwd_decoder(embedded[step], prev_hidden)
            pad_indexes = torch.arange(
                batch_size,
                device=device)[bwd_input[step] == self.opts.PAD_token]
            cur_hidden = cur_hidden.index_copy(1, pad_indexes,
                                               prev_hidden[:, pad_indexes])
            bwd_last_layer_states.append(cur_hidden[-1])
            prev_hidden = cur_hidden
        bwd_last_layer_hidden = torch.stack(bwd_last_layer_states, dim=0)
        logits = None
        if hasattr(self, "bwd_out_proj"):
            logits = self.bwd_out_proj(bwd_last_layer_hidden)
            if self.opts.bwd_use_logits_mask:
                logits = self.bwd_logits_mask_layer(
                    logits,
                    remain_syllables=bwd_remain_syllables,
                    decoder_input=bwd_input,
                    sample_n_to_check=1)
        return logits, bwd_last_layer_hidden

    def _forward_fwd_decoder(self, fwd_input, fwd_segment_ids,
                             fwd_remain_syllables, initial_hidden, bwd_hidden,
                             sample_n):
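        # At each step the latent module receives the previous top-layer forward state
        # and the backward hidden state bwd_hidden[-(step + 1)], returning prior and
        # recognition statistics plus a sampled z that is concatenated with the word
        # embedding and fed to the forward GRU. With sample_n > 1 the batch dimension
        # is tiled before the loop.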
        fwd_last_layer_states = []
        mu_p_list = []
        log_var_p_list = []
        mu_r_list = []
        log_var_r_list = []
        latent_vector_list = []

        embedded = self.embedding(fwd_input, 0, segment_ids=fwd_segment_ids)
        prev_hidden = initial_hidden
        if sample_n > 1:
            embedded = embedded.repeat(1, sample_n, 1)
            prev_hidden = prev_hidden.repeat(1, sample_n, 1)
            bwd_hidden = bwd_hidden.repeat(1, sample_n, 1)

        for step in range(fwd_input.size(0)):
            mu_p, log_var_p, mu_r, log_var_r, z = self.latent_module.forward_train_path(
                prev_hidden[-1], bwd_hidden[-(step + 1)])
            cur_hidden = self.fwd_decoder(
                torch.cat([embedded[step], z], dim=-1), prev_hidden)
            fwd_last_layer_states.append(cur_hidden[-1])
            prev_hidden = cur_hidden

            mu_p_list.append(mu_p)
            log_var_p_list.append(log_var_p)
            mu_r_list.append(mu_r)
            log_var_r_list.append(log_var_r)
            latent_vector_list.append(z)

        fwd_last_layer_hidden = torch.stack(fwd_last_layer_states, dim=0)
        latent_vector = torch.stack(latent_vector_list, dim=0)
        out_proj_inp = fwd_last_layer_hidden
        if self.opts.latent_out_attach:
            out_proj_inp = torch.cat([fwd_last_layer_hidden, latent_vector],
                                     dim=-1)
        logits = self.fwd_out_proj(out_proj_inp)
        if self.opts.fwd_use_logits_mask:
            logits = self.fwd_logits_mask_layer(
                logits,
                remain_syllables=fwd_remain_syllables,
                decoder_input=fwd_input,
                sample_n_to_check=sample_n)

        mu_p = torch.stack(mu_p_list, dim=0)
        log_var_p = torch.stack(log_var_p_list, dim=0)
        mu_r = torch.stack(mu_r_list, dim=0)
        log_var_r = torch.stack(log_var_r_list, dim=0)

        return logits, fwd_last_layer_hidden, latent_vector, mu_p, log_var_p, mu_r, log_var_r

    def forward(self,
                inputs,
                keyword_ids,
                segment_ids=None,
                remain_syllables=None,
                mode="train"):
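        # Build the <SOS>-shifted forward input and the reversed backward target,
        # prepare per-direction segment ids and remaining-syllable counts, initialise
        # both decoders from the keyword embedding, run the backward decoder and then
        # the forward decoder with per-step latents, and return the logits, targets,
        # optional bag-of-words logits, and latent statistics.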
        assert mode in ("train", "valid", "test")
        if mode != "train":
            assert not self.training
        else:
            assert self.training

        fwd_tgt, bwd_inp = inputs
        sos = torch.full((fwd_tgt.size(1), ),
                         fill_value=self.SOS_token,
                         dtype=torch.long,
                         device=fwd_tgt.device)
        fwd_inp = torch.cat([sos.unsqueeze(0), fwd_tgt[:-1]], dim=0)
        bwd_tgt = fwd_inp.flip(0)

        if self.opts.need_segment_ids:
            if segment_ids is None:
                fwd_seg_ids = self.embedding.get_segment_ids(fwd_inp)
                bwd_seg_ids = self.embedding.get_segment_ids(bwd_inp)
            else:
                fwd_seg_ids, bwd_seg_ids = segment_ids
                fwd_seg_ids = torch.cat(
                    [torch.zeros_like(fwd_seg_ids[:1]), fwd_seg_ids[:-1]],
                    dim=0)
        else:
            fwd_seg_ids = bwd_seg_ids = None
        fwd_rem_syls = bwd_rem_syls = None
        if self.opts.fwd_need_remain_syllables or self.opts.bwd_need_remain_syllables:
            if remain_syllables is None:
                if self.opts.fwd_need_remain_syllables:
                    fwd_rem_syls = self.fwd_logits_mask_layer.get_remain_syllables(
                        fwd_inp)
                if self.opts.bwd_need_remain_syllables:
                    bwd_rem_syls = self.bwd_logits_mask_layer.get_remain_syllables(
                        bwd_inp)
            else:
                fwd_rem_syls, bwd_rem_syls = remain_syllables

        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        fwd_initial_states = self._init_states(keyword_embs, "fwd")
        bwd_initial_states = self._init_states(keyword_embs, "bwd")

        bwd_logits, bwd_hidden = self._forward_bwd_decoder(
            bwd_inp, bwd_seg_ids, bwd_rem_syls, bwd_initial_states)

        sample_n = self.opts.train_sample_n if mode == "train" else self.opts.test_sample_n
        fwd_logits, fwd_last_layer_hidden, latent_vector, mu_p, log_var_p, mu_r, log_var_r = self._forward_fwd_decoder(
            fwd_inp, fwd_seg_ids, fwd_rem_syls, fwd_initial_states, bwd_hidden,
            sample_n)

        bow_logits = None
        if self.opts.use_bow_loss:
            bow_inp = torch.cat([latent_vector, fwd_last_layer_hidden], dim=-1)
            bow_logits = self.bow_proj_layer(bow_inp)
            expand_dim = bow_logits.size(
                0) if self.opts.bow_window is None else self.opts.bow_window
            bow_logits = bow_logits.unsqueeze(0).expand(expand_dim, -1, -1, -1)

        fwd_tgt = self.expand_tgt(fwd_tgt, sample_n)

        return (fwd_logits, bwd_logits), (
            fwd_tgt, bwd_tgt), bow_logits, mu_p, log_var_p, mu_r, log_var_r

    # normal_vector ~ N(0,1): [seq_len, batch_size, latent_dim]
    # keyword_ids: [batch_size]
    def generate(self, keyword_ids, normal_vector, approach, gen_options):
        assert not self.training
        assert approach in ("beam", "greedy")
        return getattr(self,
                       "_gen_{}".format(approach))(keyword_ids, normal_vector,
                                                   **gen_options)

    # input: [seq_len, batch_size], the first token of each sequence should be <SOS>
    # hidden_step: [num_layers, batch_size, hidden_size]
    # normal_vector_step ~ N(0,1): [batch_size, latent_dim]
    def _gen_forward_step(self,
                          input,
                          hidden_step,
                          normal_vector_step,
                          use_cache=False):
        segment_ids = None
        if self.opts.need_segment_ids:
            segment_ids = self.embedding.get_segment_ids(input,
                                                         use_cache=use_cache,
                                                         restrict=False)
        embedded = self.embedding(input,
                                  0,
                                  segment_ids=segment_ids,
                                  segment_emb_restrict=False)
        z = self.latent_module.forward_gen_path(hidden_step[-1],
                                                normal_vector_step)
        hidden_step = self.fwd_decoder(torch.cat([embedded[-1], z], dim=-1),
                                       hidden_step)
        out_proj_inp = hidden_step[-1]
        if self.opts.latent_out_attach:
            out_proj_inp = torch.cat([hidden_step[-1], z], dim=-1)
        logits_step = self.fwd_out_proj(out_proj_inp)
        if self.opts.fwd_use_logits_mask:
            logits_step = self.fwd_logits_mask_layer(logits_step,
                                                     use_cache=use_cache,
                                                     decoder_input=input,
                                                     sample_n_to_check=1,
                                                     only_last_step=True)
        return logits_step, hidden_step

    # keyword_ids: [batch_size]
    # normal_vector: [seq_len, batch_size, latent_dim]
    def _gen_greedy(self, keyword_ids, normal_vector, **kwargs):
        batch_size = normal_vector.size(1)
        device = normal_vector.device
        max_seq_len = self.opts.gen_max_seq_len

        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        hidden = self._init_states(keyword_embs, "fwd")
        input = torch.full((1, batch_size),
                           self.SOS_token,
                           dtype=torch.long,
                           device=device)
        lens = torch.full((batch_size, ),
                          max_seq_len,
                          dtype=torch.long,
                          device=device)

        output_steps = []
        for step in range(max_seq_len):
            logits_step, hidden = self._gen_forward_step(input,
                                                         hidden,
                                                         normal_vector[step],
                                                         use_cache=True)
            out_step = logits_step.argmax(dim=-1, keepdim=False)
            output_steps.append(out_step)
            lens[(out_step == self.EOS_token)
                 & (lens == max_seq_len)] = step + 1
            if step == max_seq_len - 1 or (lens < max_seq_len).all():
                break
            input = torch.cat([input, out_step.unsqueeze(0)], dim=0)
        output = torch.stack(output_steps, dim=0)

        if self.opts.need_segment_ids:
            self.embedding.clear_segment_emb_cache()
        if self.opts.fwd_need_remain_syllables:
            self.fwd_logits_mask_layer.clear_cache()

        return output

    def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
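        # Beam search over the forward decoder: the per-step normal vectors are
        # repeated beam_width times once, while the running prefixes and hidden states
        # are re-gathered with the back-pointers after every step.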
        device = normal_vector.device
        _, batch_size, latent_dim = normal_vector.size()
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]

        input = torch.full((1, batch_size),
                           fill_value=self.SOS_token,
                           dtype=torch.long,
                           device=device)
        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        hidden = self._init_states(keyword_embs, "fwd")
        output_step = torch.zeros(batch_size * beam_width,
                                  dtype=torch.long,
                                  device=device)
        back_pointers = torch.zeros(batch_size * beam_width,
                                    dtype=torch.long,
                                    device=device)
        batch_beams = [
            Beam(beam_width, length_norm, self.EOS_token, n_best)
            for _ in range(batch_size)
        ]

        # first step
        logits_step, hidden = self._gen_forward_step(input,
                                                     hidden,
                                                     normal_vector[0],
                                                     use_cache=False)
        step_batch_beams(batch_beams,
                         logits_step,
                         output_step,
                         func="init_beams")

        # remaining steps
        input = input.repeat_interleave(beam_width, dim=1)
        normal_vector = normal_vector.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        hidden = hidden.repeat_interleave(beam_width, dim=1)
        for step in range(1, max_seq_len):
            logits_step, hidden = self._gen_forward_step(input,
                                                         hidden,
                                                         normal_vector[step],
                                                         use_cache=False)
            logits_step = logits_step.view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams,
                             logits_step,
                             output_step,
                             back_pointers,
                             func="update_beams")
            if all(b.done for b in batch_beams):
                break
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
            hidden = hidden.index_select(dim=1, index=back_pointers)

        output = list(
            chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output,
                                       self.PAD_token,
                                       0,
                                       device=device)[0]

        return output