Code example #1
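A BERT-style pretraining loop: it tracks both a language-model (LM) loss and a next-sentence-prediction (NSP) loss, writes scalars to a TensorBoard writer on rank 0, and supports ADLR autoresume, periodic checkpointing, validation, and a timed exit interval.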
def train(model, optimizer, lr_scheduler, train_data_iterator,
          val_data_iterator, timers, args, writer):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0
    total_nsp_loss = 0.0

    # Iterations.
    iteration = args.iteration
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    while iteration < args.train_iters:

        lm_loss, nsp_loss, skipped_iter = train_step(train_data_iterator,
                                                     model, optimizer,
                                                     lr_scheduler, args,
                                                     timers)
        skipped_iters += skipped_iter
        iteration += 1

        # Update losses.
        current_lm_loss = lm_loss.data.detach().float()
        current_nsp_loss = nsp_loss.data.detach().float()
        total_lm_loss += current_lm_loss
        total_nsp_loss += current_nsp_loss

        # Logging.

        if args.DDP_impl == 'torch':
            timers_to_log = [
                'forward', 'backward', 'optimizer', 'batch generator',
                'data loader'
            ]
        else:
            timers_to_log = [
                'forward', 'backward', 'allreduce', 'optimizer',
                'batch generator', 'data loader'
            ]

        learning_rate = optimizer.param_groups[0]['lr']

        if writer and args.rank == 0:
            writer.add_scalar('learning_rate', learning_rate, iteration)
            writer.add_scalar('lm_loss', current_lm_loss, iteration)
            writer.add_scalar('nsp_loss', current_nsp_loss, iteration)
            if args.fp16:
                writer.add_scalar('loss_scale', optimizer.loss_scale,
                                  iteration)
            normalizer = iteration % args.log_interval
            if normalizer == 0:
                normalizer = args.log_interval
            timers.write(timers_to_log,
                         writer,
                         iteration,
                         normalizer=normalizer)

        if iteration % args.log_interval == 0:
            avg_nsp_loss = total_nsp_loss.item() / args.log_interval
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            elapsed_time = timers('interval time').elapsed()
            if writer and args.rank == 0:
                writer.add_scalar('iteration_time',
                                  elapsed_time / args.log_interval, iteration)
            log_string = ' iteration {:8d}/{:8d} |'.format(
                iteration, args.train_iters)
            log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
                elapsed_time * 1000.0 / args.log_interval)
            log_string += ' learning rate {:.3E} |'.format(learning_rate)
            log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
            log_string += ' nsp loss {:.6E} |'.format(avg_nsp_loss)
            if args.fp16:
                log_string += ' loss scale {:.1f} |'.format(
                    optimizer.loss_scale)
            print_rank_0(log_string)
            total_nsp_loss = 0.0
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(iteration))
                report_memory_flag = False
            timers.log(timers_to_log, normalizer=args.log_interval)

        # Autoresume
        if (iteration % args.adlr_autoresume_interval
                == 0) and args.adlr_autoresume:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler, args)

        # Checkpointing
        if args.save and args.save_interval and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, val_data_iterator, model, args,
                                       writer, iteration, timers, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, iteration),
                  flush=True)
            exit()

    return iteration, skipped_iters
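All of the Megatron-style loops in this collection call a timers registry (timers('interval time').start(), .elapsed(), timers.log(...), timers.write(...)) that is not shown in the snippets. The code below is a minimal, hypothetical sketch written only to match those call sites; the real implementation (for example in Megatron-LM) differs in details such as CUDA synchronization, cross-rank reduction, and reset semantics.

import time


class _Timer:
    """Accumulates wall-clock time for one named phase."""

    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started = False
        self.start_time = None

    def start(self):
        assert not self.started, 'timer has already been started'
        self.start_time = time.time()
        self.started = True

    def stop(self):
        assert self.started, 'timer is not running'
        self.elapsed_ += time.time() - self.start_time
        self.started = False

    def reset(self):
        self.elapsed_ = 0.0
        self.started = False

    def elapsed(self, reset=True):
        """Return accumulated seconds; optionally reset the counter."""
        was_running = self.started
        if was_running:
            self.stop()
        total = self.elapsed_
        if reset:
            self.reset()
        if was_running:
            self.start()
        return total


class Timers:
    """Callable registry of named timers, e.g. timers('forward').start()."""

    def __init__(self):
        self._timers = {}

    def __call__(self, name):
        if name not in self._timers:
            self._timers[name] = _Timer(name)
        return self._timers[name]

    def log(self, names, normalizer=1.0):
        """Print the average time per call (ms) for each named timer."""
        assert normalizer > 0.0
        line = 'time (ms)'
        for name in names:
            value = self._timers[name].elapsed() * 1000.0 / normalizer
            line += ' | {}: {:.2f}'.format(name, value)
        print(line, flush=True)

    def write(self, names, writer, iteration, normalizer=1.0):
        """Log the average time per call (ms) to a TensorBoard SummaryWriter."""
        assert normalizer > 0.0
        for name in names:
            value = self._timers[name].elapsed(reset=False) * 1000.0 / normalizer
            writer.add_scalar(name + '_time', value, iteration)

With this sketch, a single Timers() instance is created once, passed into train(...), and the per-phase timers ('forward', 'backward', and so on) are started and stopped inside train_step.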
Code example #2
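A variant that keeps the iteration counter on args.iteration, threads transformer memory states (mems) and a forward_step_func through train_step, and delegates console and TensorBoard logging to report_iteration_metrics. Which timers get logged depends on whether DeepSpeed or torch DDP handles the allreduce.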
def train(model,
          optimizer,
          lr_scheduler,
          train_data_iterator,
          val_data_iterator,
          timers,
          args,
          summary_writer=None):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    mems = []
    while args.iteration < args.train_iters:

        lm_loss, skipped_iter, mems = train_step(
            train_data_iterator,
            model,
            optimizer,
            lr_scheduler,
            args,
            timers,
            mems=mems,
            forward_step_func=forward_step)
        skipped_iters += skipped_iter
        args.iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()

        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate,
                                     avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval,
                                     args.iteration, args.train_iters, args)
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False
            # for i in range(torch.distributed.get_world_size()):
            #     if i == torch.distributed.get_rank():
            #         print(get_hostname())
            #         timers.log(['forward', 'backward', 'optimizer',
            #                     'batch generator', 'data loader'],
            #                    normalizer=args.log_interval, reset=False)
            #     torch.distributed.barrier()
            if args.deepspeed or args.DDP_impl == 'torch':
                timers.log([
                    'forward', 'backward', 'optimizer', 'batch generator',
                    'data loader'
                ],
                           normalizer=args.log_interval)
            else:
                timers.log([
                    'forward', 'backward', 'allreduce', 'optimizer',
                    'batch generator', 'data loader'
                ],
                           normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler,
                            args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(prefix,
                                       val_data_iterator,
                                       model,
                                       args,
                                       timers,
                                       verbose=False,
                                       step=args.iteration,
                                       summary_writer=summary_writer,
                                       forward_step_func=forward_step)

    return args.iteration, skipped_iters
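The Megatron-style examples also rely on print_rank_0 (examples #1, #4, #5) and report_memory (all of them) without defining them. The sketches below are assumptions written to fit those call sites; the actual Megatron-LM helpers differ (for instance, they report memory in MiB and cached bytes per rank).

import torch


def print_rank_0(message):
    """Print once per job: on global rank 0, or unconditionally if not distributed."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def report_memory(name):
    """Print current and peak GPU memory usage for this process."""
    gb = 1024.0 ** 3
    print('{} memory (GB) | allocated: {:.2f} | max allocated: {:.2f}'
          ' | reserved: {:.2f}'.format(name,
                                       torch.cuda.memory_allocated() / gb,
                                       torch.cuda.max_memory_allocated() / gb,
                                       torch.cuda.memory_reserved() / gb),
          flush=True)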
Code example #3
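A Hugging Face transformers-style fine-tuning loop: it builds a DataLoader with a random or distributed sampler, splits parameters into weight-decay and no-decay groups for AdamW, uses a warmup scheduler, supports gradient accumulation and checkpoint resumption, and logs all-reduced loss and accuracy to TensorBoard, with periodic evaluation and checkpoint saving.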
def train(args, model, train_data, dev_data, device, tokenizer=None):
    args.train_batch_size = args.per_device_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator)
    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=int(args.max_train_steps * args.warmup),
        num_training_steps=args.max_train_steps,
    )
    start_epoch = 0
    purge_step = None
    if os.path.isdir(args.model_name_or_path):
        start_epoch, completed_steps, args.max_train_steps = load_checkpoint_from_disk(args.model_name_or_path, optimizer, lr_scheduler)
        purge_step = completed_steps
    if args.local_rank in [-1, 0]:
        if args.tensorboard_dir is not None:
            tb_writer = SummaryWriter(args.tensorboard_dir, purge_step=purge_step)
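        # NOTE: the logging calls further down assume tb_writer exists, so
        # args.tensorboard_dir must be set whenever this rank logs.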
    # Distributed training (should be after apex fp16 initialization)
    model = set_model_distributed(args, model)
    accs = AverageMeter()
    # Train!
    world_size = (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    total_batch_size = args.per_device_train_batch_size * world_size * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_data)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    if os.path.isdir(args.model_name_or_path):
        logger.info(f"  Load checkpoint from {args.model_name_or_path}")
        logger.info(f"  Completed optimization steps = {completed_steps}")
        logger.info(f"  Start epoch = {start_epoch}")
    else:
        completed_steps = 0
    # Only show the progress bar once on each machine.
    # progress_bar = tqdm(range(args.max_train_steps), disable=args.local_rank not in [-1, 0])
    step_loss = 0
    log_loss = 0.0
    accumulate_step = 0
    best_acc = 0
    best_acc_step = 0
    global start
    start = time.time()  # log time
    see_memory = True
    model.zero_grad()
    for epoch in range(start_epoch, args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # if step == 0:
            #     print(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(batch['input_ids'][0][0])))
            #     print(batch['labels'][0])
            #     print(batch['input_ids'].shape)
            # print(batch['attention_mask'][0][0])
            # print(batch['token_type_ids'][0][0])
            for key in batch:
                if torch.is_tensor(batch[key]):
                    batch[key] = batch[key].to(device)
            input_ids, attention_masks, token_type_ids, labels = batch['input_ids'], batch['attention_mask'], batch['token_type_ids'], batch['labels']
            batch_size = input_ids.shape[0]
            sequence_len = input_ids.shape[-1]
            inputs = {"input_ids": input_ids,
                      "attention_mask": attention_masks,
                      "token_type_ids": token_type_ids,
                      "labels": labels}
            outputs = model(**inputs)
            loss = outputs.loss
            loss = loss / args.gradient_accumulation_steps
            accumulate_step += 1
            loss.backward()
            step_loss += loss.item()
            logits = outputs.logits
            acc = (logits.argmax(1)==labels).sum().item() / batch_size
            accs.update(acc, batch_size)
            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                # progress_bar.update(1)
                completed_steps += 1
                step_loss = step_loss * args.gradient_accumulation_steps / accumulate_step
                log_loss += step_loss
                if completed_steps % args.log_steps == 0:
                    log_loss = log_loss / args.log_steps
                    log_acc = accs.avg
                    accs.reset()
                    if args.local_rank != -1:
                        # reduce from all process
                        log_loss = torch.tensor([log_loss], device=device)
                        torch.distributed.all_reduce(log_loss)
                        log_loss = log_loss[0] / torch.distributed.get_world_size()
                        log_loss = log_loss.item()
                        log_acc = torch.tensor([log_acc], device=device)
                        torch.distributed.all_reduce(log_acc)
                        log_acc = (log_acc[0] / torch.distributed.get_world_size()).item()
                    consume_time = (time.time() - start) / args.log_steps
                    time_left = consume_time * (args.max_train_steps - completed_steps)
                    if args.local_rank in [-1, 0]:
                        if see_memory:
                            report_memory('(after {} steps)'.format(completed_steps))
                        # log information
                        tb_writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], completed_steps)
                        tb_writer.add_scalar('loss/Train', log_loss, completed_steps)
                        tb_writer.add_scalar('accuracy/Train', log_acc, completed_steps)
                        logger.info("Epoch {} | Steps {:d} | loss {:.3f} | acc {:.3f} | Seconds per batch: {:.3f} | Time left {:.3f}".format(epoch, completed_steps, log_loss, log_acc, consume_time, time_left))
                        # progress_bar.set_description("Train loss: {:.4f}".format(log_loss))
                    start = time.time()  # reset time
                    see_memory = False
                    log_loss = 0
                step_loss = 0.0
                accumulate_step = 0
                if args.eval_steps is not None and completed_steps % args.eval_steps == 0:
                    # evaluation during train
                    loss, eval_metric = evaluation(args, model, dev_data, device)
                    if args.local_rank in [-1, 0]:
                        # only main process can log
                        tb_writer.add_scalar('loss/Eval', loss, completed_steps)
                        for key in eval_metric:
                            tb_writer.add_scalar(f'{key}/Eval', eval_metric[key], completed_steps)
                        assert "accuracy" in eval_metric
                        if eval_metric['accuracy'] > best_acc:
                            best_acc = eval_metric['accuracy']
                            best_acc_step = completed_steps
                            output_dir = os.path.join(args.output_dir, "best")
                            save_model(model, optimizer, lr_scheduler, output_dir, epoch, completed_steps, args.max_train_steps)
                        logger.info("Best accuracy {:.3f} on step {:d}".format(best_acc, best_acc_step))
                        start = time.time()  # reset time
                if args.local_rank in [-1, 0] and args.save_steps is not None and completed_steps % args.save_steps == 0:
                    # save the model on main process
                    output_dir = os.path.join(args.output_dir, "checkpoint-{:d}".format(completed_steps))
                    save_model(model, optimizer, lr_scheduler, output_dir, epoch, completed_steps, args.max_train_steps)
                    start = time.time() # reset time

            if completed_steps >= args.max_train_steps:
                break
    loss, eval_metric = evaluation(args, model, dev_data, device)
    if args.local_rank in [-1, 0]:
        # only main process can log
        tb_writer.add_scalar('loss/Eval', loss, completed_steps)
        for key in eval_metric:
            tb_writer.add_scalar(f'{key}/Eval', eval_metric[key], completed_steps)
    if args.local_rank in [-1, 0]:
        # save the model on main process
        output_dir = os.path.join(args.output_dir, "checkpoint-{:d}".format(completed_steps))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        save_model(model, optimizer, lr_scheduler, output_dir, epoch, completed_steps, args.max_train_steps)
    return completed_steps
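Example #3 also relies on an AverageMeter that the snippet does not define. A minimal sketch in the spirit of the classic PyTorch ImageNet example, matching the update(val, n) / avg / reset() call sites, could look like this (assumption: a plain running average with no cross-rank reduction):

class AverageMeter:
    """Tracks a running average, e.g. of per-batch accuracy."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # val is the per-batch value, n the number of samples it covers.
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0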
Code example #4
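A single-loss variant of example #1 without a TensorBoard writer: under DeepSpeed the loss scale is read from optimizer.cur_scale, otherwise from optimizer.loss_scale, and timer selection depends on the module-level USE_TORCH_DDP flag.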
def train(model, optimizer, lr_scheduler,
          train_data_iterator, val_data_iterator, timers, args):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0

    # Iterations.
    iteration = args.iteration
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    while iteration < args.train_iters:

        lm_loss, skipped_iter = train_step(train_data_iterator,
                                           model,
                                           optimizer,
                                           lr_scheduler,
                                           args, timers)
        skipped_iters += skipped_iter
        iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()

        # Logging.
        if iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            elapsed_time = timers('interval time').elapsed()
            log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                            args.train_iters)
            log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
                elapsed_time * 1000.0 / args.log_interval)
            log_string += ' learning rate {:.3E} |'.format(learning_rate)
            log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
            if args.fp16:
                log_string += ' loss scale {:.1f} |'.format(
                    optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)
            print_rank_0(log_string)
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(iteration))
                report_memory_flag = False
            if USE_TORCH_DDP:
                timers.log(['forward', 'backward', 'optimizer',
                            'batch generator', 'data loader'],
                           normalizer=args.log_interval)
            else:
                timers.log(['forward', 'backward', 'allreduce', 'optimizer',
                            'batch generator', 'data loader'],
                           normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(
                prefix, val_data_iterator, model, args, timers, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, iteration), flush=True)
            exit()

    return iteration, skipped_iters
Code example #5
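Like example #2, but mems is passed positionally to train_step, there is no forward_step_func argument, timer selection uses USE_TORCH_DDP, and the loop adds an exit-interval barrier before terminating.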
def train(model,
          optimizer,
          lr_scheduler,
          train_data_iterator,
          val_data_iterator,
          timers,
          args,
          summary_writer=None):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    mems = []
    while args.iteration < args.train_iters:

        lm_loss, skipped_iter, mems = train_step(train_data_iterator, model,
                                                 optimizer, lr_scheduler, args,
                                                 timers, mems)
        skipped_iters += skipped_iter
        args.iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()

        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate,
                                     avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval,
                                     args.iteration, args.train_iters, args)
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False
            if USE_TORCH_DDP:
                timers.log([
                    'forward', 'backward', 'optimizer', 'batch generator',
                    'data loader'
                ],
                           normalizer=args.log_interval)
            else:
                timers.log([
                    'forward', 'backward', 'allreduce', 'optimizer',
                    'batch generator', 'data loader'
                ],
                           normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler,
                            args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(prefix,
                                       val_data_iterator,
                                       model,
                                       args,
                                       timers,
                                       False,
                                       step=args.iteration,
                                       summary_writer=summary_writer)

        if args.exit_interval and args.iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, args.iteration),
                  flush=True)
            exit()

    return args.iteration, skipped_iters