    def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        assert len(swap_info.tensors) <= len(dest_buffers)

        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(
            swap_info.tensors)
        swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

        READ_TIMER = 'swap_submit_read_param'
        WAIT_TIMER = 'swap_wait_read_param'

        self._start_timer(READ_TIMER)
        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
        self._stop_timer(READ_TIMER)

        swap_bytes = sum([
            buffer.numel() * buffer.element_size() for buffer in swap_buffers
        ])

        self._start_timer(WAIT_TIMER)
        aio_handle.wait()
        self._stop_timer(WAIT_TIMER)

        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
        for t, buffer in zip(swap_info.tensors, compute_buffers):
            t.data = buffer.data

        self._log_timers([READ_TIMER, WAIT_TIMER])
        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(
                f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
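
# Hedged sketch (hypothetical helper, not the library's code): `get_sized_buffers`
# used above is assumed to hand back zero-copy views of the pinned buffers trimmed
# to the requested element counts, roughly like this:
def sized_buffer_views_sketch(buffers, lengths):
    # narrow(0, 0, n) returns a view over the first n elements without copying
    return [buf.narrow(0, 0, n) if n > 0 else buf
            for buf, n in zip(buffers, lengths)]
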
    def _report_statistics(self, message):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=self.dtype).element_size()
            swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
            logger.info(
                f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB'
            )
    def _initialize_parameters(self, parameters, src_tensors, aio_handle):
        assert len(parameters) == len(src_tensors)

        swap_paths = self._get_swap_paths(parameters=parameters,
                                          num_elems=[src.numel() for src in src_tensors])

        SWAP_INIT_TIMER = "swap_init_write"
        self._start_timer(SWAP_INIT_TIMER)

        pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel,
            dtype=self.dtype)
        assert pinned_buffers is not None

        self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                        unpinned_tensors=src_tensors,
                                        dest_paths=swap_paths,
                                        pinned_buffers=pinned_buffers)

        if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
            for i, tensor in enumerate(src_tensors):
                logger.info(
                    f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
                )

        self.swap_buffer_manager.free(pinned_buffers)

        self._stop_timer(SWAP_INIT_TIMER)
        self._log_timers([SWAP_INIT_TIMER])
def main():
    start = time.time()
    args = construct_arguments()
    model, optimizer = prepare_model_optimizer(args)
    start_epoch = 0
    if None not in [args.load_training_checkpoint, args.load_checkpoint_id]:
        start_epoch = load_checkpoint(args, model)
    run(args, model, optimizer, start_epoch)
    elapsed = time.time() - start
    logger = args.logger
    logger.info(f"Elapsed time: {elapsed} seconds")
    def _initialize_from_swapped_fp16_params(self,
                                             aio_handle,
                                             fp16_partitions_info,
                                             fp16_num_elems,
                                             fp16_pinned_buffers,
                                             fp32_parameters):
        assert len(fp32_parameters) == len(fp16_partitions_info)
        assert len(fp32_parameters) == len(fp16_num_elems)
        assert all([buffer.is_pinned() for buffer in fp16_pinned_buffers])

        fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters,
                                               num_elems=fp16_num_elems)

        fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel,
            dtype=self.dtype)

        fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
        assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
        f"numel of fp16 buffers {fp16_buffer_numel} is too small for initializing fp32 params {self.largest_numel}"

        fp32_swap_buffers = SwapBufferPool(fp32_pinned_buffers)
        fp16_swap_buffers = SwapBufferPool(fp16_pinned_buffers)

        curr_index = 0
        while curr_index < len(fp32_parameters):
            fp16_pinned_tensors = self._swap_in_fp16_params(
                aio_handle=aio_handle,
                fp16_num_elems=fp16_num_elems[curr_index:],
                fp16_partitions_info=fp16_partitions_info[curr_index:],
                fp16_swap_buffers=fp16_swap_buffers)

            if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
                for i, tensor in enumerate(fp16_pinned_tensors):
                    true_index = curr_index + i
                    logger.info(
                        f'swap_in_fp16_param: fp32_id = {id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}'
                    )

            swap_out_count = self._swap_out_fp16_params(
                aio_handle=aio_handle,
                fp32_swap_paths=fp32_swap_paths[curr_index:],
                fp32_swap_buffers=fp32_swap_buffers,
                fp16_pinned_tensors=fp16_pinned_tensors)
            assert swap_out_count == len(fp16_pinned_tensors), \
            f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"

            fp16_swap_buffers.reset()
            fp32_swap_buffers.reset()
            curr_index += swap_out_count

        self.swap_buffer_manager.free(fp32_pinned_buffers)
    def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
        UNSWAPPED_READ_GRADIENTS = 'unswapped_read_gradients'
        self._start_timer(UNSWAPPED_READ_GRADIENTS)
        tensor_count = len(swap_info.unswapped_gradients)
        num_elem_count = swap_info.read_unswapped_gradients(dest_buffer)
        self._stop_timer(UNSWAPPED_READ_GRADIENTS)
        self._log_timers([UNSWAPPED_READ_GRADIENTS])

        # It should be safe to discard unswapped gradient partitions
        swap_info.release_unswapped_gradients()

        if SWAPPER_DEBUG_MODE:
            logger.info(
                f'optimizer_retrieve_unswapped_gradients: param={swap_info.param_id} tensor_count={tensor_count} elem_count={num_elem_count}'
            )
    def swap_out_optimizer_state(self, parameter, async_swap=False):
        swap_info = self._get_param_swap_info(parameter=parameter)

        if swap_info is None:
            return

        self._start_timer(SWAP_OUT_PARAM_TIMER)
        pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(
            swap_info)
        swap_bytes = sum([
            self._io_aligned_numel(t.numel()) * t.element_size()
            for t in swap_info.tensors
        ])

        WRITE_TIMER = 'swap_submit_write'
        self._start_timer(WRITE_TIMER)

        swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
        assert self.aio_handle.wait() == len(pinned_tensors)
        for t in pinned_tensors:
            t.data = torch.Tensor()

        if len(unpinned_tensors) > 0:
            pinned_buffers = self.swap_buffer_manager.allocate_all(
                num_elems=self.largest_numel, dtype=self.dtype)
            self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                            unpinned_tensors=unpinned_tensors,
                                            dest_paths=unpinned_paths,
                                            pinned_buffers=pinned_buffers)
            self.allocated_swap_buffers += pinned_buffers

            for t in unpinned_tensors:
                t.data = torch.Tensor()
        self._stop_timer(WRITE_TIMER)

        self.swap_buffer_manager.free(self.allocated_swap_buffers)
        self.allocated_swap_buffers = []

        self._stop_timer(SWAP_OUT_PARAM_TIMER)
        self.timer_names.add(SWAP_OUT_PARAM_TIMER)

        self._log_timers([WRITE_TIMER])

        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(
                f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')
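
# Hedged sketch (hypothetical helper): `self._io_aligned_numel` used above is
# assumed to round an element count up to the asynchronous-I/O alignment so that
# reads and writes land on block boundaries. One way such padding can look:
def io_aligned_numel_sketch(numel, numel_alignment):
    remainder = numel % numel_alignment
    return numel if remainder == 0 else numel + numel_alignment - remainder
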
def load_checkpoint(args, model):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger

    logger.info(
        f"Restoring previous training checkpoint from PATH={args.load_training_checkpoint}, CKPT_ID={args.load_checkpoint_id}"
    )
    start_epoch, global_step, global_data_samples = load_training_checkpoint(
        args=args,
        model=model,
        PATH=args.load_training_checkpoint,
        ckpt_id=args.load_checkpoint_id)
    logger.info(
        f"The model is loaded from last checkpoint at epoch {start_epoch} when the global steps were at {global_step} and global data samples at {global_data_samples}"
    )

    if args.rewarmup:
        logger.info(
            f"Rewarmup learning rate with last_global_step_from_restore = {global_step}"
        )
        last_global_step_from_restore = global_step

    lr_this_step = config["training"]["learning_rate"] * warmup_linear_decay_exp(
        global_step,
        config["training"]["decay_rate"],
        config["training"]["decay_step"],
        config["training"]["total_training_steps"],
        config["training"]["warmup_proportion"])
    logger.info(f"Restart training with lr = {lr_this_step}")

    # Run validation for checkpoint before training
    if not args.finetune and args.max_seq_length == 512:
        logger.info(
            f"Validation Loss of Checkpoint {start_epoch} before pretraining")
        index = start_epoch - 1 if start_epoch > 0 else start_epoch
        pretrain_validation(args, index, model)

    return start_epoch
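
# Hedged sketch (illustration only, not the repo's implementation): given its
# arguments, the `warmup_linear_decay_exp` schedule called above is assumed to
# ramp linearly over the warmup fraction of training and then decay
# exponentially every `decay_step` steps. A minimal version of that shape:
def warmup_linear_decay_exp_sketch(global_step, decay_rate, decay_step,
                                   total_steps, warmup_proportion):
    warmup_steps = warmup_proportion * total_steps
    if global_step < warmup_steps:
        return global_step / max(warmup_steps, 1.0)  # linear warmup toward 1.0
    return decay_rate ** ((global_step - warmup_steps) / decay_step)
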
def run(args, model, optimizer, start_epoch):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger

    if args.use_nvidia_dataset:
        pretrain_dataset_provider = NvidiaBertDatasetProvider(args)
    else:
        pretrain_dataset_provider = BingBertDatasetProvider(args)

    for index in range(start_epoch, config["training"]["num_epochs"]):
        logger.info(f"Training Epoch: {index + 1}")
        pre = time.time()
        train(args, index, model, optimizer, pretrain_dataset_provider)

        # Save ckpts according to "--ckpt_to_save" option,
        # e.g. "--ckpt_to_save 160 161" to save epoch 160 and 161.
        if args.ckpt_to_save is None or (index + 1) in args.ckpt_to_save:
            logger.info(
                f"Saving a checkpointing of the model for epoch: {index+1}")

            checkpoint_model(PATH=args.saved_model_path,
                             ckpt_id='epoch{}_step{}'.format(
                                 index + 1, global_step),
                             model=model,
                             epoch=index + 1,
                             last_global_step=global_step,
                             last_global_data_samples=global_data_samples)

        post = time.time()
        logger.info(f"Time for shard {index + 1}: {post-pre} seconds")

        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args, global_steps=current_global_step):
            print(
                f'Warning: Early training termination due to max steps limit, epoch={index+1}, global_step={current_global_step}'
            )
            break
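
# Hedged sketch (illustration only): judging from how it is called above,
# `is_time_to_exit` is assumed to compare the step counters against user-supplied
# budgets. The `max_steps` / `max_steps_per_epoch` attribute names are assumptions.
def is_time_to_exit_sketch(args, epoch_steps=0, global_steps=0):
    max_epoch_steps = getattr(args, 'max_steps_per_epoch', 0) or 0
    max_global_steps = getattr(args, 'max_steps', 0) or 0
    return (0 < max_epoch_steps <= epoch_steps) or \
           (0 < max_global_steps <= global_steps)
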
    def __init__(self, *super_args, **super_kwargs):
        super().__init__(*super_args, **super_kwargs)
        assert isinstance(self.module,
                          PipelineModule), "model must be a PipelineModule"

        # We schedule the all-reduces, so disable it in super().backward()
        self.enable_backward_allreduce = False
        assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
            " with pipeline parallelism."

        # pipeline step for logging
        self.log_batch_step_id = -1

        self.micro_batch_size = self.train_micro_batch_size_per_gpu()
        self.micro_batches = self.gradient_accumulation_steps()

        # Set Grid and Communication Groups
        self.grid = self.module._grid
        if self.grid.get_global_rank() == 0:
            logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                        f'micro_batch_size={self.micro_batch_size}')

        self.global_rank = self.grid.get_global_rank()

        assert self.dp_world_size == self.grid.data_parallel_size
        assert self.train_batch_size() == \
            self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size

        # Set Stage Info
        self.num_stages = self.grid.pipe_parallel_size
        self.stage_id = self.grid.get_stage_id()
        self.prev_stage = self.stage_id - 1
        self.next_stage = self.stage_id + 1

        self.data_iterator = None
        self.batch_fn = None

        self._force_grad_boundary = False

        self.batch_timer = ThroughputTimer(
            batch_size=self.micro_batch_size * self.micro_batches,
            num_workers=self.dp_world_size,
            logging_fn=self.tput_log,
            monitor_memory=False,
            steps_per_output=self.steps_per_print())

        # PipelineEngine needs to handle data loading specially due to only the first
        # and last stages loading inputs/labels. We construct a sampler that shards
        # the training data across the data parallel ranks.
        if self.training_data:
            self._build_data_iter(self.training_data)

        self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
        self.is_data_parallel = self.grid.data_parallel_size > 1
        self.is_model_parallel = self.grid.model_parallel_size > 1

        # Partition input/output buffers
        self.is_pipe_partitioned = self.is_model_parallel
        self.is_grad_partitioned = False

        model_parameters = filter(lambda p: p.requires_grad,
                                  self.module.parameters())
        num_params = sum([p.numel() for p in model_parameters])
        unique_params = num_params
        # Subtract tied parameters if we don't own them
        if self.module.tied_comms:
            tied_params = 0
            for key, d in self.module.tied_comms.items():
                if self.global_rank != min(d['ranks']):
                    tied_params += sum(p.numel()
                                       for p in d['module'].parameters())
            unique_params -= tied_params
        params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(
            self.device)
        dist.all_reduce(params_tensor,
                        group=self.grid.get_model_parallel_group())
        params_tensor = params_tensor.tolist()
        total_params = params_tensor[0]
        unique_params = params_tensor[1]
        if self.grid.data_parallel_id == 0:
            logger.info(
                f'RANK={self.global_rank} '
                f'STAGE={self.stage_id} '
                f'LAYERS={self.module._local_stop - self.module._local_start} '
                f'[{self.module._local_start}, {self.module._local_stop}) '
                f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')

        # Initialize peer-to-peer communication and allreduce groups
        if self.is_pipe_parallel:
            p2p.init_process_groups(self.grid)

        # Pipeline buffers
        self.num_pipe_buffers = 0
        self.pipe_buffers = {
            'inputs': [],  # batch input and received activations
            'labels': [],  # labels from batch input
            'outputs': [],  # activations
            'output_tensors': [],  # tensor object to preserve backward graph
        }
        self.pipe_recv_buf = None
        self.grad_layer = None

        self.meta_buffer = None

        self.first_output_send = True
        self.first_gradient_send = True

        # Stores the loss for the current micro-batch being processed
        self.loss = torch.tensor(0.0).to(self.device)

        # Stores the loss for the entire batch
        self.total_loss = None
        self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
        self.dp_group_loss = torch.tensor(0.0,
                                          requires_grad=False).to(self.device)

        if self._config.pipeline['activation_checkpoint_interval'] > 0:
            self.module.activation_checkpoint_interval = self._config.pipeline[
                'activation_checkpoint_interval']

        if self.is_last_stage():
            self.loss_model = self.module.loss_fn

        # Initialize pipeline communicators. Just send a 0.
        if is_even(self.stage_id):
            if not self.is_last_stage():
                p2p.send(self.loss, self.next_stage)
            if not self.is_first_stage():
                p2p.recv(self.loss, self.prev_stage)
        else:
            if not self.is_first_stage():
                p2p.recv(self.loss, self.prev_stage)
            if not self.is_last_stage():
                p2p.send(self.loss, self.next_stage)

        # XXX look into timer reporting timing
        # Cycle the timers once up front so the first real measurements are not skewed.
        if self.wall_clock_breakdown():
            self.timers('forward_microstep').start()
            self.timers('forward_microstep').stop()
            self.timers('backward_microstep').start()
            self.timers('backward_microstep').stop()
            self.timers('backward_inner_microstep').start()
            self.timers('backward_inner_microstep').stop()
            self.timers('backward_allreduce_microstep').start()
            self.timers('backward_allreduce_microstep').stop()
            self.timers('backward_allreduce').start()
            self.timers('backward_allreduce').stop()
            self.timers('step_microstep').start()
            self.timers('step_microstep').stop()
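
# Hedged note: `is_even` above is a simple parity check used to order the
# send/recv handshake between adjacent pipeline stages (even stages send first,
# odd stages receive first), so the two sides never block on each other. Sketch:
def is_even_sketch(number):
    return number % 2 == 0
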
def train(args,
          index,
          model,
          optimizer,
          pretrain_dataset_provider,
          finetune=False):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    dataset_iterator, total_length = pretrain_dataset_provider.get_shard(index)
    current_data_sample_count = global_data_samples

    config = args.config
    logger = args.logger
    logger.info(
        f'worker-{dist.get_rank()}: begin epoch {index+1} current_sample_count {current_data_sample_count} shard_length {total_length} global_data_samples {global_data_samples}'
    )

    pretrain_dataset_provider.prefetch_shard(index + 1)

    model.train()

    epoch_step = 0
    rounds = 20
    all_step_time = 0.0
    step_counts = 0

    for _, batch_index in enumerate(tqdm(dataset_iterator, smoothing=1)):
        try:
            step_start = time.time()
            batch = pretrain_dataset_provider.get_batch(batch_index)
            batch = tuple(t.to(args.device) for t in batch)  # Move to GPU

            # Calculate forward pass
            loss = model.network(batch)
            unscaled_loss = loss.item()
            current_data_sample_count += (args.train_micro_batch_size_per_gpu *
                                          dist.get_world_size())

            # Prefetch training data
            pretrain_dataset_provider.prefetch_batch()

            model.network.backward(loss)

            loss = None

            if model.network.is_gradient_accumulation_boundary():
                if args.fp16:
                    # Modify the learning rate with the special warmup BERT uses.
                    # If args.fp16 is False, BertAdam is used, which handles this internally.
                    lr_this_step = update_learning_rate(
                        args, config, global_step, optimizer)
                else:
                    # Read the current lr for reporting; assumes `optimizer` exposes the
                    # standard torch.optim `param_groups` (otherwise lr_this_step is undefined here).
                    lr_this_step = optimizer.param_groups[0]['lr']

                report_step_metrics(args, lr_this_step, unscaled_loss,
                                    global_step, current_data_sample_count)

                model.network.step()

                report_lamb_coefficients(args, optimizer)
                global_step += 1
                epoch_step += 1
            else:
                # Call DeepSpeed engine step on micro steps
                model.network.step()

        except StopIteration:
            continue

        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args,
                           epoch_steps=epoch_step,
                           global_steps=current_global_step):
            print(
                f'Warning: Early epoch termination due to max steps limit, epoch step ={epoch_step}, global step = {current_global_step}, epoch = {index+1}'
            )
            break
        step_time = time.time() - step_start
        all_step_time += step_time
        if global_step % rounds == 0 and global_step != 0 and \
                model.network.is_gradient_accumulation_boundary() and \
                dist.get_rank() == 0:
            one_step_bs = (args.train_micro_batch_size_per_gpu *
                           args.gradient_accumulation_steps *
                           dist.get_world_size() * rounds)
            print(' At step {}, the throughput is {:.2f} Samples/s'.format(
                global_step * args.gradient_accumulation_steps,
                one_step_bs / all_step_time),
                  flush=True)
            all_step_time = 0.0

    pretrain_dataset_provider.release_shard(index)

    global_data_samples = current_data_sample_count

    # Run Validation Loss
    if not finetune and args.max_seq_length == 512:
        pretrain_validation(args, index, model)
def print_object(obj, name, exclude_list=[]):
    logger.info('{}:'.format(name))
    for arg in sorted(vars(obj)):
        if arg not in exclude_list:
            dots = '.' * (29 - len(arg))
            logger.info('  {} {} {}'.format(arg, dots, getattr(obj, arg)))
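
# Usage sketch (hypothetical values): dump parsed arguments while hiding noisy
# fields. Assumes a module-level `logger` is already configured, as above.
from types import SimpleNamespace

example_args = SimpleNamespace(seed=42, train_batch_size=64, local_rank=0)
print_object(example_args, 'args', exclude_list=['local_rank'])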