예제 #1
0
def run_train_test(yaml_list=None, param_dict=None):
    """Check that repeated training steps on fixed random data lower the loss.

    Builds the model through ``model_setup(yaml_list, param_dict)``, then
    calls ``train_step`` up to ``max_steps`` times on copies of a single
    random token batch. The check succeeds as soon as two consecutive,
    non-NaN ``lm_loss`` values show a decrease. On success the test
    directories are removed (only by the single process, or by rank 0 when
    running distributed).

    Args:
        yaml_list: Forwarded unchanged to ``model_setup``.
        param_dict: Forwarded unchanged to ``model_setup``.

    Raises:
        AssertionError: If no decrease between consecutive losses was
            observed within ``max_steps`` steps.
    """
    from megatron.training import train_step
    from megatron.utils import Timers

    max_steps = 64

    model, optimizer, lr_scheduler, args_loaded = model_setup(
        yaml_list, param_dict)

    model.train()

    timers = Timers(use_wandb=False, tensorboard_writer=None)

    # Generate some random data on which we can overfit.
    # Context size of data is model seq_len + 1 in order to compute loss.
    context_tokens_tensor = torch.randint(
        0, args_loaded.padded_vocab_size,
        (4, args_loaded.seq_length + 1)).to(torch.int64)
    data_list = [{"text": context_tokens_tensor.clone()}
                 for _ in range(max_steps)]
    data_iterator = iter(data_list)

    # Run train_step until the loss decreases.
    losses = []
    loss_decreased = False
    for _ in range(max_steps):
        loss_dict, _skipped_iter = train_step(
            neox_args=args_loaded,
            timers=timers,
            data_iterator=data_iterator,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        losses.append(loss_dict["lm_loss"])
        if len(losses) < 2:
            continue
        previous_loss, current_loss = losses[-2], losses[-1]
        # A NaN on either side makes the comparison meaningless; keep going.
        if torch.isnan(current_loss) or torch.isnan(previous_loss):
            continue
        if current_loss < previous_loss:
            # Stop via break (not return) so the cleanup below still runs.
            loss_decreased = True
            break

    # Loss should have decreased by now (otherwise increasing the max_steps
    # parameter could have the testcase pass).
    assert loss_decreased, (
        "run_train_test() loss going down within " + str(max_steps) + " steps")

    # Only one process removes the test directories.
    if (torch.distributed.get_world_size() == 1
            or torch.distributed.get_rank() == 0):
        clear_test_dirs()
def _train(model, optimizer, lr_scheduler, forward_step, train_dataloader,
           valid_dataloader, end_of_epoch_callback):
    """Run the epoch-based training loop.

    Resumes from ``args.iteration``: whole epochs already covered are
    skipped, and within the first resumed epoch the leading batches are
    skipped as well. After every step the loop logs, optionally checks
    autoresume termination, optionally saves a checkpoint, and optionally
    evaluates. At the end of each epoch it saves a checkpoint (when
    ``args.save`` is set) and invokes ``end_of_epoch_callback`` if given.

    Args:
        model: Model to train; put into training mode via ``model.train()``.
        optimizer: Optimizer; its first param group's ``lr`` and its
            ``loss_scale`` are passed to ``training_log``.
        lr_scheduler: Scheduler passed through to ``train_step`` and to the
            checkpoint/autoresume helpers.
        forward_step: Callable passed through to ``train_step`` and to
            ``evaluate_and_print_results``.
        train_dataloader: Iterable of batches whose ``sampler`` exposes
            ``set_epoch``.
        valid_dataloader: Passed to ``evaluate_and_print_results``.
        end_of_epoch_callback: Optional callable invoked as
            ``callback(model, epoch)`` after every epoch; may be ``None``.
    """
    args = get_args()
    timers = get_timers()

    # Put the model into training mode.
    model.train()

    # Accumulator handed to training_log on every step.
    running_losses = {}

    # Work out where to resume from the global iteration counter.
    first_epoch, batches_to_skip = divmod(args.iteration,
                                          args.train_iters_per_epoch)
    step = args.iteration

    # Flag threaded through training_log, which returns its next value.
    memory_report_pending = True

    timers('interval time').start()
    for epoch in range(first_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Give the sampler a per-epoch value so it can reshuffle.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        for batch_idx, batch in enumerate(train_dataloader):

            # Skip batches that precede the resume point.
            if batch_idx < batches_to_skip:
                continue
            # Later epochs must start from their first batch.
            batches_to_skip = 0

            # One optimization step.
            step_losses, _ = train_step(forward_step, batch, model, optimizer,
                                        lr_scheduler)
            step += 1

            # Logging.
            current_lr = optimizer.param_groups[0]['lr']
            memory_report_pending = training_log(
                step_losses, running_losses, current_lr, step,
                optimizer.loss_scale, memory_report_pending)

            # Autoresume (the interval is only read when the feature is on).
            if args.adlr_autoresume:
                if step % args.adlr_autoresume_interval == 0:
                    check_adlr_autoresume_termination(step, model, optimizer,
                                                      lr_scheduler)

            # Periodic checkpoint.
            if args.save and args.save_interval:
                if step % args.save_interval == 0:
                    save_checkpoint(step, model, optimizer, lr_scheduler)

            # Periodic evaluation.
            if args.eval_interval:
                if step % args.eval_interval == 0:
                    prefix = 'iteration {}'.format(step)
                    evaluate_and_print_results(prefix, forward_step,
                                               valid_dataloader, model, step,
                                               False)

        # Checkpoint once the epoch is finished.
        if args.save:
            save_checkpoint(step, model, optimizer, lr_scheduler)

        # Optional end-of-epoch hook.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
예제 #3
0
def _train(model, optimizer, lr_scheduler, forward_step, train_dataloader,
           valid_dataloader, end_of_epoch_callback):
    """Run the epoch-based training loop.

    Resumes from ``args.iteration``: whole epochs already covered are
    skipped, and within the first resumed epoch the leading batches are
    skipped as well. After every step the loop logs, optionally checks
    autoresume termination, optionally saves a checkpoint, optionally
    evaluates, and exits the process when ``args.exit_interval`` is hit
    (saving first unless a checkpoint was just written). At the end of each
    epoch it saves a checkpoint (when ``args.save`` is set) and invokes
    ``end_of_epoch_callback`` if given.

    Args:
        model: Iterable of modules; each is put into training mode.
        optimizer: Optimizer; its first param group's ``lr`` and
            ``get_loss_scale().item()`` are passed to ``training_log``.
        lr_scheduler: Scheduler passed through to ``train_step`` and to the
            checkpoint/autoresume helpers.
        forward_step: Callable passed through to ``train_step`` and to
            ``evaluate_and_print_results``.
        train_dataloader: Iterable of batches whose ``sampler`` exposes
            ``set_epoch``.
        valid_dataloader: Passed to ``evaluate_and_print_results``.
        end_of_epoch_callback: Optional callable invoked as
            ``callback(model, epoch)`` after every epoch; may be ``None``.

    Raises:
        AssertionError: If ``get_num_microbatches()`` is not 1.
    """
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches(
    ) == 1, "finetuning with gradient accumulation doesn't currently work"

    # Put every module into training mode.
    for module in model:
        module.train()

    # Accumulator handed to training_log on every step.
    running_losses = {}

    # Work out where to resume from the global iteration counter.
    first_epoch, batches_to_skip = divmod(args.iteration,
                                          args.train_iters_per_epoch)
    step = args.iteration

    # Flag threaded through training_log, which returns its next value.
    memory_report_pending = True

    timers('interval-time').start()
    for epoch in range(first_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Give the sampler a per-epoch value so it can reshuffle.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        for batch_idx, batch in enumerate(train_dataloader):

            # Skip batches that precede the resume point.
            if batch_idx < batches_to_skip:
                continue
            # Later epochs must start from their first batch.
            batches_to_skip = 0

            # One optimization step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
                train_step(forward_step, batch, model, optimizer,
                           lr_scheduler)
            step += 1

            # Logging; the parameter norm is only computed on request.
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            else:
                params_norm = None
            current_lr = optimizer.param_groups[0]['lr']
            loss_scale = optimizer.get_loss_scale().item()
            memory_report_pending = training_log(
                losses_dict, running_losses, current_lr, step, loss_scale,
                memory_report_pending, skipped_iter, grad_norm, params_norm,
                num_zeros_in_grad)

            # Autoresume (the interval is only read when the feature is on).
            if args.adlr_autoresume:
                if step % args.adlr_autoresume_interval == 0:
                    check_adlr_autoresume_termination(step, model, optimizer,
                                                      lr_scheduler)

            # Periodic checkpoint; remember it to avoid a duplicate on exit.
            checkpoint_written = False
            if args.save and args.save_interval:
                if step % args.save_interval == 0:
                    save_checkpoint(step, model, optimizer, lr_scheduler)
                    checkpoint_written = True

            # Periodic evaluation.
            if args.eval_interval:
                if step % args.eval_interval == 0:
                    prefix = 'iteration {}'.format(step)
                    evaluate_and_print_results(prefix, forward_step,
                                               valid_dataloader, model, step,
                                               False)

            # Iteration-based exit.
            if args.exit_interval:
                if step % args.exit_interval == 0:
                    if not checkpoint_written:
                        save_checkpoint(step, model, optimizer, lr_scheduler)
                    torch.distributed.barrier()
                    print_rank_0(
                        'exiting program at iteration {}'.format(step))
                    sys.exit()

        # Checkpoint once the epoch is finished.
        if args.save:
            save_checkpoint(step, model, optimizer, lr_scheduler)

        # Optional end-of-epoch hook.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)