Code example #1
0
File: nmtmodel.py  Project: zkaden/InTrainingMF4NMT
    def load(model_path: str):
        """Restore an NMTModel from a single serialized state dict.

        The state dict is read from ``<model_path>.dict.pt`` and applied
        via load_partial_state_dict (presumably tolerant of partially
        matching keys — confirm against its definition).
        """
        model = NMTModel()
        print("Loading whole model")
        state = torch.load(model_path + ".dict.pt")
        load_partial_state_dict(model, state)
        return model
Code example #2
0
    def load(model_path: str):
        """Restore a PrunedModel from ``<model_path>.dict.pt``.

        The pruned state is inferred from the filename: a path containing
        ".pruned" marks the returned model as currently pruned.
        """
        dict_path = model_path + ".dict.pt"
        model = PrunedModel()
        print("Loading whole model")
        load_partial_state_dict(model, torch.load(dict_path))
        # Assign the boolean directly instead of an if/else over True/False.
        model.currently_pruned = ".pruned" in model_path
        return model
Code example #3
0
    def load(model_path: str):
        """Rebuild an NMTModel from separately saved encoder and decoder dicts.

        Expects two files next to each other: ``<model_path>.enc.pt`` and
        ``<model_path>.dec.pt``, loaded into model.encoder / model.decoder.
        """
        model = NMTModel()
        print("Loading encoder")
        load_partial_state_dict(model.encoder, torch.load(model_path + ".enc.pt"))
        print("Loading decoder")
        load_partial_state_dict(model.decoder, torch.load(model_path + ".dec.pt"))
        return model
Code example #4
0
    def load(model_path: str):
        """Restore a MixedPrecisionModel from ``<model_path>.dict.pt``.

        Quantization and pruning state are inferred from the filename.
        Note the model is (un)quantized *before* the state dict is applied
        — presumably so the parameter layout matches the checkpoint;
        confirm against quantize()/unquantize().
        """
        dict_path = model_path + ".dict.pt"
        model = MixedPrecisionModel()
        print("Loading whole model")
        if ".quantized" in model_path:
            model.quantize()
        else:
            model.unquantize()
        load_partial_state_dict(model, torch.load(dict_path))
        # Assign the boolean directly instead of an if/else over True/False.
        model.currently_pruned = ".pruned" in model_path
        return model
Code example #5
0
    def load(model_path: str):
        """Restore a PostFactorizedModel from ``<model_path>.dict.pt``.

        Paths containing ".postfactorized" rebuild the low-rank factorized
        variant; note this mutates the global tconfig flags as a side
        effect. Any other path loads the standard architecture.
        """
        dict_path = model_path + ".dict.pt"
        print("Loading whole model")
        if ".postfactorized" in model_path:
            print("loading factorized model")
            # Single source of truth for all three factorization ranks
            # (was the magic constant 256 repeated seven times).
            rank = 256
            tconfig.embedding_factorization = True
            tconfig.ffward_factorization = True
            tconfig.inner_factorization = True
            tconfig.embedding_rank = rank
            tconfig.ffward_rank = rank
            tconfig.inner_rank = rank
            model = PostFactorizedModel(embedding_rank=rank,
                                        ffward_rank=rank,
                                        inner_rank=rank)
            model.currently_factorized = ["embeddings", "ffward", "attention"]
        else:
            print("loading standard model")
            model = PostFactorizedModel()
        load_partial_state_dict(model, torch.load(dict_path))
        return model
Code example #6
0
    def load(model_path: str):
        """Rebuild a MultiWayModel from per-component checkpoint files.

        Loads the shared decoder from ``<model_path>.dec.pt``, one encoder
        per key from ``<model_path>.<key>.enc.pt``, and — when the model
        uses a discriminator — best-effort loads it from
        ``<model_path>.disc.pt`` (a missing/broken file is not fatal).
        """
        model = MultiWayModel()
        print("Loading decoder")
        load_partial_state_dict(model.decoder, torch.load(model_path + ".dec.pt"))
        print("Loading encoders")
        for key in model.keys:
            enc_path = model_path + "." + key + ".enc.pt"
            load_partial_state_dict(model.encoder[key], torch.load(enc_path))
        if model.use_discriminator:
            print("Loading discriminator")
            disc_path = model_path + ".disc.pt"
            # Keep only the failing calls inside try; was a bare `except:`
            # which also swallowed SystemExit/KeyboardInterrupt.
            try:
                load_partial_state_dict(model.discriminator,
                                        torch.load(disc_path))
            except Exception:
                # Deliberate best-effort: the model stays usable without
                # its discriminator weights.
                print("Failed")
        return model