Example #1
def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    pbar = tqdm(train_dataset)
    for i, batch in enumerate(pbar):
        with tf.GradientTape() as tape:
            inputs, targets = batch
            outputs = model(inputs)
            loss_value = loss(targets, outputs.logits)

        if SDP_ENABLED:
            tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        pbar.set_description(f"Loss: {loss_value:.4f}")

        if SDP_ENABLED:
            if i == 0:
                # Broadcast variables from rank 0 once, after the first optimizer step.
                sdp.broadcast_variables(model.variables, root_rank=0)
                sdp.broadcast_variables(opt.variables(), root_rank=0)

        if max_steps and i >= max_steps:
            break

    train_results = {"loss": loss_value.numpy()}
    return train_results
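A minimal setup sketch for the fit loop above, assuming sdp is smdistributed.dataparallel.tensorflow and SDP_ENABLED is a flag controlled by the surrounding script (it mirrors the GPU pinning done in Example #8 below and only runs where the SageMaker data parallel library is installed):

import tensorflow as tf

SDP_ENABLED = True  # assumption: the original script likely derives this from its job configuration
if SDP_ENABLED:
    import smdistributed.dataparallel.tensorflow as sdp
    sdp.init()
    # Pin this worker process to a single GPU chosen by its local rank.
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[sdp.local_rank()], "GPU")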
Example #2
def training_step(images, labels, batch):
    """
    Runs one training step, with considerations for the distributed nature of the job.
    :param images: batch of input images
    :param labels: labels for the images, in the same order
    :param batch: batch index; variables must be broadcast on the first batch
    :return: loss for this step, averaged across all workers
    """
    with tf.GradientTape() as tape:
        # record losses for automatic differentiation (learning)
        probs = model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap the tape with SMDataParallel's DistributedGradientTape so gradients are averaged across workers
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(
        grads, model.trainable_variables))  # learn on this worker

    if batch == 0:
        # Broadcast model and optimizer variables from rank 0 so every worker starts from the same state.
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(optimizer.variables(), root_rank=0)

    # all loss values from all workers reduced to one loss value
    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
Example #3
def training_step(mnist_model, loss, opt, images, labels, batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    ########################################################
    ####### 4. SageMaker Distributed Data Parallel  ########
    #######  - Optimize AllReduce operation         ########
    ########################################################

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)

    #######################################################

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    ########################################################
    ####### 5. SageMaker Distributed Data Parallel   #######
    #######  - Broadcast the initial model variables #######
    #######    from rank 0 to ranks 1 ~ n            #######
    ########################################################

    if batch == 0:
        # SMDataParallel: Broadcast model and optimizer variables
        smdp.broadcast_variables(mnist_model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    #######################################################

    return loss_value
Example #4
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
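A driver loop for a step function like the one above usually compiles it with tf.function and derives first_batch from the batch index; a minimal sketch, assuming dist is smdistributed.dataparallel.tensorflow and that mnist_model, loss, and opt are defined as in the example (the random dataset is only a stand-in for the sharded training data):

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

# Stand-in data; the real scripts shard the training set across workers.
images = tf.random.uniform([256, 28, 28, 1])
labels = tf.random.uniform([256], maxval=10, dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)

step_fn = tf.function(training_step)  # compile the step defined above
for batch_idx, (x, y) in enumerate(dataset):
    # first_batch=True triggers the one-time broadcast from rank 0.
    loss_value = step_fn(x, y, batch_idx == 0)
    if dist.rank() == 0:
        print(f"step {batch_idx}: loss {loss_value.numpy():.4f}")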
Example #5
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Create a new DistributedGradientTape, which uses TensorFlow’s GradientTape under the hood,
    # using an AllReduce to combine gradient values before applying gradients to model weights.
    tape = smdataparallel.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    # Broadcast model and optimizer variables after the first batch so all workers start in sync
    if first_batch:
        smdataparallel.broadcast_variables(mnist_model.variables, root_rank=0)
        smdataparallel.broadcast_variables(opt.variables(), root_rank=0)

    return loss_value
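Because every worker holds identical weights after the broadcast and the gradient AllReduce, checkpoints are normally written from rank 0 only; a minimal sketch, assuming the smdataparallel alias and mnist_model from the example (the path is illustrative):

import smdistributed.dataparallel.tensorflow as smdataparallel

# Only rank 0 writes the checkpoint; all ranks hold the same weights.
if smdataparallel.rank() == 0:
    mnist_model.save_weights("/tmp/checkpoints/mnist-smdp.ckpt")  # illustrative path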
Example #6
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # SMDataParallel: all_reduce call
    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
Example #7
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            train_pred = model(images, training=True)
            loss_value = loss(labels, train_pred)
        # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
        tape = smdp.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # Change: Broadcast model and optimizer variables
            smdp.broadcast_variables(model.variables, root_rank=0)
            smdp.broadcast_variables(opt.variables(), root_rank=0)

        # Change: all_reduce call
        train_loss_value = smdp.oob_allreduce(
            loss_value)  # Average the loss across workers

        train_loss(train_loss_value)
        train_accuracy(labels, train_pred)
        return
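The train_loss and train_accuracy calls at the end update running Keras metrics; a minimal sketch of how such metrics are typically defined and reset, assuming the names used above:

import tensorflow as tf

# Running metrics updated inside training_step and reset once per epoch.
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# Illustrative end-of-epoch reporting on rank 0 only:
# if smdp.rank() == 0:
#     print(f"loss={train_loss.result().numpy():.4f}  acc={train_accuracy.result().numpy():.4f}")
# train_loss.reset_states()
# train_accuracy.reset_states()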
Example #8
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # TODO: Hide the smddpcommon dependency. This code does not use GradientTape, so the bucket size has to be passed explicitly.
    if train_args.bucket_cap_mb:
        bucket_cap_bytes = int(train_args.bucket_cap_mb * 1024 * 1024)
    else:
        bucket_cap_bytes = int(64 * 1024 * 1024)
    hc.setBucketSize(bucket_cap_bytes)

    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[smddp.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if smddp.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{smddp.size()}gpus"
                f"-{train_args.per_gpu_batch_size * smddp.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout"
            )
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        if not os.path.exists(path_args.log_dir):
            os.makedirs(path_args.log_dir)
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir, f"{run_name}.log")
            ),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
    if train_args.optimizer == "lamb":
        optimizer = get_lamb_optimizer(train_args)
    elif train_args.optimizer == "adamw":
        optimizer = get_adamw_optimizer(train_args)

    if _PRE_TF_2_4_0:
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer, loss_scale="dynamic"
        )
    else:
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)

    gradient_accumulator = GradientAccumulator()

    loaded_optimizer_weights = None

    model = create_model(model_class=TFAutoModelForPreTraining, model_args=model_args)
    tokenizer = create_tokenizer(model_args.model_type)
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if smddp.rank() == 0:
            model.load_weights(model_ckpt)
            if model_args.load_optimizer_state == "true":
                loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], val filenames are [0]. Note the different subdirectories.
    # TODO: Move to the same folder structure and remove the if/else.
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
    )  # Of shape [per_gpu_batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if smddp.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 1
    start_time = time.perf_counter()
    train_start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(i, dtype=tf.float32))
        # weight_decay = wd_schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale() if _PRE_TF_2_4_0 else optimizer.loss_scale
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 1:
            if smddp.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            print (" RANK {} is broadcasting".format(smddp.rank()))
            #smddp.broadcast_variables(model.variables + optimizer.variables(), root_rank=0)
            smddp.broadcast_variables(model.variables, root_rank=0)
            smddp.broadcast_variables(optimizer.variables(), root_rank=0)
            print(" RANK {} is done broadcasting".format(smddp.rank()))
            # smddp.broadcast_variables(optimizer.variables(), root_rank=0)
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step
        )
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            from albert.run_squad import get_squad_results_while_pretraining
            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if smddp.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"]
                logger.info(f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}")
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()

        if smddp.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step
            )

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                elif is_final_step:
                    total_time = time.perf_counter() - train_start_time
                    seq_per_sec = i * train_args.per_gpu_batch_size * smddp.size() * train_args.gradient_accumulation_steps / total_time
                    logger.info(f"Final step {i}: {description} -- Average seq_per_sec: {seq_per_sec:.2f} -- Total Time: {total_time}")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}")
                    start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = os.path.join(
                    path_args.filesystem_prefix, path_args.checkpoint_dir, f"{run_name}-step{i}"
                )
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                optimizer_weights = optimizer.get_weights()
                np.save(optimizer_ckpt, optimizer_weights)
                # optimizer.set_weights(optimizer_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    "global_batch_size": train_args.per_gpu_batch_size * smddp.size(),
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()
                    wandb_run_name = wandb.run.name

            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)
            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break

    if smddp.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")