def create_model(hparams, vocab_src, vocab_tgt):
    """Build and wire up a complete AEVNMT model from hyperparameters.

    Constructs the generative components (source language model, translation
    model), the inference network, and the prior(s) over the latent code,
    then assembles them into an ``AEVNMT`` instance.

    Args:
        hparams: Hyperparameter object with nested fields (``emb``, ``prior``,
            ``likelihood``, ``dropout``, ...) — structure assumed from usage
            below; confirm against the hparams definition.
        vocab_src: Source vocabulary; must support ``.size()`` and
            token-to-index lookup via ``vocab_src[PAD_TOKEN]``.
        vocab_tgt: Target vocabulary with the same interface.

    Returns:
        An ``AEVNMT`` model instance.

    Raises:
        NotImplementedError: If any auxiliary LM/TM components are configured
            (not yet supported; see Issue #17).
        ValueError: If the per-prior latent sizes are inconsistent with the
            number of priors or with the total latent size.
    """
    # Generative components
    src_embedder = torch.nn.Embedding(vocab_src.size(), hparams.emb.size, padding_idx=vocab_src[PAD_TOKEN])
    tgt_embedder = torch.nn.Embedding(vocab_tgt.size(), hparams.emb.size, padding_idx=vocab_tgt[PAD_TOKEN])
    language_model = create_language_model(vocab_src, src_embedder, hparams)

    # Auxiliary generative components — constructed only to detect that the
    # user requested them; any non-empty result is rejected below.
    aux_lms = create_aux_language_models(vocab_src, src_embedder, hparams)
    aux_tms = create_aux_translation_models(src_embedder, tgt_embedder, hparams)
    if aux_lms or aux_tms:
        raise NotImplementedError("Aux losses are not yet supported with the new loss functions. See Issue #17.")

    translation_model = create_translation_model(vocab_tgt, src_embedder, tgt_embedder, hparams)

    # Build one prior per ';'-separated family. The total latent size is
    # either split explicitly via hparams.prior.latent_sizes or divided
    # evenly across the priors.
    priors = []
    n_priors = len(hparams.prior.family.split(";"))
    if hparams.prior.latent_sizes:
        # Accept space/semicolon/colon/comma separated size lists.
        latent_sizes = [int(size) for size in re.split('[ ;:,]+', hparams.prior.latent_sizes.strip())]
        if len(latent_sizes) != n_priors:
            raise ValueError("You need to specify a latent_size for each prior using --latent_sizes 'list'")
        if sum(latent_sizes) != hparams.prior.latent_size:
            raise ValueError("The sum of latent_sizes must equal latent_size")
    else:
        if hparams.prior.latent_size % n_priors != 0:
            raise ValueError("Use a latent size multiple of the number of priors")
        latent_sizes = [hparams.prior.latent_size // n_priors] * n_priors
    for prior_family, prior_params, latent_size in zip(hparams.prior.family.split(";"),
                                                       hparams.prior.params.split(";"),
                                                       latent_sizes):
        # Each prior's parameters are a space-separated list of floats.
        prior_params = [float(param) for param in prior_params.split()]
        priors.append(create_prior(prior_family, latent_size, prior_params))

    # Inference network embedders: when embeddings are shared with the
    # generative model, wrap them so gradients do not flow back through the
    # inference path (DetachedEmbeddingLayer); otherwise create fresh
    # embeddings of the same shape.
    inf_model = create_inference_model(
        DetachedEmbeddingLayer(src_embedder) if hparams.emb.shared else torch.nn.Embedding(
            src_embedder.num_embeddings,
            src_embedder.embedding_dim,
            padding_idx=src_embedder.padding_idx),
        DetachedEmbeddingLayer(tgt_embedder) if hparams.emb.shared else torch.nn.Embedding(
            tgt_embedder.num_embeddings,
            tgt_embedder.embedding_dim,
            padding_idx=tgt_embedder.padding_idx),
        latent_sizes,
        hparams)

    constraints = create_constraints(hparams)

    model = AEVNMT(
        latent_size=hparams.prior.latent_size,
        src_embedder=src_embedder,
        tgt_embedder=tgt_embedder,
        language_model=language_model,
        translation_model=translation_model,
        inference_model=inf_model,
        dropout=hparams.dropout,
        # feed_z / tied_embeddings deliberately passed as None here —
        # presumably handled inside the component constructors instead;
        # NOTE(review): confirm against AEVNMT's signature.
        feed_z=None,
        tied_embeddings=None,
        # Multiple priors are combined into a single product-of-priors layer.
        prior=priors[0] if len(priors) == 1 else ProductOfPriorsLayer(priors),
        constraints=constraints,
        aux_lms=aux_lms,
        aux_tms=aux_tms,
        mixture_likelihood=hparams.likelihood.mixture,
        mixture_likelihood_dir_prior=hparams.likelihood.mixture_dir_prior)
    return model
def create_model(hparams, vocab_src, vocab_tgt):
    """Build and wire up a complete AEVNMT model from flat hyperparameters.

    Variant of the factory that uses a flat hparams namespace
    (``hparams.emb_size``, ``hparams.latent_size``, ...) and constructs the
    language model and attention-based translation model inline.

    Args:
        hparams: Flat hyperparameter object — fields assumed from usage
            below; confirm against the hparams definition.
        vocab_src: Source vocabulary; must support ``.size()`` and
            token-to-index lookup (``vocab_src[PAD_TOKEN]`` etc.).
        vocab_tgt: Target vocabulary with the same interface.

    Returns:
        An ``AEVNMT`` model instance.

    Raises:
        ValueError: If the per-prior latent sizes are inconsistent with the
            number of priors or with the total latent size.
    """
    # Generative components
    src_embedder = torch.nn.Embedding(vocab_src.size(), hparams.emb_size, padding_idx=vocab_src[PAD_TOKEN])
    tgt_embedder = torch.nn.Embedding(vocab_tgt.size(), hparams.emb_size, padding_idx=vocab_tgt[PAD_TOKEN])
    language_model = CorrelatedCategoricalsLM(
        embedder=src_embedder,
        sos_idx=vocab_src[SOS_TOKEN],
        eos_idx=vocab_src[EOS_TOKEN],
        latent_size=hparams.latent_size,
        hidden_size=hparams.hidden_size,
        dropout=hparams.dropout,
        num_layers=hparams.num_dec_layers,
        cell_type=hparams.cell_type,
        tied_embeddings=hparams.tied_embeddings,
        feed_z=hparams.feed_z,
        gate_z=False  # TODO implement
    )

    # Auxiliary generative components (unlike the other factory variant,
    # these are accepted here and forwarded to AEVNMT below).
    aux_lms = create_aux_language_models(vocab_src, src_embedder, hparams)
    aux_tms = create_aux_translation_models(src_embedder, tgt_embedder, hparams)

    # Translation model: encoder-decoder with attention, sharing the
    # generative embedders.
    encoder = create_encoder(hparams)
    attention = create_attention(hparams)
    decoder = create_decoder(attention, hparams)
    translation_model = AttentionBasedTM(
        src_embedder=src_embedder,
        tgt_embedder=tgt_embedder,
        tgt_sos_idx=vocab_tgt[SOS_TOKEN],
        tgt_eos_idx=vocab_tgt[EOS_TOKEN],
        encoder=encoder,
        decoder=decoder,
        latent_size=hparams.latent_size,
        dropout=hparams.dropout,
        feed_z=hparams.feed_z,
        tied_embeddings=hparams.tied_embeddings
    )

    # Build one prior per ';'-separated family. The total latent size is
    # either split explicitly via hparams.latent_sizes or divided evenly
    # across the priors.
    priors = []
    n_priors = len(hparams.prior.split(";"))
    if hparams.latent_sizes:
        # Accept space/semicolon/colon/comma separated size lists.
        latent_sizes = [int(size) for size in re.split('[ ;:,]+', hparams.latent_sizes.strip())]
        if len(latent_sizes) != n_priors:
            raise ValueError("You need to specify a latent_size for each prior using --latent_sizes 'list'")
        if sum(latent_sizes) != hparams.latent_size:
            raise ValueError("The sum of latent_sizes must equal latent_size")
    else:
        if hparams.latent_size % n_priors != 0:
            raise ValueError("Use a latent size multiple of the number of priors")
        latent_sizes = [hparams.latent_size // n_priors] * n_priors
    for prior_family, prior_params, latent_size in zip(hparams.prior.split(";"),
                                                       hparams.prior_params.split(";"),
                                                       latent_sizes):
        # Each prior's parameters are a space-separated list of floats.
        prior_params = [float(param) for param in prior_params.split()]
        priors.append(create_prior(prior_family, latent_size, prior_params))

    # Inference network embedders: when sharing is enabled, wrap the
    # generative embedders so gradients do not flow back through the
    # inference path (DetachedEmbeddingLayer); otherwise create fresh
    # embeddings of the same shape.
    inf_model = create_inference_model(
        DetachedEmbeddingLayer(src_embedder) if hparams.inf_share_embeddings else torch.nn.Embedding(
            src_embedder.num_embeddings,
            src_embedder.embedding_dim,
            padding_idx=src_embedder.padding_idx),
        DetachedEmbeddingLayer(tgt_embedder) if hparams.inf_share_embeddings else torch.nn.Embedding(
            tgt_embedder.num_embeddings,
            tgt_embedder.embedding_dim,
            padding_idx=tgt_embedder.padding_idx),
        latent_sizes,
        hparams)

    model = AEVNMT(
        latent_size=hparams.latent_size,
        src_embedder=src_embedder,
        tgt_embedder=tgt_embedder,
        language_model=language_model,
        translation_model=translation_model,
        inference_model=inf_model,
        dropout=hparams.dropout,
        feed_z=hparams.feed_z,
        tied_embeddings=hparams.tied_embeddings,
        # Multiple priors are combined into a single product-of-priors layer.
        prior=priors[0] if len(priors) == 1 else ProductOfPriorsLayer(priors),
        # mdr: minimum desired rate constraint on the KL term —
        # NOTE(review): semantics assumed from the name; confirm in AEVNMT.
        mdr=hparams.minimum_desired_rate,
        aux_lms=aux_lms,
        aux_tms=aux_tms,
        mixture_likelihood=hparams.mixture_likelihood,
        mixture_likelihood_dir_prior=hparams.mixture_likelihood_dir_prior)
    return model