コード例 #1
0
def main():
    """Translate ``config.src_file`` with a vocab-based (field-less) checkpoint.

    Loads the checkpoint named by ``config.ckpt_path`` with
    ``use_fields=False``, wraps the source/target files in a dataset and a
    non-shuffling ``DataLoader`` of ``config.batch_size``, moves the model to
    GPU when available and runs ``translate_using_dataloader``.
    """
    # The loader class is imported under its real name: the local variable
    # below is called ``dataloader``, so calling a global ``dataloader(...)``
    # in the same function raises UnboundLocalError.
    from torch.utils.data import DataLoader

    config = prepare_config(training=False)
    ckpt_config, vocabs, model = load_checkpoint(config.ckpt_path,
                                                 use_fields=False)
    exp = prepare_experiment(ckpt_config.exp, ckpt_config.save_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): other versions of this script spell the dataset class
    # ``Gloss2TransDataset``; confirm ``gloss2transdataset`` is defined here.
    dataset = gloss2transdataset("test",
                                 config.src_file,
                                 config.tgt_file,
                                 ckpt_config,
                                 vocabs=vocabs)
    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=config.batch_size,
                            shuffle=False)  # keep output order == input order

    model = model.to(device)
    print(model)

    translator = prepare_translator(exp,
                                    config,
                                    ckpt_config,
                                    model,
                                    device,
                                    fields=None,
                                    vocabs=vocabs)

    translator.translate_using_dataloader(dataset, dataloader)
コード例 #2
0
ファイル: translate.py プロジェクト: ikarosgit/NSLT
def main():
    """Translate ``config.src_file`` using a torchtext-field checkpoint.

    The checkpoint at ``config.ckpt_path`` supplies its own config, fields
    and model. The model is moved to GPU when one is available, then the raw
    source/target text files are read and translated in batches of
    ``config.batch_size``.
    """
    # --- experiment / checkpoint setup -----------------------------------
    cfg = prepare_config(training=False)
    ckpt_cfg, ckpt_fields, net = load_checkpoint(cfg.ckpt_path)
    experiment = prepare_experiment(ckpt_cfg.exp, ckpt_cfg.save_dir)
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    net = net.to(device)
    print(net)

    translator = prepare_translator(
        experiment, cfg, ckpt_cfg, net, device,
        fields=ckpt_fields,
        vocabs=None,
    )

    # --- data loading + translation --------------------------------------
    sources, targets = load_text(cfg.src_file), load_text(cfg.tgt_file)
    translator.translate(
        src_data=sources,
        tgt_data=targets,
        batch_size=cfg.batch_size,
    )
コード例 #3
0
def main():
    """Translate one split with every checkpoint of an experiment and plot BLEU.

    Each ``*.pt`` file under ``<save_dir>/<exp>/checkpoints`` is processed in
    order of the integer after the last ``_`` in its file stem. The split
    ``mode`` is translated into ``<mode>_pred_<epoch>.txt``. Afterwards,
    BLEU-1..4 against ``<data_dir>/tgt-<mode>.txt`` is computed for each
    prediction file (read from ``<save_dir>/<exp>/results``). The best score
    per n-gram order is printed and the per-epoch curves are plotted.

    Prints a message and returns early if no checkpoint is found.

    Raises:
        ValueError: if a checkpoint's ``data_type`` is neither ``"text"`` nor
            ``"video"``.
    """
    save_dir = "save"
    mode = "valid"  # train | valid | test
    exp = "sign-opt-first"
    data_dir = "sign-data/inputs"

    # Prepare Experiment
    config = prepare_config(training=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.update("src_file", os.path.join(data_dir, f"src-{mode}.txt"))
    config.update("src_path_file", os.path.join(data_dir, f"src-path-{mode}.txt"))
    config.update("tgt_file", os.path.join(data_dir, f"tgt-{mode}.txt"))

    # NOTE(review): assumes the translator writes config.out_file into
    # <save_dir>/<exp>/results -- confirm against prepare_translator.
    out_dir = os.path.join(save_dir, exp, "results")
    path = os.path.join(save_dir, exp, "checkpoints")

    def _epoch_str(ckpt_file):
        """Return the last '_'-separated field of the checkpoint file stem."""
        return os.path.basename(ckpt_file).split(".")[0].split("_")[-1]

    print(f"Loading from '{path}'")
    ckpt_path_list = sorted(glob.glob(os.path.join(path, "*.pt")),
                            key=lambda p: int(_epoch_str(p)))
    if not ckpt_path_list:
        # Without this guard the BLEU summary below fails on max([]).
        print(f"No checkpoints found in '{path}'.")
        return
    epoch_list = [int(_epoch_str(p)) for p in ckpt_path_list]

    out_file_list = []
    for ckpt_path in ckpt_path_list:
        # Keep the raw (possibly zero-padded) string so output names follow
        # the checkpoint names exactly.
        epoch = _epoch_str(ckpt_path)
        print(f"# Translation using checkpoint '{ckpt_path}'.")

        # Update for all evaluation
        config.update("ckpt_path", ckpt_path)
        out_file = mode + "_pred_" + epoch + ".txt"
        out_file_list.append(out_file)
        config.update("out_file", out_file)

        ckpt_config, vocabs, model, _ = load_checkpoint(config.ckpt_path, device, use_fields=False)

        if ckpt_config.data_type == "text":
            dataset = Gloss2TransDataset("test",
                                         config.src_file,
                                         config.tgt_file,
                                         ckpt_config,
                                         vocabs=vocabs)
        elif ckpt_config.data_type == "video":
            transform = prepare_transform(ckpt_config.image_resize)
            dataset = Sign2TransDataset("test",
                                        config.src_file,
                                        config.src_path_file,
                                        config.tgt_file,
                                        ckpt_config,
                                        vocabs=vocabs,
                                        transform=transform)
        else:
            # Otherwise `dataset` is unbound (first checkpoint) or silently
            # reused from the previous checkpoint.
            raise ValueError(f"Unsupported data_type: {ckpt_config.data_type!r}")

        dataloader = DataLoader(dataset,
                                collate_fn=dataset.collate_fn,
                                batch_size=config.batch_size,
                                shuffle=False)

        # Separate name: `exp` above is the experiment *name* string.
        experiment = prepare_experiment(ckpt_config.exp, ckpt_config.save_dir)

        model = model.to(device)

        translator = prepare_translator(experiment,
                                        config,
                                        ckpt_config,
                                        model,
                                        device,
                                        fields=None,
                                        vocabs=vocabs)

        translator.translate_using_dataloader(dataset, dataloader)

    trans_file = os.path.join(data_dir, f"tgt-{mode}.txt")
    trans_list = load_trans_txt(trans_file)
    score_list = [[] for _ in range(4)]  # score_list[n-1] holds BLEU-n per epoch
    for out_file in out_file_list:
        # NOTE(review): despite its name, ref_list holds the model predictions
        # and trans_list the ground truth -- confirm the argument order that
        # compute_nltk_bleu_score expects.
        ref_list = load_trans_txt(os.path.join(out_dir, out_file))

        for ngram in range(1, 5):
            weights = [1 / ngram] * ngram  # uniform weights over 1..ngram-grams
            score = compute_nltk_bleu_score(ref_list, trans_list, weights=weights)
            score_list[ngram - 1].append(100 * float(score))

    c = ["b", "g", "r", "c", "m", "y", "k"]
    best = []
    for i in range(4):
        plt.plot(epoch_list, score_list[i], linestyle="-", marker="o", color=c[i % len(c)], label=f"BLEU{i+1}")
        best.append(max(score_list[i]))
    print(f"Best BLEU-1 = {best[0]}, BLEU-2 = {best[1]}, BLEU-3 = {best[2]}, BLEU-4 = {best[3]}")

    plt.xlabel("Epoch")
    plt.ylabel("BLEU")
    plt.ylim([0, 100])
    plt.legend()
    plt.show()
コード例 #4
0
ファイル: all_evaluation.py プロジェクト: ikarosgit/NSLT
def main():
    """Translate one split with every checkpoint of an experiment and plot BLEU.

    Field-based (torchtext) variant. The raw source/target text is loaded
    once, then translated with each ``*.pt`` checkpoint under
    ``<save_dir>/<exp>/checkpoints``, in order of the integer after the last
    ``_`` in the file stem. Each run writes ``<mode>_pred_<epoch>.txt``.
    BLEU-1..4 of every prediction file is then computed, the best score per
    n-gram order is printed and the per-epoch curves are plotted.

    Prints a message and returns early if no checkpoint is found.
    """
    save_dir = "/media/ikaros/HDPH-UT-1TB/nslt/save"
    mode = "valid"  # train | valid | test
    exp = "adam"

    # Prepare Experiment
    config = prepare_config(training=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.update("src_file", f"sign-data/inputs/src-{mode}.txt")
    config.update("tgt_file", f"sign-data/inputs/tgt-{mode}.txt")
    src_data = load_text(config.src_file)
    tgt_data = load_text(config.tgt_file)

    # NOTE(review): assumes the translator writes config.out_file into
    # <save_dir>/<exp>/results -- confirm against prepare_translator.
    out_dir = os.path.join(save_dir, exp, "results")
    path = os.path.join(save_dir, exp, "checkpoints")

    def _epoch_str(ckpt_file):
        """Return the last '_'-separated field of the checkpoint file stem."""
        return os.path.basename(ckpt_file).split(".")[0].split("_")[-1]

    ckpt_path_list = sorted(glob.glob(os.path.join(path, "*.pt")),
                            key=lambda p: int(_epoch_str(p)))
    if not ckpt_path_list:
        # Without this guard the BLEU summary below fails on max([]).
        print(f"No checkpoints found in '{path}'.")
        return
    epoch_list = [int(_epoch_str(p)) for p in ckpt_path_list]

    out_file_list = []
    for ckpt_path in ckpt_path_list:
        # Raw (possibly zero-padded) string so output names follow checkpoint names.
        epoch = _epoch_str(ckpt_path)
        print(f"# Translation using checkpoint '{ckpt_path}'.")

        # Update for all evaluation
        config.update("ckpt_path", ckpt_path)
        out_file = mode + "_pred_" + epoch + ".txt"
        out_file_list.append(out_file)
        config.update("out_file", out_file)

        ckpt_config, fields, model = load_checkpoint(config.ckpt_path)
        # Separate name: `exp` above is the experiment *name* string.
        experiment = prepare_experiment(ckpt_config.exp, ckpt_config.save_dir)

        model = model.to(device)

        # NOTE(review): translate.py in this project calls
        # prepare_translator(exp, config, ckpt_config, model, device,
        # fields=..., vocabs=...); confirm this positional order
        # (fields before model) matches the prepare_translator in use.
        translator = prepare_translator(experiment, config, ckpt_config, fields,
                                        model, device)

        translator.translate(src_data=src_data,
                             tgt_data=tgt_data,
                             batch_size=config.batch_size)

    # NOTE(review): reference text is read from sign-data/annotations while
    # the translator input came from sign-data/inputs -- confirm intended.
    trans_file = f"sign-data/annotations/tgt-{mode}.txt"
    trans_list = load_trans_txt(trans_file)
    score_list = [[] for _ in range(4)]  # score_list[n-1] holds BLEU-n per epoch
    for out_file in out_file_list:
        # NOTE(review): ref_list holds the model predictions here -- confirm
        # the argument order compute_nltk_bleu_score expects.
        ref_list = load_trans_txt(os.path.join(out_dir, out_file))

        for ngram in range(1, 5):
            weights = [1 / ngram] * ngram  # uniform weights over 1..ngram-grams
            score = compute_nltk_bleu_score(ref_list,
                                            trans_list,
                                            weights=weights)
            score_list[ngram - 1].append(100 * float(score))

    c = ["b", "g", "r", "c", "m", "y", "k"]
    best = []
    for i in range(4):
        plt.plot(epoch_list,
                 score_list[i],
                 linestyle="-",
                 marker="o",
                 color=c[i % len(c)],
                 label=f"BLEU{i + 1}")  # 1-based to match the printed BLEU-n
        best.append(max(score_list[i]))
    print(
        f"Best BLEU-1 = {best[0]}, BLEU-2 = {best[1]}, BLEU-3 = {best[2]}, BLEU-4 = {best[3]}"
    )

    plt.xlabel("Epoch")
    plt.ylabel("BLEU")
    plt.ylim([0, 100])
    plt.legend()
    plt.show()
コード例 #5
0
ファイル: eval_pure.py プロジェクト: ikarosgit/NSLT
def main(*, mode="valid", exp="sign-opt-init", step=5000, compare_step=None):
    """Translate one split with a single checkpoint and print its BLEU scores.

    The checkpoint ``save/<exp>/checkpoints/ckpt_<step:08d>.pt`` translates
    ``sign-data/inputs/src-<mode>.txt`` into ``<mode>_pred_<step>.txt``.
    BLEU-1..4 against ``sign-data/inputs/tgt-<mode>.txt`` is then printed.

    Args:
        mode: data split ("train", "valid", "test", or "eval" for debugging).
        exp: experiment directory name under ``save``.
        step: training step of the checkpoint to evaluate.
        compare_step: debug aid. If given, also load
            ``ckpt_<compare_step:08d>.pt``, print the mean absolute
            difference of every pair of corresponding parameters, and return
            without translating.

    Raises:
        ValueError: if the checkpoint's ``data_type`` is not ``"text"``,
            ``"video"`` or ``"opticalflow"``.
    """
    data_dir = "sign-data/inputs"
    save_dir = "save"

    # Prepare Experiment
    config = prepare_config(training=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.update("src_file", os.path.join(data_dir, f"src-{mode}.txt"))
    config.update("src_path_file",
                  os.path.join(data_dir, f"src-path-{mode}.txt"))
    config.update("src_opt_file", os.path.join(data_dir,
                                               f"src-opt-{mode}.txt"))
    config.update("tgt_file", os.path.join(data_dir, f"tgt-{mode}.txt"))

    # NOTE(review): assumes the translator writes config.out_file into
    # <save_dir>/<exp>/results -- confirm against prepare_translator.
    out_dir = os.path.join(save_dir, exp, "results")
    path = os.path.join(save_dir, exp, "checkpoints")
    ckpt_path = os.path.join(path, f"ckpt_{step:08d}.pt")

    print(f"# Translation using checkpoint '{ckpt_path}'.")

    # Update for all evaluation
    config.update("ckpt_path", ckpt_path)
    out_file = mode + "_pred_" + str(step) + ".txt"
    config.update("out_file", out_file)

    ckpt_config, vocabs, model, _ = load_checkpoint(config.ckpt_path,
                                                    device,
                                                    use_fields=False)
    model.eval()

    if compare_step is not None:
        # Loaded into separate names so the primary checkpoint's config and
        # vocabs are not overwritten.
        other_path = os.path.join(path, f"ckpt_{compare_step:08d}.pt")
        _, _, model2, _ = load_checkpoint(other_path,
                                          device,
                                          use_fields=False)
        for (name, p), (_, p2) in zip(model.named_parameters(),
                                      model2.named_parameters()):
            # Mean absolute element-wise difference; .mean() also works on
            # non-contiguous tensors, unlike p.view(-1).
            print(name, (p - p2).abs().mean().detach().cpu().numpy())
        return

    print(ckpt_config.data_type)
    if ckpt_config.data_type == "text":
        dataset = Gloss2TransDataset("test",
                                     config.src_file,
                                     config.tgt_file,
                                     ckpt_config,
                                     vocabs=vocabs)

    elif ckpt_config.data_type == "video":
        transform = prepare_transform(ckpt_config.image_resize)
        dataset = Sign2TransDataset("test",
                                    ckpt_config.data_type,
                                    config.src_file,
                                    config.src_path_file,
                                    config.tgt_file,
                                    ckpt_config,
                                    vocabs=vocabs,
                                    transform=transform)

    elif ckpt_config.data_type == "opticalflow":
        transform = prepare_opticalflow_transform(ckpt_config.image_resize)
        dataset = Sign2TransDataset("test",
                                    ckpt_config.data_type,
                                    config.src_file,
                                    config.src_opt_file,
                                    config.tgt_file,
                                    ckpt_config,
                                    vocabs=vocabs,
                                    transform=transform)

    else:
        # Otherwise `dataset` would be unbound below.
        raise ValueError(f"Unsupported data_type: {ckpt_config.data_type!r}")

    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=config.batch_size,
                            shuffle=False)

    # Separate name: the `exp` parameter is the experiment *name* string.
    experiment = prepare_experiment(ckpt_config.exp, ckpt_config.save_dir)

    model = model.to(device)

    translator = prepare_translator(experiment,
                                    config,
                                    ckpt_config,
                                    model,
                                    device,
                                    fields=None,
                                    vocabs=vocabs)

    translator.translate_using_dataloader(dataset, dataloader)

    trans_file = os.path.join(data_dir, f"tgt-{mode}.txt")
    trans_list = load_trans_txt(trans_file)
    # NOTE(review): ref_list holds the model predictions here -- confirm the
    # argument order compute_nltk_bleu_score expects.
    ref_list = load_trans_txt(os.path.join(out_dir, out_file))

    score_list = []
    for ngram in range(1, 5):
        weights = [1 / ngram] * ngram  # uniform weights over 1..ngram-grams
        score = compute_nltk_bleu_score(ref_list, trans_list, weights=weights)
        score_list.append(round(100 * float(score), 4))

    print(
        f"Best BLEU-1 = {score_list[0]}, BLEU-2 = {score_list[1]}, BLEU-3 = {score_list[2]}, BLEU-4 = {score_list[3]}"
    )