예제 #1
0
def get_optimizer(model, args):
    """Build the optimizer for a GPT-2 style model.

    Wrapper modules (``DDP`` / ``FP16_Module``) are peeled off, weight-decay
    parameter groups are built from the bare model, and an ``Adam`` optimizer
    is created over them. Unless ``args.deepspeed`` is set, the optimizer is
    wrapped in ``FP16_Optimizer`` when ``args.fp16`` is true.

    Returns:
        ``(optimizer, param_groups)``.
    """
    inner = model
    while isinstance(inner, (DDP, FP16_Module)):
        inner = inner.module
    groups = gpt2_get_params_for_weight_decay_optimization(inner)

    # Any parameter that was not explicitly tagged is treated as
    # not model-parallel.
    untagged = [p for group in groups for p in group['params']
                if not hasattr(p, 'model_parallel')]
    for p in untagged:
        p.model_parallel = False

    opt = Adam(groups, lr=args.lr, weight_decay=args.weight_decay)
    print(f'Optimizer = {opt.__class__.__name__}')

    # DeepSpeed callers get the raw optimizer; everyone else may get the
    # fp16 wrapper.
    if not args.deepspeed and args.fp16:
        opt = FP16_Optimizer(
            opt,
            static_loss_scale=args.loss_scale,
            dynamic_loss_scale=args.dynamic_loss_scale,
            dynamic_loss_args={
                'scale_window': args.loss_scale_window,
                'min_scale': args.min_scale,
                'delayed_shift': args.hysteresis,
            })

    return opt, groups
예제 #2
0
def get_optimizer(model, args):
    """Create the Adam optimizer for a BERT pre-training model.

    ``DDP`` / ``FP16_Module`` wrappers are removed first. Weight-decay
    parameter groups are then collected, in order, from the encoder layers,
    the pooler, the next-sentence head, the embeddings and the LM-head
    transform. The LM-head bias is appended to the group at index 1. When
    ``args.fp16`` is set the optimizer is wrapped in ``FP16_Optimizer``.

    Returns:
        The (possibly wrapped) optimizer.
    """
    while isinstance(model, (DDP, FP16_Module)):
        model = model.module

    bert = model.model.bert
    heads = model.model.cls
    lm_head = heads.predictions
    decay_sources = (
        bert.encoder.layer,
        bert.pooler,
        heads.seq_relationship,
        bert.embeddings,
        lm_head.transform,
    )
    param_groups = [
        group
        for module in decay_sources
        for group in get_params_for_weight_decay_optimization(module)
    ]
    # NOTE(review): index 1 is presumably the encoder's no-weight-decay
    # group; confirm against get_params_for_weight_decay_optimization.
    param_groups[1]['params'].append(lm_head.bias)

    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    if args.fp16:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=args.loss_scale,
            dynamic_loss_scale=args.dynamic_loss_scale,
            dynamic_loss_args={
                'scale_window': args.loss_scale_window,
                'min_scale': args.min_scale,
                'delayed_shift': args.hysteresis,
            })

    return optimizer
def get_optimizer(model, args):
    """Build the optimizer for a GPT-2 style model.

    With ``args.cpu_optimizer`` a CPU Adam is used: ``torch.optim.AdamW`` when
    ``args.cpu_torch_adam`` is set, otherwise DeepSpeed's ``DeepSpeedCPUAdam``.
    Without it, ``Adam`` is used. Under ``args.deepspeed`` the raw optimizer
    is returned. Otherwise it is wrapped in ``FP16_Optimizer`` when
    ``args.fp16`` is true.

    Returns:
        The (possibly wrapped) optimizer.
    """
    inner = model
    while isinstance(inner, (DDP, FP16_Module)):
        inner = inner.module
    groups = gpt2_get_params_for_weight_decay_optimization(inner)

    # Parameters without an explicit flag default to not model-parallel.
    for group in groups:
        for p in group['params']:
            if not hasattr(p, 'model_parallel'):
                p.model_parallel = False

    if not args.cpu_optimizer:
        optimizer_cls = Adam
    elif args.cpu_torch_adam:
        # Apex FusedAdam uses decoupled weight decay so use the same here.
        optimizer_cls = torch.optim.AdamW
    else:
        # TODO add option for decoupled weight decay in DeepCPUAdam
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        optimizer_cls = DeepSpeedCPUAdam
    optimizer = optimizer_cls(groups,
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    print(f'Optimizer = {optimizer.__class__.__name__}')
    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer

    if args.fp16:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=args.loss_scale,
            dynamic_loss_scale=args.dynamic_loss_scale,
            dynamic_loss_args={
                'scale_window': args.loss_scale_window,
                'min_scale': args.min_scale,
                'delayed_shift': args.hysteresis,
            })
    return optimizer
예제 #4
0
파일: train_utils.py 프로젝트: spatil6/GLM
def get_optimizer(param_groups, args):
    """Create the optimizer for ``param_groups`` as configured by ``args``.

    Selection:
      * ``args.cpu_optimizer``: ``torch.optim.AdamW`` when
        ``args.cpu_torch_adam`` is set, otherwise DeepSpeed's
        ``DeepSpeedCPUAdam``.
      * otherwise ``args.optimizer`` selects one of two optimizers.
        ``'adam'`` gives ``Adam`` with ``args.adam_beta1``/``args.adam_beta2``
        betas and ``args.adam_eps``. ``'adafactor'`` gives
        ``transformers.Adafactor`` with ``relative_step=False`` and
        ``warmup_init=False``.

    When ``args.fp16`` is set the optimizer is wrapped in ``FP16_Optimizer``.

    Raises:
        NotImplementedError: if ``args.optimizer`` is not ``'adam'`` or
            ``'adafactor'`` (non-CPU path), or if ``args.deepspeed`` is set.
    """
    if args.cpu_optimizer:
        # Apex FusedAdam uses decoupled weight decay so use the same here
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'adafactor':
        from transformers import Adafactor
        optimizer = Adafactor(param_groups,
                              lr=args.lr,
                              relative_step=False,
                              warmup_init=False)
    else:
        raise NotImplementedError(
            f"Unsupported optimizer {args.optimizer!r}; "
            "expected 'adam' or 'adafactor'")

    print(f'Optimizer = {optimizer.__class__.__name__}')
    # DeepSpeed optimizer setup is not supported by this function.
    if getattr(args, "deepspeed", False):
        raise NotImplementedError(
            "get_optimizer does not support DeepSpeed (args.deepspeed is set)")

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer
예제 #5
0
def get_optimizer(model, args):
    """Create the Adam optimizer for a BERT pre-training model.

    Wrappers of type ``args.DDP_type`` or ``FP16_Module`` are removed first.
    Weight-decay parameter groups are then collected, in order, from the
    encoder layers, pooler, next-sentence head, embeddings and LM-head
    transform, and the LM-head bias is appended to the group at index 1.
    Every parameter lacking ``model_parallel`` is tagged ``False``. Adam uses
    fixed betas ``(0.9, 0.999)``. The optimizer is wrapped in
    ``FP16_Optimizer`` when ``args.fp16`` is set.

    Returns:
        The (possibly wrapped) optimizer.
    """
    while isinstance(model, (args.DDP_type, FP16_Module)):
        model = model.module

    bert = model.model.bert
    heads = model.model.cls
    lm_head = heads.predictions
    param_groups = []
    for module in (bert.encoder.layer, bert.pooler, heads.seq_relationship,
                   bert.embeddings, lm_head.transform):
        param_groups.extend(get_params_for_weight_decay_optimization(module))
    # NOTE(review): index 1 is presumably the encoder's no-weight-decay
    # group; confirm against get_params_for_weight_decay_optimization.
    param_groups[1]['params'].append(lm_head.bias)

    untagged = [p for group in param_groups for p in group['params']
                if not hasattr(p, 'model_parallel')]
    for p in untagged:
        p.model_parallel = False

    optimizer = Adam(param_groups,
                     betas=(0.9, 0.999),
                     lr=args.lr,
                     weight_decay=args.weight_decay)

    if args.fp16:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=args.loss_scale,
            dynamic_loss_scale=args.dynamic_loss_scale,
            dynamic_loss_args={
                'scale_window': args.loss_scale_window,
                'min_scale': args.min_scale,
                'delayed_shift': args.hysteresis,
            })

    return optimizer
def get_optimizer(param_groups, args):
    """Create the Adam-family optimizer for ``param_groups``.

    With ``args.cpu_optimizer`` a CPU Adam is used: ``torch.optim.AdamW`` when
    ``args.cpu_torch_adam`` is set, otherwise DeepSpeed's ``DeepSpeedCPUAdam``.
    Without it, ``Adam`` is used. When ``args.fp16`` is set the optimizer is
    wrapped in ``FP16_Optimizer``.

    Raises:
        NotImplementedError: if ``args.deepspeed`` is set (not supported here).
    """
    if args.cpu_optimizer:
        # Apex FusedAdam uses decoupled weight decay so use the same here
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            # TODO add option for decoupled weight decay in DeepCPUAdam
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    else:
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay)

    print(f'Optimizer = {optimizer.__class__.__name__}')
    # DeepSpeed optimizer setup is not supported by this function.
    if getattr(args, "deepspeed", False):
        raise NotImplementedError(
            "get_optimizer does not support DeepSpeed (args.deepspeed is set)")

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis
                                   })

    return optimizer
예제 #7
0
def main():
    """Train ``args.arch`` on an ImageFolder dataset with plain SGD.

    Command-line options are parsed with the module-level ``parser`` and
    stored in the global ``args``. Training runs for epochs
    ``args.start_epoch`` through ``args.nepochs`` inclusive. Whenever
    ``epoch % args.epochdecay == 0`` the learning rate is divided by 10 and
    the optimizer is rebuilt. The model is validated after every epoch, a
    summary line is appended to ``name_log_txt``, and checkpoints are saved
    under ``args.save_folder``.
    """
    global args, best_prec1
    args = parser.parse_args()

    N_img = 224
    N_img_scale = 256

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(N_img),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(N_img_scale),
            transforms.CenterCrop(N_img),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch]()

    with open(name_log_txt, "a") as text_file:
        print(model, file=text_file)
    model = nn.DataParallel(model)
    model = model.cuda()
    if args.half:
        model = model.half()
        # BN_convert_float is defined elsewhere; presumably it keeps
        # BatchNorm layers in fp32 -- confirm.
        model = BN_convert_float(model)
    num_ep = args.nepochs

    # Resume from a saved state_dict if requested.
    if args.resume:
        model.load_state_dict(torch.load(args.resume))
        print('model loaded')

    criterion = nn.CrossEntropyLoss().cuda()

    def make_optimizer(learning_rate):
        # Fresh SGD over all model parameters (momentum state starts empty),
        # wrapped in FP16_Optimizer for half-precision training.
        sgd = optimizer.SGD(model.parameters(),
                            lr=learning_rate,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
        if args.half:
            sgd = FP16_Optimizer(sgd,
                                 static_loss_scale=args.static_loss_scale,
                                 dynamic_loss_scale=args.dynamic_loss_scale,
                                 dynamic_loss_args={'scale_window': 1000})
        return sgd

    lr = args.lr
    optim = make_optimizer(lr)

    for epoch in range(args.start_epoch, num_ep + 1):
        # Make sure we set the bn right
        model.train()

        batch_time_total = AverageMeter()
        data_time = AverageMeter()
        lossm = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        # Step decay: divide the LR by 10 and rebuild the optimizer.
        # NOTE(review): this also fires before any training when
        # args.start_epoch == 0 (0 % k == 0) -- confirm that is intended.
        if epoch % args.epochdecay == 0:
            lr = lr / 10.0
            optim = make_optimizer(lr)
        end = time.time()

        for i, (inputs, targets) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda(non_blocking=True)
            if args.half:
                inputs = inputs.half()

            end = time.time()

            # Forward
            optim.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # update (FP16_Optimizer handles loss scaling in its backward)
            if args.half:
                optim.backward(loss)
            else:
                loss.backward()
            optim.step()

            # measure elapsed time, accuracy and loss
            batch_time_total.update(time.time() - end)
            end = time.time()
            prec1, prec5 = accuracy(outputs.data, targets, topk=(1, 5))
            n_samples = float(inputs.size(0))
            # loss is a 0-dim tensor; loss.data[0] raises IndexError on
            # PyTorch >= 0.4, so read the scalar with .item().
            lossm.update(loss.item(), n_samples)
            top1.update(float(prec1[0]), n_samples)
            top5.update(float(prec5[0]), n_samples)

            if i % args.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch,
                          i,
                          len(train_loader),
                          batch_time=batch_time_total,
                          data_time=data_time,
                          loss=lossm,
                          top1=top1,
                          top5=top5))

            if args.debug and i > 50:
                break

        top1test, top5test = validate(val_loader, model, criterion, epoch)
        with open(name_log_txt, "a") as text_file:
            print("lr: {}, epoch {}, train top1:{}(top5:{}), "
                  "test top1:{} (top5:{})".format(lr, epoch, top1.avg,
                                                  top5.avg, top1test,
                                                  top5test),
                  file=text_file)

        # Checkpoint after every epoch.
        if not args.debug:
            torch.save(model.state_dict(),
                       args.save_folder + '/' + name_log_txt +
                       '_current_model.t7')

    # Save the final model.
    torch.save(model.state_dict(),
               args.save_folder + '/' + name_log_txt + '_model.t7')
예제 #8
0
def setup_model_and_optim(args, train_data, tokenizer):
    """Build the language model, its optimizer and learning-rate schedulers.

    Args:
        args: Parsed options (model type/sizes, checkpoint paths, optimizer
            name, fp16 / loss-scaling flags, LR-schedule settings).
        train_data: Training data loader, or ``None`` to skip creating the
            LR schedulers.
        tokenizer: Only used for the transformer padding index
            (``tokenizer.command_name_map['pad'].Id``).

    Returns:
        ``(model, optim, LR, LR_Warmer, criterion)``. ``LR`` is ``None`` when
        ``train_data`` is ``None``. ``LR_Warmer`` is ``None`` unless
        ``train_data`` is given and ``args.warmup != 0``. ``criterion`` is an
        unreduced (per-element) cross-entropy loss.

    Raises:
        ValueError: if ``args.load_optim`` is set but ``args.load`` names no
            checkpoint.
    """
    global rnn_model

    ntokens = args.data_size
    if args.model.lower() == 'transformer':
        embed_tokens = m.Embedding(
            ntokens,
            args.decoder_embed_dim,
            padding_idx=tokenizer.command_name_map['pad'].Id)
        model = m.TransformerModel(m.DecoderPreprocessor(args, embed_tokens),
                                   m.TransformerDecoder(args, embed_tokens))
    else:
        model = m.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                           args.nlayers, args.dropout, args.tied)
        rnn_model = model
    # Both stay None when no schedule is built (previously LR was unbound
    # when train_data was None, raising UnboundLocalError at return).
    LR = None
    LR_Warmer = None
    print('* number of parameters: %d' %
          sum(p.nelement() for p in model.parameters()))
    if args.cuda:
        model.cuda()

    has_checkpoint = args.load is not None and args.load != ''
    if args.load_optim and not has_checkpoint:
        raise ValueError(
            'args.load_optim requires a checkpoint path in args.load')

    optim_sd = None
    if has_checkpoint:
        sd = torch.load(args.load, map_location='cpu')
        if args.load_optim:
            ckpt_dir = os.path.dirname(args.load)
            # Optimizer state is read from 'optim.pt' next to the model
            # checkpoint. This load was commented out, which left optim_sd
            # undefined (NameError) where it is used below.
            optim_sd = torch.load(os.path.join(ckpt_dir, 'optim.pt'),
                                  map_location='cpu')
            rng = torch.load(os.path.join(ckpt_dir, 'rng.pt'))
            torch.cuda.set_rng_state(rng[0])
            torch.set_rng_state(rng[1])
        try:
            model.load_state_dict(sd)
        except RuntimeError:
            # Key mismatch: retry with weight norm applied (the checkpoint was
            # presumably saved from a weight-normed model), then strip it.
            if hasattr(model, 'rnn'):
                apply_weight_norm(model.rnn, hook_child=False)
            else:
                apply_weight_norm(model, hook_child=False)
            model.load_state_dict(sd)
            remove_weight_norm(model)

    if not args.no_weight_norm:
        if hasattr(model, 'rnn'):
            apply_weight_norm(model.rnn, hook_child=False)
        else:
            apply_weight_norm(model, hook_child=False)

    # Adam is forced when a slanted-triangular schedule is requested. This
    # choice was computed before but never used (args.optim was always
    # passed); confirm with callers if a different optimizer is required.
    optim_choice = 'Adam' if args.stlr_cut_frac else args.optim
    # getattr performs the same lookup as eval('torch.optim.' + name)
    # without evaluating arbitrary text from the command line.
    optim_cls = getattr(torch.optim, optim_choice)
    if args.fp16:
        model = FP16_Module(model)
        optim = optim_cls(model.parameters(), lr=args.lr)
        optim = FP16_Optimizer(optim,
                               static_loss_scale=args.loss_scale,
                               dynamic_loss_scale=args.dynamic_loss_scale)
    else:
        optim = optim_cls(model.parameters(), lr=args.lr)

    if optim_sd is not None:
        optim.load_state_dict(optim_sd)

    # add linear learning rate scheduler
    if train_data is not None:
        if args.constant_decay:
            num_iters = args.constant_decay
        else:
            num_iters = args.train_iters * args.epochs

        init_step = -1
        if optim_sd is not None:
            # TODO: this no longer makes sense given the new data loaders
            init_step = optim_sd['iter'] - optim_sd['skipped_iter']
            train_data.batch_sampler.start_iter = (optim_sd['iter'] %
                                                   len(train_data)) + 1

        warmup_iter = args.warmup * num_iters

        if args.stlr_cut_frac is not None:
            LR = SlantedTriangularLR(optim,
                                     cut_frac=args.stlr_cut_frac,
                                     num_iters=num_iters)
        else:
            LR = AnnealingLR(optim,
                             start_lr=args.lr,
                             warmup_iter=warmup_iter,
                             num_iters=num_iters,
                             decay_style=args.decay_style)

        if args.warmup != 0:
            LR_Warmer = WarmupLR(optim, warmup_iter, last_iter=init_step)

    # wrap model for distributed training
    if args.world_size > 1:
        model = DDP(model)

    # reduction='none' is the non-deprecated spelling of reduce=False.
    criterion = nn.CrossEntropyLoss(reduction='none')
    return model, optim, LR, LR_Warmer, criterion
예제 #9
0
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex "
            "to use distributed and fp16 training.")

    # Add model parallel attribute if it is not set.
    for param_group in optimizer_grouped_parameters:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          bias_correction=False)
    if args.loss_scale == 0:
        optimizer = FP16_Optimizer(optimizer,
                                   dynamic_loss_scale=True,
                                   verbose=False)
    else:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   verbose=False)
else:
    optimizer = Adam(optimizer_grouped_parameters,
                     args.learning_rate,
                     max_grad_norm=1.0)

#########################################################################
# Training !
##########################################################################

if args.local_rank == -1 or get_rank() == 0:
예제 #10
0
def main():
    """Greedy layer-wise training of ``args.arch`` on an ImageFolder dataset.

    Each of the ``args.ncnn`` CNN blocks, together with its auxiliary net, is
    trained in turn for epochs ``args.start_epoch`` .. ``args.nepochs - 1``.
    Earlier blocks are kept in eval mode and only run forward, with their
    outputs detached. After every epoch the model is validated and a summary
    is appended to ``name_log_txt``. A checkpoint is written after each block
    finishes.

    Raises:
        ValueError: if ``args.large_size_images`` is neither 0 nor 1.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Setup sizes: 112px crops from 128px, or 224px crops from 256px.
    if args.large_size_images == 0:
        N_img = 112
        N_img_scale = 128
        print('using 112')
    elif args.large_size_images == 1:
        N_img = 224
        N_img_scale = 256
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError('args.large_size_images must be 0 or 1, got '
                         '{!r}'.format(args.large_size_images))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(N_img),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(N_img_scale),
            transforms.CenterCrop(N_img),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    model = models.__dict__[args.arch](nlin=args.nlin).cuda()

    n_cnn = len(model.main_cnn.blocks)
    args.ncnn = n_cnn
    with open(name_log_txt, "a") as text_file:
        print(model, file=text_file)
    if len(device_ids) == 1:
        # single gpu mode: still wrap in DataParallel so .module works later
        model = nn.DataParallel(model)
    else:
        model = DataParallelSpecial(model)

    if args.half:
        model = model.half()
        # BN_convert_float is defined elsewhere; presumably it keeps
        # BatchNorm layers in fp32 -- confirm.
        model = BN_convert_float(model)

    num_ep = args.nepochs
    layer_lr = [args.lr] * n_cnn
    layer_optim = [None] * n_cnn

    # Resume from a saved state_dict if requested.
    if args.resume:
        model.load_state_dict(torch.load(args.resume))
        print('model loaded')

    def make_layer_optimizer(n):
        # SGD over block n and its auxiliary net only, at layer_lr[n];
        # wrapped in FP16_Optimizer for half-precision training.
        to_train = itertools.chain(
            model.module.main_cnn.blocks[n].parameters(),
            model.module.auxillary_nets[n].parameters())
        layer_sgd = optim.SGD(to_train,
                              lr=layer_lr[n],
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        if args.half:
            layer_sgd = FP16_Optimizer(
                layer_sgd,
                static_loss_scale=args.static_loss_scale,
                dynamic_loss_scale=args.dynamic_loss_scale,
                dynamic_loss_args={'scale_window': 1000})
        return layer_sgd

    for n in range(args.ncnn):
        layer_optim[n] = make_layer_optimizer(n)

    # Lets do the training
    criterion = nn.CrossEntropyLoss().cuda()
    for n in range(args.ncnn):
        for epoch in range(args.start_epoch, num_ep):

            # Train mode overall, but already-trained blocks stay in eval.
            model.train()
            for k in range(n):
                model.module.main_cnn.blocks[k].eval()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            top1 = AverageMeter()
            top5 = AverageMeter()

            # Step decay for the current block; rebuilding the optimizer
            # also discards its momentum state.
            # NOTE(review): this also fires before any training when
            # args.start_epoch == 0 (0 % k == 0) -- confirm that is intended.
            if epoch % args.epochdecay == 0:
                layer_lr[n] = layer_lr[n] / 10.0
                layer_optim[n] = make_layer_optimizer(n)
            end = time.time()

            for i, (inputs, targets) in enumerate(train_loader):
                # measure data loading time
                data_time.update(time.time() - end)

                targets = targets.cuda(non_blocking=True)
                inputs = inputs.cuda(non_blocking=True)
                if args.half:
                    inputs = inputs.half()

                # Main loop
                if torch.cuda.device_count() > 1:
                    # This only initializes the multi-gpu
                    _, representation = model(inputs, init=True)
                else:
                    representation = inputs

                # Forward only through the blocks already trained.
                for k in range(n):
                    outputs, representation = model(representation, n=k)

                if n > 0:
                    if torch.cuda.device_count() > 1:
                        representation = [
                            rep.detach() for rep in representation
                        ]
                    else:
                        representation = representation.detach()

                # update current layer
                layer_optim[n].zero_grad()
                outputs, representation = model(representation, n=n)
                loss = criterion(outputs, targets)

                if args.half:
                    layer_optim[n].backward(loss)
                else:
                    loss.backward()
                layer_optim[n].step()

                # measure elapsed time, accuracy and loss
                batch_time.update(time.time() - end)
                end = time.time()
                prec1, prec5 = accuracy(outputs.data, targets, topk=(1, 5))
                n_samples = float(inputs.size(0))
                # loss is a 0-dim tensor; loss.data[0] raises IndexError on
                # PyTorch >= 0.4, so read the scalar with .item().
                losses.update(loss.item(), n_samples)
                top1.update(float(prec1[0]), n_samples)
                top5.update(float(prec5[0]), n_samples)

                if i % args.print_freq == 0:
                    print('n:{0} Epoch: [{1}][{2}/{3}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              n,
                              epoch,
                              i,
                              len(train_loader),
                              batch_time=batch_time,
                              data_time=data_time,
                              loss=losses,
                              top1=top1,
                              top5=top5))

                if args.debug and i > 50:
                    break

            # evaluate on validation set
            top1test, top5test, top1ens, top5ens = validate(
                val_loader, model, criterion, epoch, n)
            with open(name_log_txt, "a") as text_file:
                print("n: {}, epoch {}, train top1:{}(top5:{}), "
                      "test top1:{} (top5:{}), top1ens:{} top5ens:{}".format(
                          n, epoch, top1.avg, top5.avg, top1test, top5test,
                          top1ens, top5ens),
                      file=text_file)

        # Checkpoint after each block finishes training.
        if not args.debug:
            torch.save(model.state_dict(),
                       args.save_folder + '/' + name_log_txt +
                       '_current_model.t7')

    # Save the final model.
    torch.save(model.state_dict(),
               args.save_folder + '/' + name_log_txt + '_model.t7')
예제 #11
0
def _build_smp_config(args):
    """Build the config dict passed to ``smp.init`` from parsed CLI ``args``.

    Per the original note, any value set here is overridden by the config
    supplied in the notebook when launching the SageMaker job.
    """
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # With activation_checkpointing, whole transformer layers are
        # checkpointed later in main(), so attention checkpointing is off.
        "checkpoint_attentions": not args.activation_checkpointing,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": not args.manual_partition,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    # Older smp versions (smp_version < 110) take "fp16_params"; newer ones
    # take "fp16" plus the additional keys below.
    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config["activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches
    return smp_config


def _compute_layer_assignments(num_layers, pp_size, partition_assignment=None):
    """Return a list giving the pipeline rank of each transformer layer.

    Args:
        num_layers: total number of transformer layers in the model.
        pp_size: pipeline-parallel degree (number of partitions).
        partition_assignment: optional sequence of per-partition layer counts
            (each passed through ``int``). When None, layers are split as
            evenly as possible and the highest-numbered partitions receive the
            remainder.

    Raises:
        ValueError: if ``partition_assignment`` does not sum to ``num_layers``.
    """
    if partition_assignment is not None:
        counts = [int(partition_assignment[pp_rank]) for pp_rank in range(pp_size)]
        total_layers = sum(counts)
        if total_layers != num_layers:
            raise ValueError(
                f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {num_layers}"
            )
    else:
        # Evenly distribute layers; the last `rem` partitions get one extra.
        div, rem = divmod(num_layers, pp_size)
        counts = [
            div + 1 if pp_rank >= pp_size - rem else div
            for pp_rank in range(pp_size)
        ]

    assignments = []
    # (TODO) The original notes that iterating reversed(range(pp_size)) was
    # required for 175B, otherwise a hang for partition "8,17,17,18,18,18".
    # Needs further investigation.
    for pp_rank, nl in enumerate(counts):
        print(f"{nl} layers assigned to partition {pp_rank}")
        assignments.extend([pp_rank] * nl)
    return assignments


def main():
    """Train GPT-2 with SageMaker model parallelism (smdistributed.modelparallel).

    Parses CLI args, initialises smp, builds the GPT-2 model (optionally
    tensor-parallel and/or fp16), wraps it in ``smp.DistributedModel``,
    optionally assigns transformer layers to pipeline partitions manually,
    builds an Adam/AdamW optimizer plus LR scheduler, optionally resumes from a
    full or partial checkpoint, trains, reports/asserts CI metrics, and
    optionally saves the final full model.

    Raises:
        ValueError: on inconsistent CLI options (optimizer-state sharding
            without ``skip_full_optimizer``, or a ``partition_assignment`` that
            does not match the pipeline degree or the number of layers).
    """
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    smp_config = _build_smp_config(args)
    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            "[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    partition_assignment = None
    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        # Explicit raise rather than assert so the check survives `python -O`.
        if len(partition_assignment) != smp.pp_size():
            raise ValueError(
                f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"
            )

    # Global rank 0 always creates the output dirs; when use_fsx == 0 each
    # node's local rank 0 does too (presumably because storage is node-local
    # in that case).
    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in (args.model_dir, args.checkpoint_dir):
            os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # Skipping proper weight initialisation in the original model improves
    # start-up time; DistributedModel overrides those weights anyway when
    # tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum(np.prod(p.size()) for p in model.parameters())
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Seed by tp_rank so weights differ across tensor-parallel ranks.
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: the returned DistributedModel must be used in place of
    # the original model for the rest of the script.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    # Locate the container of transformer blocks (name depends on whether
    # tensor parallelism replaced them, and on the smp version's wrapping).
    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print("Manual partition enabled")
        assignments = _compute_layer_assignments(args.num_layers,
                                                 smp.pp_size(),
                                                 partition_assignment)
        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])

    # Build parameter groups (weight decay and non-decay) from the unwrapped model.
    if args.smp_version < 110:
        iter_model = model
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    optimizer_cls = optim.AdamW if args.use_adamw > 0 else optim.Adam
    optimizer = optimizer_cls(param_groups,
                              betas=(args.beta1, args.beta2),
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    dynamic_loss_args = {
        "scale_window": 1000,
        "min_scale": 1,
        "delayed_shift": 2
    }
    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args=dynamic_loss_args,
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args=dynamic_loss_args,
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # Load only after wrapping model and optimizer with smp Distributed*.
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint. "
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        # CI regression gates (deliberately assertions).
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # Saves the full model at the end of training.
        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
예제 #12
0
    try:
        model.load_state_dict(sd)
    except:
        apply_weight_norm(model.rnn, hook_child=False)
        model.load_state_dict(sd)
        remove_weight_norm(model.rnn)

if not args.no_weight_norm:
    apply_weight_norm(model, 'rnn', hook_child=False)

# Create the optimizer and the fp16 wrappers. The model is wrapped in
# FP16_Module *before* the optimizer is built from model.parameters(), which
# matches the original ordering.
if args.fp16:
    model = FP16_Module(model)
# Look the optimizer class up by name instead of eval()-ing a string built
# from a command-line argument (eval would run arbitrary expressions).
optimizer_cls = getattr(torch.optim, args.optim)
optim = optimizer_cls(model.parameters(), lr=args.lr)
if args.fp16:
    optim = FP16_Optimizer(optim,
                           static_loss_scale=args.loss_scale,
                           dynamic_loss_scale=args.dynamic_loss_scale)

if args.load_optim:
    optim.load_state_dict(optim_sd)

# add linear learning rate scheduler
if train_data is not None:
    if args.constant_decay:
        num_iters = args.constant_decay
    else:
        num_iters = args.train_iters * args.epochs
예제 #13
0
def train(task_ids, model):
    """Train ``model`` on the tasks selected by ``task_ids`` and return it.

    Args:
        task_ids: indices into ``args.tasks``. ``task_ids[0]`` is the current
            task used for extra-data generation, the generation token, GEM and
            the regularizer.
        model: model to continue training, or a falsy value to load a fresh
            ``MODEL_CLASS.from_pretrained(args.model_name)``.

    Returns:
        The trained model. For a single task listed in ``args.skip_tasks``,
        training is skipped and the model is returned early. It is reloaded
        from its saved checkpoint only when that task is the last remaining
        skip task.

    Side effects:
        - Updates the module-level TOKENIZER, SPECIAL_TOKENS,
          SPECIAL_TOKEN_IDS, MODEL_CONFIG and TOKENS_WEIGHT.
        - May append to ``args.memory_data`` and remove from
          ``args.skip_tasks``.
        - Writes the tokenizer, config, per-epoch checkpoints and a final
          checkpoint into the task's model directory.
    """
    global TOKENS_WEIGHT

    tasks = [args.tasks[task_id] for task_id in task_ids]

    logger.info("start to train { task: %s, seq train type: %s }", tasks,
                args.seq_train_type)
    model_dir = get_model_dir(tasks)
    make_dir(model_dir)

    train_dataset = [
        swap_name(TASK_DICT[t]["train"], args.seq_distil, args.ref1)
        for t in tasks
    ]
    train_extra_data = []
    if "lll" in args.seq_train_type and task_ids[0] > 0 and not args.skip_tasks:
        # Extra data for the previous task is produced by create_extra_data,
        # which receives the current model; no gradients are needed for it.
        prev_task = args.tasks[task_ids[0] - 1]
        with torch.no_grad():
            create_extra_data(tasks[0], prev_task, model, train_extra_data)
    elif "gem" in args.seq_train_type and task_ids[0] > 0:
        # For GEM the collected real data goes to args.memory_data rather than
        # into this task's training set.
        get_real_data(tasks[0], train_extra_data, accum=False, encode=True)
        args.memory_data.append(train_extra_data)
        train_extra_data = []
    logger.info('extra training data size: {}'.format(len(train_extra_data)))

    if not model:
        model = MODEL_CLASS.from_pretrained(args.model_name).cuda()
        model.resize_token_embeddings(len(TOKENIZER))
        if not args.fp32:
            model = FP16_Module(model)

    # Register a task-specific generation token and persist tokenizer/config.
    gen_token = get_gen_token(tasks[0])
    TOKENIZER.add_tokens([gen_token])
    TOKENIZER.save_pretrained(model_dir)
    SPECIAL_TOKENS[tasks[0]] = gen_token
    SPECIAL_TOKEN_IDS[tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
    logger.info('gen token = {} , gen token id = {}'.format(
        gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
    MODEL_CONFIG.vocab_size = len(TOKENIZER)
    MODEL_CONFIG.to_json_file(os.path.join(model_dir, CONFIG_NAME))
    # Keep one loss weight per vocabulary entry.
    if len(TOKENIZER) != TOKENS_WEIGHT.shape[0]:
        TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))

    if args.skip_tasks and len(tasks) == 1:
        logger.info("*********** skip task: {} ***********".format(tasks[0]))
        if tasks[0] in args.skip_tasks:
            if len(args.skip_tasks) == 1:
                model_dir = get_model_dir(tasks)
                model_path = os.path.join(model_dir, FINAL_SAVE_NAME)
                config_path = os.path.join(model_dir, CONFIG_NAME)
                model_config = CONFIG_CLASS.from_json_file(config_path)
                model = MODEL_CLASS(model_config).cuda()
                state_dict = torch.load(model_path)
                model.load_state_dict(state_dict)
                if not args.fp32:
                    model = FP16_Module(model)
                if args.seq_train_type in REG_TYPE_KEYS:
                    logger.info("calulating reg_params ...")
                    train_qadata = QADataset(train_dataset, "train",
                                             SPECIAL_TOKEN_IDS[tasks[0]],
                                             train_extra_data)
                    max_train_batch_size = max(
                        len(train_qadata) // args.min_n_steps,
                        args.min_batch_size)
                    train_dataloader = create_dataloader(
                        train_qadata, "train", max_train_batch_size)
                    parallel_model = DataParallelModel(WrapModel(model),
                                                       args.device_ids)
                    regularizer = REG_TYPES[args.seq_train_type](
                        model, parallel_model, [train_dataloader], tasks[0])
                    regularizer.task_start_do()
                    regularizer.task_end_do()
                    torch.save(model.state_dict(),
                               os.path.join(model_dir, FINAL_SAVE_NAME))
                    logger.info("done reg_params!")
            args.skip_tasks.remove(tasks[0])
            return model

    model.resize_token_embeddings(
        len(TOKENIZER) if not args.multitask_specific else len(TOKENIZER) + 4)
    if args.multitask_specific:
        # One extra loss weight per extra embedding row added above.
        TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([4]).cuda()))
    if args.distil:
        teacher_dir = "models/gpt2/lll/{task}_0.2/{task}".format(task=tasks[0])
        teacher_model = MODEL_CLASS.from_pretrained(args.model_name).cuda()
        # Use a context manager so the config file handle is closed.
        with open(teacher_dir + "/config.json") as config_file:
            teacher_vocab_size = json.load(config_file)['vocab_size']
        teacher_model.resize_token_embeddings(teacher_vocab_size)
        teacher_ckpt = teacher_dir + "/model-finish"
        print("load teacher model from {}".format(teacher_ckpt))
        teacher_model.load_state_dict(torch.load(teacher_ckpt))
        if not args.fp32:
            teacher_model = FP16_Module(teacher_model)
        teacher_model.eval()
        teacher_model = DataParallelModel(WrapModel(teacher_model),
                                          args.device_ids)

    if not args.fp32:  # again because resize_token_embeddings makes embedding layer fp32
        model = FP16_Module(model)

    parallel_model = DataParallelModel(WrapModel(model), args.device_ids)

    train_qadata = QADataset(train_dataset, "train",
                             SPECIAL_TOKEN_IDS[tasks[0]], train_extra_data)
    max_train_batch_size = max(
        len(train_qadata) // args.min_n_steps, args.min_batch_size)
    train_dataloader = create_dataloader(train_qadata, "train",
                                         max_train_batch_size)
    if not args.unbound and args.seq_train_type not in [
            "multitask", "multilm"
    ]:
        n_train_epochs = args.n_train_epochs[tasks[0]]
    else:
        n_train_epochs = args.n_train_epochs['_'.join(tasks)]
    n_train_optimization_steps = len(train_qadata) * n_train_epochs
    logger.info(
        'len of train dataset: {} , max train batch size {} , num of opt steps: {}'
        .format(len(train_qadata), max_train_batch_size,
                n_train_optimization_steps))

    # Exclude biases and LayerNorm parameters from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': args.weight_decay,
        },
        {
            'params': [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0,
        },
    ]

    if "gem" in args.seq_train_type:
        model.task_id = task_ids[0]
        if not hasattr(model, "grad_dims"):
            model.grad_dims = []
            for param in model.parameters():
                model.grad_dims.append(param.data.numel())
        if not hasattr(model, "grads"):
            model.grads = torch.zeros(sum(model.grad_dims), len(args.tasks))
            model.grads = model.grads.cuda()

    if args.seq_train_type in REG_TYPE_KEYS:
        optimizer = Weight_Regularized_AdamW(optimizer_grouped_parameters,
                                             lr=args.learning_rate,
                                             eps=args.adam_epsilon)
    else:
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    if not args.fp32:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=None,
                                   dynamic_loss_scale=True,
                                   dynamic_loss_args={
                                       'scale_window': 100,
                                       'min_scale': 1,
                                       'delayed_shift': 2
                                   })

    scheduler = AnnealingLR(optimizer,
                            start_lr=args.learning_rate,
                            warmup_iter=int(args.n_warmup_ratio *
                                            len(train_qadata)),
                            num_iters=int(n_train_optimization_steps),
                            decay_style=args.decay_style)
    train_loss_fct = DataParallelCriterion(
        CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT),
        args.device_ids)
    if args.distil:
        kd_loss_fct = DataParallelCriterion(
            nn.KLDivLoss(reduction="batchmean"), args.device_ids)

    if args.seq_train_type in REG_TYPE_KEYS:
        copy_train_dataloader = create_dataloader(train_qadata, "train",
                                                  max_train_batch_size)
        # NOTE(review): when task_ids[0] == 0 this index is -1, i.e. the last
        # task in args.tasks — confirm the regularizer ignores prev_task then.
        prev_task = args.tasks[task_ids[0] - 1]
        regularizer = REG_TYPES[args.seq_train_type](model, parallel_model,
                                                     [copy_train_dataloader],
                                                     tasks[0], prev_task)
        regularizer.task_start_do()

    tot_n_steps = 0
    train_once = TrainStep(model, optimizer, scheduler)
    if "gem" in args.seq_train_type and task_ids[0] != 0:
        gem_step = GEMStep(model, parallel_model, train_loss_fct, optimizer)
    model.train()
    for ep in range(n_train_epochs):
        cum_loss, cum_qa_loss, cum_lm_loss, cur_n_inputs = 0, 0, 0, 0
        # NOTE(review): if train_dataloader yields no batches, n_steps stays
        # unbound and cur_n_inputs is 0, so the epoch summary below fails.
        for n_steps, (_, _, cqa, _, Y, gen_X, gen_Y,
                      is_extra) in enumerate(train_dataloader):

            n_inputs = sum(_cqa.shape[0] for _cqa in cqa)
            if args.multitask_specific:
                for i in range(len(is_extra)):
                    gen_X[i][:, 0] += is_extra[i]
                    is_extra[i] = is_extra[i] * 0

            # Move each per-device shard onto its device.
            for i in range(len(cqa)):
                cqa[i] = (cqa[i].to(args.device_ids[i]), )
                Y[i] = Y[i].to(args.device_ids[i])
                gen_X[i] = (gen_X[i].to(args.device_ids[i]), )
                gen_Y[i] = gen_Y[i].to(args.device_ids[i])
                is_extra[i] = is_extra[i].to(args.device_ids[i])

            if args.distil:
                losses = get_distil_losses(teacher_model,
                                           parallel_model,
                                           cqa,
                                           Y,
                                           gen_X,
                                           gen_Y,
                                           is_extra,
                                           kd_loss_fct,
                                           train_loss_fct,
                                           args.temperature_kd,
                                           pad_idx=FILL_VAL)
            else:
                losses = get_losses(parallel_model, cqa, Y, gen_X, gen_Y,
                                    train_loss_fct)
            loss = sum(losses)
            if "gem" in args.seq_train_type and task_ids[0] != 0:
                gem_step(task_ids[0])
            train_once(loss, n_inputs)

            qa_loss = losses[0].item() * n_inputs
            lm_loss = losses[1].item() * n_inputs
            cum_loss += (qa_loss + lm_loss)
            cum_qa_loss += qa_loss
            cum_lm_loss += lm_loss
            cur_n_inputs += n_inputs

            if (n_steps + 1) % args.logging_steps == 0:
                logger.info(
                    'progress {:.3f} , lr {:.1E} , loss {:.3f} , qa loss {:.3f} , lm loss {:.3f} , avg batch size {:.1f}'
                    .format(ep + cur_n_inputs / len(train_qadata),
                            scheduler.get_lr(), cum_loss / cur_n_inputs,
                            cum_qa_loss / cur_n_inputs,
                            cum_lm_loss / cur_n_inputs,
                            cur_n_inputs / (n_steps + 1)))

        torch.save(model.state_dict(),
                   os.path.join(model_dir, SAVE_NAME + str(ep + 1)))
        tot_n_steps += (n_steps + 1)
        logger.info(
            'epoch {}/{} done , tot steps {} , lr {:.1E} , loss {:.2f} , qa loss {:.2f} , lm loss {:.2f} , avg batch size {:.1f}'
            .format(ep + 1, n_train_epochs, tot_n_steps, scheduler.get_lr(),
                    cum_loss / cur_n_inputs, cum_qa_loss / cur_n_inputs,
                    cum_lm_loss / cur_n_inputs, cur_n_inputs / (n_steps + 1)))

    # task end do for reg
    if args.seq_train_type in REG_TYPE_KEYS:
        regularizer.task_end_do()
    torch.save(model.state_dict(), os.path.join(model_dir, FINAL_SAVE_NAME))

    return model