def get_model(args):
    model = models.RNNModel(args.model, args.ntokens, args.emsize, args.nhid, 
                            args.nlayers, args.dropout, args.tied).to(args.device)

    # Horovod: scale learning rate by the number of GPUs.
    args.base_lr = args.base_lr * hvd.size()
    optimizer = optim.SGD(model.parameters(), lr=args.base_lr,
                          momentum=args.momentum, weight_decay=args.wd)

    if args.kfac_update_freq > 0:
        preconditioner = kfac.KFAC(
                model, lr=args.base_lr, stat_decay=args.stat_decay,
                damping=args.damping, kl_clip=args.kl_clip,
                TCov=args.kfac_cov_update_freq,
                TInv=args.kfac_update_freq,
                diag_blocks=args.diag_blocks,
                diag_warmup=args.diag_warmup)
    else:
        preconditioner = None

    optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters(),
            compression=hvd.Compression.none, op=hvd.Average)

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay, alpha=0.25)
    lr_schedules = [LambdaLR(optimizer, lrs)]
    if preconditioner is not None:
        lr_schedules.append(LambdaLR(preconditioner, lrs))

    criterion = nn.NLLLoss()

    return model, optimizer, preconditioner, lr_schedules, lrs, criterion
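
# Usage sketch (not part of the original example): one way the objects returned
# by get_model() can be consumed in a Horovod training step. Assumes args.model
# is 'LSTM' (so the hidden state is a tuple), that models.RNNModel follows the
# PyTorch word_language_model API (init_hidden(), forward(input, hidden) returning
# (log_probs, hidden)), and that train_loader / args.batch_size exist and yield
# (data, targets) batches already on args.device.
def train_epoch(args, train_loader, model, optimizer, preconditioner, criterion):
    model.train()
    hidden = model.init_hidden(args.batch_size)
    for data, targets in train_loader:
        # Detach the hidden state so gradients do not flow across batch boundaries.
        hidden = tuple(h.detach() for h in hidden)
        optimizer.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, args.ntokens), targets.view(-1))
        loss.backward()
        # Finish the Horovod allreduce before touching the gradients, let K-FAC
        # precondition them in place, then apply SGD without a second sync.
        optimizer.synchronize()
        if preconditioner is not None:
            preconditioner.step()
        with optimizer.skip_synchronize():
            optimizer.step()
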
Example 2
def prepare_optimizers(args, model, checkpoint, global_steps):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]

    if args.lr_decay == 'poly':
        Scheduler = PolyWarmUpScheduler
    elif args.lr_decay == 'linear':
        Scheduler = LinearWarmUpScheduler
    else:
        raise ValueError('Unknown lr decay "{}"'.format(args.lr_decay))

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)

    if checkpoint is not None:
        if args.resume_step >= args.previous_phase_end_step:
            # Override hyperparameters saved in the previous checkpoint.
            for state in checkpoint['optimizer']['state'].values():
                state['step'] = global_steps
            for group in checkpoint['optimizer']['param_groups']:
                group['step'] = global_steps
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])

    lr_schedulers = [
        Scheduler(optimizer,
                  warmup=args.warmup_proportion,
                  total_steps=args.max_steps)
    ]

    scaler = None
    if args.fp16:
        scaler = GradScaler()
        if checkpoint is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

    preconditioner = None
    if args.kfac:
        preconditioner = kfac.KFAC(
            model,
            lr=args.learning_rate,
            factor_decay=args.kfac_stat_decay,
            damping=args.kfac_damping,
            kl_clip=args.kfac_kl_clip,
            factor_update_freq=args.kfac_factor_interval,
            inv_update_freq=args.kfac_inv_interval,
            # Skip TrainingHeads, which contains the decoder: a Linear module with
            # shape (seq_len, vocab_size) whose factors are too large to invert.
            skip_layers=args.kfac_skip_layers,
            # BERT calls KFAC very infrequently so no need to optimize for
            # communication. Optimize for memory instead.
            comm_method=kfac.CommMethod.HYBRID_OPT,
            grad_worker_fraction=0.5,
            inv_dtype=torch.float16,
            # Compute the factors and update the running averages during the
            # forward/backward pass because we use gradient accumulation but do
            # not accumulate the input/output data.
            accumulate_data=False,
            compute_factor_in_hook=True,
            distribute_layer_factors=False,
            grad_scaler=scaler,
        )

        lrs = Scheduler(preconditioner,
                        warmup=args.warmup_proportion,
                        total_steps=args.max_steps)
        lr_schedulers.append(lrs)

        if checkpoint is not None and 'preconditioner' in checkpoint:
            preconditioner.load_state_dict(checkpoint['preconditioner'])

        if is_main_process():
            logger.info(preconditioner)

    return optimizer, preconditioner, lr_schedulers, scaler
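
# Usage sketch (not from the original source): one way the objects returned by
# prepare_optimizers() can be combined in a single optimization step under
# gradient accumulation. The autocast/GradScaler calls follow the standard
# PyTorch AMP pattern; batch_iter and the model(**batch) call are placeholders
# for the BERT pretraining micro-batches and loss, and KFAC(grad_scaler=scaler)
# lets the preconditioner account for the loss scale on its own.
def optimization_step(args, batch_iter, model, optimizer, preconditioner,
                      lr_schedulers, scaler):
    optimizer.zero_grad()
    for batch in batch_iter:  # gradient-accumulation micro-batches
        with torch.cuda.amp.autocast(enabled=args.fp16):
            loss = model(**batch)  # placeholder for the BERT pretraining loss
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
    if preconditioner is not None:
        preconditioner.step()  # precondition the accumulated (scaled) gradients
    if scaler is not None:
        scaler.step(optimizer)  # unscale, check for inf/NaN, then step FusedLAMB
        scaler.update()
    else:
        optimizer.step()
    for scheduler in lr_schedulers:
        scheduler.step()  # the warmup schedulers above are stepped per optimizer step
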
Example 3
        optim.lr_scheduler.CosineAnnealingLR(optimizer, 3 * len(train_loader),
                                             1e-4)
    ]
else:
    optimizer = optim.SGD(model.parameters(),
                          lr=args.base_lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

if use_kfac:
    preconditioner = kfac.KFAC(
        model,
        lr=args.base_lr,
        factor_decay=args.stat_decay,
        damping=args.damping,
        kl_clip=args.kl_clip,
        fac_update_freq=args.kfac_cov_update_freq,
        kfac_update_freq=args.kfac_update_freq,
        diag_blocks=args.diag_blocks,
        diag_warmup=args.diag_warmup,
        distribute_layer_factors=args.distribute_layer_factors)
    kfac_param_scheduler = kfac.KFACParamScheduler(
        preconditioner,
        damping_alpha=args.damping_alpha,
        damping_schedule=args.damping_schedule,
        update_freq_alpha=args.kfac_update_freq_alpha,
        update_freq_schedule=args.kfac_update_freq_schedule)

# KFAC guarantees grads are equal across ranks before opt.step() is called,
# so if we do not use KFAC we need to wrap the optimizer with Horovod.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
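
# Sketch of the wrap described in the comment above (an assumption, not part of
# the original snippet): when K-FAC is not used, nothing else synchronizes the
# gradients, so the plain SGD optimizer is wrapped with Horovod's allreduce.
if not use_kfac:
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Average)
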
def get_model(args):
    if args.model.lower() == 'resnet34':
        model = models.resnet34()
    elif args.model.lower() == 'resnet50':
        model = models.resnet50()
    elif args.model.lower() == 'resnet101':
        model = models.resnet101()
    elif args.model.lower() == 'resnet152':
        model = models.resnet152()
    elif args.model.lower() == 'resnext50':
        model = models.resnext50_32x4d()
    elif args.model.lower() == 'resnext101':
        model = models.resnext101_32x8d()
    else:
        raise ValueError('Unknown model \'{}\''.format(args.model))

    if args.cuda:
        model.cuda()

    # Horovod: scale learning rate by the number of workers and the number of
    # batches aggregated per allreduce.
    args.base_lr = args.base_lr * hvd.size() * args.batches_per_allreduce
    optimizer = optim.SGD(model.parameters(),
                          lr=args.base_lr,
                          momentum=args.momentum,
                          weight_decay=args.wd)

    if args.kfac_update_freq > 0:
        preconditioner = kfac.KFAC(
            model,
            lr=args.base_lr,
            factor_decay=args.stat_decay,
            damping=args.damping,
            kl_clip=args.kl_clip,
            fac_update_freq=args.kfac_cov_update_freq,
            kfac_update_freq=args.kfac_update_freq,
            diag_blocks=args.diag_blocks,
            diag_warmup=args.diag_warmup,
            distribute_layer_factors=args.distribute_layer_factors)
        kfac_param_scheduler = kfac.KFACParamScheduler(
            preconditioner,
            damping_alpha=args.damping_alpha,
            damping_schedule=args.damping_decay,
            update_freq_alpha=args.kfac_update_freq_alpha,
            update_freq_schedule=args.kfac_update_freq_decay,
            start_epoch=args.resume_from_epoch)
    else:
        preconditioner = None

    compression = hvd.Compression.fp16 if args.fp16_allreduce \
                                       else hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Average,
        backward_passes_per_step=args.batches_per_allreduce)

    # Restore from a previous checkpoint, if resume_from_epoch is specified.
    # Horovod: restore on the first worker, which will then broadcast the
    # weights to the other workers.
    if args.resume_from_epoch > 0 and hvd.rank() == 0:
        filepath = args.checkpoint_format.format(epoch=args.resume_from_epoch)
        checkpoint = torch.load(filepath)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay)
    lr_scheduler = [LambdaLR(optimizer, lrs)]
    if preconditioner is not None:
        lr_scheduler.append(LambdaLR(preconditioner, lrs))
        lr_scheduler.append(kfac_param_scheduler)

    loss_func = LabelSmoothLoss(args.label_smoothing)

    return model, optimizer, preconditioner, lr_scheduler, lrs, loss_func
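
# Usage sketch (not part of the original example): the epoch-level loop that
# consumes the objects returned above. train(), validate(), val_loader, and
# args.epochs are assumed placeholders; the checkpoint keys mirror the ones
# loaded in get_model() on resume.
def run(args, model, optimizer, preconditioner, lr_scheduler, loss_func,
        train_loader, val_loader):
    for epoch in range(args.resume_from_epoch, args.epochs):
        train(epoch, model, optimizer, preconditioner, loss_func, train_loader)
        validate(epoch, model, loss_func, val_loader)
        # Step every schedule once per epoch: the LambdaLR warmup/decay on the
        # optimizer (and preconditioner) plus the KFACParamScheduler that anneals
        # the damping and K-FAC update frequencies.
        for scheduler in lr_scheduler:
            scheduler.step()
        # Horovod: only rank 0 writes checkpoints; on resume, get_model() loads
        # them on rank 0 and broadcasts the weights to the other workers.
        if hvd.rank() == 0:
            torch.save({'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       args.checkpoint_format.format(epoch=epoch + 1))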