import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

import models  # project-local package of model classes (import path assumed)
# Other project-local helpers used below (import paths assumed):
# convert_to_half, cross_entropy, get_parameters, MultiStepLRWithWarmup,
# load_checkpoint, save_checkpoint, get_train_loader, get_val_loader,
# get_progress_monitor, train, validate.


def run(cur_gpu, hparams):
    """Train (or evaluate) one model replica.

    `cur_gpu` is the local GPU index (-1 for CPU-only execution); with
    distributed_mode == 'gpus' it also serves as the process rank.
    """
    # Initialize the process group so DistributedDataParallel can synchronize
    # gradients across replicas.
    if hparams.distributed_mode == 'gpus':
        dist.init_process_group(backend=hparams.dist_backend,
                                init_method=hparams.dist_url,
                                world_size=hparams.world_size,
                                rank=cur_gpu)

    # Instantiate the requested architecture by name, then move it to this
    # process's GPU (cur_gpu == -1 means CPU-only).
    model = getattr(models, hparams.model_name)(hparams.n_classes,
                                                hparams.n_channels,
                                                hparams.model_version)

    if cur_gpu >= 0:
        torch.cuda.set_device(cur_gpu)
        model.cuda()

    # Half-precision training: convert_to_half (a project helper) casts the
    # model to fp16; fp32 master weights are kept separately via
    # get_parameters(clone=True) below.
    if hparams.fp16:
        model = convert_to_half(model)

    # Wrap the model so gradients are all-reduced across replicas each step.
    if hparams.distributed_mode == 'gpus':
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[cur_gpu],
                                                    output_device=cur_gpu)

    criterion = cross_entropy  # project-local loss function

    # Weight decay is applied only to non-normalization parameters; decaying
    # the affine parameters of BatchNorm/GroupNorm layers is commonly skipped.
    # With fp16, get_parameters(clone=True) presumably also returns fp32
    # master copies for the optimizer to update in place of the fp16 model
    # parameters.
    norm_layers = (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)
    params_no_bn, params_no_bn_clone = get_parameters(model,
                                                      exclude=norm_layers,
                                                      clone=hparams.fp16)
    params_bn, params_bn_clone = get_parameters(model,
                                                include=norm_layers,
                                                clone=hparams.fp16)
    optimizer = optim.SGD(
        [{'params': params_no_bn_clone if hparams.fp16 else params_no_bn,
          'weight_decay': hparams.weight_decay},
         {'params': params_bn_clone if hparams.fp16 else params_bn,
          'weight_decay': 0.0}],
        lr=hparams.initial_learning_rate,
        momentum=hparams.momentum)

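    # Sketch of the fp16 master-weight step that train() presumably performs
    # with these (params, params_clone) pairs; inferred from the clone=
    # plumbing above, not taken from this repo:
    #
    #     for p, master in zip(params, params_clone):
    #         master.grad = p.grad.float()   # fp16 grads -> fp32 master grads
    #     optimizer.step()                   # update the fp32 master weights
    #     for p, master in zip(params, params_clone):
    #         p.data.copy_(master)           # cast back into the fp16 model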
    # Step-decay schedule with a warmup phase over the first lr_warmup_epochs
    # (MultiStepLRWithWarmup is a project-local scheduler).
    lr_scheduler = MultiStepLRWithWarmup(optimizer,
                                         hparams.lr_milestones,
                                         hparams.lr_warmup_epochs,
                                         factor_min=hparams.lr_factor_min,
                                         gamma=hparams.lr_decay_rate)

    # Optionally resume training state (weights, optimizer, scheduler and
    # best accuracies) from a previous checkpoint.
    best_acc1 = 0
    best_acc5 = 0
    start_epoch = hparams.start_epoch
    if hparams.checkpoint and os.path.isfile(hparams.checkpoint):
        start_epoch, model, optimizer, lr_scheduler, best_acc1, best_acc5 = load_checkpoint(
            hparams.checkpoint, cur_gpu, model, optimizer, lr_scheduler)

    # Let cuDNN pick the fastest convolution algorithms for the fixed input size.
    torch.backends.cudnn.benchmark = True

    # Build the training loader (returned together with its sampler so the
    # shuffle can be re-seeded each epoch) and the validation loader.
    train_loader, train_sampler = get_train_loader(
        hparams.data_dir, hparams.image_size, hparams.per_replica_batch_size,
        hparams.n_data_loading_workers, hparams.distributed_mode,
        hparams.world_size, cur_gpu)
    val_loader = get_val_loader(hparams.data_dir, hparams.image_size,
                                hparams.per_replica_batch_size,
                                hparams.n_data_loading_workers,
                                hparams.distributed_mode, hparams.world_size,
                                cur_gpu)
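    # Sketch (an assumption, not this repo's code) of the sampler plumbing
    # get_train_loader presumably performs, based on the set_epoch call below:
    #
    #     sampler = (torch.utils.data.distributed.DistributedSampler(dataset)
    #                if distributed_mode == 'gpus' else None)
    #     loader = torch.utils.data.DataLoader(dataset,
    #                                          batch_size=per_replica_batch_size,
    #                                          shuffle=sampler is None,
    #                                          sampler=sampler,
    #                                          num_workers=n_workers,
    #                                          pin_memory=True)
    #     return loader, sampler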

    # Evaluation-only mode: run a single validation pass and return its metrics.
    if hparams.evaluate:
        return validate(cur_gpu, val_loader, model, criterion, 0, hparams)

    # Progress/logging monitor; it may be None on non-logging ranks, hence
    # the `monitor and ...` guards below.
    monitor = get_progress_monitor(cur_gpu, hparams.log_dir,
                                   hparams.steps_per_epoch, hparams.epochs,
                                   hparams.print_freq, start_epoch)

    for epoch in range(start_epoch, hparams.epochs):
        monitor and monitor.before_epoch()

        # Re-seed the distributed sampler so each epoch gets a fresh shuffle
        # that stays consistent across replicas.
        if train_sampler:
            train_sampler.set_epoch(epoch)
        train(cur_gpu, train_loader, model, criterion, optimizer, lr_scheduler,
              params_no_bn + params_bn, params_no_bn_clone + params_bn_clone,
              epoch, hparams, monitor)

        loss, acc1, acc5 = validate(cur_gpu, val_loader, model, criterion,
                                    epoch, hparams)

        monitor and monitor.after_epoch(loss, acc1, acc5)

        # Only the primary process (rank 0, or -1 for CPU/single-GPU) writes
        # checkpoints. Record both accuracies of the best top-1 model;
        # previously best_acc5 was never updated after a resume.
        if hparams.save_model and cur_gpu in (-1, 0):
            is_best = acc1 > best_acc1
            if is_best:
                best_acc1, best_acc5 = acc1, acc5
            save_checkpoint(hparams.model_dir, epoch, model, optimizer,
                            lr_scheduler, best_acc1, best_acc5, is_best)

    # Tear down the distributed process group and close out the monitor.
    if hparams.distributed_mode == 'gpus':
        dist.destroy_process_group()

    monitor and monitor.end()
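

# Minimal launch sketch (not part of the original file): run's
# (cur_gpu, hparams) signature matches torch.multiprocessing.spawn's
# convention of passing the process index first, so multi-GPU training is
# presumably started like this. `build_hparams()` is a hypothetical stand-in
# for however this repo parses its configuration.
if __name__ == '__main__':
    import torch.multiprocessing as mp

    hparams = build_hparams()  # hypothetical config/argument parser
    if hparams.distributed_mode == 'gpus':
        # One process per replica; each receives its index as cur_gpu.
        mp.spawn(run, args=(hparams,), nprocs=hparams.world_size)
    else:
        run(0 if torch.cuda.is_available() else -1, hparams)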