Example 1
def get_bert_dataset(model,
                     args,
                     inputs,
                     embedding_dict=None,
                     positional_dict=None,
                     merge_both_embeddings=False):
    config = model.config
    shapeOf = model.builder.getTensorShape
    # The inputs after the first three (ind, pos, seg) are always lists
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]

    if config.task == "PRETRAINING":
        return get_pretraining_dataset(
            tensor_shapes,
            input_files=args.input_files,
            sequence_length=config.sequence_length,
            mask_tokens=config.mask_tokens,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            duplication_factor=args.duplication_factor,
            shuffle=args.shuffle,
            synthetic=args.synthetic_data,
            epochs_to_cache=args.epochs_to_cache,
            start_data_at_epoch=args.continue_training_from_epoch)

    if config.task == "SQUAD":
        ds = get_squad_dataset(
            tensor_shapes,
            input_file=args.input_files[0],
            output_dir=args.squad_results_dir,
            sequence_length=config.sequence_length,
            vocab_file=args.vocab_file,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            embedding_dict=embedding_dict,
            positional_dict=positional_dict,
            merge_both_embeddings=merge_both_embeddings,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            shuffle=args.shuffle,
            is_training=not args.inference,
            overwrite_cache=args.overwrite_cache,
            no_drop_remainder=args.no_drop_remainder,
            evaluate_script=args.squad_evaluate_script,
            synthetic=args.synthetic_data,
            do_lower_case=args.do_lower_case,
            max_pipeline_stage=model.total_pipeline_stages
            if args.execution_mode == "PIPELINE" else 1,
            seed=args.seed,
            mpi_size=args.mpi_size,
            mpi_rank=args.mpi_rank,
            is_distributed=args.mpi_size > 1)

        return ds
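
The `reduce(chain, inputs[3:], inputs[:3])` idiom above flattens the trailing list-valued inputs (masks, labels) into one flat sequence of tensor IDs. A minimal, self-contained sketch with illustrative tensor names:

from functools import reduce
from itertools import chain

# First three entries are single tensor IDs; the remaining entries are lists
# of tensor IDs (illustrative names, not the real IDs).
inputs = ["indices", "positions", "segments",
          ["mask_0", "mask_1"], ["label_mlm", "label_nsp"]]

flat = list(reduce(chain, inputs[3:], inputs[:3]))
# ['indices', 'positions', 'segments', 'mask_0', 'mask_1', 'label_mlm', 'label_nsp']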
Example 2
def get_bert_dataset(model, args, inputs):
    shapeOf = model.builder.getTensorShape
    # The inputs after the first three (ind, pos, seg) are always lists
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]

    if args.task == "PRETRAINING":
        ds = get_pretraining_dataset(args, tensor_shapes)
    elif args.task == "SQUAD":
        ds = get_squad_dataset(args,
                               tensor_shapes,
                               host_embeddings=model.get_model_embeddings())
    else:
        raise RuntimeError(f"Unsupported Task {args.task} in get_bert_dataset")

    return ds
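
Assuming the Example 2 definition is in scope, a minimal stand-in for `model` and `args` (purely illustrative, not the real objects from the BERT build or CLI) exercises the unsupported-task branch:

from types import SimpleNamespace

# Illustrative stubs only: the real model and args are built elsewhere.
fake_model = SimpleNamespace(
    builder=SimpleNamespace(getTensorShape=lambda tensorId: []),
    get_model_embeddings=lambda: None)
fake_args = SimpleNamespace(task="GLUE")

try:
    get_bert_dataset(fake_model, fake_args, [])
except RuntimeError as err:
    print(err)  # Unsupported Task GLUE in get_bert_dataset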
Example 3
def get_bert_dataset(model, args, inputs):
    config = model.config
    shapeOf = model.builder.getTensorShape
    # The inputs after the first three (ind, pos, seg) are always lists
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]
    if config.task == "PRETRAINING":
        return get_pretraining_dataset(
            tensor_shapes,
            input_files=args.input_files,
            sequence_length=config.sequence_length,
            mask_tokens=config.mask_tokens,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            duplication_factor=args.duplication_factor,
            shuffle=args.shuffle,
            synthetic=args.synthetic_data,
            epochs_to_cache=args.epochs_to_cache)
    if config.task == "SQUAD":
        return get_squad_dataset(
            tensor_shapes,
            input_file=args.input_files[0],
            output_dir=args.squad_results_dir,
            sequence_length=config.sequence_length,
            vocab_file=args.vocab_file,
            vocab_length=config.vocab_length,
            batch_size=config.batch_size,
            batches_per_step=args.batches_per_step,
            accumulation_factor=args.gradient_accumulation_factor,
            replication_factor=args.replication_factor,
            shuffle=args.shuffle,
            is_training=not args.inference,
            overwrite_cache=args.overwrite_cache,
            no_drop_remainder=args.no_drop_remainder,
            evaluate_script=args.squad_evaluate_script,
            synthetic=args.synthetic_data,
            do_lower_case=args.do_lower_case)
Example 4
def main(args):
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    model = Bert(config, pipeline=args.pipeline, initializers=initializers)

    if not config.use_packed_sequence_format:
        # If config.host_embedding is enabled, indices and positions hold embedding matrices instead of index vectors.
        indices, positions, segments, masks, labels = bert_add_inputs(
            args, model)
        logits = model.build_graph(indices, positions, segments, masks)
        outputs, accuracies, losses, final_loss, writer = bert_add_outputs(
            args, model, logits, labels)
        dataset = get_bert_dataset(
            model, args, [indices, positions, segments, masks, labels])

    else:  # use_packed_sequence_format
        if args.task != "PRETRAINING":
            raise RuntimeError(
                "Packed sequence format currently only supported for pretraining."
            )
        input_tensor_shapes = packed_bert_utils.add_inputs(model)
        logits = packed_bert_utils.logits_graph(model)
        losses, accuracies, final_loss, outputs = packed_bert_utils.pretraining_loss_and_accuracy(
            model, logits)
        writer = bert_writer(args) if not args.inference else None
        dataset = get_pretraining_dataset(args, input_tensor_shapes)

    device = acquire_device(args, bert_required_ipus(args, model))

    logger.info(f"Dataset length: {len(dataset)}")

    data_flow = popart.DataFlow(args.batches_per_step, outputs)

    iteration = bert_iteration(args, dataset, writer)

    if args.inference:
        session, anchors = bert_inference_session(model, args, data_flow,
                                                  device)
        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks, *labels]
        bert_infer_loop(args, session, dataset, inputs, logits, anchors,
                        accuracies, losses, iteration)
        device.detach()
    else:
        if not args.no_training:
            optimizer_factory = bert_optimizer_factory(args, model, iteration)
            if args.save_initializers_externally:
                save_dir = Path(args.checkpoint_dir,
                                f'model_{args.continue_training_from_epoch}')
                save_dir.mkdir(parents=True, exist_ok=True)
                weight_tensors = [
                    item for sublist in model.tensors.values()
                    for item in sublist
                ]
                vars_path = f'vars_{args.continue_training_from_epoch}.onnx'
                vars_path = os.path.join(save_dir, vars_path)
                model.builder.saveInitializersExternally(
                    weight_tensors, vars_path)

            session, anchors = bert_training_session(model, args, data_flow,
                                                     final_loss, device,
                                                     optimizer_factory)
            logger.info("Training Started")
            bert_train_loop(args, session, writer, dataset, accuracies, losses,
                            anchors, iteration, optimizer_factory)

            save_model(args, session, iteration.count)
            if args.wandb_save_checkpoints:
                artifact = wandb.Artifact(name=args.wandb_save_checkpoints,
                                          type="model")
                artifact.add_dir(args.checkpoint_dir)
                wandb.log_artifact(artifact)

            device.detach()
            logger.info("Training Finished")

    return session, iteration
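
The excerpt does not show how `main` is invoked; presumably the script parses its command line first. A hypothetical entry point (`parse_bert_args` is an assumed helper, not part of the excerpt):

if __name__ == "__main__":
    # parse_bert_args is hypothetical here; the real script defines its own CLI parsing.
    args = parse_bert_args()
    main(args)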