示例#1
0
                     shard_size=OPTS.shard,
                     seed=OPTS.seed)

# Build the LANMT hyper-parameters: start from a copy of the shared
# `basic_options` and add the model-specific values read from OPTS.
lanmt_options = basic_options.copy()
lanmt_options.update(
    dict(encoder_layers=5,
         prior_layers=OPTS.priorl,
         # NOTE(review): q_layers reuses OPTS.priorl (same depth as the
         # prior) -- confirm this is intended rather than a separate option.
         q_layers=OPTS.priorl,
         decoder_layers=OPTS.decoderl,
         latent_dim=OPTS.latentdim,
         # When fine-tuning, the KL budget is forced to 0.
         KL_budget=0. if OPTS.finetune else OPTS.klbudget,
         budget_annealing=OPTS.annealbudget,
         max_train_steps=training_maxsteps,
         fp16=OPTS.fp16))

nmt = LANMTModel2(**lanmt_options)
if OPTS.scorenet:
    # Score-network mode: load an existing LANMT checkpoint into `nmt` and
    # choose which latent score-network class to use.
    # NOTE(review): presumably resetting shard disables data sharding in
    # this mode -- confirm against where OPTS.shard is consumed.
    OPTS.shard = 0
    lanmt_model_path = envswitch.load("lanmt_path")
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a missing checkpoint fail later with a
    # far less helpful error from `nmt.load`.
    if not os.path.exists(lanmt_model_path):
        raise FileNotFoundError(
            "LANMT checkpoint not found: {}".format(lanmt_model_path))
    nmt.load(lanmt_model_path)
    if is_root_node():
        print("Successfully loaded LANMT: {}".format(lanmt_model_path))
    if torch.cuda.is_available():
        nmt.cuda()
    from lib_score_matching6 import LatentScoreNetwork6
    ScoreNet = LatentScoreNetwork6
    # Force to use a specified network
    if OPTS.modelclass == "shunet5":
        from lib_score_matching5_shu import LatentScoreNetwork5
        ScoreNet = LatentScoreNetwork5
def mean_t(tensor, mask):
    """Masked mean of ``tensor`` over dimension 1 (the time dimension).

    Args:
        tensor: values to average; must have exactly the same shape as
            ``mask`` (enforced with an assert).
        mask: binary mask selecting which positions contribute.

    Returns:
        ``sum(tensor * mask, dim=1) / sum(mask, dim=1)``, i.e. ``tensor``
        with dimension 1 reduced away.
    """
    assert tensor.shape == mask.shape
    masked_total = (tensor * mask).sum(dim=1)
    token_count = mask.sum(dim=1)
    return masked_total / token_count


def cosine_loss(x1, x2, mask):
    """Return one minus the masked mean cosine similarity of ``x1`` and ``x2``.

    Cosine similarity is taken along dimension 2. The resulting per-position
    similarities are then averaged over dimension 1 with ``mean_t`` using
    ``mask``. ``mask`` must match the similarity tensor's shape, which is
    ``x1``/``x2`` with dimension 2 removed.

    Returns:
        ``1 - mean_t(cos_sim(x1, x2), mask)``; the original notes this as
        shape ``[bsz]``.
    """
    per_step_similarity = F.cosine_similarity(x1, x2, dim=2)
    return 1 - mean_t(per_step_similarity, mask)


if __name__ == '__main__':
    # Smoke test: build a tiny LANMT model, wrap it in a score network and
    # run one forward pass on a toy source/target pair (on GPU if present).
    import sys
    sys.path.append(".")

    lanmt = LANMTModel2(
        src_vocab_size=1000,
        tgt_vocab_size=1000,
        prior_layers=1,
        decoder_layers=1,
    )
    snet = LatentScoreNetwork4(lanmt)
    src_ids = torch.tensor([[1, 2, 3, 4, 5]])
    tgt_ids = torch.tensor([[1, 2, 3]])
    if torch.cuda.is_available():
        lanmt.cuda()
        snet.cuda()
        src_ids, tgt_ids = src_ids.cuda(), tgt_ids.cuda()
    snet(src_ids, tgt_ids)