Example #1
def count_zeros_fp32(parameters, param_is_distributed):

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    # Initialize as a CUDA tensor so the all_reduce below works even when no
    # parameter passes the filter.
    total_num_zeros = torch.tensor(0.0, device=torch.device("cuda"))
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = not hasattr(param, "shared") or not param.shared
        is_not_tp_duplicate = smp.tp_rank() == 0 or (
            param in param_is_distributed and param_is_distributed[param])
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=smp.get_mp_process_group())
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros
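The per-parameter count above is just numel() minus count_nonzero(); the rest of the function is filtering and the model-parallel reduction. A minimal single-tensor sketch of that counting step, with made-up values and no smp or all_reduce involved:

import torch

grad = torch.tensor([0.0, 1.5, 0.0, -2.0, 0.0])
num_zeros = grad.numel() - torch.count_nonzero(grad)
print(num_zeros.item())  # 3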
Example #2
    def hook_fn(model, optimizer):
        optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"])
        # Partial checkpoint: each rank saved its own fp32 master copies, which can
        # be restored directly (indexed by rdp_rank when optimizer state is sharded).
        if partial:
            if args.shard_optimizer_state:
                assert isinstance(
                    opt_state_dict["fp32_from_fp16"], list
                ), "Loading with shard_optimizer_state=True must use the checkpoint that was trained with shard_optimizer_state=True!"
                optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][
                    smp.rdp_rank()]
            else:
                optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]

            for current_group, saved_group in zip(
                    optimizer.fp32_from_float16_groups,
                    optimizer.fp32_from_fp16):
                for current, saved in zip(current_group, saved_group):
                    current.data.copy_(saved.data)

        else:
            # Full checkpoint: the saved fp32 masters are merged across tensor-parallel
            # ranks, so each local param is looked up by name and, if distributed,
            # re-sharded along its distribution axis with torch.narrow below.
            optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]
            param_name_groups = opt_state_dict["param_name_groups"]
            param_id_to_index = optimizer._param_id_to_index()
            param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group
            param_index_to_name = param_index_to_name_tp_group[smp.tp_rank()]
            for group_idx, (current_group, saved_group) in enumerate(
                    zip(optimizer.fp32_from_float16_groups,
                        optimizer.fp32_from_fp16)):
                for current in current_group:
                    param_id = id(current)
                    param_index = param_id_to_index[param_id]
                    param_name = param_index_to_name[param_index]
                    arr_index = param_name_groups[group_idx][param_name]
                    saved = saved_group[arr_index]
                    if optimizer.master_distribution_axis[
                            param_id] is not None:
                        axis = optimizer.master_distribution_axis[param_id]
                        slice_size = saved.size(axis) // smp.tp_size()
                        saved = torch.narrow(saved.data, axis,
                                             slice_size * smp.tp_rank(),
                                             slice_size).contiguous()
                    else:
                        saved = saved.data
                    current.data.copy_(saved)

        optimizer.grad_scaler.load_state_dict(opt_state_dict["grad_scaler"])
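The full-checkpoint branch above relies on torch.narrow to cut this rank's shard out of a merged fp32 master copy. A standalone sketch of that slicing step, with a made-up 4x3 weight and a hypothetical tensor-parallel layout (no smp involved):

import torch

saved = torch.arange(12.0).reshape(4, 3)   # merged fp32 master, distributed along axis 0
tp_size, tp_rank, axis = 2, 1, 0           # hypothetical layout: 2 ranks, this is rank 1
slice_size = saved.size(axis) // tp_size
shard = torch.narrow(saved, axis, slice_size * tp_rank, slice_size).contiguous()
print(shard.shape)  # torch.Size([2, 3]) -- the rows owned by tp_rank 1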
Example #3
def memory_status(msg="", reset_max=True, sync=True):

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    if py3nvml is not None:
        py3nvml.nvmlInit()
        handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
        info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"
    else:
        total_used_str = ""

    alloced = torch.cuda.memory_allocated(device=local_rank)
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank)
    cached = torch.cuda.memory_reserved(device=local_rank)
    max_cached = torch.cuda.max_memory_reserved(device=local_rank)

    # convert to GB for printing
    alloced /= 1024**3
    cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f} '
        f'{total_used_str}')
    if reset_max:
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()
    if py3nvml is not None:
        py3nvml.nvmlShutdown()
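A typical use of memory_status is to bracket the step being profiled. The snippet below is only a usage sketch: train_step and batch are hypothetical names, and the calls assume an initialized SageMaker model-parallel job.

memory_status(msg="before step")
loss = train_step(model, optimizer, batch)  # hypothetical training step
memory_status(msg="after step")             # with reset_max=True this also resets the peak counters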
Example #4
def memory_status_cpu(msg=""):
    import gc
    global last_mem_usage
    global base_mem_usage
    gc.collect()
    gc.collect()
    gc.collect()
    objects = gc.get_objects()
    tensors = [
        obj for obj in objects
        if isinstance(obj, torch.Tensor) and not obj.is_cuda
    ]
    torch_usage = 0
    for t in tensors:
        torch_usage += t.numel() * dtype_to_bit[t.dtype]
    #total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes
    current_usage = process.memory_info().data
    total_usage = current_usage - base_mem_usage
    usage_change = current_usage - last_mem_usage
    last_mem_usage = current_usage

    torch_usage /= 1024**3
    total_usage /= 1024**3
    usage_change /= 1024**3
    base_usage = base_mem_usage / 1024**3

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()
    if rdp_rank != 0:
        return

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}'
    )
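memory_status_cpu depends on module-level globals that the snippet does not show: a psutil process handle, a baseline and last-measurement value, and a per-dtype size table. One plausible way to define them is sketched below; the names match the snippet, but the values are assumptions (in particular, the table is written as bytes per element, and memory_info().data is Linux-only).

import psutil
import torch

process = psutil.Process()                   # current process, queried for its data segment size
base_mem_usage = process.memory_info().data  # baseline captured at import time
last_mem_usage = base_mem_usage

# Assumed size-per-element table for CPU tensor accounting (bytes per element).
dtype_to_bit = {
    torch.float32: 4, torch.float16: 2, torch.bfloat16: 2,
    torch.int64: 8, torch.int32: 4, torch.uint8: 1, torch.bool: 1,
}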
Example #5
def get_tp_merged_fp32_from_fp16_param_groups(optimizer,
                                              cpu_fp32_from_fp16_groups):
    def _merge_param_group_tp_group(group_idx, param_group):
        result_fp32_from_fp16_param_group = []
        param_name_group = {}
        for i, param in enumerate(param_group):
            # for each param, obtain param_name from param using two dicts above for tp_rank 0
            param_index = param_id_to_index_tp_group[rank_0][
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]]
            param_name = param_index_to_name_tp_group[rank_0][param_index]
            # obtain the distribution axis for the param and check if it's distributed
            # axis = master_distribution_axis_tp_rank_0[fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]]
            axis = master_distribution_axis_tp_rank_0.get(
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i],
                None)
            if axis is not None:
                tensors = []
                for r in range(smp.tp_size()):
                    # if distributed, for each rank, obtain param id from index using above two dicts
                    param_index_r = param_name_to_index_tp_group[r][param_name]
                    param_id_r = param_index_to_id_tp_group[r][param_index_r]

                    # search param id in fp32_from_fp16_groups_param_ids and find the index.
                    group_param_idx = fp32_from_fp16_paramid_groups_tp_group[
                        r][group_idx].index(param_id_r)
                    # use the param corresponding to the index from fp32_from_fp16_groups for concatenation along axis
                    tensors.append(fp32_from_fp16_param_groups_tp_group[r]
                                   [group_idx][group_param_idx])
                result_fp32_from_fp16_param_group.append(
                    torch.cat(tensors, axis))
            else:
                # if not distributed set tp_rank 0 param as the param
                result_fp32_from_fp16_param_group.append(param)
            param_name_group[param_name] = i
        return result_fp32_from_fp16_param_group, param_name_group

    # get param_index_to_name all and param_name_to_index_all
    param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group
    param_name_to_index_tp_group = smp_state.param_name_to_index_tp_group
    # get mapping of param_id_to_index_all and param_index_to_id_all
    param_id_to_index = optimizer._param_id_to_index()
    param_id_to_index_tp_group = smp.allgather(param_id_to_index, smp.TP_GROUP)
    param_index_to_id_tp_group = _get_param_index_to_id(
        param_id_to_index_tp_group)
    # allgather all param ids and all params for fp32_from_fp16_groups
    fp32_from_fp16_paramid_groups = optimizer.fp32_from_fp16_paramid_groups
    fp32_from_fp16_paramid_groups_tp_group = smp.allgather(
        fp32_from_fp16_paramid_groups, smp.TP_GROUP)
    fp32_from_fp16_param_groups_tp_group = smp.allgather(
        cpu_fp32_from_fp16_groups, smp.TP_GROUP)
    # broadcast distribution axis from tp_rank 0 to all tp_ranks
    master_distribution_axis_tp_rank_0 = None
    if smp.tp_rank() == 0:
        master_distribution_axis_tp_rank_0 = optimizer.master_distribution_axis
        smp.broadcast(master_distribution_axis_tp_rank_0, smp.TP_GROUP)
    else:
        master_distribution_axis_tp_rank_0 = smp.recv_from(
            0, smp.RankType.TP_RANK)

    result_fp32_from_fp16_param_groups = []
    param_name_groups = []
    rank_0 = 0
    # iterate through all the params for tp_group_fp32_from_fp16_groups[rank_0]
    for group_idx, param_group in enumerate(
            fp32_from_fp16_param_groups_tp_group[rank_0]):
        result_fp32_from_fp16_param_group, param_name_group = _merge_param_group_tp_group(
            group_idx, param_group)
        result_fp32_from_fp16_param_groups.append(
            result_fp32_from_fp16_param_group)
        param_name_groups.append(param_name_group)
    return result_fp32_from_fp16_param_groups, param_name_groups
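The merge itself reduces to concatenating each tensor-parallel rank's shard of a distributed parameter along its distribution axis; the rest of the function is bookkeeping that lines the shards up by parameter name and id. A pure-PyTorch sketch of that core step, with made-up shards standing in for the allgathered fp32_from_fp16 params:

import torch

tp_size, axis = 2, 0
full_weight = torch.arange(12.0).reshape(4, 3)
shards = list(torch.chunk(full_weight, tp_size, dim=axis))  # what each tp_rank would hold

merged = torch.cat(shards, dim=axis)  # what _merge_param_group_tp_group reconstructs
assert torch.equal(merged, full_weight)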
Example #6
def clip_grad_norm_fp32(parameters,
                        param_is_distributed,
                        max_norm,
                        norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        param_is_distributed (dict): maps each parameter to whether it is
            distributed across tensor-parallel ranks
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    torch.cuda.set_device(smp.local_rank())
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = not hasattr(param, "shared") or not param.shared
        is_not_tp_duplicate = smp.tp_rank() == 0 or (
            param in param_is_distributed and param_is_distributed[param])
        if grad_not_none:
            grad = param.grad.detach()
            # Make sure the grads are in fp32
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = torch.tensor(0.0, device=torch.device("cuda"))

    # Calculate norm.
    if norm_type == inf:
        if len(grads_for_norm) > 0:
            total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=smp.get_mp_process_group())
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0],
                                                      device=torch.device(
                                                          "cuda",
                                                          smp.local_rank()))
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            if len(grads_for_norm) > 0:
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False  # no per-parameter norm
                )
                # Since we will be summing across data parallel groups,
                # we need the pow(norm-type).
                total_norm = grad_norm**norm_type
        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm**norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=smp.get_mp_process_group())
        total_norm = total_norm.item()**(1.0 / norm_type)

    # Scale.
    if len(grads) > 0:
        clip_coeff = max_norm / (total_norm + 1.0e-6)
        if clip_coeff < 1.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0],
                                                      device=torch.device(
                                                          "cuda",
                                                          smp.local_rank()))
            multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf,
                                 [grads, grads], clip_coeff)

    return total_norm
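Aside from the model-parallel reductions and the fused apex kernels, the clipping rule is the standard one: compute the global norm over all gradients, then scale every gradient by max_norm / (norm + eps) when that ratio is below 1. A plain-PyTorch sketch of the same math on a single device (no smp, no amp_C; the random gradients are placeholders):

import torch

grads = [torch.randn(10), torch.randn(5)]  # stand-ins for param.grad tensors
max_norm, norm_type = 1.0, 2.0

total_norm = torch.norm(
    torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type)
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for g in grads:
        g.mul_(clip_coeff)  # in place, like multi_tensor_scale in the function above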
Example #7
def main():
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overridden by the config set in the notebook when launching the SageMaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing is enabled, whole transformer layers are
        # checkpointed below instead of only the attentions
        "checkpoint_attentions": not args.activation_checkpointing,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config[
            "activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])
    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full are set, will try to load from the full checkpoint. "
                "If the intention is to load from a partial checkpoint, please don't set --load_full."
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
Example #8
def main():

    model_args, data_args, training_args, smp_args = parse_args()
    model, tokenizer = initialize_model_and_tokenizer(model_args)

    # Get datasets
    train_dataset, eval_dataset = Preprocess.datasets(model_args, data_args,
                                                      training_args)

    if is_sagemaker_mp_enabled():
        initialize_smp(smp_args, training_args)

        torch.set_default_dtype(torch.float32)

        num_params = print_num_parameters(model)

        # smdistributed: Set the device to the GPU ID used by the current process.
        # Input tensors should be transferred to this device.
        torch.cuda.set_device(smp.local_rank())
        device = torch.device("cuda")

        if not training_args.same_seed:
            # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
            set_seed(training_args.seed + smp.tp_rank())

        model = smp.DistributedModel(model,
                                     trace_device=smp_args.trace_device,
                                     gradient_as_bucket_view=True)

        torch.set_default_dtype(torch.float32)

        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module

        param_groups = get_param_groups_by_weight_decay(iter_model)

        if training_args.use_adamw > 0:
            optimizer = optim.AdamW(
                param_groups,
                betas=(training_args.beta1, training_args.beta2),
                lr=training_args.lr,
                weight_decay=training_args.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                param_groups,
                betas=(training_args.beta1, training_args.beta2),
                lr=training_args.lr,
                weight_decay=training_args.weight_decay,
            )

        optimizer = smp.DistributedOptimizer(optimizer)
        lr_scheduler = get_learning_rate_scheduler(optimizer, training_args)

        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

        # Initialize Trainer instance

        trainer = SMPTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=eval_dataset if training_args.do_eval else None,
            tokenizer=tokenizer,
            data_collator=default_data_collator,
        )

        trainer.train_smp(
            model,
            optimizer,
            lr_scheduler,
            start_train_path_index,
            start_batch_index,
            num_params,
            total_steps,
            training_args,
            prescaled_batch=smp_args.prescaled_batch,
        )
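Both main() variants call a get_param_groups_by_weight_decay helper that is not shown on this page. A common implementation splits parameters into a decayed group and an undecayed group (biases and other 1-D parameters such as LayerNorm weights); the sketch below is an assumption about what that helper might look like, not its actual code.

def get_param_groups_by_weight_decay(module):
    # Hypothetical helper: split parameters into decay / no-decay groups.
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        # Common convention: no weight decay for biases and 1-D params (LayerNorm).
        if param.ndim == 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    # The no-decay group overrides the optimizer-level weight_decay with 0.0;
    # the decay group inherits whatever weight_decay the optimizer was built with.
    return [{"params": decay}, {"params": no_decay, "weight_decay": 0.0}]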