def train_batch(self, data_iterator, epoch_idx, batch_idx):
        if self.neox_args.is_pipe_parallel:
            reduced_loss = megatron_train.train_step_pipe(
                neox_args=self.neox_args,
                timers=self.timers,
                model=self.model,
                data_iterator=data_iterator,
            )
        else:
            losses = []
            for _ in range(self.neox_args.gradient_accumulation_steps):
                self.timers("forward").start()
                loss = megatron_train.forward_step(
                    neox_args=self.neox_args,
                    timers=self.timers,
                    data_iterator=data_iterator,
                    model=self.model,
                )
                self.timers("forward").stop()
                losses.append(loss)
                # Calculate gradients, reduce across processes, and clip.
                self.timers("backward").start()
                megatron_train.backward_step(
                    neox_args=self.neox_args,
                    timers=self.timers,
                    optimizer=self.optimizer,
                    model=self.model,
                    loss=loss,
                )
                self.timers("backward").stop()
                # Update parameters.
                self.timers("optimizer").start()
                if self.neox_args.deepspeed:
                    self.model.step()
                else:
                    raise ValueError("Must be using deepspeed to run neox")
                self.timers("optimizer").stop()
            reduced_loss = {
                "lm_loss": megatron_utils.reduce_losses(losses).mean()
            }

        if self.neox_args.precision == "fp16" and self.model.optimizer.overflow:
            skipped_iter = 1
        else:
            skipped_iter = 0
        self.neox_args.iteration += 1

        self.overflow_monitor.check(
            skipped_iter)  # check for repeated overflow
        if self.neox_args.log_gradient_noise_scale:  # log noise scale if applicable
            self.noise_scale_logger.update()

        # get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you
        # may have no tunable parameters on a specific rank
        if self.optimizer.param_groups:
            lr = self.optimizer.param_groups[0].get("lr", 0)
        else:
            lr = 0

        # Logging.
        self.report_memory_flag, additional_metrics = megatron_train.training_log(
            neox_args=self.neox_args,
            timers=self.timers,
            loss_dict=reduced_loss,
            total_loss_dict=self.total_train_loss_dict,
            learning_rate=lr,
            iteration=self.neox_args.iteration,
            loss_scale=self.optimizer.cur_scale
            if self.neox_args.precision == "fp16" else None,
            report_memory_flag=self.report_memory_flag,
            skipped_iter=skipped_iter,
            model=self.model,
            optimizer=self.optimizer,
            noise_scale_logger=self.noise_scale_logger,
            return_metrics=True,
        )
        if (additional_metrics is not None
                and additional_metrics["num_nans"] == 0
                and additional_metrics["num_skipped"] == 0):
            self.tflops = additional_metrics["flops_per_sec_per_gpu"] / 10**12

        if (self.neox_args.exit_interval and
                self.neox_args.iteration % self.neox_args.exit_interval == 0):
            torch.distributed.barrier()
            time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            rank = torch.distributed.get_rank()
            megatron_utils.print_rank_0(
                "rank: {} | time: {} | exiting the program at iteration {}".format(
                    rank, time_str, self.neox_args.iteration))
            self.context.set_stop_requested(True)
        return reduced_loss
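
# --- Illustrative sketch (not part of the examples on this page) ---
# train_batch above runs a forward and a backward pass per micro-batch and a
# single optimizer step per accumulation window. A minimal plain-PyTorch sketch
# of that pattern, assuming no pipeline parallelism or DeepSpeed and using
# hypothetical names (accumulate_and_step, compute_loss):

import torch


def accumulate_and_step(model, optimizer, micro_batches, compute_loss):
    """Forward/backward over each micro-batch, then one optimizer step."""
    losses = []
    num_micro = len(micro_batches)
    for micro_batch in micro_batches:
        loss = compute_loss(model, micro_batch)  # forward pass
        (loss / num_micro).backward()            # accumulate averaged gradients
        losses.append(loss.detach())
    optimizer.step()                             # single parameter update
    optimizer.zero_grad()
    # Analogous to reduce_losses(losses).mean() in the example, minus the
    # cross-rank all-reduce.
    return torch.stack(losses).mean()
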
def _train(model, optimizer, lr_scheduler, forward_step, train_dataloader,
           valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, _ = train_step(forward_step, batch, model, optimizer,
                                        lr_scheduler)
            iteration += 1

            # Logging.
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration, optimizer.loss_scale,
                                              report_memory_flag)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  lr_scheduler)

            # Checkpointing
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model, iteration,
                                           False)

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
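
# --- Illustrative sketch (not part of the examples on this page) ---
# _train above resumes mid-epoch by splitting the saved global iteration count
# into an epoch index and an in-epoch offset, then skipping the already
# consumed batches of the resumed epoch. A small self-contained sketch of that
# arithmetic, with made-up numbers:

def resume_position(iteration, train_iters_per_epoch):
    """Return (start_epoch, start_iteration) for a saved global iteration."""
    start_epoch = iteration // train_iters_per_epoch
    start_iteration = iteration % train_iters_per_epoch
    return start_epoch, start_iteration


# e.g. 2,500 completed iterations at 1,000 batches per epoch resumes in
# epoch 2 at batch 500; batches 0-499 of that epoch are skipped via `continue`.
assert resume_position(2500, 1000) == (2, 500)
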
Example 3
def _train(model, optimizer, lr_scheduler, forward_step, train_dataloader,
           valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer,
                             lr_scheduler)

            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(
                losses_dict, losses_dict_sum, optimizer.param_groups[0]['lr'],
                iteration,
                optimizer.get_loss_scale().item(), report_memory_flag,
                skipped_iter, grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  lr_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model, iteration,
                                           False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
                torch.distributed.barrier()
                print_rank_0(
                    'exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
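
# --- Illustrative sketch (not part of the examples on this page) ---
# The loop above fires checkpointing, evaluation, and a clean exit whenever the
# global iteration hits a configured interval, and the saved_checkpoint flag
# avoids writing the same checkpoint twice when the save and exit intervals
# coincide. A minimal sketch of that cadence with hypothetical callables
# (save_fn, eval_fn, exit_fn):

def periodic_actions(iteration, save_interval, eval_interval, exit_interval,
                     save_fn, eval_fn, exit_fn):
    """Fire interval-based actions after one completed iteration."""
    saved = False
    if save_interval and iteration % save_interval == 0:
        save_fn(iteration)
        saved = True
    if eval_interval and iteration % eval_interval == 0:
        eval_fn(iteration)
    if exit_interval and iteration % exit_interval == 0:
        if not saved:  # avoid writing the same checkpoint twice
            save_fn(iteration)
        exit_fn(iteration)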