Example #1
def load_data(args) -> dict:
    """Load all required data into a dictionary."""
    data = {
        'src': Vocabulary.from_embeddings(args.src_input,
                                          top_n_words=args.vocab_limit),
        'trg': Vocabulary.from_embeddings(args.trg_input,
                                          top_n_words=args.vocab_limit),
        'dico': Dictionary.from_txt(args.train_dico,
                                    delimiter=args.dico_delimiter),
    }
    logging.info("============ Data Summary")
    logging.info(f"Source language tokens: {len(data['src'].word2id)}")
    logging.info(f"Target language tokens: {len(data['trg'].word2id)}")
    # lower_case is only a fallback for when the original casing is not in the dictionary
    data['dico'].vocabulary_check(data['src'], data['trg'], lower_case=True)
    if args.eval_dico is not None:
        data['eval_dico'] = Dictionary.from_txt(args.eval_dico,
                                                delimiter=args.dico_delimiter)
        data['eval_dico'].vocabulary_check(data['src'],
                                           data['trg'],
                                           lower_case=True)
        logging.info(f"Evaluation pairs: {len(data['eval_dico'].pairs)}")
    return data
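
Dictionary.from_txt above is project code, but the file it reads is just delimiter-separated word pairs. A minimal sketch of such a loader, with a hypothetical load_pairs helper standing in for the project's API:

# Sketch of a delimiter-separated dictionary loader; load_pairs is a
# hypothetical helper, not part of the project's API.
def load_pairs(path, delimiter='\t'):
    pairs = []
    with open(path, encoding='utf-8') as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            src_word, trg_word = line.split(delimiter, 1)
            pairs.append((src_word, trg_word))
    return pairs

# e.g. pairs = load_pairs('train.dico', delimiter=' ')
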
Example #2
                    help="test dataset reference file path.")
parser.add_argument("--noise_prob", type=float, default=0., required=False,
                    help="add noise prob.")
parser.add_argument("--existed_vocab", type=str, default="", required=False,
                    help="existed vocab path.")
parser.add_argument("--max_len", type=int, default=64, required=False,
                    help="max length of sentences.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="dataset output path.")
parser.add_argument("--format", type=str, default="tfrecord", required=False,
                    help="dataset format.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    if args.train_src and args.train_ref:
        train = BiLingualDataLoader(
            src_filepath=args.train_src,
            tgt_filepath=args.train_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=args.noise_prob),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            train.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.tfrecord")
            )
        else:
            pass  # handling for other dataset formats (not shown here)
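
The parse_known_args() call used in this snippet differs from parse_args() in that it tolerates unrecognized flags and returns them separately, which is why the result is unpacked as args, _. A self-contained illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_len", type=int, default=64)

# Unknown flags are returned as a leftover list instead of raising an error.
args, unknown = parser.parse_known_args(["--max_len", "32", "--foo", "bar"])
print(args.max_len)  # 32
print(unknown)       # ['--foo', 'bar']
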
Example #3
context.set_context(mode=context.GRAPH_MODE,
                    device_target=args.device_target,
                    device_id=args.device_id)


def get_config(config_file):
    tfm_config = TransformerConfig.from_json_file(config_file)
    tfm_config.compute_type = mstype.float16
    tfm_config.dtype = mstype.float32

    return tfm_config


if __name__ == '__main__':
    vocab = Dictionary.load_from_persisted_dict(args.vocab_file)
    config = get_config(args.gigaword_infer_config)
    dec_len = config.max_decode_length

    tfm_model = TransformerInferModel(config=config,
                                      use_one_hot_embeddings=False)
    tfm_model.init_parameters_data()

    params = tfm_model.trainable_params()
    weights = load_infer_weights(config)

    for param in params:
        value = param.data
        name = param.name

        if name not in weights:
            raise ValueError(f"parameter '{name}' is missing from the loaded weights")
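
The loop matches each trainable parameter to an entry in the externally loaded weight dictionary by name. The same name-keyed pattern reduced to plain Python and NumPy (the Param class is a stand-in, not a MindSpore API):

import numpy as np

class Param:
    """Stand-in for a framework parameter: a name plus an array."""
    def __init__(self, name, shape):
        self.name = name
        self.data = np.zeros(shape)

params = [Param("encoder.embedding", (4, 8)), Param("decoder.out", (8, 2))]
weights = {"encoder.embedding": np.ones((4, 8)),
           "decoder.out": np.ones((8, 2))}

for param in params:
    if param.name not in weights:
        raise ValueError(f"no weight found for parameter '{param.name}'")
    param.data = weights[param.name]  # copy by matching name
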
Example #4
                    type=str,
                    required=True,
                    help="model working platform.")


def get_config(config_file):
    config = TransformerConfig.from_json_file(config_file)
    config.compute_type = mstype.float32
    config.dtype = mstype.float32
    return config


if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if args.vocab.endswith("bin"):
        vocab = Dictionary.load_from_persisted_dict(args.vocab)
    else:
        vocab = Dictionary.load_from_text([args.vocab])
    _config = get_config(args.config)

    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(
        # mode=context.GRAPH_MODE,
        mode=context.PYNATIVE_MODE,
        device_target=args.platform,
        reserve_class_name_in_scope=False,
        device_id=device_id)
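
The vocabulary branch above dispatches on the file extension: persisted binary dictionaries go through load_from_persisted_dict, anything else through load_from_text, which takes a list of paths. The same dispatch in isolation, with hypothetical loaders in place of the project's Dictionary methods:

# Hypothetical loaders standing in for Dictionary.load_from_persisted_dict
# and Dictionary.load_from_text.
def load_binary(path):
    return ("binary", path)

def load_text(paths):
    return ("text", paths)

def load_vocab(path):
    if path.endswith("bin"):
        return load_binary(path)
    return load_text([path])  # load_from_text expects a list of paths

print(load_vocab("vocab.bin"))  # ('binary', 'vocab.bin')
print(load_vocab("vocab.txt"))  # ('text', ['vocab.txt'])
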
Example #5
        raise TypeError("`--processes` must be an integer.")

    available_dict = []
    args_groups = []
    for file in os.listdir(source_folder):
        if args.prefix and not file.startswith(args.prefix):
            continue
        if file.endswith(".txt"):
            output_path = os.path.join(output_folder,
                                       file.replace(".txt", "_bpe.txt"))
            dict_path = os.path.join(output_folder,
                                     file.replace(".txt", ".dict"))
            available_dict.append(dict_path)
            args_groups.append(
                (codes, os.path.join(source_folder, file),
                 output_path, dict_path))

    kernel_size = 1 if args.processes <= 0 else args.processes
    kernel_size = min(kernel_size, cpu_count())
    pool = Pool(kernel_size)
    for arg in args_groups:
        pool.apply_async(bpe_encode, args=arg)
    pool.close()
    pool.join()

    vocab = Dictionary.load_from_text(available_dict)
    if args.threshold is not None:
        vocab = vocab.shrink(args.threshold)
    vocab.persistence(args.vocab_path)
    print(f" | Vocabulary Size: {len(vocab)}")