Example #1
    def _generate(
        model: nn.Module,
        start_tokens: torch.LongTensor,
        end_token: int,
        filename: Optional[str] = None,
        device="cpu",
    ):

        # Load model weights from the checkpoint at `args.model` (module-level
        # CLI args), optionally apply dynamic quantization, and put the model
        # in eval mode for inference.
        ckpt = torch.load(args.model, map_location=device)
        model.load_state_dict(ckpt['model'])
        if args.torchquant:
            model = quantize_dynamic_torch(model)
        model.eval()

        batch_size = train_data.batch_size

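        # Sample latent codes from the standard-normal prior N(0, I): one
        # vector of size `config.latent_dims` per sentence in the batch.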
        dst = MultivariateNormalDiag(loc=torch.zeros(batch_size,
                                                     config.latent_dims),
                                     scale_diag=torch.ones(
                                         batch_size, config.latent_dims))

        latent_z = dst.rsample().to(device)

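        # Sample tokens from the decoder's output distribution ('infer_sample')
        # until `end_token` is emitted or `max_decoding_length` is reached.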
        helper = model.decoder.create_helper(decoding_strategy='infer_sample',
                                             start_tokens=start_tokens,
                                             end_token=end_token)
        outputs = model.decode(helper=helper,
                               latent_z=latent_z,
                               max_decoding_length=100)

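        # During inference the transformer decoder returns a tuple of
        # (outputs, sequence_lengths); keep only the outputs.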
        if config.decoder_type == "transformer":
            outputs = outputs[0]

        sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

        if filename is None:
            fh = sys.stdout
        else:
            fh = open(filename, 'w', encoding='utf-8')

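        # Convert ids to tokens and write one sentence per line, truncating
        # each sample at its first EOS token (which is kept in the output).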
        for sent in sample_tokens:
            sent = tx.utils.compat_as_text(list(sent))
            end_id = len(sent)
            if vocab.eos_token in sent:
                end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        print('Output done')
        if fh is not sys.stdout:
            fh.close()
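
A call site might look like the following. This is a minimal sketch: `model`, `vocab`, and `batch_size` are assumed to exist in the surrounding script (the body above already relies on the module-level `args`, `train_data`, `config`, and `vocab`), and `samples.txt` is an arbitrary output path.

    start_tokens = torch.full(
        (batch_size,), vocab.bos_token_id, dtype=torch.long)
    _generate(model, start_tokens, vocab.eos_token_id,
              filename='samples.txt', device='cpu')
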
Example #2
    def forward(
            self,  # type: ignore
            data_batch: tx.data.Batch,
            kl_weight: float,
            start_tokens: torch.LongTensor,
            end_token: int) -> Dict[str, Tensor]:
        # encoder -> connector -> decoder
        text_ids = data_batch["text_ids"]
        input_embed = self.encoder_w_embedder(text_ids)
        _, encoder_states = self.encoder(input_embed,
                                         sequence_length=data_batch["length"])

        # The connector MLP maps the final encoder state to the parameters of
        # the approximate posterior q(z|x): a mean and a log-variance vector.
        mean_logvar = self.connector_mlp(encoder_states)
        mean, logvar = torch.chunk(mean_logvar, 2, 1)
        kl_loss = kl_divergence(mean, logvar)
        dst = MultivariateNormalDiag(loc=mean,
                                     scale_diag=torch.exp(0.5 * logvar))

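        # rsample() draws a reparameterized sample, so gradients flow back
        # through `mean` and `logvar` into the encoder and connector.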
        latent_z = dst.rsample()
        helper = None
        if self._config.decoder_type == "lstm":
            helper = self.lstm_decoder.create_helper(
                decoding_strategy="train_greedy",
                start_tokens=start_tokens,
                end_token=end_token)

        # Decode with teacher forcing: inputs are the ids shifted right
        # (text_ids[:, :-1]); the targets below are the ids shifted left.
        seq_lengths = data_batch["length"] - 1
        outputs = self.decode(helper=helper,
                              latent_z=latent_z,
                              text_ids=text_ids[:, :-1],
                              seq_lengths=seq_lengths)

        logits = outputs.logits

        # Losses: per-token reconstruction cross-entropy over the shifted targets
        rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=text_ids[:, 1:], logits=logits, sequence_length=seq_lengths)

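        # Training objective: reconstruction loss plus the KL term scaled by
        # `kl_weight` (typically annealed from 0 during training).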
        nll = rc_loss + kl_weight * kl_loss

        ret = {
            "nll": nll,
            "kl_loss": kl_loss,
            "rc_loss": rc_loss,
            "lengths": seq_lengths,
        }

        return ret
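
The `kl_divergence` helper called above is not shown in this example. For a diagonal Gaussian posterior q(z|x) with the given mean and log-variance against a standard-normal prior, the KL term has a closed form, -0.5 * (1 + logvar - mean^2 - exp(logvar)) per dimension. A minimal sketch consistent with the call `kl_divergence(mean, logvar)`, averaging over the batch and summing over latent dimensions:

    def kl_divergence(means: Tensor, logvars: Tensor) -> Tensor:
        # Per-dimension KL(N(mu, sigma^2) || N(0, I)), reduced by a mean over
        # the batch and a sum over the latent dimensions.
        kl_cost = -0.5 * (logvars - means ** 2 - torch.exp(logvars) + 1.0)
        return torch.sum(torch.mean(kl_cost, dim=0))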