Code Example #1
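    # Constructor of the deconvolutional SNLI model: it builds the token
    # embedder, a strided Conv1d sentence encoder, a ConvTranspose1d
    # generator with a tied output projection, a vMF or Gaussian latent code
    # generator, and a feed-forward NLI discriminator over the codes.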
    def __init__(self, params: Params, vocab: Vocabulary) -> None:
        super().__init__(vocab=vocab)

        disc_hidden_dim = params.pop_int('disc_hidden_dim', 1200)
        disc_num_layers = params.pop_int('disc_num_layers', 1)
        code_dist_type = params.pop_choice('code_dist_type',
                                           ['gaussian', 'vmf'],
                                           default_to_first_choice=True)
        code_dim = params.pop_int('code_dim', 500)

        emb_dropout = params.pop_float('emb_dropout', 0.0)
        disc_dropout = params.pop_float('disc_dropout', 0.0)
        latent_dropout = params.pop_float('latent_dropout', 0.0)
        l2_weight = params.pop_float('l2_weight', 0.0)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.disc_dropout = nn.Dropout(disc_dropout)
        self.latent_dropout = nn.Dropout(latent_dropout)
        self._l2_weight = l2_weight

        self._token_embedder = Embedding.from_params(
            vocab=vocab, params=params.pop('token_embedder'))
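        # Convolutional sentence encoder: three strided Conv1d layers reduce
        # a (batch, 300, 29) stack of token embeddings to a single 500-dim
        # vector (lengths 29 -> 13 -> 5 -> 1 with kernel 5, stride 2).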
        self._encoder = nn.Sequential(
            nn.Conv1d(in_channels=300,
                      out_channels=300,
                      kernel_size=5,
                      stride=2),
            nn.Conv1d(in_channels=300,
                      out_channels=600,
                      kernel_size=5,
                      stride=2),
            nn.Conv1d(in_channels=600,
                      out_channels=500,
                      kernel_size=5,
                      stride=2))
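        # Deconvolutional generator: mirrors the encoder with ConvTranspose1d
        # layers, expanding a 500-dim code back to 300-dim hidden states of
        # length 29 (1 -> 5 -> 13 -> 29).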
        self._generator = nn.Sequential(
            nn.ConvTranspose1d(in_channels=500,
                               out_channels=600,
                               kernel_size=5,
                               stride=2), nn.ReLU(),
            nn.ConvTranspose1d(in_channels=600,
                               out_channels=300,
                               kernel_size=5,
                               stride=2), nn.ReLU(),
            nn.ConvTranspose1d(in_channels=300,
                               out_channels=300,
                               kernel_size=5,
                               stride=2), nn.ReLU())
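        # Projection from generator hidden states to vocabulary logits; its
        # weight is tied to the token embedding matrix (which must therefore
        # be 300-dimensional).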
        self._generator_projector = nn.Linear(
            in_features=300, out_features=vocab.get_vocab_size(), bias=False)
        self._generator_projector.weight = self._token_embedder.weight

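        # Latent code distribution over the 500-dim encoding: a von
        # Mises-Fisher or Gaussian code generator that returns a sampled code
        # and its KL divergence term.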
        if code_dist_type == 'vmf':
            vmf_kappa = params.pop_int('vmf_kappa', 150)
            self._code_generator = VmfCodeGenerator(input_dim=500,
                                                    code_dim=code_dim,
                                                    kappa=vmf_kappa)
        elif code_dist_type == 'gaussian':
            self._code_generator = GaussianCodeGenerator(input_dim=500,
                                                         code_dim=code_dim)
        else:
            raise ValueError('Unknown code_dist_type')

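        # NLI classifier over pairs of latent codes, fed the standard
        # [u; v; u * v; |u - v|] feature combination (hence the 4x input
        # dimension), with ReLU hidden layers and a linear output layer.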
        self._discriminator = FeedForward(
            input_dim=4 * self._code_generator.get_output_dim(),
            hidden_dims=[disc_hidden_dim] * disc_num_layers +
            [self._NUM_LABELS],
            num_layers=disc_num_layers + 1,
            activations=[Activation.by_name('relu')()] * disc_num_layers +
            [Activation.by_name('linear')()],
            dropout=disc_dropout)

        self._kl_weight = 1.0
        self._discriminator_weight = params.pop_float('discriminator_weight',
                                                      0.1)
        self._gumbel_temperature = 1.0

        # Metrics
        self._metrics = {
            'generator_loss': ScalarMetric(),
            'kl_divergence': ScalarMetric(),
            'discriminator_accuracy': CategoricalAccuracy(),
            'discriminator_loss': ScalarMetric(),
            'loss': ScalarMetric()
        }
Code Example #2
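# A deconvolutional SNLI model: premise and hypothesis are each auto-encoded
# with a strided Conv1d encoder and a ConvTranspose1d generator, and a
# feed-forward classifier predicts the label from the two latent codes.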
class DeconvSNLIModel(Model):
    """NOTE: THE INPUT MUST BE OF LENGTH 29."""

    _NUM_LABELS = 3

    def __init__(self, params: Params, vocab: Vocabulary) -> None:
        super().__init__(vocab=vocab)

        disc_hidden_dim = params.pop_int('disc_hidden_dim', 1200)
        disc_num_layers = params.pop_int('disc_num_layers', 1)
        code_dist_type = params.pop_choice('code_dist_type',
                                           ['gaussian', 'vmf'],
                                           default_to_first_choice=True)
        code_dim = params.pop_int('code_dim', 500)

        emb_dropout = params.pop_float('emb_dropout', 0.0)
        disc_dropout = params.pop_float('disc_dropout', 0.0)
        latent_dropout = params.pop_float('latent_dropout', 0.0)
        l2_weight = params.pop_float('l2_weight', 0.0)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.disc_dropout = nn.Dropout(disc_dropout)
        self.latent_dropout = nn.Dropout(latent_dropout)
        self._l2_weight = l2_weight

        self._token_embedder = Embedding.from_params(
            vocab=vocab, params=params.pop('token_embedder'))
        self._encoder = nn.Sequential(
            nn.Conv1d(in_channels=300,
                      out_channels=300,
                      kernel_size=5,
                      stride=2),
            nn.Conv1d(in_channels=300,
                      out_channels=600,
                      kernel_size=5,
                      stride=2),
            nn.Conv1d(in_channels=600,
                      out_channels=500,
                      kernel_size=5,
                      stride=2))
        self._generator = nn.Sequential(
            nn.ConvTranspose1d(in_channels=500,
                               out_channels=600,
                               kernel_size=5,
                               stride=2), nn.ReLU(),
            nn.ConvTranspose1d(in_channels=600,
                               out_channels=300,
                               kernel_size=5,
                               stride=2), nn.ReLU(),
            nn.ConvTranspose1d(in_channels=300,
                               out_channels=300,
                               kernel_size=5,
                               stride=2), nn.ReLU())
        self._generator_projector = nn.Linear(
            in_features=300, out_features=vocab.get_vocab_size(), bias=False)
        self._generator_projector.weight = self._token_embedder.weight

        if code_dist_type == 'vmf':
            vmf_kappa = params.pop_int('vmf_kappa', 150)
            self._code_generator = VmfCodeGenerator(input_dim=500,
                                                    code_dim=code_dim,
                                                    kappa=vmf_kappa)
        elif code_dist_type == 'gaussian':
            self._code_generator = GaussianCodeGenerator(input_dim=500,
                                                         code_dim=code_dim)
        else:
            raise ValueError('Unknown code_dist_type')

        self._discriminator = FeedForward(
            input_dim=4 * self._code_generator.get_output_dim(),
            hidden_dims=[disc_hidden_dim] * disc_num_layers +
            [self._NUM_LABELS],
            num_layers=disc_num_layers + 1,
            activations=[Activation.by_name('relu')()] * disc_num_layers +
            [Activation.by_name('linear')()],
            dropout=disc_dropout)

        self._kl_weight = 1.0
        self._discriminator_weight = params.pop_float('discriminator_weight',
                                                      0.1)
        self._gumbel_temperature = 1.0

        # Metrics
        self._metrics = {
            'generator_loss': ScalarMetric(),
            'kl_divergence': ScalarMetric(),
            'discriminator_accuracy': CategoricalAccuracy(),
            'discriminator_loss': ScalarMetric(),
            'loss': ScalarMetric()
        }

    def get_regularization_penalty(self):
        sum_sq = sum(p.pow(2).sum() for p in self.parameters())
        l2_norm = sum_sq.sqrt()
        return self.l2_weight * l2_norm

    @property
    def l2_weight(self):
        return self._l2_weight

    @property
    def kl_weight(self):
        return self._kl_weight

    @kl_weight.setter
    def kl_weight(self, value):
        self._kl_weight = value

    @property
    def discriminator_weight(self):
        return self._discriminator_weight

    @discriminator_weight.setter
    def discriminator_weight(self, value):
        self._discriminator_weight = value

    def embed(self, tokens: torch.Tensor) -> torch.Tensor:
        return self._token_embedder(tokens)

    def encode(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = inputs.transpose(1, 2)  # (B, H, L)
        enc_hidden = self._encoder(inputs.contiguous()).squeeze(2)  # (B, H)
        return enc_hidden

    def sample_code_and_compute_kld(
            self, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._code_generator(hidden)

    def discriminate(self, premise_hidden: torch.Tensor,
                     hypothesis_hidden: torch.Tensor) -> torch.Tensor:
        disc_input = torch.cat([
            premise_hidden, hypothesis_hidden,
            premise_hidden * hypothesis_hidden,
            (premise_hidden - hypothesis_hidden).abs()
        ],
                               dim=-1)
        disc_input = self.disc_dropout(disc_input)
        disc_logits = self._discriminator(disc_input)
        return disc_logits

    def generate_project(self, hiddens):
        # hiddens: (B, L, H)
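        # Both the projection weights and the hidden states are L2-normalized,
        # so the logits are cosine similarities scaled by an inverse
        # temperature of 100 (i.e. divided by 0.01).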
        gen_proj_weights = self._generator_projector.weight  # (V, H)
        gen_proj_weights_norm = gen_proj_weights.norm(dim=1, keepdim=True)
        gen_proj_weights = gen_proj_weights / gen_proj_weights_norm
        hiddens_norm = hiddens.norm(dim=2, keepdim=True)
        hiddens = hiddens / hiddens_norm
        logits = functional.linear(input=hiddens,
                                   weight=gen_proj_weights,
                                   bias=None)
        return logits * 100.0  # divide by temperature (0.01).

    def generate(self, code: torch.Tensor) -> torch.Tensor:
        gen_hiddens = self._generator(code.unsqueeze(2)).transpose(1, 2)
        logits = self.generate_project(gen_hiddens)
        generated = logits.argmax(dim=2)
        return generated

    def convert_to_readable_text(self,
                                 generated: torch.Tensor) -> List[List[str]]:
        sequences = [seq.cpu().tolist() for seq in generated.unbind(0)]
        readable_sequences = []
        for seq in sequences:
            readable_seq = []
            for word_index in seq:
                if word_index != 0:
                    word = self.vocab.get_token_from_index(word_index)
                    readable_seq.append(word)
            readable_sequences.append(readable_seq)
        return readable_sequences

    def compute_generator_loss(self, code: torch.Tensor,
                               targets: torch.Tensor) -> torch.Tensor:
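        # Reconstruct the whole sequence from the latent code alone (no
        # teacher forcing); the per-token average from
        # sequence_cross_entropy_with_logits is rescaled by the sequence
        # length to get a per-sentence sum.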
        hiddens = self._generator(code.unsqueeze(2))
        hiddens = hiddens.transpose(1, 2).contiguous()
        logits = self._generator_projector(hiddens)
        loss = sequence_cross_entropy_with_logits(
            logits=logits,
            targets=targets.contiguous(),
            weights=torch.ones_like(targets),
            average=None)
        loss = loss * logits.shape[1]
        return loss

    def forward(self,
                premise: Dict[str, torch.Tensor],
                hypothesis: Dict[str, torch.Tensor],
                label: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        pre_tokens = premise['tokens']
        hyp_tokens = hypothesis['tokens']
        pre_token_embs = self.embed(pre_tokens)
        hyp_token_embs = self.embed(hyp_tokens)
        pre_token_embs = self.emb_dropout(pre_token_embs)
        hyp_token_embs = self.emb_dropout(hyp_token_embs)

        output_dict = {}

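        # Encode each sentence, sample its latent code (with a KL term), and
        # reconstruct the sentence from its own code; the discriminator is
        # applied to the pair of codes only when a label is provided.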
        pre_hidden = self.encode(inputs=pre_token_embs)
        hyp_hidden = self.encode(inputs=hyp_token_embs)
        pre_code, pre_kld = self.sample_code_and_compute_kld(pre_hidden)
        hyp_code, hyp_kld = self.sample_code_and_compute_kld(hyp_hidden)
        pre_code = self.latent_dropout(pre_code)
        hyp_code = self.latent_dropout(hyp_code)
        pre_kld = pre_kld.mean()
        hyp_kld = hyp_kld.mean()

        pre_gen_loss = self.compute_generator_loss(code=pre_code,
                                                   targets=pre_tokens)
        hyp_gen_loss = self.compute_generator_loss(code=hyp_code,
                                                   targets=hyp_tokens)
        pre_gen_loss = pre_gen_loss.mean()
        hyp_gen_loss = hyp_gen_loss.mean()

        gen_loss = pre_gen_loss + hyp_gen_loss
        kld = pre_kld + hyp_kld
        loss = gen_loss + self.kl_weight * kld

        if label is not None:
            disc_logits = self.discriminate(premise_hidden=pre_code,
                                            hypothesis_hidden=hyp_code)
            disc_loss = functional.cross_entropy(input=disc_logits,
                                                 target=label)
            loss = loss + self.discriminator_weight * disc_loss
            output_dict['discriminator_loss'] = disc_loss
            self._metrics['discriminator_loss'](disc_loss)
            self._metrics['discriminator_accuracy'](predictions=disc_logits,
                                                    gold_labels=label)

        output_dict['generator_loss'] = gen_loss
        output_dict['kl_divergence'] = kld
        output_dict['loss'] = loss
        self._metrics['generator_loss'](gen_loss)
        self._metrics['kl_divergence'](kld)
        self._metrics['loss'](loss)

        return output_dict

    def get_metrics(
            self,
            reset: bool = False) -> Dict[str, Union[float, Dict[str, float]]]:
        metrics = {
            k: v.get_metric(reset=reset)
            for k, v in self._metrics.items()
        }
        metrics['kl_weight'] = self.kl_weight
        metrics['discriminator_weight'] = self.discriminator_weight
        return metrics
Code Example #3
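# A semi-supervised SNLI model: an LSTM generator reconstructs the hypothesis
# from the premise's latent code and the label, a feed-forward discriminator
# classifies encoded (premise, hypothesis) pairs, and unlabeled pairs are
# handled by marginalizing over labels or sampling a label from the
# discriminator.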
class SNLIModel(Model):

    _NUM_LABELS = 3

    def __init__(self, params: Params, vocab: Vocabulary) -> None:
        super().__init__(vocab=vocab)

        enc_hidden_dim = params.pop_int('enc_hidden_dim', 300)
        gen_hidden_dim = params.pop_int('gen_hidden_dim', 300)
        disc_hidden_dim = params.pop_int('disc_hidden_dim', 1200)
        disc_num_layers = params.pop_int('disc_num_layers', 1)
        code_dist_type = params.pop_choice('code_dist_type',
                                           ['gaussian', 'vmf'],
                                           default_to_first_choice=True)
        code_dim = params.pop_int('code_dim', 300)
        label_emb_dim = params.pop_int('label_emb_dim', 50)
        shared_encoder = params.pop_bool('shared_encoder', True)
        tie_embedding = params.pop_bool('tie_embedding', False)
        auto_weighting = params.pop_bool('auto_weighting', False)

        emb_dropout = params.pop_float('emb_dropout', 0.0)
        disc_dropout = params.pop_float('disc_dropout', 0.0)
        l2_weight = params.pop_float('l2_weight', 0.0)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.disc_dropout = nn.Dropout(disc_dropout)
        self._l2_weight = l2_weight
        self.auto_weighting = auto_weighting

        self._token_embedder = Embedding.from_params(
            vocab=vocab, params=params.pop('token_embedder'))
        self._label_embedder = Embedding(num_embeddings=self._NUM_LABELS,
                                         embedding_dim=label_emb_dim)
        self._encoder = PytorchSeq2VecWrapper(
            nn.LSTM(input_size=self._token_embedder.get_output_dim(),
                    hidden_size=enc_hidden_dim,
                    batch_first=True))
        self._generator = PytorchSeq2SeqWrapper(
            nn.LSTM(input_size=(self._token_embedder.get_output_dim() +
                                code_dim + label_emb_dim),
                    hidden_size=gen_hidden_dim,
                    batch_first=True))
        self._generator_projector = nn.Linear(
            in_features=self._generator.get_output_dim(),
            out_features=vocab.get_vocab_size())
        self._discriminator_encoder = PytorchSeq2VecWrapper(
            nn.LSTM(input_size=self._token_embedder.get_output_dim(),
                    hidden_size=enc_hidden_dim,
                    batch_first=True))
        if shared_encoder:
            self._discriminator_encoder = self._encoder
        if tie_embedding:
            self._generator_projector.weight = self._token_embedder.weight

        self._discriminator = FeedForward(
            input_dim=4 * self._discriminator_encoder.get_output_dim(),
            hidden_dims=[disc_hidden_dim] * disc_num_layers +
            [self._NUM_LABELS],
            num_layers=disc_num_layers + 1,
            activations=[Activation.by_name('relu')()] * disc_num_layers +
            [Activation.by_name('linear')()],
            dropout=disc_dropout)
        if code_dist_type == 'vmf':
            vmf_kappa = params.pop_int('vmf_kappa', 150)
            self._code_generator = VmfCodeGenerator(
                input_dim=self._encoder.get_output_dim(),
                code_dim=code_dim,
                kappa=vmf_kappa)
        elif code_dist_type == 'gaussian':
            self._code_generator = GaussianCodeGenerator(
                input_dim=self._encoder.get_output_dim(), code_dim=code_dim)
        else:
            raise ValueError('Unknown code_dist_type')

        self._kl_weight = 1.0
        self._discriminator_weight = params.pop_float('discriminator_weight',
                                                      0.1)
        self._gumbel_temperature = 1.0

        self._use_sampling = params.pop_bool('use_sampling', False)

        if auto_weighting:
            self.num_tasks = num_tasks = 3
            self.task_weights = nn.Parameter(torch.zeros(num_tasks))

        # Metrics
        self._metrics = {
            'labeled': {
                'generator_loss': ScalarMetric(),
                'kl_divergence': ScalarMetric(),
                'discriminator_entropy': ScalarMetric(),
                'discriminator_accuracy': CategoricalAccuracy(),
                'discriminator_loss': ScalarMetric(),
                'loss': ScalarMetric()
            },
            'unlabeled': {
                'generator_loss': ScalarMetric(),
                'kl_divergence': ScalarMetric(),
                'discriminator_entropy': ScalarMetric(),
                'loss': ScalarMetric()
            },
            'aux': {
                'discriminator_entropy': ScalarMetric(),
                'discriminator_accuracy': CategoricalAccuracy(),
                'discriminator_loss': ScalarMetric(),
                'gumbel_temperature': ScalarMetric(),
                'loss': ScalarMetric(),
                'code_log_prob': ScalarMetric(),
                'cosine_dist': ScalarMetric()
            }
        }

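    # Fine-tuning hooks: `con_y_weight`, `con_z_weight` and `con_z2_weight`
    # weight the auxiliary consistency losses computed in `aux_forward`;
    # when `con_autoweight` is set, learnable weights are used instead.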
    def add_finetune_parameters(self,
                                con_autoweight=False,
                                con_y_weight=None,
                                con_z_weight=None,
                                con_z2_weight=None):
        self.con_autoweight = con_autoweight
        self.con_y_weight = con_y_weight
        self.con_z_weight = con_z_weight
        self.con_z2_weight = con_z2_weight
        if con_autoweight:
            self.con_y_weight_p = nn.Parameter(torch.zeros(1))
            self.con_z_weight_p = nn.Parameter(torch.zeros(1))
            self.con_z2_weight_p = nn.Parameter(torch.zeros(1))

    def finetune_main_parameters(self, exclude_generator=False):
        params = []
        for name, param in self.named_parameters():
            if exclude_generator:
                if 'generator' in name:
                    continue
            params.append(param)
        return params

    def finetune_aux_parameters(self):
        gen_params = list(self._generator.parameters())
        gen_proj_params = list(self._generator_projector.parameters())
        emb_params = list(self._token_embedder.parameters())
        # enc_params = list(self._encoder.parameters())
        # code_gen_params = list(self._code_generator.parameters())
        con_params = []
        if self.con_autoweight:
            con_params = [
                self.con_y_weight_p, self.con_z_weight_p, self.con_z2_weight_p
            ]
        return gen_params + gen_proj_params + emb_params + con_params

    def get_regularization_penalty(self):
        sum_sq = sum(p.pow(2).sum() for p in self.parameters())
        l2_norm = sum_sq.sqrt()
        return self.l2_weight * l2_norm

    @property
    def gumbel_temperature(self):
        return self._gumbel_temperature

    @gumbel_temperature.setter
    def gumbel_temperature(self, value):
        self._gumbel_temperature = value

    @property
    def l2_weight(self):
        return self._l2_weight

    @property
    def kl_weight(self):
        return self._kl_weight

    @kl_weight.setter
    def kl_weight(self, value):
        self._kl_weight = value

    @property
    def discriminator_weight(self):
        return self._discriminator_weight

    @discriminator_weight.setter
    def discriminator_weight(self, value):
        self._discriminator_weight = value

    def embed(self, tokens: torch.Tensor) -> torch.Tensor:
        return self._token_embedder(tokens)

    def encode(self,
               inputs: torch.Tensor,
               mask: torch.Tensor,
               drop_start_token: bool = True) -> torch.Tensor:
        if drop_start_token:
            inputs = inputs[:, 1:]
            mask = mask[:, 1:]
        enc_hidden = self._encoder(inputs.contiguous(), mask)
        return enc_hidden

    def sample_code_and_compute_kld(
            self, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._code_generator(hidden)

    def discriminator_encode(self,
                             inputs: torch.Tensor,
                             mask: torch.Tensor,
                             drop_start_token: bool = True) -> torch.Tensor:
        if drop_start_token:
            inputs = inputs[:, 1:]
            mask = mask[:, 1:]
        enc_hidden = self._discriminator_encoder(inputs.contiguous(), mask)
        return enc_hidden

    def discriminate(self, premise_hidden: torch.Tensor,
                     hypothesis_hidden: torch.Tensor) -> torch.Tensor:
        disc_input = torch.cat([
            premise_hidden, hypothesis_hidden,
            premise_hidden * hypothesis_hidden,
            (premise_hidden - hypothesis_hidden).abs()
        ],
                               dim=-1)
        disc_input = self.disc_dropout(disc_input)
        disc_logits = self._discriminator(disc_input)
        return disc_logits

    def construct_generator_inputs(self, embeddings: torch.Tensor,
                                   code: torch.Tensor,
                                   label: torch.Tensor) -> torch.Tensor:
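        # Condition the generator on the latent code and the label by
        # broadcasting both over time and concatenating them to every token
        # embedding.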
        batch_size, max_length, _ = embeddings.shape
        code_expand = code.unsqueeze(1).expand(batch_size, max_length, -1)
        label_emb = self._label_embedder(label)
        label_emb_expand = label_emb.unsqueeze(1).expand(
            batch_size, max_length, -1)
        inputs = torch.cat([embeddings, code_expand, label_emb_expand], dim=-1)
        return inputs

    def beam_search_step(
        self, prev_predicted: torch.Tensor, prev_state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            prev_predicted: (group_size,)
            prev_state: {'h': (group_size, hidden_dim),
                         'c': (group_size, hidden_dim),
                         'code': (group_size, code_dim),
                         'label': (group_size,)
                         'length': (group_size,),
                         'length_alpha': (group_size,)}

        Returns:
            log_probs: (group_size, vocab_size)
            state: {'h': (group_size, hidden_dim),
                    'c': (group_size, hidden_dim),
                    'code': (group_size, code_dim),
                    'label': (group_size,),
                    'length': (group_size,),
                    'length_alpha': (group_size,)}
        """
        # prev_word_emb: (group_size, 1, word_dim)
        prev_word_emb = self.embed(prev_predicted).unsqueeze(1)
        lstm_prev_state = (prev_state['h'].unsqueeze(0),
                           prev_state['c'].unsqueeze(0))
        code = prev_state['code']
        label = prev_state['label']
        length = prev_state['length']
        length_alpha = prev_state['length_alpha']
        # input_t: (group_size, 1, word_dim + code_dim + label_dim)
        input_t = self.construct_generator_inputs(embeddings=prev_word_emb,
                                                  code=code,
                                                  label=label)
        hidden_t, state_t = self._generator._module(input=input_t,
                                                    hx=lstm_prev_state)
        # log_probs: (group_size, vocab_size)
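        # Length penalty of the form ((5 + length) / 6) ** alpha, as in GNMT.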
        length_penalty = ((5.0 + length.float()) / 6)**length_alpha.float()
        log_probs = functional.log_softmax(
            self._generator_projector(hidden_t).squeeze(1), dim=-1)
        log_probs = log_probs / length_penalty.unsqueeze(1)
        state = {
            'h': state_t[0].squeeze(0),
            'c': state_t[1].squeeze(0),
            'code': code,
            'label': label,
            'length': length + 1,
            'length_alpha': length_alpha
        }
        return log_probs, state

    def generate(self, code: torch.Tensor, label: torch.Tensor,
                 max_length: int, beam_size: int,
                 lp_alpha: float) -> torch.Tensor:
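        # Beam-search decoding: `BeamSearch` (AllenNLP-style interface) calls
        # `beam_search_step` to advance the generator LSTM one token at a
        # time, with a length penalty controlled by `lp_alpha`.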
        start_index = self.vocab.get_token_index('<s>')
        end_index = self.vocab.get_token_index('</s>')
        beam_search = BeamSearch(end_index=end_index,
                                 max_steps=max_length,
                                 beam_size=beam_size,
                                 per_node_beam_size=3)
        batch_size = code.shape[0]
        start_predictions = (
            torch.empty(batch_size).to(label).fill_(start_index))
        zero_state = code.new_zeros(batch_size,
                                    self._generator._module.hidden_size)
        start_state = {
            'h': zero_state,
            'c': zero_state,
            'code': code,
            'label': label,
            'length': label.new_ones(batch_size),
            'length_alpha': (code.new_empty(batch_size).fill_(lp_alpha))
        }

        all_predictions, last_log_probs = beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            step=self.beam_search_step)
        return all_predictions

    def gumbel_softmax(self, logits):
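        # Gumbel-Softmax relaxation: perturb logits with Gumbel(0, 1) noise
        # g = -log(-log(u)) and apply a temperature-scaled softmax, giving a
        # differentiable approximation of categorical sampling.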
        u = torch.rand_like(logits)
        g = -torch.log(-torch.log(u + 1e-20) + 1e-20)
        new_logits = (logits + g) / self.gumbel_temperature
        probs = functional.softmax(new_logits, dim=-1)
        return probs

    def generate_soft(self, code: torch.Tensor, label: torch.Tensor,
                      length: torch.Tensor) -> torch.Tensor:
        """
        Generate soft predictions using the Gumbel-Softmax
        reparameterization. Note that the generated sentence
        always has exactly the length of `length`.
        """
        start_index = self.vocab.get_token_index('<s>')
        batch_size = code.shape[0]
        prev_word = (torch.empty(
            batch_size,
            device=code.device).long().unsqueeze(1).fill_(start_index))
        prev_word_emb = self.embed(prev_word)
        max_length = length.max().item()
        generated_embs = []
        self._generator.stateful = True
        self._generator.reset_states()
        for t in range(max_length):
            input_t = self.construct_generator_inputs(embeddings=prev_word_emb,
                                                      code=code,
                                                      label=label)
            mask = length.gt(t).long().unsqueeze(1)
            hidden_t = self._generator(input_t, mask)
            logit_t = self._generator_projector(hidden_t)
            gumbel_probs_t = self.gumbel_softmax(logit_t)
            emb_t = torch.matmul(gumbel_probs_t, self._token_embedder.weight)
            generated_embs.append(emb_t)
            prev_word_emb = emb_t
        self._generator.stateful = False
        self._generator.reset_states()
        generated_embs = torch.cat(generated_embs, dim=1)
        return generated_embs

    def convert_to_readable_text(self,
                                 generated: torch.Tensor) -> List[List[str]]:
        sequences = [seq.cpu().tolist() for seq in generated.unbind(0)]
        readable_sequences = []
        for seq in sequences:
            readable_seq = []
            for word_index in seq:
                word = self.vocab.get_token_from_index(word_index)
                if word == '</s>':
                    break
                readable_seq.append(word)
            readable_sequences.append(readable_seq)
        return readable_sequences

    def compute_generator_loss(
        self,
        embeddings: torch.Tensor,
        code: torch.Tensor,
        label: torch.Tensor,
        targets: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        inputs = self.construct_generator_inputs(embeddings=embeddings,
                                                 code=code,
                                                 label=label)
        hiddens = self._generator(inputs.contiguous(), mask)
        logits = self._generator_projector(hiddens)
        weights = mask.float()
        loss = sequence_cross_entropy_with_logits(logits=logits,
                                                  targets=targets.contiguous(),
                                                  weights=weights,
                                                  average=None)
        return loss

    def aux_forward(self, premise: Dict[str, torch.Tensor],
                    hypothesis: Dict[str, torch.Tensor]):
        """
        Generate the hypothesis dynamically given a premise
        and a sampled label, then compute the discriminator loss.
        This is intended to update only generator parameters,
        thus unnecessary gradients will not be accumulated
        for the reduced memory usage and faster training.
        """
        pre_mask = get_text_field_mask(premise)
        pre_tokens = premise['tokens']
        hyp_mask = get_text_field_mask(hypothesis)
        hyp_tokens = hypothesis['tokens']

        # with torch.no_grad():
        pre_token_embs = self.embed(pre_tokens)
        pre_hidden = self.encode(inputs=pre_token_embs,
                                 mask=pre_mask,
                                 drop_start_token=True)
        code, kld = self.sample_code_and_compute_kld(pre_hidden)
        batch_size = code.shape[0]
        label_dist = Categorical(
            logits=torch.ones(self._NUM_LABELS, device=code.device))
        label = label_dist.sample((batch_size, ))

        gen_hyp_token_embs = self.generate_soft(code=code,
                                                label=label,
                                                length=pre_mask.sum(1))
        gen_hyp_hidden = self.encode(inputs=gen_hyp_token_embs,
                                     mask=pre_mask,
                                     drop_start_token=False)

        loss = 0
        output_dict = {}
        if self.con_y_weight > 0:
            disc_logits = self.discriminate(premise_hidden=pre_hidden,
                                            hypothesis_hidden=gen_hyp_hidden)
            disc_dist = Categorical(logits=disc_logits)
            disc_entropy = disc_dist.entropy().mean()
            disc_loss = functional.cross_entropy(input=disc_logits,
                                                 target=label)
            if self.con_autoweight:
                yw = self.con_y_weight_p.exp().reciprocal()
                reg = self.con_y_weight_p * 0.5
                loss = loss + yw * disc_loss + reg
            else:
                loss = loss + self.con_y_weight * disc_loss
            output_dict['discriminator_entropy'] = disc_entropy
            output_dict['discriminator_loss'] = disc_loss
            self._metrics['aux']['discriminator_entropy'](disc_entropy)
            self._metrics['aux']['discriminator_loss'](disc_loss)
            self._metrics['aux']['discriminator_accuracy'](
                predictions=disc_logits, gold_labels=label)

        if self.con_z_weight > 0:
            hyp_token_embs = self.embed(hyp_tokens)
            hyp_hidden = self.encode(inputs=hyp_token_embs,
                                     mask=hyp_mask,
                                     drop_start_token=True)
            hyp_code, hyp_kld = self.sample_code_and_compute_kld(hyp_hidden)
            gen_hyp_dist = self._code_generator.get_distribution(
                gen_hyp_hidden)
            code_log_prob = -gen_hyp_dist.log_prob(hyp_code).mean()
            code_loss = -code_log_prob
            if self.con_autoweight:
                zw = self.con_z_weight_p.exp().reciprocal()
                reg = self.con_z_weight_p * 0.5
                loss = loss + zw * code_loss + reg
            else:
                loss = loss + self.con_z_weight * code_loss

            output_dict['code_loss'] = code_loss
            self._metrics['aux']['code_log_prob'](code_log_prob)

        if self.con_z2_weight > 0:
            gen_hyp_dist = self._code_generator.get_distribution(
                gen_hyp_hidden)
            mu = gen_hyp_dist.loc
            mu_bar = mu.mean(dim=0, keepdim=True)
            mu_bar = mu_bar / mu_bar.norm(dim=1, keepdim=True)
            cosine_dist = 1 - (mu * mu_bar).sum(dim=1)
            z2_loss = -cosine_dist.mean(dim=0)  # Scalar
            if self.con_autoweight:
                z2w = self.con_z2_weight_p.exp().reciprocal()
                reg = self.con_z2_weight_p * 0.5
                loss = loss + z2w * z2_loss + reg
            else:
                loss = loss + self.con_z2_weight * z2_loss
            output_dict['cosine_dist_mean'] = z2_loss
            self._metrics['aux']['cosine_dist'](-z2_loss)

        output_dict['loss'] = loss
        self._metrics['aux']['gumbel_temperature'](self.gumbel_temperature)
        self._metrics['aux']['loss'](loss)
        return output_dict

    def forward(self,
                premise: Dict[str, torch.Tensor],
                hypothesis: Dict[str, torch.Tensor],
                label: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        """
        premise and hypothesis are padded with
        the BOS and the EOS token.
        """
        pre_mask = get_text_field_mask(premise)
        hyp_mask = get_text_field_mask(hypothesis)
        pre_tokens = premise['tokens']
        hyp_tokens = hypothesis['tokens']
        pre_token_embs = self.embed(pre_tokens)
        hyp_token_embs = self.embed(hyp_tokens)
        pre_token_embs = self.emb_dropout(pre_token_embs)
        hyp_token_embs = self.emb_dropout(hyp_token_embs)

        output_dict = {}

        if label is not None:  # Labeled
            pre_hidden = self.encode(inputs=pre_token_embs,
                                     mask=pre_mask,
                                     drop_start_token=True)
            # hyp_hidden = self.encode(
            #     inputs=hyp_token_embs, mask=hyp_mask, drop_start_token=True)
            pre_disc_hidden = self.discriminator_encode(inputs=pre_token_embs,
                                                        mask=pre_mask,
                                                        drop_start_token=True)
            hyp_disc_hidden = self.discriminator_encode(inputs=hyp_token_embs,
                                                        mask=hyp_mask,
                                                        drop_start_token=True)
            disc_logits = self.discriminate(premise_hidden=pre_disc_hidden,
                                            hypothesis_hidden=hyp_disc_hidden)
            disc_dist = Categorical(logits=disc_logits)
            disc_entropy = disc_dist.entropy().mean()
            disc_loss = functional.cross_entropy(input=disc_logits,
                                                 target=label)

            code, kld = self.sample_code_and_compute_kld(pre_hidden)
            kld = kld.mean()
            gen_mask = hyp_mask[:, 1:]
            gen_loss = self.compute_generator_loss(
                embeddings=hyp_token_embs[:, :-1],
                code=code,
                label=label,
                targets=hyp_tokens[:, 1:],
                mask=gen_mask)
            gen_loss = gen_loss.mean()
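            # Combine the losses either with fixed weights or, when
            # `auto_weighting` is enabled, with learned task weights: each
            # task loss is scaled by exp(-s) and regularized by s / 2, where
            # s is the corresponding entry of `task_weights`.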
            if not self.auto_weighting:
                loss = (gen_loss + self.kl_weight * kld +
                        self.discriminator_weight * disc_loss)
            else:
                tw0_weight = self.task_weights[0].exp().reciprocal()
                tw1_weight = self.task_weights[1].exp().reciprocal()
                tw0_reg = 0.5 * self.task_weights[0]
                tw1_reg = 0.5 * self.task_weights[1]
                loss = (tw0_weight * (gen_loss + kld) +
                        tw1_weight * disc_loss) + tw0_reg + tw1_reg

            output_dict['discriminator_entropy'] = disc_entropy
            output_dict['discriminator_loss'] = disc_loss
            output_dict['generator_loss'] = gen_loss
            output_dict['kl_divergence'] = kld
            output_dict['loss'] = loss

            self._metrics['labeled']['discriminator_entropy'](disc_entropy)
            self._metrics['labeled']['discriminator_loss'](disc_loss)
            self._metrics['labeled']['discriminator_accuracy'](
                predictions=disc_logits, gold_labels=label)
            self._metrics['labeled']['generator_loss'](gen_loss)
            self._metrics['labeled']['kl_divergence'](kld)
            self._metrics['labeled']['loss'](loss)
        else:  # Unlabeled
            pre_hidden = self.encode(inputs=pre_token_embs,
                                     mask=pre_mask,
                                     drop_start_token=True)
            # hyp_hidden = self.encode(
            #     inputs=hyp_token_embs, mask=hyp_mask, drop_start_token=True)
            pre_disc_hidden = self.discriminator_encode(inputs=pre_token_embs,
                                                        mask=pre_mask,
                                                        drop_start_token=True)
            hyp_disc_hidden = self.discriminator_encode(inputs=hyp_token_embs,
                                                        mask=hyp_mask,
                                                        drop_start_token=True)
            disc_logits = self.discriminate(premise_hidden=pre_disc_hidden,
                                            hypothesis_hidden=hyp_disc_hidden)
            disc_dist = Categorical(logits=disc_logits)
            disc_entropy = disc_dist.entropy().mean()

            code, kld = self.sample_code_and_compute_kld(pre_hidden)
            kld = kld.mean()

            batch_size = pre_hidden.shape[0]
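            # Without sampling, marginalize over labels: compute the
            # generator loss under each of the three candidate labels and
            # combine them with the discriminator's predicted probabilities;
            # the discriminator's entropy is subtracted from the loss.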
            if not self._use_sampling:
                label = torch.arange(self._NUM_LABELS,
                                     dtype=torch.long,
                                     device=pre_hidden.device)
                label_repeat = label.unsqueeze(1).repeat(1,
                                                         batch_size).view(-1)
                targets_repeat = hyp_tokens[:, 1:].repeat(self._NUM_LABELS, 1)
                gen_mask_repeat = hyp_mask[:, 1:].repeat(self._NUM_LABELS, 1)
                hyp_token_embs_repeat = hyp_token_embs[:, :-1].repeat(
                    self._NUM_LABELS, 1, 1)
                code_repeat = code.repeat(self._NUM_LABELS, 1)
                gen_loss = self.compute_generator_loss(
                    embeddings=hyp_token_embs_repeat,
                    code=code_repeat,
                    label=label_repeat,
                    targets=targets_repeat,
                    mask=gen_mask_repeat)
                gen_loss = (gen_loss.contiguous().view(-1, self._NUM_LABELS) *
                            disc_dist.probs)
                gen_loss = gen_loss.sum(1).mean()
                loss = gen_loss + self.kl_weight * kld - disc_entropy
            else:
                label = disc_dist.sample()
                gen_loss = self.compute_generator_loss(
                    embeddings=hyp_token_embs[:, :-1],
                    code=code,
                    label=label,
                    targets=hyp_tokens[:, 1:],
                    mask=hyp_mask[:, 1:])
                gen_loss = gen_loss.mean()
                loss = gen_loss + self.kl_weight * kld - disc_entropy
            if self.auto_weighting:
                tw2_weight = self.task_weights[2].exp().reciprocal()
                tw2_reg = 0.5 * self.task_weights[2]
                loss = tw2_weight * loss + tw2_reg

            output_dict['discriminator_entropy'] = disc_entropy
            output_dict['generator_loss'] = gen_loss
            output_dict['kl_divergence'] = kld
            output_dict['loss'] = loss

            self._metrics['unlabeled']['discriminator_entropy'](disc_entropy)
            self._metrics['unlabeled']['generator_loss'](gen_loss)
            self._metrics['unlabeled']['kl_divergence'](kld)
            self._metrics['unlabeled']['loss'](loss)

        return output_dict

    def get_metrics(
            self,
            reset: bool = False) -> Dict[str, Union[float, Dict[str, float]]]:
        metrics = {
            label_type:
            {k: v.get_metric(reset=reset)
             for k, v in label_metrics.items()}
            for label_type, label_metrics in self._metrics.items()
        }
        metrics['kl_weight'] = self.kl_weight
        metrics['discriminator_weight'] = self.discriminator_weight
        return metrics
Code Example #4
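    # Constructor of the LSTM-based SNLI model: it builds the token and label
    # embedders, an LSTM sentence encoder, an LSTM generator conditioned on
    # the latent code and the label, an (optionally shared) discriminator
    # encoder, a feed-forward NLI classifier, and a vMF or Gaussian latent
    # code generator.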
    def __init__(self, params: Params, vocab: Vocabulary) -> None:
        super().__init__(vocab=vocab)

        enc_hidden_dim = params.pop_int('enc_hidden_dim', 300)
        gen_hidden_dim = params.pop_int('gen_hidden_dim', 300)
        disc_hidden_dim = params.pop_int('disc_hidden_dim', 1200)
        disc_num_layers = params.pop_int('disc_num_layers', 1)
        code_dist_type = params.pop_choice('code_dist_type',
                                           ['gaussian', 'vmf'],
                                           default_to_first_choice=True)
        code_dim = params.pop_int('code_dim', 300)
        label_emb_dim = params.pop_int('label_emb_dim', 50)
        shared_encoder = params.pop_bool('shared_encoder', True)
        tie_embedding = params.pop_bool('tie_embedding', False)
        auto_weighting = params.pop_bool('auto_weighting', False)

        emb_dropout = params.pop_float('emb_dropout', 0.0)
        disc_dropout = params.pop_float('disc_dropout', 0.0)
        l2_weight = params.pop_float('l2_weight', 0.0)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.disc_dropout = nn.Dropout(disc_dropout)
        self._l2_weight = l2_weight
        self.auto_weighting = auto_weighting

        self._token_embedder = Embedding.from_params(
            vocab=vocab, params=params.pop('token_embedder'))
        self._label_embedder = Embedding(num_embeddings=self._NUM_LABELS,
                                         embedding_dim=label_emb_dim)
        self._encoder = PytorchSeq2VecWrapper(
            nn.LSTM(input_size=self._token_embedder.get_output_dim(),
                    hidden_size=enc_hidden_dim,
                    batch_first=True))
        self._generator = PytorchSeq2SeqWrapper(
            nn.LSTM(input_size=(self._token_embedder.get_output_dim() +
                                code_dim + label_emb_dim),
                    hidden_size=gen_hidden_dim,
                    batch_first=True))
        self._generator_projector = nn.Linear(
            in_features=self._generator.get_output_dim(),
            out_features=vocab.get_vocab_size())
        self._discriminator_encoder = PytorchSeq2VecWrapper(
            nn.LSTM(input_size=self._token_embedder.get_output_dim(),
                    hidden_size=enc_hidden_dim,
                    batch_first=True))
        if shared_encoder:
            self._discriminator_encoder = self._encoder
        if tie_embedding:
            self._generator_projector.weight = self._token_embedder.weight

        self._discriminator = FeedForward(
            input_dim=4 * self._discriminator_encoder.get_output_dim(),
            hidden_dims=[disc_hidden_dim] * disc_num_layers +
            [self._NUM_LABELS],
            num_layers=disc_num_layers + 1,
            activations=[Activation.by_name('relu')()] * disc_num_layers +
            [Activation.by_name('linear')()],
            dropout=disc_dropout)
        if code_dist_type == 'vmf':
            vmf_kappa = params.pop_int('vmf_kappa', 150)
            self._code_generator = VmfCodeGenerator(
                input_dim=self._encoder.get_output_dim(),
                code_dim=code_dim,
                kappa=vmf_kappa)
        elif code_dist_type == 'gaussian':
            self._code_generator = GaussianCodeGenerator(
                input_dim=self._encoder.get_output_dim(), code_dim=code_dim)
        else:
            raise ValueError('Unknown code_dist_type')

        self._kl_weight = 1.0
        self._discriminator_weight = params.pop_float('discriminator_weight',
                                                      0.1)
        self._gumbel_temperature = 1.0

        self._use_sampling = params.pop_bool('use_sampling', False)

        if auto_weighting:
            self.num_tasks = num_tasks = 3
            self.task_weights = nn.Parameter(torch.zeros(num_tasks))

        # Metrics
        self._metrics = {
            'labeled': {
                'generator_loss': ScalarMetric(),
                'kl_divergence': ScalarMetric(),
                'discriminator_entropy': ScalarMetric(),
                'discriminator_accuracy': CategoricalAccuracy(),
                'discriminator_loss': ScalarMetric(),
                'loss': ScalarMetric()
            },
            'unlabeled': {
                'generator_loss': ScalarMetric(),
                'kl_divergence': ScalarMetric(),
                'discriminator_entropy': ScalarMetric(),
                'loss': ScalarMetric()
            },
            'aux': {
                'discriminator_entropy': ScalarMetric(),
                'discriminator_accuracy': CategoricalAccuracy(),
                'discriminator_loss': ScalarMetric(),
                'gumbel_temperature': ScalarMetric(),
                'loss': ScalarMetric(),
                'code_log_prob': ScalarMetric(),
                'cosine_dist': ScalarMetric()
            }
        }
Code Example #5
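# An SNLI model that auto-encodes the premise and the hypothesis separately:
# each sentence is encoded by an LSTM, compressed into a latent code (vMF or
# Gaussian), and reconstructed by an LSTM generator, while a feed-forward
# classifier predicts the label from the pair of codes.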
class SeparatedSNLIModel(Model):

    _NUM_LABELS = 3

    def __init__(self, params: Params, vocab: Vocabulary) -> None:
        super().__init__(vocab=vocab)

        enc_hidden_dim = params.pop_int('enc_hidden_dim', 300)
        gen_hidden_dim = params.pop_int('gen_hidden_dim', 300)
        disc_hidden_dim = params.pop_int('disc_hidden_dim', 1200)
        disc_num_layers = params.pop_int('disc_num_layers', 1)
        code_dist_type = params.pop_choice('code_dist_type',
                                           ['gaussian', 'vmf'],
                                           default_to_first_choice=True)
        code_dim = params.pop_int('code_dim', 300)
        tie_embedding = params.pop_bool('tie_embedding', False)

        emb_dropout = params.pop_float('emb_dropout', 0.0)
        disc_dropout = params.pop_float('disc_dropout', 0.0)
        l2_weight = params.pop_float('l2_weight', 0.0)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.disc_dropout = nn.Dropout(disc_dropout)
        self._l2_weight = l2_weight

        self._token_embedder = Embedding.from_params(
            vocab=vocab, params=params.pop('token_embedder'))
        self._encoder = PytorchSeq2VecWrapper(
            nn.LSTM(input_size=self._token_embedder.get_output_dim(),
                    hidden_size=enc_hidden_dim,
                    batch_first=True))
        self._generator = PytorchSeq2SeqWrapper(
            nn.LSTM(input_size=(self._token_embedder.get_output_dim() +
                                code_dim),
                    hidden_size=gen_hidden_dim,
                    batch_first=True))
        self._generator_projector = nn.Linear(
            in_features=self._generator.get_output_dim(),
            out_features=vocab.get_vocab_size())
        if tie_embedding:
            self._generator_projector.weight = self._token_embedder.weight

        if code_dist_type == 'vmf':
            vmf_kappa = params.pop_int('vmf_kappa', 150)
            self._code_generator = VmfCodeGenerator(
                input_dim=self._encoder.get_output_dim(),
                code_dim=code_dim,
                kappa=vmf_kappa)
        elif code_dist_type == 'gaussian':
            self._code_generator = GaussianCodeGenerator(
                input_dim=self._encoder.get_output_dim(), code_dim=code_dim)
        else:
            raise ValueError('Unknown code_dist_type')

        self._discriminator = FeedForward(
            input_dim=4 * self._code_generator.get_output_dim(),
            hidden_dims=[disc_hidden_dim] * disc_num_layers +
            [self._NUM_LABELS],
            num_layers=disc_num_layers + 1,
            activations=[Activation.by_name('relu')()] * disc_num_layers +
            [Activation.by_name('linear')()],
            dropout=disc_dropout)

        self._kl_weight = 1.0
        self._discriminator_weight = params.pop_float('discriminator_weight',
                                                      0.1)
        self._gumbel_temperature = 1.0

        # Metrics
        self._metrics = {
            'generator_loss': ScalarMetric(),
            'kl_divergence': ScalarMetric(),
            'discriminator_accuracy': CategoricalAccuracy(),
            'discriminator_loss': ScalarMetric(),
            'loss': ScalarMetric()
        }

    def get_regularization_penalty(self):
        sum_sq = sum(p.pow(2).sum() for p in self.parameters())
        l2_norm = sum_sq.sqrt()
        return self.l2_weight * l2_norm

    @property
    def l2_weight(self):
        return self._l2_weight

    @property
    def kl_weight(self):
        return self._kl_weight

    @kl_weight.setter
    def kl_weight(self, value):
        self._kl_weight = value

    @property
    def discriminator_weight(self):
        return self._discriminator_weight

    @discriminator_weight.setter
    def discriminator_weight(self, value):
        self._discriminator_weight = value

    def embed(self, tokens: torch.Tensor) -> torch.Tensor:
        return self._token_embedder(tokens)

    def encode(self,
               inputs: torch.Tensor,
               mask: torch.Tensor,
               drop_start_token: bool = True) -> torch.Tensor:
        if drop_start_token:
            inputs = inputs[:, 1:]
            mask = mask[:, 1:]
        enc_hidden = self._encoder(inputs.contiguous(), mask)
        return enc_hidden

    def sample_code_and_compute_kld(
            self, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._code_generator(hidden)

    def discriminate(self, premise_hidden: torch.Tensor,
                     hypothesis_hidden: torch.Tensor) -> torch.Tensor:
        disc_input = torch.cat([
            premise_hidden, hypothesis_hidden,
            premise_hidden * hypothesis_hidden,
            (premise_hidden - hypothesis_hidden).abs()
        ],
                               dim=-1)
        disc_input = self.disc_dropout(disc_input)
        disc_logits = self._discriminator(disc_input)
        return disc_logits

    def construct_generator_inputs(self, embeddings: torch.Tensor,
                                   code: torch.Tensor) -> torch.Tensor:
        batch_size, max_length, _ = embeddings.shape
        code_expand = code.unsqueeze(1).expand(batch_size, max_length, -1)
        inputs = torch.cat([embeddings, code_expand], dim=-1)
        return inputs

    def generate(self, code: torch.Tensor,
                 max_length: torch.Tensor) -> torch.Tensor:
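        # Greedy decoding: feed the previous prediction back into the LSTM
        # one step at a time, stopping each sequence once it emits </s> or
        # reaches its max_length; finished positions are filled with the pad
        # index (0).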
        start_index = self.vocab.get_token_index('<s>')
        end_index = self.vocab.get_token_index('</s>')
        pad_index = 0
        done = torch.zeros_like(max_length).long()
        max_max_length = max_length.max().item()

        prev_word = (
            torch.empty_like(done).long().unsqueeze(1).fill_(start_index))
        generated = []
        self._generator.stateful = True
        self._generator.reset_states()
        for t in range(max_max_length):
            if done.byte().all():
                break
            prev_word_emb = self.embed(prev_word)
            input_t = self.construct_generator_inputs(embeddings=prev_word_emb,
                                                      code=code)
            hidden_t = self._generator(input_t, 1 - done.unsqueeze(1))
            pred_t = self._generator_projector(hidden_t).argmax(2)
            pred_t.masked_fill_(done.byte(), pad_index)
            generated.append(pred_t)
            done.masked_fill_(pred_t.eq(end_index).squeeze(1), 1)
            done.masked_fill_(max_length.le(t + 1), 1)
            prev_word = pred_t
        self._generator.stateful = False
        generated = torch.cat(generated, dim=1)
        return generated

    def convert_to_readable_text(self,
                                 generated: torch.Tensor) -> List[List[str]]:
        sequences = [seq.cpu().tolist() for seq in generated.unbind(0)]
        readable_sequences = []
        for seq in sequences:
            readable_seq = []
            for word_index in seq:
                if word_index != 0:
                    word = self.vocab.get_token_from_index(word_index)
                    readable_seq.append(word)
            readable_sequences.append(readable_seq)
        return readable_sequences

    def compute_generator_loss(self, embeddings: torch.Tensor,
                               code: torch.Tensor, targets: torch.Tensor,
                               mask: torch.Tensor) -> torch.Tensor:
        inputs = self.construct_generator_inputs(embeddings=embeddings,
                                                 code=code)
        hiddens = self._generator(inputs.contiguous(), mask)
        logits = self._generator_projector(hiddens)
        weights = mask.float()
        loss = sequence_cross_entropy_with_logits(logits=logits,
                                                  targets=targets.contiguous(),
                                                  weights=weights,
                                                  average=None)
        return loss

    def forward(self,
                premise: Dict[str, torch.Tensor],
                hypothesis: Dict[str, torch.Tensor],
                label: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        """
        premise and hypothesis are padded with
        the BOS and the EOS token.
        """
        pre_mask = get_text_field_mask(premise)
        hyp_mask = get_text_field_mask(hypothesis)
        pre_tokens = premise['tokens']
        hyp_tokens = hypothesis['tokens']
        pre_token_embs = self.embed(pre_tokens)
        hyp_token_embs = self.embed(hyp_tokens)
        pre_token_embs = self.emb_dropout(pre_token_embs)
        hyp_token_embs = self.emb_dropout(hyp_token_embs)

        output_dict = {}

        pre_hidden = self.encode(inputs=pre_token_embs,
                                 mask=pre_mask,
                                 drop_start_token=True)
        hyp_hidden = self.encode(inputs=hyp_token_embs,
                                 mask=hyp_mask,
                                 drop_start_token=True)
        pre_code, pre_kld = self.sample_code_and_compute_kld(pre_hidden)
        hyp_code, hyp_kld = self.sample_code_and_compute_kld(hyp_hidden)
        pre_kld = pre_kld.mean()
        hyp_kld = hyp_kld.mean()

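        # Teacher-forced reconstruction: inputs are the embeddings without the
        # final token and targets are the tokens without the leading <s>.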
        pre_gen_mask = pre_mask[:, 1:]
        hyp_gen_mask = hyp_mask[:, 1:]
        pre_gen_loss = self.compute_generator_loss(
            embeddings=pre_token_embs[:, :-1],
            code=pre_code,
            targets=pre_tokens[:, 1:],
            mask=pre_gen_mask)
        hyp_gen_loss = self.compute_generator_loss(
            embeddings=hyp_token_embs[:, :-1],
            code=hyp_code,
            targets=hyp_tokens[:, 1:],
            mask=hyp_gen_mask)
        pre_gen_loss = pre_gen_loss.mean()
        hyp_gen_loss = hyp_gen_loss.mean()

        gen_loss = pre_gen_loss + hyp_gen_loss
        kld = pre_kld + hyp_kld
        loss = gen_loss + self.kl_weight * kld

        if label is not None:
            disc_logits = self.discriminate(premise_hidden=pre_code,
                                            hypothesis_hidden=hyp_code)
            disc_loss = functional.cross_entropy(input=disc_logits,
                                                 target=label)
            loss = loss + self.discriminator_weight * disc_loss
            output_dict['discriminator_loss'] = disc_loss
            self._metrics['discriminator_loss'](disc_loss)
            self._metrics['discriminator_accuracy'](predictions=disc_logits,
                                                    gold_labels=label)

        output_dict['generator_loss'] = gen_loss
        output_dict['kl_divergence'] = kld
        output_dict['loss'] = loss
        self._metrics['generator_loss'](gen_loss)
        self._metrics['kl_divergence'](kld)
        self._metrics['loss'](loss)

        return output_dict

    def get_metrics(
            self,
            reset: bool = False) -> Dict[str, Union[float, Dict[str, float]]]:
        metrics = {
            k: v.get_metric(reset=reset)
            for k, v in self._metrics.items()
        }
        metrics['kl_weight'] = self.kl_weight
        metrics['discriminator_weight'] = self.discriminator_weight
        return metrics