def _generate(model: nn.Module,
              start_tokens: torch.LongTensor,
              end_token: int,
              filename: Optional[str] = None,
              device: str = "cpu"):
    # Load the model from the checkpoint, optionally apply dynamic
    # quantization, and sample sentences from the prior.
    ckpt = torch.load(args.model)
    model.load_state_dict(ckpt['model'])
    if args.torchquant:
        model = quantize_dynamic_torch(model)
    model.eval()

    batch_size = train_data.batch_size

    # Sample latent codes from the standard-normal prior.
    dst = MultivariateNormalDiag(
        loc=torch.zeros(batch_size, config.latent_dims),
        scale_diag=torch.ones(batch_size, config.latent_dims))
    latent_z = dst.rsample().to(device)

    helper = model.decoder.create_helper(
        decoding_strategy='infer_sample',
        start_tokens=start_tokens,
        end_token=end_token)
    outputs = model.decode(
        helper=helper, latent_z=latent_z, max_decoding_length=100)

    if config.decoder_type == "transformer":
        outputs = outputs[0]
    sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

    if filename is None:
        fh = sys.stdout
    else:
        fh = open(filename, 'w', encoding='utf-8')

    # Write each sampled sentence up to (and including) the first EOS token.
    for sent in sample_tokens:
        sent = tx.utils.compat_as_text(list(sent))
        end_id = len(sent)
        if vocab.eos_token in sent:
            end_id = sent.index(vocab.eos_token)
        fh.write(' '.join(sent[:end_id + 1]) + '\n')

    print('Output done')
    if fh is not sys.stdout:
        fh.close()
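# Note: `quantize_dynamic_torch` used above is defined elsewhere in this
# project. As a hedged sketch only, such a helper could be built on PyTorch's
# built-in dynamic quantization; the name `_quantize_dynamic_torch_sketch`
# below is hypothetical and is not called anywhere.
def _quantize_dynamic_torch_sketch(model: nn.Module) -> nn.Module:
    # Quantize the weights of Linear and LSTM modules to int8; activations
    # are quantized on the fly at inference time, so no calibration data
    # is required.
    return torch.quantization.quantize_dynamic(
        model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)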
def forward(self,  # type: ignore
            data_batch: tx.data.Batch,
            kl_weight: float,
            start_tokens: torch.LongTensor,
            end_token: int) -> Dict[str, Tensor]:
    # encoder -> connector -> decoder
    text_ids = data_batch["text_ids"]
    input_embed = self.encoder_w_embedder(text_ids)
    _, encoder_states = self.encoder(
        input_embed, sequence_length=data_batch["length"])

    # Map the final encoder state to the posterior mean and log-variance.
    mean_logvar = self.connector_mlp(encoder_states)
    mean, logvar = torch.chunk(mean_logvar, 2, 1)
    kl_loss = kl_divergence(mean, logvar)

    # Reparameterized sample from the approximate posterior q(z | x).
    dst = MultivariateNormalDiag(
        loc=mean, scale_diag=torch.exp(0.5 * logvar))
    latent_z = dst.rsample()

    helper = None
    if self._config.decoder_type == "lstm":
        helper = self.lstm_decoder.create_helper(
            decoding_strategy="train_greedy",
            start_tokens=start_tokens,
            end_token=end_token)

    # Decode with teacher forcing: inputs are the sequence shifted right.
    seq_lengths = data_batch["length"] - 1
    outputs = self.decode(
        helper=helper, latent_z=latent_z,
        text_ids=text_ids[:, :-1], seq_lengths=seq_lengths)
    logits = outputs.logits

    # Losses: reconstruction plus KL, weighted for annealing.
    rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=text_ids[:, 1:], logits=logits,
        sequence_length=seq_lengths)
    nll = rc_loss + kl_weight * kl_loss

    ret = {
        "nll": nll,
        "kl_loss": kl_loss,
        "rc_loss": rc_loss,
        "lengths": seq_lengths,
    }
    return ret
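# Note: `kl_divergence(mean, logvar)` above comes from elsewhere in this file.
# For a diagonal-Gaussian posterior N(mean, exp(logvar)) against a standard
# normal prior, the usual closed form is
#     KL = -0.5 * sum_d (1 + logvar_d - mean_d^2 - exp(logvar_d)),
# averaged over the batch. The reference sketch below is an assumption about
# that helper, not necessarily the exact implementation used here.
def _kl_divergence_sketch(mean: torch.Tensor,
                          logvar: torch.Tensor) -> torch.Tensor:
    # Sum over latent dimensions, then average over the batch.
    return torch.mean(
        -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1))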