Example #1
from functools import reduce
from itertools import chain


def get_bert_dataset(model,
                     args,
                     inputs,
                     embedding_dict=None,
                     positional_dict=None,
                     merge_both_embeddings=False):
    config = model.config
    shapeOf = model.builder.getTensorShape
    # The inputs after the first three (ind, pos, seg) are always lists
    inputs = reduce(chain, inputs[3:], inputs[:3])
    # Pair each tensor ID with its shape as reported by the builder.
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]

    if config.task == "PRETRAINING":
        return get_pretraining_dataset(
            tensor_shapes,
            input_files=args.input_files,
            sequence_length=config.sequence_length,
            mask_tokens=config.mask_tokens,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            duplication_factor=args.duplication_factor,
            shuffle=args.shuffle,
            synthetic=args.synthetic_data,
            epochs_to_cache=args.epochs_to_cache,
            start_data_at_epoch=args.continue_training_from_epoch)

    if config.task == "SQUAD":
        ds = get_squad_dataset(
            tensor_shapes,
            input_file=args.input_files[0],
            output_dir=args.squad_results_dir,
            sequence_length=config.sequence_length,
            vocab_file=args.vocab_file,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            embedding_dict=embedding_dict,
            positional_dict=positional_dict,
            merge_both_embeddings=merge_both_embeddings,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            shuffle=args.shuffle,
            is_training=not args.inference,
            overwrite_cache=args.overwrite_cache,
            no_drop_remainder=args.no_drop_remainder,
            evaluate_script=args.squad_evaluate_script,
            synthetic=args.synthetic_data,
            do_lower_case=args.do_lower_case,
            max_pipeline_stage=(model.total_pipeline_stages
                                if args.execution_mode == "PIPELINE" else 1),
            seed=args.seed,
            mpi_size=args.mpi_size,
            mpi_rank=args.mpi_rank,
            is_distributed=args.mpi_size > 1)

        return ds
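
The flattening step above, reduce(chain, inputs[3:], inputs[:3]), seeds the accumulator with the first three tensor IDs (indices, positions, segments) and chains each remaining list-valued input onto it, yielding one flat iterator of tensor IDs. A minimal, self-contained sketch of the idiom; the tensor IDs and the shapeOf stub are invented for illustration and are not part of the examples on this page:

from functools import reduce
from itertools import chain

# Invented stand-ins: three scalar tensor IDs followed by list-valued inputs.
inputs = ["ind", "pos", "seg", ["mask_0", "mask_1"], ["label"]]

def shapeOf(tensorId):
    # Hypothetical shape lookup standing in for model.builder.getTensorShape.
    return [1, 128]

flat = reduce(chain, inputs[3:], inputs[:3])
tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in flat]
print(tensor_shapes)
# [('ind', [1, 128]), ('pos', [1, 128]), ('seg', [1, 128]),
#  ('mask_0', [1, 128]), ('mask_1', [1, 128]), ('label', [1, 128])]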
Example #2
from functools import reduce
from itertools import chain


def get_bert_dataset(model, args, inputs):
    shapeOf = model.builder.getTensorShape
    # The inputs after the first three (ind, pos, seg) are always lists
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]

    if args.task == "PRETRAINING":
        ds = get_pretraining_dataset(args, tensor_shapes)
    elif args.task == "SQUAD":
        ds = get_squad_dataset(args,
                               tensor_shapes,
                               host_embeddings=model.get_model_embeddings())
    else:
        raise RuntimeError(f"Unsupported Task {args.task} in get_bert_dataset")

    return ds
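
The else branch above turns an unknown task into an explicit failure instead of silently returning None. A runnable sketch of that dispatch shape; the build_dataset name, the stub constructors, and the SimpleNamespace args are illustrative stand-ins, not the real dataset API:

from types import SimpleNamespace

def get_pretraining_dataset(args, tensor_shapes):
    # Stub standing in for the real pretraining dataset constructor.
    return ("pretraining", tensor_shapes)

def get_squad_dataset(args, tensor_shapes, host_embeddings=None):
    # Stub standing in for the real SQuAD dataset constructor.
    return ("squad", tensor_shapes, host_embeddings)

def build_dataset(args, tensor_shapes, host_embeddings=None):
    if args.task == "PRETRAINING":
        return get_pretraining_dataset(args, tensor_shapes)
    elif args.task == "SQUAD":
        return get_squad_dataset(args, tensor_shapes,
                                 host_embeddings=host_embeddings)
    raise RuntimeError(f"Unsupported Task {args.task} in get_bert_dataset")

build_dataset(SimpleNamespace(task="GLUE"), [])  # raises RuntimeError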
Example #3
from functools import reduce
from itertools import chain


def get_bert_dataset(model, args, inputs):
    config = model.config
    shapeOf = model.builder.getTensorShape
    # The inputs after the first three (ind, pos, seg) are always lists
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]
    if config.task == "PRETRAINING":
        return get_pretraining_dataset(
            tensor_shapes,
            input_files=args.input_files,
            sequence_length=config.sequence_length,
            mask_tokens=config.mask_tokens,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            duplication_factor=args.duplication_factor,
            shuffle=args.shuffle,
            synthetic=args.synthetic_data,
            epochs_to_cache=args.epochs_to_cache)
    if config.task == "SQUAD":
        return get_squad_dataset(
            tensor_shapes,
            input_file=args.input_files[0],
            output_dir=args.squad_results_dir,
            sequence_length=config.sequence_length,
            vocab_file=args.vocab_file,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            shuffle=args.shuffle,
            is_training=not args.inference,
            overwrite_cache=args.overwrite_cache,
            no_drop_remainder=args.no_drop_remainder,
            evaluate_script=args.squad_evaluate_script,
            synthetic=args.synthetic_data,
            do_lower_case=args.do_lower_case)