コード例 #1
0
def tune_model_weights():
    """Grid-search ensemble model weights and record the score of each setting.

    Parses the command line (the ``generate`` parser extended by
    ``add_tune_args``), enumerates candidate weight vectors, runs
    ``generate.generate`` once per vector and records the returned score.

    Candidate weights come from ``args.n_grid + 1`` evenly spaced points in
    ``[args.weight_lower_bound, args.weight_upper_bound]``. One grid value is
    picked for every model except the last. The last model receives
    ``1 - sum(other weights)``. Combinations whose free weights already sum to
    more than 1 are skipped. The number of models is the number of entries in
    ``args.path`` split on ``CHECKPOINT_PATHS_DELIMITER``.

    The results table is rewritten to ``args.output_file_name`` (CSV) after
    every evaluation, so partial results survive an interrupted run.

    Returns:
        pandas.DataFrame with one row per evaluated setting. Column ``weight``
        holds the comma-separated string assigned to ``args.model_weights``.
        Column ``bleu_score`` holds the value returned by
        ``generate.generate``. The frame is empty (no columns) when no
        setting was evaluated.
    """
    parser = generate.get_parser_with_args()
    parser = add_tune_args(parser)
    args = options.parse_args_and_arch(parser)
    print(args.model_weights)
    n_models = len(args.path.split(CHECKPOINT_PATHS_DELIMITER))
    print(n_models)

    weight_grid = np.linspace(
        args.weight_lower_bound, args.weight_upper_bound, args.n_grid + 1
    )

    # np.linspace points are not exact decimals, so e.g. 0.3 + 0.7 can evaluate
    # to 1.0000000000000002. Without a small slack such valid settings would be
    # silently dropped from the search.
    sum_tolerance = 1e-9

    weight_vec = []
    # All models but the last get a free grid weight; the last one absorbs the
    # remainder so every vector sums to 1.
    for free_weights in itertools.product(weight_grid, repeat=n_models - 1):
        weight_sum = sum(free_weights)
        if weight_sum > 1 + sum_tolerance:
            continue
        # Clamp so round-off cannot produce a tiny negative last weight.
        last_weight = max(1 - weight_sum, 0.0)
        weight_vec.append(",".join(str(w) for w in (*free_weights, last_weight)))

    print(len(weight_vec))
    records = []
    output = pd.DataFrame()
    for weight in weight_vec:
        args.model_weights = weight
        print(args.model_weights)
        generate.validate_args(args)
        score = generate.generate(args)
        print(score)
        # DataFrame.append was removed in pandas 2.0; collect rows instead.
        records.append({"weight": args.model_weights, "bleu_score": score})
        # Rewrite the whole table each time so an interrupted search still
        # leaves every finished evaluation on disk.
        output = pd.DataFrame(records)
        output.to_csv(args.output_file_name)
    return output
コード例 #2
0
def generate_main(data_dir, extra_flags=None):
    """Run ``generate`` on the test files found in *data_dir*.

    The parsed command line points at these files inside *data_dir*:
    ``dictionary-in.txt``, ``dictionary-out.txt``, ``test.in``, ``test.out``
    and ``checkpoint_last.pt``. It also fixes beam 3, length penalty 0.0,
    batch size 64, max-len-b 5 and disables the progress bar. Any
    *extra_flags* are appended after these defaults. The parsed arguments are
    validated and passed to ``generate.generate``. Its result is discarded,
    and this function returns ``None``.

    Args:
        data_dir: directory holding the files listed above.
        extra_flags: optional list of additional command-line tokens.
    """
    parser = generate.get_parser_with_args()

    def in_data_dir(filename):
        return os.path.join(data_dir, filename)

    cli_tokens = [
        "--source-vocab-file", in_data_dir("dictionary-in.txt"),
        "--target-vocab-file", in_data_dir("dictionary-out.txt"),
        "--source-text-file", in_data_dir("test.in"),
        "--target-text-file", in_data_dir("test.out"),
        "--path", in_data_dir("checkpoint_last.pt"),
        "--beam", "3",
        "--length-penalty", "0.0",
        "--batch-size", "64",
        "--max-len-b", "5",
        "--no-progress-bar",
    ]
    if extra_flags:
        # Plain list concatenation (not extend) keeps the original's
        # requirement that extra_flags be a list.
        cli_tokens = cli_tokens + extra_flags

    parsed_args = options.parse_args_and_arch(parser, cli_tokens)
    generate.validate_args(parsed_args)
    generate.generate(parsed_args)
コード例 #3
0
 def evaluation_function(parameterization):
     """Score one ensemble weight setting with ``generate.generate``.

     Looks up ``w1``, ``w2`` and ``w3`` in *parameterization* (missing keys
     become the text ``None``) and joins them with commas. The result is
     stored as ``model_weights`` on the ``args`` object taken from the
     enclosing scope, which is then validated and passed to the generator.

     Returns ``{"bleu_score": (score, 0.0)}``. The second tuple element is
     presumably a zero noise/SEM estimate for the calling optimizer. TODO:
     confirm against the caller.
     """
     weight_keys = ("w1", "w2", "w3")
     args.model_weights = ",".join(
         str(parameterization.get(key)) for key in weight_keys
     )
     generate.validate_args(args)
     bleu = generate.generate(args)
     return {"bleu_score": (bleu, 0.0)}