def save(
    output_save_file,
    model,
    optimizer,
    lr_scheduler,
    model_config,
    num_params,
    total_steps,
    curr_train_path_index,
    args,
    partial=True,
    translate_to_hf=False,
    seq_length=1024,
    batch_idx=0,
):
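    """Write a training checkpoint (model, optimizer, LR scheduler, metadata) via smp.save.

    With partial=True each rank saves its local shard of the state; with
    partial=False the full model state is gathered to rank 0 (optionally
    translated to the Hugging Face GPT-2 layout) before saving.
    """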
    save_fn = save_fp16_optimizer_megatron if args.megatron else save_fp16_optimizer
    save_dict = {
        "cli_args": args.__dict__,
        "num_params": num_params,
        "total_steps": total_steps,
        "curr_train_path_index": curr_train_path_index,
        "model_config": model_config,
        "batch_idx": batch_idx,
    }

    if lr_scheduler is not None:
        save_dict["lr_scheduler"] = lr_scheduler.state_dict()
    if partial:
        save_dict["model"] = model.local_state_dict()
    else:
        model_state_dict = model.state_dict(gather_to_rank0=True)
        if smp.rank() == 0:
            save_dict["model"] = (
                translate_state_dict_to_hf_gpt2(model_state_dict, seq_length)
                if translate_to_hf
                else model_state_dict
            )

    if args.fp16:
        if not partial and args.skip_full_optimizer:
            print("Skipping saving the final optimizer state")
        else:
            if args.shard_optimizer_state == 0 or partial:
                save_dict["optimizer"] = save_fn(args, model, optimizer, partial=partial)
            else:
                print("Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping...")
    else:
        # fp32
        if partial:
            save_dict["optimizer"] = optimizer.local_state_dict()
        else:
            if not args.skip_full_optimizer:
                save_dict["optimizer"] = optimizer.state_dict()
            else:
                print("Skipping saving of full optimizer state")

    if (smp.rdp_rank() == 0 and partial) or smp.rank() == 0:
        smp.save(save_dict, output_save_file, partial=partial)

    print(f"Finished checkpointing after {total_steps} steps: {output_save_file}")
Example 2
def save_fp16_optimizer(args, model, optimizer, partial=True):
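    """Collect FP16 optimizer state (loss scaler, flags, fp32 master weights) for checkpointing.

    For partial saves the per-rank local state is kept, and when
    shard_optimizer_state is enabled the sharded fp32_from_fp16 groups are
    gathered onto rdp rank 0. For full saves the fp32 master weight groups are
    merged across TP and PP ranks.
    """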
    optimizer_state_dict = {}
    loss_scaler = optimizer.loss_scaler
    _model = loss_scaler.model
    loss_scaler.model = None
    _loss_scaler = copy.deepcopy(loss_scaler)
    loss_scaler.model = _model
    optimizer_state_dict["loss_scaler"] = _loss_scaler
    optimizer_state_dict["dynamic_loss_scale"] = optimizer.dynamic_loss_scale
    optimizer_state_dict["overflow"] = optimizer.overflow
    optimizer_state_dict["first_closure_call_this_step"] = optimizer.first_closure_call_this_step
    cpu_fp32_from_fp16_groups = [[param.cpu() for param in group]
                                 for group in optimizer.fp32_from_fp16_groups]
    if optimizer.master_params_created:
        register_optimizer_hooks(model)
    if partial:
        optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict()
        if args.shard_optimizer_state:
            if smp.rdp_rank() == 0:
                print(
                    "With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0"
                )
                gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups]
                for src in range(1, smp.rdp_size()):
                    gathered_cpu_fp32_from_fp16_groups.append(
                        smp.recv_from(src, smp.RankType.RDP_RANK))
                optimizer_state_dict["fp32_from_fp16"] = gathered_cpu_fp32_from_fp16_groups
            else:
                smp.send(cpu_fp32_from_fp16_groups, 0, smp.RankType.RDP_RANK)
                optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups
        else:
            optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups
        if smp.pp_size() > 1:
            print(
                "WARNING: Ensure that partition decision doesnt change between runs (you can ensure this by setting use_times=False in smp config)."
                "If you want to save and load with partition decision changing between runs, use full save and load instead."
            )
    else:
        optimizer_state_dict["optimizer_state_dict"] = optimizer.state_dict()
        if smp.tp_size() > 1 and not args.shard_optimizer_state:
            tp_merged_fp32_from_fp16_groups, param_name_groups = get_tp_merged_fp32_from_fp16_param_groups(
                optimizer, cpu_fp32_from_fp16_groups)
            pp_merged_fp32_from_fp16_groups, param_name_groups = get_pp_merged_fp32_from_fp16_param_groups(
                optimizer, tp_merged_fp32_from_fp16_groups, param_name_groups)
        else:
            raise ValueError(
                "Saving the full optimizer state is not supported when TP is not enabled or when shard_optimizer_state is enabled"
            )
        optimizer_state_dict["fp32_from_fp16"] = pp_merged_fp32_from_fp16_groups
        optimizer_state_dict["param_name_groups"] = param_name_groups
    return optimizer_state_dict
def save_ckptsum(args, model, optimizer, filename):
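    """Save per-tensor sums of the local model and optimizer state to `filename`.

    The sums are later compared against a reloaded checkpoint by
    load_and_verify_ckptsum as a cheap save/load consistency check.
    """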
    results = collections.defaultdict(dict)
    model_result = collections.defaultdict(dict)

    if args.fp16:
        from fp16.fp16util import register_optimizer_hooks

        register_optimizer_hooks(model)

    def _get_optimizer_result(optimizer_states):
        _optimizer_result = collections.defaultdict(dict)
        for param_idx, state in optimizer_states.items():
            for key, val in state.items():
                if isinstance(val, torch.Tensor):
                    _optimizer_result["tensors"][
                        f"{param_idx}_{key}"] = torch.sum(val)
                else:
                    _optimizer_result["scalars"][f"{param_idx}_{key}"] = val
        return _optimizer_result

    if not args.shard_optimizer_state:
        optimizer_result = _get_optimizer_result(
            optimizer.local_state_dict()["state"])
    else:
        local_state_dict = optimizer.local_state_dict()["state"]
        if smp.rdp_rank() == 0:
            optimizer_result = []
            for partial_local_state_dict in local_state_dict:
                optimizer_result.append(
                    _get_optimizer_result(partial_local_state_dict))

    for param_name, param in model.local_state_dict().items():
        if isinstance(param, torch.Tensor):
            model_result["tensors"][param_name] = torch.sum(param)
        else:
            model_result["scalars"][param_name] = param

    if smp.rdp_rank() == 0:
        results["optimizer"] = optimizer_result
        results["model"] = model_result
        smp.save(results, filename)
def memory_status_cpu(msg=""):
    import gc
    global last_mem_usage
    global base_mem_usage
    gc.collect()
    gc.collect()
    gc.collect()
    objects = gc.get_objects()
    tensors = [
        obj for obj in objects
        if isinstance(obj, torch.Tensor) and not obj.is_cuda
    ]
    torch_usage = 0
    for t in tensors:
        torch_usage += t.numel() * dtype_to_bit[t.dtype]
    #total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes
    current_usage = process.memory_info().data
    total_usage = current_usage - base_mem_usage
    usage_change = current_usage - last_mem_usage
    last_mem_usage = current_usage

    torch_usage /= 1024**3
    total_usage /= 1024**3
    usage_change /= 1024**3
    base_usage = base_mem_usage / 1024**3

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()
    if rdp_rank != 0:
        return

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}'
    )
def memory_status(msg="", reset_max=True, sync=True):
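    """Print GPU memory statistics on rdp rank 0.

    Reports torch allocator stats (allocated/reserved and their maxima) in GB,
    plus total device usage via py3nvml when it is available, and optionally
    resets the max counters afterwards.
    """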

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    if py3nvml is not None:
        py3nvml.nvmlInit()
        handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
        info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"
    else:
        total_used_str = ""

    alloced = torch.cuda.memory_allocated(device=local_rank)
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank)
    cached = torch.cuda.memory_reserved(device=local_rank)
    max_cached = torch.cuda.max_memory_reserved(device=local_rank)

    # convert to GB for printing
    alloced /= 1024**3
    cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f} '
        f'{total_used_str}')
    if reset_max:
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()
    if py3nvml is not None:
        py3nvml.nvmlShutdown()
Example 6
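    # Fragment: this hook restores FP16 optimizer state after partitioning.
    # `opt_state_dict`, `partial`, `args`, and `smp_state` come from the
    # enclosing load function, which is not shown in this excerpt.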
    def hook_fn(model, optimizer):
        optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"])
        if partial:
            if args.shard_optimizer_state:
                assert isinstance(
                    opt_state_dict["fp32_from_fp16"], list
                ), "Loading with shard_optimizer_state=True must use the checkpoint that was trained with shard_optimizer_state=True!"
                optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][
                    smp.rdp_rank()]
            else:
                optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]

            for current_group, saved_group in zip(
                    optimizer.fp32_from_float16_groups,
                    optimizer.fp32_from_fp16):
                for current, saved in zip(current_group, saved_group):
                    current.data.copy_(saved.data)

        else:
            optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]
            param_name_groups = opt_state_dict["param_name_groups"]
            param_id_to_index = optimizer._param_id_to_index()
            param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group
            param_index_to_name = param_index_to_name_tp_group[smp.tp_rank()]
            for group_idx, (current_group, saved_group) in enumerate(
                    zip(optimizer.fp32_from_float16_groups,
                        optimizer.fp32_from_fp16)):
                for current in current_group:
                    param_id = id(current)
                    param_index = param_id_to_index[param_id]
                    param_name = param_index_to_name[param_index]
                    arr_index = param_name_groups[group_idx][param_name]
                    saved = saved_group[arr_index]
                    axis = optimizer.master_distribution_axis[param_id]
                    if axis is not None:
                        slice_size = saved.size(axis) // smp.tp_size()
                        saved = torch.narrow(saved.data, axis,
                                             slice_size * smp.tp_rank(),
                                             slice_size).contiguous()
                    else:
                        saved = saved.data
                    current.data.copy_(saved)

        optimizer.grad_scaler.load_state_dict(opt_state_dict["grad_scaler"])
def load_and_verify_ckptsum(args, model, optimizer, filename):
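    """Reload the tensor sums written by save_ckptsum and register hooks that verify them.

    model_check_fn runs after partitioning and opt_check_fn after an optimizer
    step; each asserts that the current sums match the saved ones.
    """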
    results = smp.load(filename)
    optimizer_result = (
        results["optimizer"]
        if not args.shard_optimizer_state
        else results["optimizer"][smp.rdp_rank()]
    )
    model_result = results["model"]

    def opt_check_fn(mod, opt):
        loaded_opt_states = (
            opt.orig_state_dict()["state"]
            if args.shard_optimizer_state
            else opt.local_state_dict()["state"]
        )
        for param_idx, state in loaded_opt_states.items():
            for key, val in state.items():
                if isinstance(val, torch.Tensor):
                    assert torch.isclose(
                        torch.sum(val), optimizer_result["tensors"][f"{param_idx}_{key}"]
                    ), f"mismatch for param_idx: {param_idx}, key is {key}"
                else:
                    assert (
                        val == optimizer_result["scalars"][f"{param_idx}_{key}"]
                    ), f"mismatch for param_idx: {param_idx}, key is {key}"
        print("Optimizer save/load check passed successfully")

    def model_check_fn(mod, opt):
        for param_name, param in mod.local_state_dict().items():
            if isinstance(param, torch.Tensor):
                assert torch.isclose(
                    torch.sum(param), model_result["tensors"][param_name]
                ), f"mismatch for param_name: {param_name}"
            else:
                assert (
                    param == model_result["scalars"][param_name]
                ), f"mismatch for param_name: {param_name}"
        print("Model save/load check passed successfully")

    model.register_post_partition_hook(model_check_fn)
    model.register_post_step_hook(opt_check_fn)
Example 8
    def train_smp(
        self,
        model,
        optimizer,
        lr_scheduler,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
        prescaled_batch,
    ):
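        """Minimal SMP training loop over self.train_dataset.

        Builds a DistributedSampler/DataLoader keyed on the (reduced) data
        parallel rank, runs self.train_step on each batch, and stops once
        args.max_steps steps have been taken.
        """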

        model.train()

        dp_rank = smp.dp_rank() if not prescaled_batch else smp.rdp_rank()
        dp_size = smp.dp_size() if not prescaled_batch else smp.rdp_size()

        start = time.time()
        throughput = None
        to_save = {"loss": [], "val_loss": []}
        loss_metric = 0

        def should_record():

            # only record on the ranks in the tp group that contains global rank 0
            if smp.tp_size() > 1:
                tp_group = smp.get_tp_group()
                return 0 in tp_group
            else:
                return smp.rank() == 0

        # Set the same seed for computation
        set_seed(args.seed)

        sampler = torch.utils.data.DistributedSampler(
            self.train_dataset,
            shuffle=True,
            seed=args.seed,
            rank=dp_rank,
            num_replicas=dp_size,
            drop_last=True,
        )

        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            sampler=sampler,
            batch_size=args.per_device_train_batch_size,
            collate_fn=self.data_collator,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )

        total_steps = 0
        for batch_idx, input_data in enumerate(train_dataloader):

            step_start = time.time()
            optimizer.zero_grad(set_to_none=True)

            input_ids = input_data["input_ids"]
            attention_mask = input_data["attention_mask"]

            loss_mb = self.train_step(model, optimizer, input_ids,
                                      attention_mask, args)

            loss = loss_mb.reduce_mean()

            optimizer.step()
            lr_scheduler.step()

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time

            if smp.rank() == 0 and not total_steps % 10:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )
            if total_steps == args.max_steps:
                break
Example 9
def main():
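    """Entry point: configure and initialize smdistributed.modelparallel, build the
    GPT-2 model and optimizer under the SMP wrappers, optionally resume from a
    checkpoint, run training, and save the final full model if requested."""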
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overridden by the config set in the notebook when launching the SageMaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # when activation_checkpointing is enabled, attention checkpointing is turned off here
        # because the transformer layers are checkpointed below via smp.set_activation_checkpointing
        "checkpoint_attentions": not args.activation_checkpointing,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config["activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            "[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models, which are saved every checkpoint_freq steps during training. If you want to restart training you need partial checkpoints."
        )

    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as the pipeline parallel degree, but got {len(partition_assignment)} vs {smp.pp_size()}"

    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must assign the same total number of transformer layers as the model, but got {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])
    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint."
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
Example 10
def train(
    model,
    optimizer,
    lr_scheduler,
    model_config,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
):
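    """Main pretraining loop.

    Iterates over the training shards (optionally prefetching the next shard in
    a separate process), runs train_step on each batch, handles FP16 gradient
    clipping, runs validation every args.validation_freq steps, and writes a
    partial checkpoint every args.checkpoint_freq steps. Returns
    (total_steps, throughput, loss_metric).
    """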
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before train step")
    model.train()
    if args.parallel_proc_data_processing:
        pool = ProcessPoolExecutor(1)

    dp_rank = smp.dp_rank() if not args.prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not args.prescaled_batch else smp.rdp_size()
    data_type = "BERT" if args.use_bert_data else "GPT"

    if args.use_bert_data:
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if os.path.isfile(os.path.join(args.training_dir, p))
            and "training" in p
        ])
    else:
        if args.zipped_data > 0:
            file_extension = ".json.gz"
        else:
            file_extension = ".json"
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if p.endswith(file_extension)
        ])

    train_dataloader = create_pretraining_dataloader(
        [train_paths[start_train_path_index]],
        args.train_batch_size,
        args.max_context_width,
        seed=args.seed,
        dp_rank=dp_rank,
        dp_size=dp_size,
        shuffle=args.same_seed < 1,
        zipped=args.zipped_data > 0,
        use_last_file_only=args.fast_validation > 0,
        data_type=data_type,
    )

    if args.validation_freq is not None:
        # load all validation examples
        if smp.rank() == 0:
            print("Creating val dataloader")
        if args.use_bert_data:
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if os.path.isfile(os.path.join(args.test_dir, p))
                and "testing" in p
            ])

        else:
            if args.zipped_data > 0:
                file_extension = ".json.gz"
            else:
                file_extension = ".json"
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if p.endswith(file_extension)
            ])
        val_dataloader = create_pretraining_dataloader(
            val_paths,
            args.val_batch_size,
            args.max_context_width,
            seed=args.seed,
            dp_rank=dp_rank,
            dp_size=dp_size,
            shuffle=True,
            zipped=args.zipped_data > 0,
            use_last_file_only=args.fast_validation > 0,
            data_type=data_type,
        )
        if smp.rank() == 0:
            print("Created val dataloader")

    start = time.time()
    throughput = None
    to_save = {"loss": [], "val_loss": []}
    loss_metric = 0

    def should_record():
        # only record on the ranks in the tp group that contains global rank 0
        if smp.tp_size() > 1:
            tp_group = smp.get_tp_group()
            return 0 in tp_group
        else:
            return smp.rank() == 0

    # Set the same seed for computation
    set_seed(args.seed)

    for index in range(start_train_path_index, args.epochs * len(train_paths)):
        next_train_path_index = (index + 1) % len(train_paths)
        curr_train_path_index = index % len(train_paths)

        if total_steps >= args.max_steps:
            break

        if args.parallel_proc_data_processing:
            dataset_future = pool.submit(
                create_pretraining_dataloader,
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

        if smp.rank() == 0:
            if args.use_bert_data:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_file}"
                )
            else:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_paths}"
                )

        for batch_idx, input_data in enumerate(train_dataloader):
            if batch_idx < start_batch_index:
                if smp.rank() == 0:
                    print(
                        f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}..."
                    )
                if start_batch_index == len(train_dataloader):
                    # If saving at the last batch of the file, read from the next file
                    start_batch_index = 0
                    break
                continue
            else:
                start_batch_index = 0

            if args.use_bert_data:
                input_ids, _, attention_mask, _, _ = input_data
            else:
                input_ids, attention_mask = input_data

            if total_steps >= args.max_steps:
                break

            step_start = time.time()

            if args.smp_version < 110:
                optimizer.zero_grad(set_grads_to_None=True)
            else:
                optimizer.zero_grad(set_to_none=True)

            if args.logits_output:
                train_output = train_step(model, optimizer, input_ids,
                                          attention_mask, args)
                loss_mb = train_output["loss"]
                logits_mb = train_output["logits"]
                if smp.tp_size() > 1:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=1)
                else:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=0)
            else:
                # Return value, loss_mb is a StepOutput object
                loss_mb = train_step(model, optimizer, input_ids,
                                     attention_mask, args)

            # smdistributed: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                loss_metric = loss.item()

            if args.enable_memory_profiling > 0:
                memory_status_cpu("After_train_step_cpu")
                memory_status(msg="After_train_step")

            if args.clean_cache > 0:
                # empty the cache to avoid OOM
                torch.cuda.empty_cache()

            if args.fp16:
                if args.smp_version < 110:
                    optimizer.update_master_grads()
                optimizer.clip_master_grads(args.grad_clip)

            optimizer.step()
            if not (args.fp16 and optimizer.overflow):
                lr_scheduler.step()

            if args.enable_memory_profiling > 0:
                memory_status(msg="After_opt_step")

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time
            if smp.rank() == 0 and not total_steps % args.logging_freq:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )

            # evaluate on validation
            if args.validation_freq and not (total_steps %
                                             args.validation_freq):
                cur_state = np.random.get_state()
                model = model.eval()
                val_loss, val_ppl = eval_model(model, val_dataloader,
                                               args.validation_batches,
                                               args.use_bert_data)
                if is_main_process(smp.rank()):
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation loss: {val_loss}"
                    )
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation perplexity: {val_ppl}"
                    )
                loss_metric = val_loss
                if args.logits_output:
                    to_save["val_loss"].append(val_loss)
                model = model.train()
                if args.preserve_np_state > 0:
                    np.random.set_state(cur_state)

            # checkpoint
            if not (total_steps % args.checkpoint_freq):
                base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
                out_path = os.path.join(args.checkpoint_dir, base_path)
                total_ckpts = total_steps // args.checkpoint_freq

                delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0)

                save(
                    out_path,
                    model,
                    optimizer,
                    lr_scheduler,
                    model_config,
                    num_params,
                    total_steps,
                    curr_train_path_index,
                    args,
                    partial=True,
                    batch_idx=batch_idx + 1,
                )

            if args.logits_output:
                to_save["loss"].append(loss.item())

        if total_steps >= args.max_steps:
            if should_record() and args.logits_output:
                to_save["logits"] = logits.detach().cpu()
                output_file = f"rank_{smp.rank()}_" + args.logits_output
                torch.save(to_save, os.path.join(args.model_dir, output_file))
                print(
                    f"logits and loss saved at {os.path.join(args.model_dir, output_file)}"
                )
            break

        del train_dataloader

        if args.parallel_proc_data_processing:
            s = time.time()
            train_dataloader = dataset_future.result(timeout=None)
            wait_time = time.time() - s
            if wait_time > 1:
                # TODO if this happens, we should try num_workers>1 in dataloader
                print(
                    f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                )
        else:
            train_dataloader = create_pretraining_dataloader(
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

    return total_steps, throughput, loss_metric
Example 11
def load_model_and_optimizer(
    output_dir,
    model,
    optimizer,
    lr_scheduler,
    partial,
    args,
    translate_from_hf=False,
    seq_length=1024,
    load_model=True,
    load_optimizer=True,
    num_params=0,
):
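    """Find the checkpoint in `output_dir` with the most training steps and restore
    model, LR scheduler, and (optionally) optimizer state.

    Returns (model, optimizer, total_steps, curr_train_path_index, batch_idx)
    so training can resume where the checkpoint left off.
    """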
    # Find longest-trained checkpoint
    re_pattern = f"trained_gpt_nparams-{num_params}_steps-(?P<total_steps>\d+)\.pt"
    if partial:
        re_pattern += "_(?P<rank>\d+)"
    else:
        re_pattern += "$"

    ckpt_paths = sorted(
        [(int(re.match(re_pattern,
                       p).group("total_steps")), os.path.join(output_dir, p))
         for p in os.listdir(output_dir) if re.match(re_pattern, p)],
        reverse=True,
    )
    if not ckpt_paths:
        raise Exception(
            f'No checkpoints could be found in "{output_dir}".  Candidates: {os.listdir(output_dir)}'
        )

    local_ckpt_path = ckpt_paths[0][1]

    if partial:
        # need to pass prefix without ranks to smp
        local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt"

    if args.gather_if_shard > 0:
        # Should expect v2 checkpoint here
        checkpoint = smp.load(local_ckpt_path, partial=partial)
    else:
        # Loading separately for model and opt
        checkpoint = torch.load(
            f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_0")
        if smp.rdp_rank() != 0:
            opt_checkpoint = torch.load(
                f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_{smp.rdp_rank()}"
            )

    if load_model:
        checkpointed_model = (translate_hf_state_dict_to_smdistributed(
            checkpoint["model"], seq_length)
                              if translate_from_hf else checkpoint["model"])
        model.load_state_dict(checkpointed_model,
                              same_partition_load=args.same_partition_load > 0)
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    if load_optimizer:
        checkpoint = (checkpoint if args.gather_if_shard > 0
                      or smp.rdp_rank() == 0 else opt_checkpoint)

        def opt_load_hook(mod, opt):
            load_fn = load_fp16_optimizer
            if not partial and args.skip_full_optimizer:
                print(
                    "Skipping loading the final optimizer state, and reloading master_params from model_params"
                )
                opt.reload_model_params()
            else:
                load_fn(args, mod, opt, checkpoint, partial=partial)

        if args.smp_version < 110:
            model.register_post_step_hook(opt_load_hook)
        elif not partial and args.skip_full_optimizer:
            print(
                "Skipping loading the final optimizer state, and reloading master_params from model_params for fp16"
            )
            if args.fp16:
                model.register_post_step_hook(
                    lambda mod, opt: opt.reload_model_params())
        else:
            optimizer.load_optimizer_backcompat(checkpoint["optimizer"],
                                                args.gather_if_shard)

    print(f'Loaded model from "{local_ckpt_path}"')

    batch_idx = 0
    if "batch_idx" in checkpoint:
        batch_idx = checkpoint["batch_idx"]

    return (
        model,
        optimizer,
        checkpoint["total_steps"],
        checkpoint["curr_train_path_index"],
        batch_idx,
    )