Пример #1
0
                    help="model working platform.")


def get_config(config):
    """Build the transformer configuration with float32 types.

    Args:
        config: Value forwarded to ``TransformerConfig.from_json_file``
            (presumably the path of a JSON configuration file).

    Returns:
        The object returned by ``TransformerConfig.from_json_file`` with its
        ``compute_type`` and ``dtype`` attributes set to ``mstype.float32``.
    """
    transformer_config = TransformerConfig.from_json_file(config)
    # Both type attributes are overridden regardless of what was loaded.
    transformer_config.compute_type = transformer_config.dtype = mstype.float32
    return transformer_config


if __name__ == '__main__':
    # parse_known_args() returns the unrecognised options separately; they
    # are discarded here.
    args, _ = parser.parse_known_args()

    # A vocabulary path ending in "bin" is loaded as a persisted dictionary;
    # any other path is loaded as a text vocabulary.
    vocab = (Dictionary.load_from_persisted_dict(args.vocab)
             if args.vocab.endswith("bin")
             else Dictionary.load_from_text([args.vocab]))
    _config = get_config(args.config)

    # The device index comes from the DEVICE_ID environment variable and
    # falls back to 0 when the variable is not set.
    device_id = int(os.getenv('DEVICE_ID', '0'))

    # PyNative mode is selected here; context.GRAPH_MODE was the previously
    # commented-out alternative.
    context.set_context(mode=context.PYNATIVE_MODE,
                        device_target=args.platform,
                        reserve_class_name_in_scope=False,
                        device_id=device_id)

    if args.metric == 'rouge':
        result = infer(_config)
Пример #2
0
        raise TypeError("`--processes` must be an integer.")

    # Build one work item per ".txt" file found in source_folder and remember
    # the dictionary path assigned to each of them.
    available_dict = []
    args_groups = []
    for file in os.listdir(source_folder):
        # When a prefix was supplied, ignore files that do not start with it.
        if args.prefix and not file.startswith(args.prefix):
            continue
        if file.endswith(".txt"):
            # Both output names are derived from the input file name and are
            # placed in output_folder.
            # NOTE(review): str.replace substitutes every ".txt" occurrence,
            # not only the trailing extension -- confirm file names never
            # contain ".txt" elsewhere.
            output_path = os.path.join(output_folder,
                                       file.replace(".txt", "_bpe.txt"))
            dict_path = os.path.join(output_folder,
                                     file.replace(".txt", ".dict"))
            available_dict.append(dict_path)
            # Positional arguments later passed to bpe_encode:
            # (codes, input path, output path, dictionary path).
            args_groups.append(
                (codes, os.path.join(source_folder,
                                     file), output_path, dict_path))

    # Pool size: 1 when args.processes is not positive, otherwise
    # args.processes, capped at cpu_count().
    kernel_size = 1 if args.processes <= 0 else args.processes
    kernel_size = min(kernel_size, cpu_count())
    pool = Pool(kernel_size)
    # NOTE(review): assuming Pool is multiprocessing.Pool, the result handles
    # returned by apply_async are discarded, so an exception raised inside
    # bpe_encode would not surface here -- confirm this is intended.
    for arg in args_groups:
        pool.apply_async(bpe_encode, args=arg)
    # Stop accepting new tasks, then wait for the submitted ones to finish.
    pool.close()
    pool.join()

    # Load a vocabulary from the dictionary paths collected above (presumably
    # written by bpe_encode -- verify), optionally shrink it with
    # args.threshold, then hand it to vocab.persistence with args.vocab_path.
    vocab = Dictionary.load_from_text(available_dict)
    if args.threshold is not None:
        vocab = vocab.shrink(args.threshold)
    vocab.persistence(args.vocab_path)
    print(f" | Vocabulary Size: {len(vocab)}")