コード例 #1
0
def save(
    output_save_file,
    model,
    optimizer,
    lr_scheduler,
    model_config,
    num_params,
    total_steps,
    curr_train_path_index,
    args,
    partial=True,
    translate_to_hf=False,
    seq_length=1024,
    batch_idx=0,
):
    """Write a training checkpoint (metadata, model, optimizer, scheduler).

    Args:
        output_save_file: destination path handed to ``smp.save``.
        model: SMP model; ``local_state_dict()`` is stored for partial
            checkpoints, ``state_dict(gather_to_rank0=True)`` for full ones.
        optimizer: optimizer whose state is stored (see ``args`` flags).
        lr_scheduler: scheduler whose ``state_dict()`` is stored, or None.
        model_config, num_params, total_steps, curr_train_path_index,
            batch_idx: metadata stored verbatim in the checkpoint.
        args: CLI namespace; its ``__dict__`` is stored, and ``megatron``,
            ``fp16``, ``skip_full_optimizer`` and ``shard_optimizer_state``
            decide how (and whether) optimizer state is saved.
        partial: True saves per-rank shards; False gathers a full checkpoint.
        translate_to_hf: for full checkpoints, pass the gathered model state
            through ``translate_state_dict_to_hf_gpt2``.
        seq_length: forwarded to ``translate_state_dict_to_hf_gpt2``.
    """
    # fp16 optimizer serializer; only consulted when args.fp16 is set.
    fp16_saver = (
        save_fp16_optimizer_megatron if args.megatron else save_fp16_optimizer
    )
    checkpoint = {
        "cli_args": args.__dict__,
        "num_params": num_params,
        "total_steps": total_steps,
        "curr_train_path_index": curr_train_path_index,
        "model_config": model_config,
        "batch_idx": batch_idx,
    }

    if lr_scheduler is not None:
        checkpoint["lr_scheduler"] = lr_scheduler.state_dict()

    # Model weights: this rank's shard, or the full state gathered on rank 0.
    if not partial:
        gathered = model.state_dict(gather_to_rank0=True)
        if smp.rank() == 0:
            if translate_to_hf:
                gathered = translate_state_dict_to_hf_gpt2(gathered, seq_length)
            checkpoint["model"] = gathered
    else:
        checkpoint["model"] = model.local_state_dict()

    # Optimizer state.
    if not args.fp16:
        # fp32 path
        if partial:
            checkpoint["optimizer"] = optimizer.local_state_dict()
        elif args.skip_full_optimizer:
            print("Skipping saving of full optimizer state")
        else:
            checkpoint["optimizer"] = optimizer.state_dict()
    elif not partial and args.skip_full_optimizer:
        print("Skipping saving the final optimizer state")
    elif args.shard_optimizer_state == 0 or partial:
        checkpoint["optimizer"] = fp16_saver(args, model, optimizer, partial=partial)
    else:
        print("Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping...")

    # Writers: any rdp_rank-0 process for partial checkpoints, plus global rank 0.
    is_writer = (smp.rdp_rank() == 0 and partial) or smp.rank() == 0
    if is_writer:
        smp.save(checkpoint, output_save_file, partial=partial)

    print(f"Finished checkpointing after {total_steps} steps: {output_save_file}")
コード例 #2
0
    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_lr_scheduler:
            if smp.rank() == 0:
                print('Overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_lr_scheduler:
            assert cls_value == sd_value, 'AnnealingLR: class input value' \
                'and checkpoint values for {} do not match'.format(name)
        if smp.rank() == 0:
            print(' > using checkpoint value {} for {}'.format(sd_value, name))
        return sd_value
コード例 #3
0
def dist_setting(args):
    """Populate distributed-training fields on ``args`` and rescale lr/batch size.

    Exactly one backend is configured:
      * ``args.data_parallel``  -> SageMaker data parallel (``sdp``) ranks;
      * ``args.model_parallel`` -> SageMaker model parallel (``smp``) ranks,
        plus ``dp_size``/``dp_rank``;
      * otherwise ``torch.distributed`` is initialised with
        ``len(args.hosts) * args.num_gpus`` processes. In this case ``rank``
        is only computed when ``args.local_rank`` is not None.

    Afterwards ``args.lr`` is multiplied by ``world_size``, and
    ``args.batch_size`` is floor-divided by ``world_size // num_gpus``
    (never below 1).

    Returns:
        The same, mutated ``args`` object.
    """
    for flag in ("data_parallel", "model_parallel", "apex"):
        print(f"args.{flag} : {getattr(args, flag)}")

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # global rank across all hosts
        args.local_rank = sdp.get_local_rank()  # rank within this host
    elif args.model_parallel:
        args.world_size = smp.size()
        args.local_rank = smp.local_rank()  # rank within this host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
        smp_report = [
            ("smp.rank()", smp.rank()),
            ("smp.size()", smp.size()),
            ("smp.mp_rank()", smp.mp_rank()),
            ("smp.local_size()", smp.local_size()),
            ("smp.get_mp_group()", smp.get_mp_group()),
            ("smp.get_dp_group()", smp.get_dp_group()),
            ("smp.local_rank()", smp.local_rank()),
            ("smp.dp_size()", smp.dp_size()),
            ("smp.dp_rank()", smp.dp_rank()),
        ]
        print(", ".join(f"{label} : {value}" for label, value in smp_report))
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            # global rank = (host index) * (gpus per host) + local rank
            args.rank = args.num_gpus * args.host_num + args.local_rank

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        header = "Initialized the distributed environment: '{}' backend on {} nodes. ".format(
            args.backend, dist.get_world_size())
        detail = "Current host rank is {}. Number of gpus: {}".format(
            dist.get_rank(), args.num_gpus)
        logger.info(header + detail)

    print(f"**** [dist_setting] args.rank : {args.rank}")
    print(f"args.world_size : {args.world_size}")
    print(f"Use GPU: {args.local_rank} for training")

    # Scale the learning rate linearly with the number of processes.
    args.lr = args.lr * float(args.world_size)

    per_host_divisor = args.world_size // args.num_gpus
    args.batch_size = max(args.batch_size // per_host_divisor, 1)

    return args
コード例 #4
0
def train(args, model, scaler, device, train_loader, optimizer, epoch):
    """Run one training epoch of an SMP-distributed model.

    Args:
        args: namespace providing ``amp``, ``log_interval``, ``dry_run`` and
            ``num_batches`` (a falsy ``num_batches`` means "no limit").
        model: the distributed model being trained.
        scaler: gradient scaler stepped when ``args.amp`` is set.
        device: device the input tensors are moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # SM Distributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(args, model, scaler, data, target)

        # SM Distributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()

        if args.amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        if smp.rank() == 0 and batch_idx % args.log_interval == 0:
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),
            ))
        # dry_run must stop every rank after the same batch. It used to be
        # checked only inside the rank-0 logging branch, so rank 0 stopped
        # while the other ranks kept training, which can leave the
        # distributed steps mismatched (e.g. hanging).
        if args.dry_run:
            break
        if args.num_batches and batch_idx + 1 == args.num_batches:
            break
コード例 #5
0
def initialize_smp(smp_args, training_args):
    """Initialise SageMaker model parallelism and seed the RNGs.

    Integer CLI switches in ``smp_args`` are converted to booleans with
    ``> 0`` before being placed in the ``smp.init`` config. Rank 0 prints the
    arguments, library versions and the final config. Finally
    ``set_seed(training_args.seed)`` is called.
    """
    def enabled(switch):
        # Integer CLI switch -> boolean config flag.
        return switch > 0

    smp_config = {
        "ddp": smp_args.ddp,
        "pipeline_parallel_degree": smp_args.pipeline_parallel_degree,
        "microbatches": smp_args.microbatches,
        "shard_optimizer_state": enabled(smp_args.shard_optimizer_state),
        "prescaled_batch": enabled(smp_args.prescaled_batch),
        "_match_weights": enabled(smp_args.match_weights),
        "offload_activations": enabled(smp_args.offload_activations),
        "optimize": smp_args.optimize,
        "auto_partition": True,
        "default_partition": 0,
        "static_mode": enabled(smp_args.static_mode),
        "fast_mode": enabled(smp_args.fast_mode),
    }

    active_mb = smp_args.active_microbatches
    if active_mb is not None:
        smp_config["active_microbatches"] = active_mb

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", smp_args.__dict__)
        print("Transformers version: {}".format(transformers.__version__))
        print("smdistributed.modelparallel version: {}".format(
            smdistributed.modelparallel.__version__))
        print("smdistributed config: {}".format(smp_config))

    set_seed(training_args.seed)
コード例 #6
0
 def should_record():
     """Return True on the ranks whose results should be recorded.

     With tensor parallelism (``smp.tp_size() > 1``) these are the ranks whose
     tensor-parallel group contains global rank 0; otherwise only rank 0.
     """
     if smp.tp_size() <= 1:
         return smp.rank() == 0
     return 0 in smp.get_tp_group()
def setup_training(args):
    """Bind this process to its GPU and validate/normalise training arguments.

    Initialises SMP when ``args.smp > 0``, selects CUDA device
    ``smp.local_rank()``, validates ``gradient_accumulation_steps`` against
    ``train_batch_size``, and turns ``train_batch_size`` into the batch size
    per accumulation step. It also creates ``args.output_dir`` when needed.

    Returns:
        ``(device, args)``.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` is < 1 or does not
            divide ``train_batch_size``, or if ``args.output_dir`` already
            contains ``ckpt*`` entries while not resuming from a checkpoint.
    """
    assert torch.cuda.is_available()

    if args.smp > 0:
        # Initialize SMP. The configuration is obtained from the parameters passed to
        # the Sagemaker PyTorch estimator.
        smp.init()

    # SMP: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    # NOTE(review): smp.local_rank() is used even when args.smp == 0 --
    # confirm smp is initialised elsewhere in that configuration.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda", smp.local_rank())
    args.n_gpu = 1

    if args.gradient_accumulation_steps == 1:
        # A single accumulation step forces both post-accumulation flags off.
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False

    print(
        "device: {} n_gpu: {}, mp_rank: {}, rank: {}, distributed training: {}, 16-bits training: {}"
        .format(device, args.n_gpu, smp.mp_rank(), smp.rank(),
                bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(args.gradient_accumulation_steps, args.train_batch_size))

    # From here on train_batch_size is the batch size of one accumulation step.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    output_dir = args.output_dir
    if not args.resume_from_checkpoint and os.path.exists(output_dir):
        # List the directory once; any() over an empty listing is already
        # False, so no separate emptiness check is needed.
        if any(entry.startswith("ckpt") for entry in os.listdir(output_dir)):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    output_dir))

    if (not args.resume_from_checkpoint
            or not os.path.exists(output_dir)) and is_main_process():
        os.makedirs(output_dir, exist_ok=True)

    return device, args
コード例 #8
0
def delete_oldest_ckpt(args, delete_on_rank0_only=False):
    """Delete the oldest partial checkpoint once the retention limit is reached.

    Partial checkpoint files in ``args.checkpoint_dir`` are grouped by the
    ``total_steps`` value encoded in their file name. If at least
    ``args.num_kept_checkpoints`` distinct steps are present, every file of
    the smallest step is removed, making room for the next checkpoint.

    Args:
        args: namespace with ``checkpoint_dir`` and ``num_kept_checkpoints``.
        delete_on_rank0_only: if True only global rank 0 deletes; otherwise
            every process with ``smp.local_rank() == 0`` does.

    Returns:
        None.
    """
    if delete_on_rank0_only:
        is_deleter = smp.rank() == 0
    else:
        is_deleter = smp.local_rank() == 0
    if not is_deleter:
        return None

    # Raw strings: "\d" inside a normal literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    ckpt_pattern = re.compile(
        r"trained_gpt_nparams-(?P<num_params>\d+)_steps-(?P<total_steps>\d+)\.pt"
        r"_(?P<pp_rank>\d+)_(?P<tp_rank>\d+)"  # suffix of partial checkpoints
    )

    paths_per_step = collections.defaultdict(list)
    for name in os.listdir(args.checkpoint_dir):
        match = ckpt_pattern.match(name)
        if match:
            step = int(match.group("total_steps"))
            paths_per_step[step].append(os.path.join(args.checkpoint_dir, name))

    # Below the limit there is nothing to delete yet.
    if paths_per_step and len(paths_per_step) >= args.num_kept_checkpoints:
        for path in paths_per_step[min(paths_per_step)]:
            os.remove(path)
    return None
コード例 #9
0
    def __init__(self, optimizer, start_lr,
                 warmup_iter, plateau_iter, total_iters,
                 decay_style, last_iter, min_lr=0.0,
                 use_checkpoint_lr_scheduler=True,
                 override_lr_scheduler=False):
        """Store the schedule parameters and set the learning rate for ``last_iter``.

        Args:
            optimizer: optimizer whose learning rate this scheduler drives.
            start_lr: stored as ``self.start_lr``.
            warmup_iter: stored as ``self.warmup_iter``.
            plateau_iter: stored as ``self.plateau_iter``.
            total_iters: stored as ``self.end_iter``; must be > 0.
            decay_style: stored as ``self.decay_style`` and printed on rank 0.
            last_iter: starting iteration, stored as ``self.num_iters`` and
                passed to ``self.step``.
            min_lr: stored as ``self.min_lr``.
            use_checkpoint_lr_scheduler: prefer checkpoint values on load.
            override_lr_scheduler: prefer constructor values on load;
                cannot be combined with ``use_checkpoint_lr_scheduler``.
        """
        self.optimizer = optimizer
        self.start_lr, self.min_lr = start_lr, min_lr
        self.warmup_iter, self.plateau_iter = warmup_iter, plateau_iter
        self.num_iters = last_iter
        self.end_iter = total_iters
        assert self.end_iter > 0
        self.decay_style = decay_style
        self.override_lr_scheduler = override_lr_scheduler
        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
        # The two checkpoint-handling modes are mutually exclusive.
        assert not (self.override_lr_scheduler and self.use_checkpoint_lr_scheduler), \
            'both override and use-checkpoint are set.'

        # Set the learning rate for the starting iteration.
        self.step(self.num_iters)

        if smp.rank() == 0:
            print('Learning rate decay style: {}'.format(self.decay_style))
コード例 #10
0
def get_rank():
    """Return this process's global rank, or 0 when not running distributed.

    ``torch.distributed`` must be both available and initialised; otherwise
    0 is returned. When SMP is initialised its rank is used, else
    ``dist.get_rank()``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return smp.rank() if smp.is_initialized() else dist.get_rank()
コード例 #11
0
 def is_world_process_zero(self) -> bool:
     """
     Return True only on the single global main process.

     With SageMaker model parallelism enabled, every SMP rank index (global,
     local, model-parallel and data-parallel) must be 0; otherwise the
     decision is delegated to the parent class.
     """
     if not self.is_model_parallel_enabled:
         return super().is_world_process_zero()
     return (
         smp.rank() == 0
         and smp.local_rank() == 0
         and smp.mp_rank() == 0
         and smp.dp_rank() == 0
     )
コード例 #12
0
 def should_log(self):
     """
     Return whether the current process should emit logs.

     With ``log_on_each_node`` set, the process whose ``local_process_index``
     is 0 logs. Otherwise only the process with index 0 logs, where the index
     is ``smp.rank()`` under SageMaker model parallelism and
     ``self.process_index`` elsewhere.
     """
     if self.log_on_each_node:
         return self.local_process_index == 0
     main_index = smp.rank() if is_sagemaker_mp_enabled() else self.process_index
     return main_index == 0
コード例 #13
0
def print_num_parameters(model):
    """Count the model's parameters, print the total on rank 0, and return it.

    A parameter object yielded more than once by ``model.parameters()`` (for
    example a tied weight) is counted only once.

    Returns:
        Total number of parameter elements as a Python ``int``.
    """
    seen_ids = set()
    num_params = 0
    for param in model.parameters():
        # Deduplicate by identity so shared parameters count once.
        if id(param) in seen_ids:
            continue
        seen_ids.add(id(param))
        # numel() is an exact int. np.prod(p.size()) returns the float 1.0
        # for a 0-d parameter, which would turn the whole total into a float.
        num_params += param.numel()

    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    return num_params
コード例 #14
0
def smp_init(args):
    """Initialise SageMaker model parallelism from ``args`` and export rank info.

    Disables DeepSpeed on ``args`` and calls ``smp.init`` with a config built
    from ``args``. It then records ``rank`` (the data-parallel rank),
    ``global_rank``, ``world_size`` and ``local_rank`` on ``args``, and
    mirrors them into the ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` environment
    variables. It also sets ``SMP_SKIP_GRAPH_VALIDATION=0`` and selects this
    process's local CUDA device.

    Returns:
        The same, mutated ``args`` object.
    """
    args.deepspeed = False
    smp.init({
        "microbatches": args.num_microbatches,
        "placement_strategy": args.placement_strategy,
        "pipeline": args.pipeline,
        "optimize": args.optimize,
        "partitions": args.num_partitions,
        "horovod": args.horovod,
        "ddp": args.ddp,
    })

    args.rank = smp.dp_rank()
    args.global_rank = smp.rank()
    args.world_size = smp.size()
    local_rank = smp.local_rank()

    os.environ.update(
        RANK=str(args.rank),
        WORLD_SIZE=str(args.world_size),
        LOCAL_RANK=str(local_rank),
        SMP_SKIP_GRAPH_VALIDATION="0",
    )

    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank
    return args
コード例 #15
0
def memory_status(msg="", reset_max=True, sync=True):
    """Print CUDA memory statistics (divided by 1024**3) for this rank's device.

    Only processes with ``smp.rdp_rank() == 0`` print. When ``py3nvml`` is
    available, the device's total used memory as reported by NVML is printed
    as well.

    Args:
        msg: tag printed at the start of the line.
        reset_max: reset the peak-memory counters after printing.
        sync: call ``torch.cuda.synchronize()`` before sampling.
    """
    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    total_used_str = ""
    if py3nvml is not None:
        py3nvml.nvmlInit()
        try:
            handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
            info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        finally:
            # Always balance nvmlInit, even when a query raises.
            py3nvml.nvmlShutdown()
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"

    gib = 1024**3
    alloced = torch.cuda.memory_allocated(device=local_rank) / gib
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank) / gib
    cached = torch.cuda.memory_reserved(device=local_rank) / gib
    max_cached = torch.cuda.max_memory_reserved(device=local_rank) / gib

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f} '
        f'{total_used_str}')
    if reset_max:
        # reset_max_memory_cached/reset_max_memory_allocated are deprecated
        # wrappers that both call reset_peak_memory_stats (and warn).
        torch.cuda.reset_peak_memory_stats()
コード例 #16
0
def dist_setting(args):
    """Populate distributed-training fields on ``args`` and rescale lr/batch size.

    Exactly one backend is configured:
      * ``args.data_parallel``  -> SageMaker data parallel, imported lazily
        through ``_sdp_import``;
      * ``args.model_parallel`` -> SageMaker model parallel (``smp``) ranks;
        ``world_size`` ends up as ``num_gpus * len(hosts)``;
      * otherwise ``torch.distributed`` is initialised with
        ``len(args.hosts) * args.num_gpus`` processes.

    Afterwards ``args.lr`` is multiplied by ``world_size``, and
    ``args.batch_size`` is floor-divided by ``world_size // num_gpus``
    (never below 1).

    Returns:
        The same, mutated ``args`` object.
    """
    print(", ".join(
        f"args.{flag} : {getattr(args, flag)}"
        for flag in ("data_parallel", "model_parallel", "apex")))

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        sdp, _ = _sdp_import(args)

        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # global rank across all hosts
        args.local_rank = sdp.get_local_rank()  # rank within this host
    elif args.model_parallel:
        args.world_size = smp.size()
        # NOTE(review): the smp.size() value above is immediately overwritten
        # here; confirm which world size is intended.
        args.world_size = args.num_gpus * len(args.hosts)
        args.local_rank = smp.local_rank()  # rank within this host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            # global rank = (host index) * (gpus per host) + local rank
            args.rank = args.num_gpus * args.host_num + args.local_rank

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        header = "Initialized the distributed environment: '{}' backend on {} nodes. ".format(
            args.backend, dist.get_world_size())
        detail = "Current host rank is {}. Number of gpus: {}".format(
            dist.get_rank(), args.num_gpus)
        logger.info(header + detail)

    # Scale the learning rate linearly with the number of processes.
    args.lr = args.lr * float(args.world_size)
    per_host_divisor = args.world_size // args.num_gpus
    args.batch_size = max(args.batch_size // per_host_divisor, 1)

    return args
コード例 #17
0
def memory_status_cpu(msg=""):
    """Print CPU-side memory statistics for this rank (``smp.rdp_rank() == 0`` only).

    Reports four figures, each divided by 1024**3:
      * memory held by non-CUDA torch tensors found by the garbage collector;
      * process memory relative to the module-level ``base_mem_usage``;
      * the change since the previous call (tracked in ``last_mem_usage``);
      * the baseline itself.

    NOTE(review): tensor size is ``numel() * dtype_to_bit[dtype]``. Whether
    ``dtype_to_bit`` holds bits or bytes per element decides the unit of the
    "torch cpu tensor usage" figure; confirm against its definition.
    """
    import gc
    global last_mem_usage
    global base_mem_usage

    # Non-zero rdp ranks never print, so skip the expensive collection and
    # heap scan on them (previously done before the rank check).
    if smp.rdp_rank() != 0:
        return

    for _ in range(3):
        gc.collect()
    torch_usage = sum(
        obj.numel() * dtype_to_bit[obj.dtype]
        for obj in gc.get_objects()
        if isinstance(obj, torch.Tensor) and not obj.is_cuda
    )
    #total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes
    current_usage = process.memory_info().data
    total_usage = current_usage - base_mem_usage
    usage_change = current_usage - last_mem_usage
    last_mem_usage = current_usage

    gib = 1024**3
    torch_usage /= gib
    total_usage /= gib
    usage_change /= gib
    base_usage = base_mem_usage / gib

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    local_rank = smp.local_rank()

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}'
    )
コード例 #18
0
    def train_smp(
        self,
        model,
        optimizer,
        lr_scheduler,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
        prescaled_batch,
    ):
        """Train ``model`` with SMP until ``args.max_steps`` steps or the data runs out.

        Each replica reads its shard of ``self.train_dataset`` through a
        ``DistributedSampler``. The shard is selected by
        ``smp.rdp_rank()``/``smp.rdp_size()`` when ``prescaled_batch`` is set,
        and by ``smp.dp_rank()``/``smp.dp_size()`` otherwise.

        NOTE(review): ``start_train_path_index``, ``start_batch_index``,
        ``num_params`` and the incoming ``total_steps`` are accepted but
        unused; the step counter always restarts at 0.
        """
        model.train()

        if prescaled_batch:
            dp_rank, dp_size = smp.rdp_rank(), smp.rdp_size()
        else:
            dp_rank, dp_size = smp.dp_rank(), smp.dp_size()

        start = time.time()

        # Set the same seed for computation
        set_seed(args.seed)

        sampler = torch.utils.data.DistributedSampler(
            self.train_dataset,
            shuffle=True,
            seed=args.seed,
            rank=dp_rank,
            num_replicas=dp_size,
            drop_last=True,
        )

        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            sampler=sampler,
            batch_size=args.per_device_train_batch_size,
            collate_fn=self.data_collator,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )

        total_steps = 0
        for input_data in train_dataloader:
            optimizer.zero_grad(set_to_none=True)

            input_ids = input_data["input_ids"]
            attention_mask = input_data["attention_mask"]

            loss_mb = self.train_step(model, optimizer, input_ids,
                                      attention_mask, args)

            loss = loss_mb.reduce_mean()

            lr_scheduler.step()

            # Count each step exactly once. The counter used to be incremented
            # twice per batch, so an odd args.max_steps was never matched and
            # training ran to the end of the data.
            total_steps += 1
            time_elapsed = time.time() - start

            if smp.rank() == 0 and not total_steps % 10:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {''} samples/sec"
                )
            if total_steps == args.max_steps:
                break
コード例 #19
0
    args.is_distributed = len(args.hosts) > 1 and args.backend is not None
    args.is_multigpus = args.num_gpus > 1
    args.multigpus_distributed = (args.is_distributed or args.is_multigpus)        
    
    logger.debug(f"args.image_folder : {args.image_folder}")
    

    args.world_size = 1
    args.local_rank = 0
    args.rank = 0
    
    if args.model_parallel:
        args.world_size = smp.size()
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
        logger.debug(f"args.world_size : {args.world_size}, args.local_rank : {args.local_rank}, args.rank : {args.rank}, \
                    args.dp_size : {args.dp_size}, args.dp_rank : {args.dp_rank}")
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
#     args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)


    ## SageMaker
    try:
        if os.environ.get('SM_CHANNEL_TRAINING') is not None:
            args.model_dir = os.environ.get('SM_MODEL_DIR')
コード例 #20
0
def main():
    """Train and evaluate GroupedNet on MNIST with SageMaker model parallelism.

    Steps:
      1. Seed the RNGs and initialise smp.
      2. Load MNIST (downloaded once per ``smp.local_rank() == 0`` process),
         optionally sharding the training set across data-parallel ranks.
      3. Train and evaluate for one epoch.
      4. Save a partial checkpoint from every ``smp.dp_rank() == 0`` process.
      5. Call ``sync_local_checkpoints_to_s3`` from every
         ``smp.local_rank() == 0`` process.

    Raises:
        ValueError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    # Training-set sharding below is enabled when either flag is set.
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    # DataLoader options shared by the train and test loaders.
    kwargs = {"batch_size": 64}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data",
                                  train=True,
                                  download=True,
                                  transform=transform)
    # Everyone waits for the download, then opens the files without downloading.
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    # Split the training set into smp.dp_size() equal fractions keyed "0", "1", ...
    # and keep the one matching this process's dp_rank.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Test split. NOTE(review): no download=True here; this presumably relies on
    # the files fetched by the download above -- confirm.
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    # NOTE(review): train() is called without an ``args`` argument and test_loss
    # is never used -- confirm against this script's train()/test() signatures.
    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(model, device, test_loader)
        scheduler.step()

    # Only global rank 0 creates the local checkpoint directory.
    # NOTE(review): dp_rank-0 processes on other hosts also save to this path
    # below; confirm the directory exists on every host.
    if smp.rank() == 0:
        if os.path.exists("/opt/ml/local_checkpoints"):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs("/opt/ml/local_checkpoints")
            print("-INFO- PATH DO NOT EXIST")

    # Wait until rank 0 has prepared the checkpoint directory before any rank saves.
    smp.barrier()

    # Processes with dp_rank 0 save their local (partial) model and optimizer state.
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            f"/opt/ml/local_checkpoints/pt_mnist_checkpoint.pt",
            partial=True,
        )
    # Make sure all saves have finished before syncing.
    smp.barrier()

    # Upload to <dirname(dirname(SM_MODULE_DIR))>/checkpoints/<SM_CURRENT_HOST>/.
    if smp.local_rank() == 0:
        print("Start syncing")
        base_s3_path = os.path.dirname(
            os.path.dirname(os.getenv("SM_MODULE_DIR", "")))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path="/opt/ml/local_checkpoints",
                                     s3_path=full_s3_path)
        print("Finished syncing")
コード例 #21
0
def _select_shard_file(files, file_id, dp_size, dp_rank):
    """Return the data file that data-parallel rank ``dp_rank`` reads for ``file_id``.

    Implements ``files[(file_id * dp_size + dp_rank [+ (dp_size % n) * file_id]) % n]``
    where ``n = len(files)`` and the bracketed offset is only added when there
    are more data-parallel ranks than files.

    Args:
        files: Non-empty, already-ordered list of data file paths.
        file_id: Position in the file sequence being loaded.
        dp_size: Number of data-parallel ranks.
        dp_rank: This process's data-parallel rank.
    """
    num_files = len(files)
    index = file_id * dp_size + dp_rank
    if dp_size > 1 and dp_size > num_files:
        index += (dp_size % num_files) * file_id
    return files[index % num_files]


def main():
    """Pretrain (or evaluate) the model over sharded training files.

    Builds the model/optimizer via ``prepare_model_and_optimizer`` and then
    loops over epochs and files indefinitely. Termination happens by returning
    from inside the step loop. While one file's dataloader is consumed, the
    next file's dataloader is built by a single background process.

    Returns:
        If ``args.do_train``: ``(args, final_loss, train_time_raw, global_step)``
        once ``global_step >= args.steps_this_run`` or ``timeout_sent`` is set.
        Before returning, checkpoints are written (subject to
        ``args.skip_checkpoint`` / ``args.save_full``) and local rank 0 syncs
        ``args.output_dir`` to S3.
        Otherwise: the mean of the collected test losses once ``global_step``
        reaches ``args.steps_this_run``.

    Raises:
        ValueError: If there are no training files to read.
    """
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    (
        model,
        optimizer,
        lr_scheduler,
        checkpoint,
        global_step,
        criterion,
    ) = prepare_model_and_optimizer(args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    # One background process that builds the next file's dataloader.
    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        restored_data_loader = None
        if (not args.resume_from_checkpoint or epoch > 0
                or (args.phase2 and global_step < 1) or args.init_checkpoint):
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and "training" in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            # The checkpoint stores ``[f_id] + files`` (see the save below).
            f_start_id = checkpoint["files"][0]
            files = checkpoint["files"][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get("epoch", 0)
            # Previously assigned to a misspelled name, which silently
            # discarded the checkpointed loader.
            restored_data_loader = checkpoint.get("data_loader", None)

        if num_files == 0:
            raise ValueError(
                f"No training data files to read (input_dir={args.input_dir!r})")

        shared_file_list = {}

        # Shard files by data-parallel rank. Presumably this is so that
        # model-parallel peers of one replica see the same data; verify
        # against smp docs.
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0

        data_file = _select_shard_file(files, f_start_id, dpsize, dprank)

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True,
                drop_last=True,
            )
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        # NOTE(review): each iteration trains on the current loader while
        # prefetching file ``f_id``. The loader prefetched on the last
        # iteration is dropped when the epoch restarts, and with a single
        # file this loop body never runs. Confirm this is intended.
        for f_id in range(f_start_id + 1, len(files)):
            # Same data-parallel rank/size as for the first file above.
            data_file = _select_shard_file(files, f_id, dpsize, dprank)

            dataset_future = pool.submit(
                create_pretraining_dataset,
                data_file,
                args.max_predictions_per_seq,
                shared_file_list,
                args,
                worker_init,
            )

            train_iter = (tqdm(train_dataloader,
                               desc="Iteration",
                               disable=args.disable_progress_bar)
                          if is_main_process() else train_dataloader)

            if raw_train_start is None:
                raw_train_start = time.time()

            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                    divisor = 1
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step)

                    # Evaluate the stop condition once per step. ``timeout_sent``
                    # is a module-level flag (presumably set by a signal
                    # handler). Re-reading it later could reach the ``return``
                    # below without ``final_loss``/``train_time_raw`` set.
                    stop_training = (global_step >= args.steps_this_run
                                     or timeout_sent)

                    if stop_training:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = (int(
                            training_steps / args.gradient_accumulation_steps)
                                          % args.log_freq)
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        average_loss = 0

                    if (stop_training or training_steps %
                        (args.num_steps_per_checkpoint *
                         args.gradient_accumulation_steps) == 0):
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step),
                                )
                            save_dict = {
                                "model": model.local_state_dict(),
                                "optimizer": optimizer.local_state_dict(),
                                "files": [f_id] + files,
                                "epoch": epoch,
                                "data_loader":
                                None if global_step >= args.steps_this_run
                                else train_dataloader,
                            }
                            if args.fp16:
                                save_dict["master params"] = list(
                                    amp.master_params(optimizer))
                            # SMP: Checkpoint mp_rank specific state
                            smp.save(save_dict,
                                     output_save_file,
                                     partial=True)

                            # Keep only the 3 most recent partial checkpoints.
                            # NOTE(review): assumes smp.save(partial=True)
                            # writes "<path>_<mp_rank>". Confirm for the smp
                            # version in use.
                            most_recent_ckpts_paths.append(output_save_file)
                            if len(most_recent_ckpts_paths) > 3:
                                ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                    0)
                                os.remove(ckpt_to_be_removed +
                                          f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if stop_training:
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader":
                                    None if global_step >= args.steps_this_run
                                    else train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=False)
                            # Release the loader only after the checkpoint
                            # above may have referenced it.
                            del train_dataloader
                            smp.barrier()
                            if smp.local_rank() == 0:
                                print(f"Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(
                                        os.getenv("SM_MODULE_DIR", "")))
                                curr_host = os.getenv("SM_CURRENT_HOST")
                                full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
                                sync_local_checkpoints_to_s3(
                                    local_path=args.output_dir,
                                    s3_path=full_s3_path)
                                print(
                                    f"Finished syncing model checkpoints to s3"
                                )
                            return args, final_loss, train_time_raw, global_step
                else:
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            criterion,
                            step,
                        )
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1
コード例 #22
0
def _list_pretraining_files(directory, use_bert_data, zipped, marker):
    """Return the sorted data file paths in ``directory`` (non-recursive).

    Args:
        directory: Directory to scan.
        use_bert_data: If true, select regular files whose name contains
            ``marker``. Otherwise select entries ending in ``.json.gz`` (when
            ``zipped``) or ``.json``.
        zipped: Whether GPT-format data is gzip-compressed. Ignored for BERT data.
        marker: Substring required in BERT-format file names.

    Returns:
        Sorted list of joined paths.
    """
    if use_bert_data:
        paths = [
            os.path.join(directory, p)
            for p in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, p)) and marker in p
        ]
    else:
        file_extension = ".json.gz" if zipped else ".json"
        paths = [
            os.path.join(directory, p)
            for p in os.listdir(directory)
            if p.endswith(file_extension)
        ]
    return sorted(paths)


def train(
    model,
    optimizer,
    lr_scheduler,
    model_config,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
):
    """Run the smdistributed training loop file by file until ``args.max_steps``.

    Makes ``args.epochs`` passes over the files in ``args.training_dir``. When
    ``args.parallel_proc_data_processing`` is set, the next file's dataloader
    is built in a background process. Runs validation every
    ``args.validation_freq`` steps (if set) and writes a partial checkpoint via
    ``save`` every ``args.checkpoint_freq`` steps.

    Args:
        model: The ``smp.DistributedModel`` being trained.
        optimizer: Distributed optimizer for the model.
        lr_scheduler: Stepped after each optimizer step, skipped on fp16 overflow.
        model_config: Model configuration, forwarded to checkpoints.
        start_train_path_index: Position in the file sequence to start from.
        start_batch_index: Number of batches of the first file to skip (resume).
        num_params: Parameter count, used in checkpoint file names.
        total_steps: Steps already taken before this call.
        args: Parsed command-line arguments.

    Returns:
        ``(total_steps, throughput, loss_metric)``. ``throughput`` is the
        samples/sec of the last step, or ``None`` if no step ran.
        ``loss_metric`` is the latest validation loss when validation is
        enabled, otherwise the latest training loss (0 if none was recorded).
    """
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before train step")
    model.train()
    pool = None
    if args.parallel_proc_data_processing:
        pool = ProcessPoolExecutor(1)

    dp_rank = smp.dp_rank() if not args.prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not args.prescaled_batch else smp.rdp_size()
    data_type = "BERT" if args.use_bert_data else "GPT"
    zipped = args.zipped_data > 0

    # Keyword arguments shared by every training dataloader built below.
    # Values are plain scalars/strings, so they can also be pickled into the
    # prefetch process.
    loader_kwargs = dict(
        seed=args.seed,
        dp_rank=dp_rank,
        dp_size=dp_size,
        shuffle=args.same_seed < 1,
        zipped=zipped,
        use_last_file_only=args.fast_validation > 0,
        data_type=data_type,
    )

    train_paths = _list_pretraining_files(args.training_dir,
                                          args.use_bert_data, zipped,
                                          "training")

    train_dataloader = create_pretraining_dataloader(
        [train_paths[start_train_path_index]],
        args.train_batch_size,
        args.max_context_width,
        **loader_kwargs,
    )

    if args.validation_freq is not None:
        # load all validation examples
        if smp.rank() == 0:
            print("Creating val dataloader")
        val_paths = _list_pretraining_files(args.test_dir, args.use_bert_data,
                                            zipped, "testing")
        # Validation always shuffles, regardless of args.same_seed.
        val_dataloader = create_pretraining_dataloader(
            val_paths,
            args.val_batch_size,
            args.max_context_width,
            **dict(loader_kwargs, shuffle=True),
        )
        if smp.rank() == 0:
            print("Created val dataloader")

    start = time.time()
    throughput = None
    to_save = {"loss": [], "val_loss": []}
    loss_metric = 0

    def should_record():
        # Only record on ranks in the TP group that contains global rank 0.
        if smp.tp_size() > 1:
            tp_group = smp.get_tp_group()
            return 0 in tp_group
        else:
            return smp.rank() == 0

    # Set the same seed for computation
    set_seed(args.seed)

    for index in range(start_train_path_index, args.epochs * len(train_paths)):
        next_train_path_index = (index + 1) % len(train_paths)
        curr_train_path_index = index % len(train_paths)

        if total_steps >= args.max_steps:
            break

        if args.parallel_proc_data_processing:
            # Prefetch the next file's dataloader while this one is consumed.
            dataset_future = pool.submit(
                create_pretraining_dataloader,
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                **loader_kwargs,
            )

        if smp.rank() == 0:
            if args.use_bert_data:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_file}"
                )
            else:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_paths}"
                )

        for batch_idx, input_data in enumerate(train_dataloader):
            if batch_idx < start_batch_index:
                if smp.rank() == 0:
                    print(
                        f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}..."
                    )
                if start_batch_index == len(train_dataloader):
                    # If saving at the last batch of the file, read from the next file
                    start_batch_index = 0
                    break
                continue
            else:
                start_batch_index = 0

            if args.use_bert_data:
                input_ids, _, attention_mask, _, _ = input_data
            else:
                input_ids, attention_mask = input_data

            if total_steps >= args.max_steps:
                break

            step_start = time.time()

            if args.smp_version < 110:
                optimizer.zero_grad(set_grads_to_None=True)
            else:
                optimizer.zero_grad(set_to_none=True)

            if args.logits_output:
                train_output = train_step(model, optimizer, input_ids,
                                          attention_mask, args)
                loss_mb = train_output["loss"]
                logits_mb = train_output["logits"]
                if smp.tp_size() > 1:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=1)
                else:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=0)
            else:
                # Return value, loss_mb is a StepOutput object
                loss_mb = train_step(model, optimizer, input_ids,
                                     attention_mask, args)

            # smdistributed: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                loss_metric = loss.item()

            if args.enable_memory_profiling > 0:
                memory_status_cpu("After_train_step_cpu")
                memory_status(msg="After_train_step")

            if args.clean_cache > 0:
                # empty the cache to avoid OOM
                torch.cuda.empty_cache()

            if args.fp16:
                if args.smp_version < 110:
                    optimizer.update_master_grads()
                optimizer.clip_master_grads(args.grad_clip)

            optimizer.step()
            if not (args.fp16 and optimizer.overflow):
                lr_scheduler.step()

            if args.enable_memory_profiling > 0:
                memory_status(msg="After_opt_step")

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time
            if smp.rank() == 0 and not total_steps % args.logging_freq:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )

            # evaluate on validation
            if args.validation_freq and not (total_steps %
                                             args.validation_freq):
                cur_state = np.random.get_state()
                model = model.eval()
                val_loss, val_ppl = eval_model(model, val_dataloader,
                                               args.validation_batches,
                                               args.use_bert_data)
                if is_main_process(smp.rank()):
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation loss: {val_loss}"
                    )
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation perplexity: {val_ppl}"
                    )
                loss_metric = val_loss
                if args.logits_output:
                    to_save["val_loss"].append(val_loss)
                model = model.train()
                if args.preserve_np_state > 0:
                    np.random.set_state(cur_state)

            # checkpoint
            if not (total_steps % args.checkpoint_freq):
                base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
                out_path = os.path.join(args.checkpoint_dir, base_path)

                delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0)

                # batch_idx + 1: on resume, skip the batch just trained on.
                save(
                    out_path,
                    model,
                    optimizer,
                    lr_scheduler,
                    model_config,
                    num_params,
                    total_steps,
                    curr_train_path_index,
                    args,
                    partial=True,
                    batch_idx=batch_idx + 1,
                )

            if args.logits_output:
                to_save["loss"].append(loss.item())

        if total_steps >= args.max_steps:
            if should_record() and args.logits_output:
                to_save["logits"] = logits.detach().cpu()
                output_file = f"rank_{smp.rank()}_" + args.logits_output
                torch.save(to_save, os.path.join(args.model_dir, output_file))
                print(
                    f"logits and loss saved at {os.path.join(args.model_dir, output_file)}"
                )
            break

        del train_dataloader

        if args.parallel_proc_data_processing:
            s = time.time()
            train_dataloader = dataset_future.result(timeout=None)
            wait_time = time.time() - s
            if wait_time > 1:
                # TODO if this happens, we should try num_workers>1 in dataloader
                print(
                    f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                )
        else:
            train_dataloader = create_pretraining_dataloader(
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                **loader_kwargs,
            )

    if pool is not None:
        # Release the prefetch worker without blocking on a still-running
        # (now unneeded) prefetch.
        pool.shutdown(wait=False)

    return total_steps, throughput, loss_metric
コード例 #23
0
def main():
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overriden by the config set in notebook when launching the sagemaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing true checkpoints transformer layers below
        "checkpoint_attentions":
        False if args.activation_checkpointing else True,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config[
            "activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])
    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint."
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
コード例 #24
0
def main():
    """Run BERT pretraining (or per-step evaluation) with optional SMP model parallelism.

    Loops over epochs indefinitely.  Each epoch iterates over the data shards
    (files in ``args.input_dir`` whose name contains ``'training'``, sorted and
    then shuffled with seed ``args.seed + epoch``), or over the shard list stored
    in ``checkpoint['files']`` when resuming.  While one shard's DataLoader is
    consumed, the next shard's DataLoader is built in a single background process.

    Returns:
        * training (``args.do_train``): ``(args, final_loss, train_time_raw, global_step)``
          once ``global_step >= args.steps_this_run`` or the module-level
          ``timeout_sent`` flag is set;
        * evaluation: the mean of the per-step test losses once
          ``global_step >= args.steps_this_run``.

    Raises:
        RuntimeError: if there are no data shards to read.
    """
    global timeout_sent

    args = parse_arguments()

    # Every RNG (and the DataLoader worker init) is seeded with an offset by local rank.
    rank_seed = args.seed + args.local_rank
    random.seed(rank_seed)
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    torch.cuda.manual_seed(rank_seed)
    worker_init = WorkerInitObj(rank_seed)

    device, args = setup_training(args)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(args, device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if is_main_process():
        dllogger.log(step="PARAMETER", data={"train_start": True})
        dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
        dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    # Single background worker that builds the next shard's DataLoader.
    pool = ProcessPoolExecutor(1)

    world_size = get_world_size()
    rank = get_rank()

    def shard_for(shard_files, f_idx):
        """Return the shard this rank reads at position ``f_idx`` of ``shard_files``.

        When there are more ranks than shards the ranks wrap around the list, and
        the extra ``remainder * f_idx`` term shifts the assignment per position.
        """
        n_shards = len(shard_files)
        remainder = world_size % n_shards if world_size > n_shards else 0
        return shard_files[(f_idx * world_size + rank + remainder * f_idx) % n_shards]

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        restored_data_loader = None
        if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
            files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
                     os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f]
            files.sort()
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            args.resume_from_checkpoint = False
            # may not exist in all checkpoints
            epoch = checkpoint.get('epoch', 0)
            # NOTE(review): this binds ``restored_dataloader`` while the check below
            # reads ``restored_data_loader``, so the saved loader is never reused and
            # the loader is always rebuilt from the selected shard.  Left as-is on
            # purpose: the saved loader appears to belong to the shard *before*
            # ``f_start_id``, so simply renaming would skip shard ``f_start_id`` on
            # resume — confirm the intended resume semantics before changing.
            restored_dataloader = checkpoint.get('data_loader', None)

        num_files = len(files)
        if num_files == 0:
            raise RuntimeError(f"No training data shards to read (input_dir={args.input_dir!r})")

        # Passed through to create_pretraining_dataset for each prefetched shard.
        shared_file_list = {}

        data_file = shard_for(files, f_start_id)

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                          batch_size=args.train_batch_size * args.n_gpu,
                                          num_workers=4, worker_init_fn=worker_init,
                                          pin_memory=True)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            data_file = shard_for(files, f_id)

            # Prefetch the next shard in the background while the current one trains.
            dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init)

            train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

            if raw_train_start is None:
                raw_train_start = time.time()

            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step)
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step)
                    divisor = 1
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(args, optimizer, model, overflow_buf, global_step)

                    # Read the stop condition exactly once per step.  ``timeout_sent`` is a
                    # module-level flag (presumably set asynchronously elsewhere); re-reading
                    # it below could reach the exit path without ``final_loss`` and
                    # ``train_time_raw`` having been computed.
                    stop_training = global_step >= args.steps_this_run or timeout_sent

                    if stop_training:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= world_size
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"final_loss": final_loss})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"average_loss": average_loss / (args.log_freq * divisor),
                                                                            "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                                                            "learning_rate": optimizer.param_groups[0]['lr']})
                        average_loss = 0

                    if stop_training or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0:
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            save_dict = {
                                        'model': model.local_state_dict(),
                                        'optimizer': optimizer.local_state_dict(),
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
                            if args.fp16:
                                save_dict['master params'] = list(amp.master_params(optimizer))
                            # SMP: Checkpoint mp_rank specific state
                            smp.save(save_dict, output_save_file, partial=True)

                            # Keep only the three most recent partial checkpoints.
                            most_recent_ckpts_paths.append(output_save_file)
                            if len(most_recent_ckpts_paths) > 3 and (args.smp == 0 or smp.dp_rank() == 0):
                                ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                os.remove(ckpt_to_be_removed+f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if stop_training:
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                                save_dict = {
                                            'model': model.local_state_dict(),
                                            'optimizer': optimizer.local_state_dict(),
                                            'files': [f_id] + files,
                                            'epoch': epoch,
                                            'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
                                if args.fp16:
                                    save_dict['master params'] = list(amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict, output_save_file, partial=False)
                            # Release the loader only after the full checkpoint is written: on a
                            # timeout exit the checkpoint above stores ``train_dataloader``.
                            del train_dataloader
                            smp.barrier()
                            return args, final_loss, train_time_raw, global_step
                else:
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, criterion, step)
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1