    # Tail of a data/loader construction call that opens on an earlier line
    # (outside this chunk): shard size and RNG seed come from the CLI options.
    shard_size=OPTS.shard, seed=OPTS.seed)

# Build the LANMT model options on top of the shared basic options.
lanmt_options = basic_options.copy()
lanmt_options.update(
    dict(encoder_layers=5,
         prior_layers=OPTS.priorl,
         q_layers=OPTS.priorl,          # posterior q(z|x,y) mirrors the prior depth
         decoder_layers=OPTS.decoderl,
         latent_dim=OPTS.latentdim,
         # During fine-tuning the KL budget is disabled (set to 0).
         KL_budget=0. if OPTS.finetune else OPTS.klbudget,
         budget_annealing=OPTS.annealbudget,
         max_train_steps=training_maxsteps,
         fp16=OPTS.fp16))
nmt = LANMTModel2(**lanmt_options)

# When training the latent score network, load a pre-trained LANMT model
# as the frozen base and pick the score-network class to train on top of it.
# NOTE(review): the nesting below is reconstructed from whitespace-mangled
# source — confirm against the original file.
if OPTS.scorenet:
    OPTS.shard = 0  # sharding disabled in score-network mode
    lanmt_model_path = envswitch.load("lanmt_path")
    assert os.path.exists(lanmt_model_path)
    nmt.load(lanmt_model_path)
    if is_root_node():
        print("Successfully loaded LANMT: {}".format(lanmt_model_path))
    if torch.cuda.is_available():
        nmt.cuda()
    from lib_score_matching6 import LatentScoreNetwork6
    ScoreNet = LatentScoreNetwork6
    # Force to use a specified network
    if OPTS.modelclass == "shunet5":
        from lib_score_matching5_shu import LatentScoreNetwork5
        ScoreNet = LatentScoreNetwork5
def mean_t(tensor, mask):
    """Masked mean over the time dimension (dim 1).

    Args:
        tensor: float tensor of shape [bsz, tsteps], values to average.
        mask: binary tensor of the same shape; 1 marks valid positions.

    Returns:
        Tensor of shape [bsz]: sum of masked values divided by the number
        of valid positions per row.

    Note:
        A row whose mask is all zeros divides by zero and yields nan/inf —
        callers must guarantee at least one valid position per row.
    """
    assert tensor.shape == mask.shape
    return (tensor * mask).sum(1) / mask.sum(1)


def cosine_loss(x1, x2, mask):
    """Cosine-distance loss between two sequences of vectors.

    Computes 1 - cosine_similarity(x1, x2) per time step, averaged over the
    time dimension with ``mean_t`` using the binary ``mask``.
    (Fixed: the original comment here was copy-pasted from ``mean_t`` and
    described only the averaging, not the loss.)

    Args:
        x1, x2: tensors of shape [bsz, tsteps, dim].
        mask: binary tensor of shape [bsz, tsteps].

    Returns:
        Tensor of shape [bsz]: per-example loss in [0, 2].
    """
    sim = F.cosine_similarity(x1, x2, dim=2)  # [bsz, tsteps]
    sim = mean_t(sim, mask)  # [bsz]
    return 1 - sim


if __name__ == '__main__':
    import sys
    sys.path.append(".")
    # Smoke test: build a tiny LANMT model and run one forward pass of the
    # score network on dummy token ids.
    lanmt = LANMTModel2(
        src_vocab_size=1000, tgt_vocab_size=1000,
        prior_layers=1, decoder_layers=1)
    # NOTE(review): LatentScoreNetwork4 is not imported in the visible part
    # of this file (only networks 5/6 are) — confirm this name resolves, or
    # whether it should be the selected ``ScoreNet`` class.
    snet = LatentScoreNetwork4(lanmt)
    x = torch.tensor([[1, 2, 3, 4, 5]])
    y = torch.tensor([[1, 2, 3]])
    if torch.cuda.is_available():
        lanmt.cuda()
        snet.cuda()
        x = x.cuda()
        y = y.cuda()
    snet(x, y)