Example #1
    swa_n += 1
    bn_update(cifar100_training_loader, swa_model)
    swa_acc1, swa_acc5 = eval_training(swa_model, epoch, suffix='SWA')
    if not cfg.DIST or dist.get_rank() == 0:
        if best_swa_acc1 < swa_acc1:
            best_swa_acc1 = swa_acc1
            best_swa_acc5 = swa_acc5
        print('best swa acc1: {:.4f}, best swa acc5: {:.4f}'.format(
            best_swa_acc1, best_swa_acc5))
        print()
# start saving the best-performing model once the learning rate has decayed to 0.01
if not cfg.DIST or dist.get_rank() == 0:
    if epoch > cfg.TRAIN.STEPS[1] and best_acc1 < acc1:
        # unwrap .module when the model is distributed-wrapped
        state_dict = net.module.state_dict() if cfg.DIST else net.state_dict()
        torch.save(
            state_dict,
            checkpoint_path.format(net=cfg.NET, epoch=epoch, type='best'))
        best_acc1 = acc1
        best_acc5 = acc5
        best_epoch = epoch
        print('best epoch: {}, best acc1: {:.4f}, acc5: {:.4f}'.format(
            best_epoch, best_acc1, best_acc5))
        print()
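
# `bn_update` above is assumed to re-estimate BatchNorm running statistics with
# a single pass over the training data, in the spirit of
# torch.optim.swa_utils.update_bn. A minimal sketch (hypothetical helper; the
# actual implementation may differ):
def bn_update(loader, model):
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum
            module.momentum = None  # None => cumulative moving average
    if not momenta:
        return
    model.train()
    with torch.no_grad():
        for images, _ in loader:
            model(images.cuda(non_blocking=True))
    for module, momentum in momenta.items():
        module.momentum = momentum
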
def main():
    global best_prec1, args

    args = parse()

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.log_dir = args.log_dir + '_' + time.asctime(
        time.localtime(time.time())).replace(" ", "-")
    os.makedirs('results/{}'.format(args.log_dir), exist_ok=True)
    global logger
    logger = create_logger('global_logger',
                           "results/{}/log.txt".format(args.log_dir))
    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        logger.info(args.local_rank)
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    logger.info(args.world_size)
    if args.local_rank == 0:
        wandb.init(
            project="tinyimagenet",
            dir="results/{}".format(args.log_dir),
            name=args.log_dir,
        )
        wandb.config.update(args)

        logger.info("\nCUDNN VERSION: {}\n".format(
            torch.backends.cudnn.version()))

    args.batch_size = int(args.batch_size / args.world_size)
    logger.info(args.batch_size)

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # create model
    global norm_layer
    norm_layer = None  # default when no norm-layer override is requested
    logger.info(args.norm_layer)
    if args.norm_layer is not None and args.norm_layer != 'False':
        if args.norm_layer == 'bn':
            norm_layer = nn.BatchNorm2d
        elif args.norm_layer == 'mybn':
            norm_layer = models.__dict__['BatchNorm2d']

    if args.pretrained:
        logger.info("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           norm_layer=norm_layer)
    else:
        logger.info("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](norm_layer=norm_layer)

    if args.sync_bn:
        import apex
        logger.info("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.mixed_precision:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          keep_batchnorm_fp32=None,
                                          loss_scale=args.loss_scale)
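
    # With Amp enabled, the backward pass inside train() is assumed to go
    # through the loss scaler rather than calling loss.backward() directly:
    #     with amp.scale_loss(loss, optimizer) as scaled_loss:
    #         scaled_loss.backward()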

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    logger.info(args.resume)
    if args.resume != '':
        # Use a local scope to avoid dangling references
        def resume():
            global best_prec1  # update the module-level best, not a local
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # ToTensor()/normalize are intentionally skipped here (too slow on
            # CPU); fast_collate keeps images as uint8 and normalization is
            # expected to happen on the GPU.
        ]))
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    def collate_fn(batch):
        return fast_collate(batch, memory_format)
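    # fast_collate is assumed to follow the apex ImageNet example: stack the
    # un-normalized PIL images into a uint8 NCHW tensor in the requested
    # memory_format, deferring ToTensor/normalize work to the GPU. A sketch:
    #     def fast_collate(batch, memory_format):
    #         imgs = [sample[0] for sample in batch]
    #         targets = torch.tensor([sample[1] for sample in batch],
    #                                dtype=torch.int64)
    #         w, h = imgs[0].size
    #         tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    #         tensor = tensor.contiguous(memory_format=memory_format)
    #         for i, img in enumerate(imgs):
    #             arr = np.asarray(img, dtype=np.uint8)
    #             if arr.ndim < 3:
    #                 arr = np.expand_dims(arr, axis=-1)
    #             tensor[i] += torch.from_numpy(np.rollaxis(arr, 2))
    #         return tensor, targets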

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               sampler=train_sampler,
                                               collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False,
                                             sampler=val_sampler,
                                             collate_fn=collate_fn)

    if args.evaluate:
        validate(args.start_epoch, val_loader, model, criterion)
        return
    try:
        from models import SyncBatchNorm
    except ImportError:
        pass
    device = torch.device("cuda")

    from models.batchrenorm import BatchRenorm2d
    from models.batchnorm import BatchNorm2d
    if args.sample_noise:
        # norm_layer may be None when no override was configured
        noise_types = tuple(t for t in (BatchRenorm2d, BatchNorm2d, norm_layer)
                            if t is not None)
        for m in model.modules():
            if isinstance(m, noise_types):
                m.sample_noise = args.sample_noise
                m.sample_mean = torch.ones(m.num_features, device=device)
                m.noise_std_mean = torch.tensor(
                    args.noise_std_mean, device=device).sqrt()
                m.noise_std_var = torch.tensor(
                    args.noise_std_var, device=device).sqrt()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        if args.warmup_noise is not None and epoch in args.warmup_noise:
            for m in model.modules():
                if isinstance(m, norm_layer):
                    m.sample_mean_std *= math.sqrt(args.warmup_scale)
                    m.sample_var_std *= math.sqrt(args.warmup_scale)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(epoch, val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename=os.path.join("results/" + args.log_dir,
                                      "{}_checkpoint.pth.tar".format(epoch)))
Example #3
def main():
    global best_prec1, args

    best_prec1 = 0
    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0 and not args.fp16:
        print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = DDP(model)

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())
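
    # With fp16 master weights (apex.fp16_utils), the update step in train() is
    # assumed to follow the classic pattern: backward into fp16 model grads,
    # copy them to fp32 master grads, unscale, step, then sync back, e.g.:
    #     model_grads_to_master_grads(model_params, master_params)
    #     if args.static_loss_scale != 1.0:
    #         for param in master_params:
    #             param.grad.data.mul_(1. / args.static_loss_scale)
    #     optimizer.step()
    #     master_params_to_model_params(model_params, master_params)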

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(master_params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.arch == "inception_v3":
        crop_size = 299
        val_size = 320  # Arbitrarily chosen, adjustable.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)
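        # adjust_learning_rate is assumed to implement the classic ImageNet
        # step schedule (decay the lr by 10x every 30 epochs), roughly:
        #     def adjust_learning_rate(optimizer, epoch):
        #         lr = args.lr * (0.1 ** (epoch // 30))
        #         for param_group in optimizer.param_groups:
        #             param_group['lr'] = lr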

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Example #4
def train300_mlperf_coco(args):
    global torch
    from coco import COCO
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.distributed = False
    if use_cuda:
        try:
            from apex.parallel import DistributedDataParallel as DDP
            if 'WORLD_SIZE' in os.environ:
                args.distributed = int(os.environ['WORLD_SIZE']) > 1
        except ImportError:
            raise ImportError(
                "Please install APEX from https://github.com/nvidia/apex")

    local_seed = args.seed
    if args.distributed:
        # necessary pytorch imports
        import torch.utils.data.distributed
        import torch.distributed as dist
        if args.no_cuda:
            device = torch.device('cpu')
        else:
            torch.cuda.set_device(args.local_rank)
            device = torch.device('cuda')
            dist.init_process_group(backend='nccl',
                                    init_method='env://')
            # set seeds properly
            args.seed = broadcast_seeds(args.seed, device)
            local_seed = (args.seed + dist.get_rank()) % 2**32
    mllogger.event(key=mllog_const.SEED, value=local_seed)
    torch.manual_seed(local_seed)
    np.random.seed(seed=local_seed)

    args.rank = dist.get_rank() if args.distributed else args.local_rank
    print("args.rank = {}".format(args.rank))
    print("local rank = {}".format(args.local_rank))
    print("distributed={}".format(args.distributed))

    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)

    input_size = 300
    train_trans = SSDTransformer(dboxes, (input_size, input_size), val=False,
                                 num_cropping_iterations=args.num_cropping_iterations)
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)

    val_annotate = os.path.join(args.data, "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data, "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    train_coco = COCODetection(train_coco_root, train_annotate, train_trans)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    mllogger.event(key=mllog_const.TRAIN_SAMPLES, value=len(train_coco))
    mllogger.event(key=mllog_const.EVAL_SAMPLES, value=len(val_coco))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_coco)
    else:
        train_sampler = None
    train_dataloader = DataLoader(train_coco,
                                  batch_size=args.batch_size,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  num_workers=4)
    # validation is only run on rank 0
    if args.rank == 0:
        val_dataloader = DataLoader(val_coco,
                                    batch_size=args.val_batch_size or args.batch_size,
                                    shuffle=False,
                                    sampler=None,
                                    num_workers=4)
    else:
        val_dataloader = None

    ssd300 = SSD300(train_coco.labelnum, model_path=args.pretrained_backbone)
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint)
        ssd300.load_state_dict(od["model"])
    ssd300.train()
    if use_cuda:
        ssd300.cuda()
    loss_func = Loss(dboxes)
    if use_cuda:
        loss_func.cuda()
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    # parallelize
    if args.distributed:
        ssd300 = DDP(ssd300)

    global_batch_size = N_gpu * args.batch_size
    mllogger.event(key=mllog_const.GLOBAL_BATCH_SIZE, value=global_batch_size)
    # Reference doesn't support group batch norm, so bn_span==local_batch_size
    mllogger.event(key=mllog_const.MODEL_BN_SPAN, value=args.batch_size)
    current_lr = args.lr * (global_batch_size / 32)

    assert args.batch_size % args.batch_splits == 0, "--batch-size must be divisible by --batch-splits"
    fragment_size = args.batch_size // args.batch_splits
    if args.batch_splits != 1:
        print("using gradient accumulation with fragments of size {}".format(fragment_size))

    current_momentum = 0.9
    optim = torch.optim.SGD(ssd300.parameters(), lr=current_lr,
                            momentum=current_momentum,
                            weight_decay=args.weight_decay)
    ssd_print(key=mllog_const.OPT_BASE_LR, value=current_lr)
    ssd_print(key=mllog_const.OPT_WEIGHT_DECAY, value=args.weight_decay)

    iter_num = args.iteration
    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}
    success = torch.zeros(1)
    if use_cuda:
        success = success.cuda()


    if args.warmup:
        nonempty_imgs = len(train_coco)
        wb = int(args.warmup * nonempty_imgs / (N_gpu*args.batch_size))
        ssd_print(key=mllog_const.OPT_LR_WARMUP_STEPS, value=wb)
        warmup_step = lambda iter_num, current_lr: lr_warmup(optim, wb, iter_num, current_lr, args)
    else:
        warmup_step = lambda iter_num, current_lr: None
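
    # lr_warmup is assumed to ramp the learning rate up to current_lr over the
    # first `wb` iterations (hypothetical linear version; the exact schedule
    # lives in lr_warmup):
    #     def lr_warmup(optim, wb, iter_num, base_lr, args):
    #         if iter_num < wb:
    #             new_lr = base_lr * (iter_num + 1) / wb
    #             for group in optim.param_groups:
    #                 group['lr'] = new_lr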

    ssd_print(key=mllog_const.OPT_LR_WARMUP_FACTOR, value=args.warmup_factor)
    ssd_print(key=mllog_const.OPT_LR_DECAY_BOUNDARY_EPOCHS, value=args.lr_decay_schedule)
    mllogger.start(
        key=mllog_const.BLOCK_START,
        metadata={mllog_const.FIRST_EPOCH_NUM: 1,
                  mllog_const.EPOCH_COUNT: args.epochs})

    optim.zero_grad()
    for epoch in range(args.epochs):
        mllogger.start(
            key=mllog_const.EPOCH_START,
            metadata={mllog_const.EPOCH_NUM: epoch})
        # set the epoch for the sampler
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if epoch in args.lr_decay_schedule:
            current_lr *= 0.1
            print("")
            print("lr decay step #{num}".format(num=args.lr_decay_schedule.index(epoch) + 1))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr

        for nbatch, (img, img_id, img_size, bbox, label) in enumerate(train_dataloader):
            current_batch_size = img.shape[0]
            # Split batch for gradient accumulation
            img = torch.split(img, fragment_size)
            bbox = torch.split(bbox, fragment_size)
            label = torch.split(label, fragment_size)
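            # Each fragment's loss is later scaled by
            # current_fragment_size / current_batch_size, so the gradients
            # accumulated across fragments match one full-batch step.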

            for fimg, fbbox, flabel in zip(img, bbox, label):
                current_fragment_size = fimg.shape[0]
                trans_bbox = fbbox.transpose(1, 2).contiguous()
                if use_cuda:
                    fimg = fimg.cuda()
                    trans_bbox = trans_bbox.cuda()
                    flabel = flabel.cuda()
                fimg.requires_grad_()
                ploc, plabel = ssd300(fimg)
                gloc, glabel = trans_bbox, flabel
                loss = loss_func(ploc, plabel, gloc, glabel)
                loss = loss * (current_fragment_size / current_batch_size)  # weighted mean
                loss.backward()

            warmup_step(iter_num, current_lr)
            optim.step()
            optim.zero_grad()
            if not np.isinf(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            if args.rank == 0 and args.log_interval and not iter_num % args.log_interval:
                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}"\
                    .format(iter_num, loss.item(), avg_loss))
            iter_num += 1


        if (args.val_epochs and (epoch+1) in args.val_epochs) or \
           (args.val_interval and not (epoch+1) % args.val_interval):
            if args.distributed:
                world_size = float(dist.get_world_size())
                for bn_name, bn_buf in ssd300.module.named_buffers(recurse=True):
                    if ('running_mean' in bn_name) or ('running_var' in bn_name):
                        dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                        bn_buf /= world_size
                        ssd_print(key=mllog_const.MODEL_BN_SPAN,
                            value=bn_buf)
            if args.rank == 0:
                if not args.no_save:
                    print("")
                    print("saving model...")
                    torch.save({"model" : ssd300.state_dict(), "label_map": train_coco.label_info},
                               "./models/iter_{}.pt".format(iter_num))

                if coco_eval(ssd300, val_dataloader, cocoGt, encoder, inv_map,
                             args.threshold, epoch + 1, iter_num,
                             log_interval=args.log_interval,
                             nms_valid_thresh=args.nms_valid_thresh):
                    success = torch.ones(1)
                    if use_cuda:
                        success = success.cuda()
            if args.distributed:
                dist.broadcast(success, 0)
            if success[0]:
                return True
            mllogger.end(
                key=mllog_const.EPOCH_STOP,
                metadata={mllog_const.EPOCH_NUM: epoch})
    mllogger.end(
        key=mllog_const.BLOCK_STOP,
        metadata={mllog_const.FIRST_EPOCH_NUM: 1,
                  mllog_const.EPOCH_COUNT: args.epochs})

    return False
Example #5
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset(
        mode="train",
        oversample_real=not args.no_oversample,
        fold=args.fold,
        padding_part=args.padding_part,
        hardcore=not args.no_hardcore,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        label_smoothing=args.label_smoothing,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf["size"]),
        normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(
                                             conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val,
                                 batch_size=batch_size * 2,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' +
                                   conf.get("prefix", args.prefix) +
                                   conf['encoder'] + "_" + str(args.fold))
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # strip the "module." prefix left by DataParallel/DDP wrapping
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix",
                                                 args.prefix), conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()

    # register each block, in order to extract the blocks' feature maps
    for name, block in model.module.encoder.blocks.named_children():
        block.register_forward_hook(hook_function)

    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']

    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train,
                                       batch_size=batch_size,
                                       num_workers=args.workers,
                                       shuffle=train_sampler is None,
                                       sampler=train_sampler,
                                       pin_memory=False,
                                       drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank,
                    args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                }, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                os.path.join(args.output_dir,
                             "{}_{}".format(snapshot_name, current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args,
                                        val_data_loader,
                                        bce_best,
                                        model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
Example #6
class BaseTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if num_gpus > 1:
            # convert to use sync_bn
            self.logger.info(
                'More than one gpu used, convert model to use SyncBN.')
            if cfg.SOLVER.FP16:
                # TODO: Multi-GPU model with FP16; everything below this raise
                # in this branch is an unreachable scaffold kept for reference
                raise NotImplementedError
                self.logger.info(
                    'Using apex to perform SyncBN and FP16 training')
                torch.distributed.init_process_group(backend='nccl',
                                                     init_method='env://')
                self.model = apex.parallel.convert_syncbn_model(self.model)
            else:
                # Multi-GPU model without FP16
                self.model = nn.DataParallel(self.model)
                self.model = convert_model(self.model)
                self.model.cuda()
                self.logger.info('Using pytorch SyncBN implementation')

                self.optim = make_optimizer(cfg, self.model, num_gpus)
                self.scheduler = WarmupMultiStepLR(self.optim,
                                                   cfg.SOLVER.STEPS,
                                                   cfg.SOLVER.GAMMA,
                                                   cfg.SOLVER.WARMUP_FACTOR,
                                                   cfg.SOLVER.WARMUP_ITERS,
                                                   cfg.SOLVER.WARMUP_METHOD)
                self.scheduler.step()
                self.mix_precision = False
                self.logger.info('Trainer Built')
                return
        else:
            # Single GPU model
            self.model.cuda()
            self.optim = make_optimizer(cfg, self.model, num_gpus)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_ITERS,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.scheduler.step()
            self.mix_precision = False
            if cfg.SOLVER.FP16:
                # Single model using FP16
                self.model, self.optim = amp.initialize(self.model,
                                                        self.optim,
                                                        opt_level='O1')
                self.mix_precision = True
                self.logger.info('Using fp16 training')
            self.logger.info('Trainer Built')
            return

        # TODO: Multi-GPU model with FP16 (unreachable scaffold kept for
        # reference; both branches above return early)
        raise NotImplementedError
        self.model.to(self.device)
        self.optim = make_optimizer(cfg, self.model, num_gpus)
        self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                           cfg.SOLVER.GAMMA,
                                           cfg.SOLVER.WARMUP_FACTOR,
                                           cfg.SOLVER.WARMUP_ITERS,
                                           cfg.SOLVER.WARMUP_METHOD)
        self.scheduler.step()

        self.model, self.optim = amp.initialize(self.model,
                                                self.optim,
                                                opt_level='O1')
        self.mix_precision = True
        self.logger.info('Using fp16 training')

        self.model = DDP(self.model, delay_allreduce=True)
        self.logger.info('Convert model using apex')
        self.logger.info('Trainer Built')

    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:
            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, '
                             'Acc: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg,
                                 self.scheduler.get_lr()[0]))

    def handle_new_epoch(self):
        self.batch_cnt = 1
        self.scheduler.step()
        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)
        if self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if self.train_epoch % self.eval_period == 0:
            self.evaluate()
        self.train_epoch += 1

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        img, target = batch
        img, target = img.cuda(), target.cuda()
        score, feat = self.model(img)
        loss = self.loss_func(score, feat, target)
        if self.mix_precision:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optim.step()

        acc = (score.max(1)[1] == target).float().mean()

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc.cpu().item())

        return self.loss_avg.avg, self.acc_avg.avg

    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()
                feat = self.model(data).detach().cpu()
                feats.append(feat)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)
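        # euclidean_dist is assumed to be the usual pairwise distance from
        # re-ID baselines: sqrt(||x||^2 + ||y||^2 - 2 x.y), clamped for
        # numerical stability, e.g.:
        #     xx = x.pow(2).sum(1, keepdim=True).expand(m, n)
        #     yy = y.pow(2).sum(1, keepdim=True).expand(n, m).t()
        #     dist = (xx + yy - 2 * x @ y.t()).clamp(min=1e-12).sqrt()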

        cmc, mAP, _ = eval_func(distmat.numpy(),
                                query_pid.numpy(),
                                gallery_pid.numpy(),
                                query_camid.numpy(),
                                gallery_camid.numpy(),
                                use_cython=self.cfg.SOLVER.CYTHON)
        self.logger.info('Validation Result:')
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
        self.logger.info('mAP: {:.2%}'.format(mAP))
        self.logger.info('-' * 20)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
Example #7
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--bert_original",
                        action='store_true',
                        help="To run for original BERT")
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    processors = {
        "nsp": NSPProcessor,
    }

    num_labels_task = {
        "nsp": 2,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()


    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        PYTORCH_PRETRAINED_BERT_CACHE,
        'distributed_{}'.format(args.local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    print('BERT original model loaded')

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
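                        # warmup_linear is assumed to be the helper from
                        # pytorch-pretrained-bert, roughly:
                        #     def warmup_linear(x, warmup=0.002):
                        #         return x / warmup if x < warmup else 1.0 - x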
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
    
    # save the fine-tuned model (strip the DataParallel/DDP wrapper if present)
    if args.do_train:
        model_to_save = model.module if hasattr(model, 'module') else model
        torch.save(model_to_save.state_dict(),
                   os.path.join(args.output_dir, 'nsp_model.pt'))

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
 
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'global_step': global_step,
                  'loss': loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #8
def train(train_loop_func, logger, args):
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(10000)

    if args.distributed:
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    np.random.seed(seed=args.seed)

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    train_loader = get_train_loader(args, args.seed - 2**31)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = SSD300(backbone=ResNet(args.backbone, args.backbone_path))
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size / 32)
    start_epoch = 0
    iteration = 0
    loss_func = Loss(dboxes)

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()

    optimizer = torch.optim.SGD(tencent_trick(ssd300),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=args.multistep,
                            gamma=0.1)
    if args.amp:
        ssd300, optimizer = amp.initialize(ssd300, optimizer, opt_level='O2')

    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300.module if args.distributed else ssd300,
                            args.checkpoint)
            checkpoint = torch.load(
                args.checkpoint,
                map_location=lambda storage, loc: storage.cuda(
                    torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not a path to a file')
            return

    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))
        return

    mean, std = generate_mean_std(args)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, epoch, optimizer,
                                    train_loader, val_dataloader, encoder,
                                    iteration, logger, args, mean, std)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args.local_rank == 0:
            logger.update_epoch_time(epoch, end_epoch_time)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map,
                           args)

            if args.local_rank == 0:
                logger.update_epoch(epoch, acc)

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {
                'epoch': epoch + 1,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'label_map': val_dataset.label_info
            }
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            save_path = os.path.join(args.save, f'epoch_{epoch}.pt')
            torch.save(obj, save_path)
            logger.log('model path', save_path)
        train_loader.reset()
    DLLogger.log((), {'total time': total_time})
    logger.log_summary()
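The learning-rate adjustment above follows the linear scaling rule: the base rate is multiplied by the number of GPUs and by the per-GPU batch size relative to a reference batch of 32. Worked arithmetic (the base values are illustrative, not taken from the example):

base_lr, n_gpu, per_gpu_batch = 2.6e-3, 8, 64  # hypothetical settings
effective_lr = base_lr * n_gpu * (per_gpu_batch / 32)
print(effective_lr)  # 0.0416: 8 GPUs and a 2x reference batch give a 16x rate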
Example #9
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='damage_')
    arg('--data-dir', type=str, default="/home/selim/datasets/xview/train")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=1)
    arg("--local_rank", default=0, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--predictions", default="../oof_preds", type=str)
    arg("--test_every", type=int, default=1)

    args = parser.parse_args()

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = models.__dict__[conf['network']](seg_classes=conf['num_classes'],
                                             backbone_arch=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    damage_loss_function = losses.__dict__[conf["damage_loss"]["type"]](
        **conf["damage_loss"]["params"]).cuda()
    mask_loss_function = losses.__dict__[conf["mask_loss"]["type"]](
        **conf["mask_loss"]["params"]).cuda()
    loss_functions = {
        "damage_loss": damage_loss_function,
        "mask_loss": mask_loss_function
    }
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    dice_best = 0
    xview_best = 0
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = XviewSingleDataset(
        mode="train",
        fold=args.fold,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf['input']),
        multiplier=conf["data_multiplier"],
        normalize=conf["input"].get("normalize", None))
    data_val = XviewSingleDataset(
        mode="val",
        fold=args.fold,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_val_transforms(conf['input']),
        normalize=conf["input"].get("normalize", None))
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            data_train)

    train_data_loader = DataLoader(data_train,
                                   batch_size=batch_size,
                                   num_workers=args.workers,
                                   shuffle=train_sampler is None,
                                   sampler=train_sampler,
                                   pin_memory=False,
                                   drop_last=True)
    val_batch_size = 1
    val_data_loader = DataLoader(data_val,
                                 batch_size=val_batch_size,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + args.prefix +
                                   conf['encoder'])
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            if conf['optimizer'].get('zero_decoder', False):
                for key in state_dict.copy().keys():
                    if key.startswith("module.final"):
                        del state_dict[key]
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    dice_best = checkpoint.get('dice_best', 0)
                    xview_best = checkpoint.get('xview_best', 0)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(args.prefix, conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    for epoch in range(start_epoch, conf['optimizer']['schedule']['epochs']):
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            if hasattr(model.module, 'encoder_stages1'):
                model.module.encoder_stages1.eval()
                model.module.encoder_stages2.eval()
                for p in model.module.encoder_stages1.parameters():
                    p.requires_grad = False
                for p in model.module.encoder_stages2.parameters():
                    p.requires_grad = False
            else:
                model.module.encoder_stages.eval()
                for p in model.module.encoder_stages.parameters():
                    p.requires_grad = False
        else:
            if hasattr(model.module, 'encoder_stages1'):
                print("Unfreezing encoder!!!")
                model.module.encoder_stages1.train()
                model.module.encoder_stages2.train()
                for p in model.module.encoder_stages1.parameters():
                    p.requires_grad = True
                for p in model.module.encoder_stages2.parameters():
                    p.requires_grad = True
            else:
                model.module.encoder_stages.train()
                for p in model.module.encoder_stages.parameters():
                    p.requires_grad = True
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank)

        model = model.eval()
        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice_best,
                    'xview_best': xview_best,
                }, args.output_dir + '/' + snapshot_name + "_last")
            if epoch % args.test_every == 0:
                preds_dir = os.path.join(args.predictions, snapshot_name)
                dice_best, xview_best = evaluate_val(
                    args,
                    val_data_loader,
                    xview_best,
                    dice_best,
                    model,
                    snapshot_name=snapshot_name,
                    current_epoch=current_epoch,
                    optimizer=optimizer,
                    summary_writer=summary_writer,
                    predictions_dir=preds_dir)
        current_epoch += 1
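The freeze/unfreeze branches above repeat the same two steps for each encoder stage: switch the module between eval and train mode (which affects BatchNorm and Dropout statistics) and toggle requires_grad on its parameters. A generic helper capturing that pattern might look like this (a sketch, not part of the original code):

def set_trainable(module, trainable):
    # train()/eval() controls BatchNorm/Dropout behaviour; requires_grad
    # controls whether the optimizer updates the parameters.
    module.train(trainable)
    for p in module.parameters():
        p.requires_grad = trainable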
Example #10
# main
best_prec1 = 0
for epoch in range(args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)

    lr_scheduler.step()
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    prec1 = validate(test_loader, model, criterion, epoch)

    # remember best prec@1 and save checkpoint
    if args.local_rank == 0:
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.logdir, 'checkpoint.pth.tar'))

        if is_best:
            shutil.copyfile(os.path.join(args.logdir, 'checkpoint.pth.tar'),
                            os.path.join(args.logdir, 'model_best.pth.tar'))
            print(' * Save best model @ Epoch {}\n'.format(epoch))
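Several of the later examples call a save_checkpoint(state, is_best, ...) helper that is not shown. A minimal sketch matching the save-latest-then-copy-best pattern written out explicitly in the loop above (the function name matches the calls below, but the defaults are assumptions):

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint='.', filename='checkpoint.pth.tar'):
    # Always persist the latest state; copy it aside when it is the best so far.
    path = os.path.join(checkpoint, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(checkpoint, 'model_best.pth.tar'))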
Example #11
def main():
    global best_prec1, args

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    if not os.path.isdir(args.checkpoint) and args.local_rank == 0:
        mkdir_p(args.checkpoint)

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()

    # args.lr = float(args.lr * float(args.batch_size * args.world_size) / 256.)  # default args.lr = 0.1 -> 256
    optimizer = set_optimizer(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()


    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale,
                                      verbosity=0)

    model = DDP(model, delay_allreduce=True)


    # optionally resume from a checkpoint
    title = 'ImageNet-' + args.arch
    args.lastepoch = -1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            args.lastepoch = checkpoint['epoch']
            if args.local_rank == 0:
                logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.', 'Valid Top5.'])

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = HybridTrainPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, data_dir=traindir, crop=crop_size, dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.test_batch, num_threads=4, device_id=args.local_rank, data_dir=valdir, crop=crop_size, size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    train_loader_len = int(train_loader._size / args.batch_size)
    if args.resume:
        scheduler = CosineAnnealingLR(optimizer, args.epochs, train_loader_len,
                                      eta_min=0., last_epoch=args.lastepoch, warmup=args.warmup)
    else:
        scheduler = CosineAnnealingLR(optimizer,
                                      args.epochs, train_loader_len, eta_min=0., warmup=args.warmup)
    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch

        if args.local_rank == 0:
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        [train_loss, train_acc, avg_train_time] = train(train_loader, model, criterion, optimizer, epoch, scheduler)
        total_time.update(avg_train_time)
        # evaluate on validation set
        [test_loss, prec1, prec5] = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            # append logger file
            logger.append([optimizer.param_groups[0]['lr'], train_loss, test_loss, train_acc, prec1, prec5])

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(prec1, prec5, args.total_batch_size / total_time.avg))

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()

    if args.local_rank == 0:
        logger.close()
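The DALI iterators above are sized as epoch_size("Reader") / world_size because each rank reads its own shard of the dataset; train_loader._size is therefore the per-rank sample count, which gives the per-rank step count when divided by the batch size. Illustrative arithmetic (the dataset size is ImageNet's, the rest are assumed):

epoch_size, world_size, batch_size = 1281167, 8, 256  # assumed configuration
per_rank_samples = int(epoch_size / world_size)       # 160145
steps_per_epoch = int(per_rank_samples / batch_size)  # 625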
Example #12
def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    print('learning rate: ', args.lr)
    param = model.parameters()
    optimizer = torch.optim.SGD(param,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.lr_adjust_type == 'step':
        scheduler = MultiStepLR(optimizer,
                                milestones=args.lr_adjust_step,
                                gamma=0.1)
    elif args.lr_adjust_type == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, args.epochs)
    elif args.lr_adjust_type == 'exp':
        scheduler = ExponentialLR(optimizer, args.gamma)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if (args.arch == "inception_v3"):
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    trans = transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        # transforms.ToTensor(), Too slow
        # normalize,
    ])
    train_dataset = datasets.ImageFolder(traindir, trans)
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    st_time = time.time()
    prec1 = 0.
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if args.grid:
            grid.set_prob(epoch, args.st_epochs)

        # train for one epoch
        #adjust_learning_rate(scheduler, optimizer, epoch, 1, 1)
        train(train_loader, model, criterion, optimizer, epoch, scheduler)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            print(epoch)
            print('Learning rate:', optimizer.param_groups[0]['lr'])
            print('Total Time: ' + format_time(time.time() - st_time))
            print('Remaining Time: ' +
                  format_time((time.time() - st_time) /
                              (epoch - args.start_epoch + 1) *
                              (args.epochs - epoch - 1)))
            print('Best Acc: ' + str(best_prec1))
            save_checkpoint(
                args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
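The long comment in the example spells out the required ordering: move the model to GPU, call amp.initialize, and only then wrap with DDP, so that DDP's allreduce hooks see the parameters Amp has already cast. A minimal self-contained sketch of that sequence (toy model; assumes apex is installed and, for the DDP line, that the process group has been initialized as in the example):

import torch
import torch.nn as nn
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

model = nn.Linear(10, 2).cuda()                      # 1. model on GPU first
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer,  # 2. Amp patches/casts here
                                  opt_level='O1')
model = DDP(model, delay_allreduce=True)             # 3. DDP wraps last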
Example #13
def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    if args.fused_adam:
        optimizer = optimizers.FusedAdam(model.parameters())
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        # enabled=False,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if (args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
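fast_collate, used as the DataLoader collate_fn in the last two examples, is why ToTensor is commented out as "Too slow": it batches PIL images directly into a uint8 tensor and defers float conversion and normalization to the GPU. A sketch consistent with the Apex ImageNet example these snippets appear to derive from:

import numpy as np
import torch

def fast_collate(batch):
    # Stack PIL images into one uint8 NCHW tensor without per-image ToTensor.
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w, h = imgs[0].size
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        arr = np.asarray(img, dtype=np.uint8)
        if arr.ndim < 3:
            arr = np.expand_dims(arr, axis=-1)               # grayscale -> HWC
        tensor[i] += torch.from_numpy(np.rollaxis(arr, 2))   # HWC -> CHW
    return tensor, targets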
Example #14
def main():
    global best_acc, mean, std, scale

    args = parse_args()
    args.mean, args.std, args.scale = mean, std, scale
    args.is_master = args.local_rank == 0

    if args.deterministic:
        cudnn.deterministic = True
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.is_master:
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
              type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
        print(f"Distributed Training Enabled: {args.distributed}")

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        # Scale learning rate based on global batch size
        # args.lr *= args.batch_size * args.world_size / 256

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    model = models.ResNet18(args.num_patches, args.num_angles)

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    optimiser = Ranger(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss().cuda()

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimiser = amp.initialize(
        model,
        optimiser,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            global best_acc
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_acc = checkpoint['best_acc']
                args.poisson_rate = checkpoint["poisson_rate"]
                model.load_state_dict(checkpoint['state_dict'])
                optimiser.load_state_dict(checkpoint['optimiser'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    train_dir = os.path.join(args.data, 'train')
    val_dir = os.path.join(args.data, 'val')

    crop_size = 225
    val_size = 256

    imagenet_train = datasets.ImageFolder(
        root=train_dir,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
        ]))
    train_dataset = SSLTrainDataset(imagenet_train, args.num_patches,
                                    args.num_angles, args.poisson_rate)
    imagenet_val = datasets.ImageFolder(root=val_dir,
                                        transform=transforms.Compose([
                                            transforms.Resize(val_size),
                                            transforms.CenterCrop(crop_size),
                                        ]))
    val_dataset = SSLValDataset(imagenet_val, args.num_patches,
                                args.num_angles)

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=args.workers,
                              pin_memory=True,
                              sampler=train_sampler,
                              collate_fn=fast_collate)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            sampler=val_sampler,
                            collate_fn=fast_collate)

    if args.evaluate:
        val_loss, val_acc = apex_validate(val_loader, model, criterion, args)
        utils.logger.info(f"Val Loss = {val_loss}, Val Accuracy = {val_acc}")
        return

    # Create dir to save model and command-line args
    if args.is_master:
        model_dir = time.ctime().replace(" ", "_").replace(":", "_")
        model_dir = os.path.join("models", model_dir)
        os.makedirs(model_dir, exist_ok=True)
        with open(os.path.join(model_dir, "args.json"), "w") as f:
            json.dump(args.__dict__, f, indent=2)
        writer = SummaryWriter()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_loss, train_acc = apex_train(train_loader, model, criterion,
                                           optimiser, args, epoch)

        # evaluate on validation set
        val_loss, val_acc = apex_validate(val_loader, model, criterion, args)

        if (epoch + 1) % args.learn_prd == 0:
            utils.adj_poisson_rate(train_loader, args)

        # remember best Acc and save checkpoint
        if args.is_master:
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimiser': optimiser.state_dict(),
                    "poisson_rate": args.poisson_rate
                }, is_best, model_dir)

            writer.add_scalars("Loss", {
                "train_loss": train_loss,
                "val_loss": val_loss
            }, epoch)
            writer.add_scalars("Accuracy", {
                "train_acc": train_acc,
                "val_acc": val_acc
            }, epoch)
            writer.add_scalar("Poisson_Rate", train_loader.dataset.pdist.rate,
                              epoch)
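train_sampler.set_epoch(epoch), called at the top of each epoch here and in the other distributed examples, reseeds DistributedSampler's shuffle; without it every epoch replays the same order. A small standalone demonstration (num_replicas and rank are passed explicitly so no process group is needed):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
for epoch in range(2):
    sampler.set_epoch(epoch)      # different shuffle each epoch
    print(list(iter(sampler)))    # this rank's half of the indices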
Example #15
def main():
    args = parser.parse_args()

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    r = -1
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        r = torch.distributed.get_rank()

    if args.distributed:
        print(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (r, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    output_dir = ''
    if args.local_rank == 0:
        if args.output:
            output_base = args.output
        else:
            output_base = './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(args.img_size)
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print(
                'Warning: AMP does not work well with nn.DataParallel, disabling. '
                'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        exit(1)
    dataset_train = Dataset(train_dir)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        rand_erase_prob=args.reprob,
        rand_erase_pp=args.repp,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
        decreasing = (eval_metric == 'loss')
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
    best_metric = None
    best_epoch = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.model,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
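Here lr_scheduler.step(epoch, eval_metrics[eval_metric]) is timm's metric-aware stepping, and CheckpointSaver likewise ranks checkpoints by the eval metric (decreasing when that metric is a loss). The closest stock-PyTorch analogue to the metric-aware step is ReduceLROnPlateau; a sketch of that idea, not what the example actually uses:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2)
for epoch in range(5):
    val_loss = 1.0  # placeholder metric; no improvement triggers a decay
    scheduler.step(val_loss)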
Example #16
def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
        model = DDP(model, shared_param=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if (args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
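FP16_Optimizer with static_loss_scale implements the classic loss-scaling trick: multiply the loss before backward so small fp16 gradients do not underflow, then divide the gradients back before stepping. The core idea in a few self-contained lines (an fp32 toy model for brevity; the real helper also manages master fp32 weights):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scale = 128.0

loss = model(torch.randn(2, 4)).sum()
(loss * loss_scale).backward()     # scaled backward: fp16 grads avoid underflow
for p in model.parameters():
    p.grad.div_(loss_scale)        # unscale before the optimizer step
optimizer.step()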
Example #17
def main():
    global best_top1, best_top5

    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
    valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.train_batch,
                                               shuffle=(train_sampler is None),
                                               pin_memory=True,
                                               num_workers=8,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(valid_data,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=8)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = DDP(model.features)
        model.cuda()
    else:
        model = model.cuda()
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay,
                          warmup=0)
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'lsadam':
        optimizer = LSAdamW(model.parameters(),
                            lr=args.lr * ((1. + 4. * args.sigma)**(0.25)),
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.weight_decay,
                            sigma=args.sigma)
    elif args.optimizer.lower() == 'lsradam':
        sigma = 0.1  # note: unused; the constructor below reads args.sigma
        optimizer = LSRAdam(model.parameters(),
                            lr=args.lr * ((1. + 4. * args.sigma)**(0.25)),
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.weight_decay,
                            sigma=args.sigma)
    elif args.optimizer.lower() == 'srsgd':
        iter_count = 1
        optimizer = SGD_Adaptive(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 iter_count=iter_count,
                                 restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradam':
        iter_count = 1
        optimizer = SRNAdam(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradamw':
        iter_count = 1
        optimizer = SRAdamW(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=0,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'srradam':
        #NOTE: need to double-check this
        iter_count = 1
        optimizer = SRRAdam(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=0,
                            restarting_iter=args.restart_schedule[0])

    schedule_index = 1
    # Defaults for a fresh run; a resumed checkpoint overwrites these below.
    start_epoch = 0
    best_top1 = 0
    best_top5 = 0
    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        # args.checkpoint = os.path.dirname(args.resume)
        # checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.local_rank))
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        best_top1 = checkpoint['best_top1']
        best_top5 = checkpoint['best_top5']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            iter_count = optimizer.param_groups[0]['iter_count']
        schedule_index = checkpoint['schedule_index']
        state['lr'] = optimizer.param_groups[0]['lr']
        if args.checkpoint == args.resume:
            logger = LoggerDistributed(os.path.join(args.checkpoint,
                                                    'log.txt'),
                                       rank=args.local_rank,
                                       title=title,
                                       resume=True)
        else:
            logger = LoggerDistributed(os.path.join(args.checkpoint,
                                                    'log.txt'),
                                       rank=args.local_rank,
                                       title=title)
            if args.local_rank == 0:
                logger.set_names([
                    'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Top1',
                    'Valid Top1', 'Train Top5', 'Valid Top5'
                ])
    else:
        logger = LoggerDistributed(os.path.join(args.checkpoint, 'log.txt'),
                                   rank=args.local_rank,
                                   title=title)
        if args.local_rank == 0:
            logger.set_names([
                'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Top1',
                'Valid Top1', 'Train Top5', 'Valid Top5'
            ])

    if args.local_rank == 0:
        logger.file.write('    Total params: %.2fM' %
                          (sum(p.numel()
                               for p in model.parameters()) / 1000000.0))

    if args.evaluate:
        if args.local_rank == 0:
            logger.file.write('\nEvaluation only')
        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               start_epoch, use_cuda, logger)
        if args.local_rank == 0:
            logger.file.write(
                ' Test Loss:  %.8f, Test Top1:  %.2f, Test Top5: %.2f' %
                (test_loss, test_top1, test_top5))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        # Shuffle the sampler.
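        # DistributedSampler seeds its shuffle from this value, so every rank
        # draws the same per-epoch permutation.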
        train_loader.sampler.set_epoch(epoch + args.manualSeed)

        if args.optimizer.lower() == 'srsgd':
            if epoch in args.schedule:
                optimizer = SGD_Adaptive(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    weight_decay=args.weight_decay,
                    iter_count=iter_count,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        elif args.optimizer.lower() == 'sradam':
            if epoch in args.schedule:
                optimizer = SRNAdam(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    betas=(args.beta1, args.beta2),
                    iter_count=iter_count,
                    weight_decay=args.weight_decay,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        elif args.optimizer.lower() == 'sradamw':
            if epoch in args.schedule:
                optimizer = SRAdamW(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    betas=(args.beta1, args.beta2),
                    iter_count=iter_count,
                    weight_decay=args.weight_decay,
                    warmup=0,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        elif args.optimizer.lower() == 'srradam':
            if epoch in args.schedule:
                optimizer = SRRAdam(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    betas=(args.beta1, args.beta2),
                    iter_count=iter_count,
                    weight_decay=args.weight_decay,
                    warmup=0,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        else:
            adjust_learning_rate(optimizer, epoch)

        if args.local_rank == 0:
            logger.file.write('\nEpoch: [%d | %d] LR: %f' %
                              (epoch + 1, args.epochs, state['lr']))

        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            train_loss, train_top1, train_top5, iter_count = train(
                train_loader, model, criterion, optimizer, epoch, use_cuda,
                logger)
        else:
            train_loss, train_top1, train_top5 = train(train_loader, model,
                                                       criterion, optimizer,
                                                       epoch, use_cuda, logger)

        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               epoch, use_cuda, logger)

        # append logger file
        if args.local_rank == 0:
            logger.append([
                state['lr'], train_loss, test_loss, train_top1, test_top1,
                train_top5, test_top5
            ])
            writer.add_scalars('train_loss', {args.model_name: train_loss},
                               epoch)
            writer.add_scalars('test_loss', {args.model_name: test_loss},
                               epoch)
            writer.add_scalars('train_top1', {args.model_name: train_top1},
                               epoch)
            writer.add_scalars('test_top1', {args.model_name: test_top1},
                               epoch)
            writer.add_scalars('train_top5', {args.model_name: train_top5},
                               epoch)
            writer.add_scalars('test_top5', {args.model_name: test_top5},
                               epoch)

        # save model
        is_best = test_top1 > best_top1
        best_top1 = max(test_top1, best_top1)
        best_top5 = max(test_top5, best_top5)
        if args.local_rank == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'schedule_index': schedule_index,
                    'state_dict': model.state_dict(),
                    'top1': test_top1,
                    'top5': test_top5,
                    'best_top1': best_top1,
                    'best_top5': best_top5,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                epoch,
                checkpoint=args.checkpoint)

    if args.local_rank == 0:
        logger.file.write('Best top1: %f' % best_top1)
        logger.file.write('Best top5: %f' % best_top5)
        logger.close()
        logger.plot()
        savefig(os.path.join(args.checkpoint, 'log.eps'))
        print('Best top1: %f' % best_top1)
        print('Best top5: %f' % best_top5)
        with open("./all_results_imagenet.txt", "a") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.write("%s\n" % args.checkpoint)
            f.write("best_top1 %f, best_top5 %f\n\n" % (best_top1, best_top5))
            fcntl.flock(f, fcntl.LOCK_UN)
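(`adjust_learning_rate` and the module-level `state` dict used above are defined elsewhere in this script; a minimal sketch consistent with how `state['lr']`, `args.schedule`, and `args.gamma` are used, with all names assumed:)
def adjust_learning_rate(optimizer, epoch):
    # Hypothetical step schedule: multiply the base LR by gamma for every milestone already passed.
    lr = args.lr * (args.gamma ** sum(epoch >= s for s in args.schedule))
    state['lr'] = lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr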
Example #18
0
def train():
    if args.local_rank == 0:
        logger.info('Initializing....')
    cudnn.enabled = True
    cudnn.benchmark = True
    # torch.manual_seed(1)
    # torch.cuda.manual_seed(1)
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.local_rank == 0:
        write_config_into_log(cfg)

    if args.local_rank == 0:
        logger.info('Building model......')
    if cfg.pretrained:
        model = make_model(cfg)
        model.load_param(cfg)
        if args.local_rank == 0:
            logger.info('Loaded pretrained model from {0}'.format(
                cfg.pretrained))
    else:
        model = make_model(cfg)

    if args.sync_bn:
        if args.local_rank == 0:
            logger.info("using apex synced BN")
        model = convert_syncbn_model(model)
    model.cuda()
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.Adam(
        [{
            'params': model.module.base.parameters(),
            'lr': cfg.get_lr(0)[0]
        }, {
            'params': model.module.classifiers.parameters(),
            'lr': cfg.get_lr(0)[1]
        }],
        weight_decay=cfg.weight_decay)

    celoss = nn.CrossEntropyLoss().cuda()
    softloss = SoftLoss()
    sp_kd_loss = SP_KD_Loss()
    criterions = [celoss, softloss, sp_kd_loss]

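    # Split the configured batch size and worker count evenly across ranks
    # so the effective global batch stays as configured.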
    cfg.batch_size = cfg.batch_size // args.world_size
    cfg.num_workers = cfg.num_workers // args.world_size
    train_loader, val_loader = make_dataloader(cfg)

    if args.local_rank == 0:
        logger.info('Begin training......')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train_one_epoch(train_loader, val_loader, model, criterions, optimizer,
                        epoch, cfg)

        total_acc = test(cfg, val_loader, model, epoch)
        if args.local_rank == 0:
            with open(cfg.test_log, 'a+') as f:
                f.write('Epoch {0}: Acc is {1:.4f}\n'.format(epoch, total_acc))
            torch.save(obj=model.state_dict(),
                       f=os.path.join(
                           cfg.snapshot_dir,
                           'ep{}_acc{:.4f}.pth'.format(epoch, total_acc)))
            logger.info('Model saved')
Example #19
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=100,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    # parser.add_argument('--gpuid', type=int, default=-1,help='The gpu id to use')
    args = parser.parse_args()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "snli": SnliProcessor,
        "mrpc": MrpcProcessor,
        "sst": SstProcessor,
        "twitter": TwitterProcessor,
    }

    num_labels_task = {
        "cola": 2,
        "mnli": 3,
        "snli": 3,
        "mrpc": 2,
        "sst": 2,
        "twitter": 2,
    }

    if args.local_rank == -1 or args.no_cuda:
        if not args.no_cuda:
            # device = torch.device("cuda",args.gpuid)
            # torch.cuda.set_device(args.gpuid)
            dummy = torch.cuda.FloatTensor(1)
        else:
            device = torch.device("cpu")
        n_gpu = 1
        # n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        if args.fp16:
            # logger.info("16-bits training currently not supported in distributed training")
            args.fp16 = False  # (see https://github.com/pytorch/pytorch/pull/13496)
    # logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

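    # Shrink the per-step batch so that per-step batch * accumulation steps
    # equals the requested train_batch_size.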
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        # logger.info("***** Running evaluation *****")
        # logger.info("  Num examples = %d", len(eval_examples))
        # logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
        'distributed_{}'.format(args.local_rank),
        num_labels=num_labels)

    if args.fp16:
        model.half()
    model.cuda()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
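    # Weight decay is applied only to weight matrices; biases and LayerNorm
    # parameters are exempt, following the original BERT recipe.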
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        # logger.info("***** Running training *****")
        # logger.info("  Num examples = %d", len(train_examples))
        # logger.info("  Batch size = %d", args.train_batch_size)
        # logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        best_eval_acc = 0.0
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # for epoch in range(int(args.num_train_epochs)):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                # for step, batch in enumerate(train_dataloader):
                model.train()
                batch = tuple(t.cuda() for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                    # if epoch>0:
                    # logits_all,eval_accuracy=do_evaluation(model,eval_dataloader,args,is_training=False)
                    # print(eval_accuracy)
            # logits_all,eval_accuracy=do_evaluation(model,eval_dataloader,args,is_training=False)
            # if best_eval_acc<eval_accuracy:
            #     best_eval_acc=eval_accuracy
            #     print(eval_accuracy)
            model_save_dir = os.path.join(args.output_dir, f'model{epoch}')
            os.makedirs(model_save_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(model_save_dir, "pytorch_model.bin"))
        print('Best eval acc:', best_eval_acc)  # stays 0.0 while the do_evaluation calls above are commented out
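(The `warmup_linear` schedule used in the update step above is not shown here; a minimal sketch matching the pytorch-pretrained-bert convention of linear warmup to the peak rate followed by linear decay:)
def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed (global_step / t_total).
    if x < warmup:
        return x / warmup
    return 1.0 - x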
Example #20
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument(
        'data',
        metavar='DIR',
        nargs='*',
        help='path(s) to dataset (if one path is provided, it is assumed\n' +
        'to have subdirectories named "train" and "val"; alternatively,\n' +
        'train and val paths can be specified directly by providing both paths as arguments)'
    )
    parser.add_argument('-a',
                        '--arch',
                        metavar='ARCH',
                        default='resnet18',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j',
                        '--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs',
                        default=15,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument(
        '-bs',
        '--batch-size',
        default=128,
        type=int,
        metavar='N',
        help='batch size for descriptor generation (default: 128)')
    parser.add_argument('-lr',
                        '--learning-rate',
                        default=0.1,
                        type=float,
                        metavar='LR',
                        help='initial learning rate',
                        dest='lr')
    parser.add_argument('--momentum',
                        default=0.9,
                        type=float,
                        metavar='M',
                        help='momentum')
    parser.add_argument('--wd',
                        '--weight-decay',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('-p',
                        '--print-freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='Run model fp16 mode.')
    parser.add_argument('--dali_cpu',
                        action='store_true',
                        help='Runs CPU based version of DALI pipeline.')
    parser.add_argument(
        '--static-loss-scale',
        type=float,
        default=1,
        help=
        'Static loss scale, positive power of 2 values can improve fp16 convergence.'
    )
    parser.add_argument(
        '--dynamic-loss-scale',
        action='store_true',
        help='Use dynamic loss scaling.  If supplied, this argument supersedes '
        + '--static-loss-scale.')
    parser.add_argument('--prof',
                        dest='prof',
                        action='store_true',
                        help='Only run 10 iterations for profiling.')
    parser.add_argument('-t',
                        '--test',
                        action='store_true',
                        help='Launch test mode with preset arguments')

    parser.add_argument("--local_rank", default=0, type=int)
    # added
    parser.add_argument(
        '-ir',
        '--imbalance-ratio',
        type=int,
        default=1,
        metavar='N',
        help=
        'ratio of 0..499 to 500..999 labels in the training dataset drawn from uniform distribution'
    )
    parser.add_argument(
        '-nr',
        '--noisy-ratio',
        type=float,
        default=0.0,
        metavar='N',
        help=
        'ratio of noisy(random) labels in the training dataset drawn from uniform distribution'
    )
    parser.add_argument(
        '-ens',
        '--ensemble-size',
        type=int,
        default=1,
        metavar='E',
        help='defines size of ensemble or, by default, no ensemble if = 1')
    parser.add_argument('-e',
                        '--ensemble-index',
                        type=int,
                        default=0,
                        metavar='E',
                        help='defines index of ensemble')
    parser.add_argument('--save-folder',
                        default='../local_data/ImageNet',
                        type=str,
                        help='dir to save data')
    parser.add_argument('-r',
                        '--run-folder',
                        default='run99',
                        type=str,
                        help='dir to save run')

    args = parser.parse_args()
    cudnn.benchmark = True

    # test mode, use default args for sanity test
    if args.test:
        args.fp16 = False
        args.epochs = 1
        args.start_epoch = 0
        args.arch = 'resnet18'
        args.batch_size = 256
        args.data = []
        args.prof = True
        args.data.append('/data/imagenet/train-jpeg/')
        args.data.append('/data/imagenet/val-jpeg/')

    if not args.data:
        raise Exception("error: too few data arguments")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # Data loading code
    if len(args.data) == 1:
        train_dir = os.path.join(args.data[0], 'train')
        val_dir = os.path.join(args.data[0], 'val')
    else:
        train_dir = args.data[0]
        val_dir = args.data[1]

    if args.arch == "inception_v3":
        crop_size = 299
        val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    if not os.path.isdir(args.save_folder):
        os.mkdir(args.save_folder)
    # make a separate folder for experiment
    run_folder = '{}/{}'.format(args.save_folder, args.run_folder)
    if not os.path.isdir(run_folder):
        os.mkdir(run_folder)
        os.mkdir(run_folder + '/data')
        os.mkdir(run_folder + '/checkpoint')
        os.mkdir(run_folder + '/descr')

    # lists for full datasets
    orig_train_list_file = '{}/{}'.format(args.save_folder,
                                          'processed/train_list.txt')
    val_list_file = '{}/{}'.format(args.save_folder, 'processed/val_list.txt')
    if args.imbalance_ratio == 1 and args.noisy_ratio == 0.0:  # use original training dataset
        full_train_list_file = orig_train_list_file
    else:
        R = 1000  # number of classes
        distorted_train_list_file = '{}/{}/full_train_list_ir_{}_nr_{}.txt'.format(
            args.save_folder, args.run_folder, args.imbalance_ratio,
            args.noisy_ratio)
        full_train_list_file = distorted_train_list_file
        if not os.path.isfile(distorted_train_list_file):
            with open(orig_train_list_file) as f:
                lines = f.readlines()
            full_train_list = [x.strip().split() for x in lines]
            class_index = random.sample(
                range(R),
                R >> 1)  # randomly sample half of classes which we will modify
            # class imbalance
            if args.imbalance_ratio != 1:
                distorted_list = list()
                for c in range(R):
                    c_list = [
                        x for i, x in enumerate(full_train_list)
                        if int(x[1]) == c
                    ]
                    A = len(c_list)
                    # select indices we will evict from the list to distort dataset
                    selected_index = list()
                    if c in class_index:
                        selected_index = random.sample(
                            range(A),
                            round(A * (args.imbalance_ratio - 1) /
                                  args.imbalance_ratio))
                    #
                    distorted_list.extend([
                        i for j, i in enumerate(c_list)
                        if j not in selected_index
                    ])
                    print(c, A, len(selected_index), len(distorted_list))
            else:
                distorted_list = full_train_list
            #
            print('Imbalance =', len(distorted_list), 'selected from original',
                  len(full_train_list))
            # noisy labels
            if args.noisy_ratio != 0.0:
                P = len(distorted_list)
                K = int(P * args.noisy_ratio)
                print('Noisy =', K, ' out of', P)
                noisy_index = random.sample(range(P), K)
                for j in noisy_index:  # hit only the K flagged positions instead of scanning all P entries
                    distorted_list[j][1] = random.randint(0, R - 1)
            #
            with open(distorted_train_list_file, "w") as f:
                for item in distorted_list:
                    f.write("%s %s\n" % (item[0], item[1]))

    # initially we use unsupervised pretraining
    unsup_prefix = 'unsup_'
    refer_prefix = ''
    unsup_postfix = '{}batch_0_ir_{}_nr_{}_sub_{}_aug_{}'.format(
        unsup_prefix, args.imbalance_ratio, args.noisy_ratio, 'none', 'none')
    refer_postfix = '{}batch_0_ir_{}_nr_{}_sub_{}_aug_{}'.format(
        refer_prefix, args.imbalance_ratio, args.noisy_ratio, 'none', 'none')
    train_list_file = '{}/{}/train_list_{}.txt'.format(args.save_folder,
                                                       args.run_folder,
                                                       unsup_postfix)
    index_list_file = '{}/{}/index_list_{}.npy'.format(args.save_folder,
                                                       args.run_folder,
                                                       unsup_postfix)
    if os.path.isfile(train_list_file) and os.path.isfile(index_list_file):
        print('Train list exists =', train_list_file)
        with open(train_list_file) as f:
            train_list = f.readlines()
    else:
        with open(full_train_list_file) as f:
            lines = f.readlines()
        lines = [l.strip() for l in lines]
        index_list = range(len(lines))
        train_list = lines
        #
        np.save(index_list_file, index_list)
        with open(train_list_file, "w") as f:
            f.write("\n".join(train_list))
        print('Train list files created =', index_list_file, train_list_file)

    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.workers,
                           device_id=args.local_rank,
                           data_dir=train_dir,
                           file_list=train_list_file,
                           crop=crop_size,
                           local_rank=args.local_rank,
                           world_size=args.world_size,
                           dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.workers,
                         device_id=args.local_rank,
                         data_dir=val_dir,
                         file_list=val_list_file,
                         crop=crop_size,
                         local_rank=args.local_rank,
                         world_size=args.world_size,
                         size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    model_folder = '{}/{}/checkpoint'.format(args.save_folder, args.run_folder)
    if args.ensemble_size > 1:
        checkpoint_refer_file = '{}/init_{}_E_{}.pt'.format(
            model_folder, refer_postfix, args.ensemble_index)
        checkpoint_unsup_file = '{}/init_{}_E_{}.pt'.format(
            model_folder, unsup_postfix, args.ensemble_index)
    else:
        checkpoint_refer_file = '{}/init_{}.pt'.format(model_folder,
                                                       refer_postfix)
        checkpoint_unsup_file = '{}/init_{}.pt'.format(model_folder,
                                                       unsup_postfix)
    # save reference checkpoint (randomly initialized)
    if os.path.isfile(checkpoint_refer_file):
        print('Model {} is already trained!'.format(checkpoint_refer_file))
    else:
        print("=> creating reference model '{}'".format(args.arch))
        modelRefer = models.__dict__[args.arch](UNSUP=False)
        modelRefer = modelRefer.cuda()
        if args.fp16:
            modelRefer = network_to_half(modelRefer)
        if args.distributed:
            # shared param/delay all reduce turns off bucketing in DDP, for lower latency runs this can improve perf
            # for the older version of APEX please use shared_param, for newer one it is delay_allreduce
            modelRefer = DDP(modelRefer, delay_allreduce=True)
        # evaluate on validation set
        criterion = nn.CrossEntropyLoss().cuda()
        [refer_prec1, refer_prec5] = validate(args,
                                              val_loader,
                                              modelRefer,
                                              criterion,
                                              unsup=False)
        val_loader.reset()
        #
        print(
            'Saving reference checkpoint at epoch {} with accuracy {}'.format(
                0, refer_prec1))
        save_checkpoint(
            {
                'epoch': 0,
                'arch': args.arch,
                'state_dict': modelRefer.state_dict(),
                'acc': refer_prec1,
            }, checkpoint_refer_file)
        del modelRefer, criterion
    # train unsupervised model
    if os.path.isfile(checkpoint_unsup_file):
        print('Model {} is already trained!'.format(checkpoint_unsup_file))
    else:
        print("=> creating unsupervised model '{}'".format(args.arch))
        modelUnsup = models.__dict__[args.arch](UNSUP=True)
        modelUnsup = modelUnsup.cuda()
        if args.fp16:
            modelUnsup = network_to_half(modelUnsup)
        if args.distributed:
            # shared param/delay all reduce turns off bucketing in DDP, for lower latency runs this can improve perf
            # for the older version of APEX please use shared_param, for newer one it is delay_allreduce
            modelUnsup = DDP(modelUnsup, delay_allreduce=True)

        # define loss function (criterion) and optimizer
        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.SGD(modelUnsup.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        if args.fp16:
            optimizer = FP16_Optimizer(
                optimizer,
                static_loss_scale=args.static_loss_scale,
                dynamic_loss_scale=args.dynamic_loss_scale)

        # evaluate on validation set
        [best_prec1, best_prec5] = validate(args,
                                            val_loader,
                                            modelUnsup,
                                            criterion,
                                            unsup=True)
        val_loader.reset()
        for epoch in range(args.start_epoch, args.epochs):
            # train for one epoch
            train(args,
                  train_loader,
                  modelUnsup,
                  criterion,
                  optimizer,
                  epoch,
                  unsup=True)
            # evaluate on validation set
            [prec1, prec5] = validate(args,
                                      val_loader,
                                      modelUnsup,
                                      criterion,
                                      unsup=True)

            # remember best prec@1 and save checkpoint
            if args.local_rank == 0:
                if prec1 > best_prec1:
                    best_prec1 = prec1
                    print(
                        'Saving best unsupervised checkpoint at epoch {} with accuracy {}'
                        .format(epoch + 1, best_prec1))
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'arch': args.arch,
                            'state_dict': modelUnsup.state_dict(),
                            'acc': best_prec1,
                            #'optimizer': optimizer.state_dict(),
                        },
                        checkpoint_unsup_file)
            else:
                print('Local rank is not zero')
            # reset DALI iterators
            train_loader.reset()
            val_loader.reset()
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}'.format(prec1, prec5))
Example #21
0
class AugmentStageTrainer:
    def __init__(self, config):
        self.config = config
        
        """device parameters"""
        self.world_size = self.config.world_size
        self.rank = self.config.rank
        self.gpu = self.config.local_rank
        self.distributed = self.config.dist

        """get the train parameters"""
        self.total_epochs = self.config.epochs
        self.train_batch_size = self.config.batch_size
        self.val_batch_size = self.config.batch_size
        self.global_batch_size = self.world_size * self.train_batch_size

        self.max_lr = self.config.lr * self.world_size
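        # Linear scaling rule: grow the peak LR with world size to compensate
        # for the larger global batch.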

        """construct the whole network"""
        self.resume_path = self.config.resume_path
        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{self.gpu}')
            torch.cuda.set_device(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')
        self.construct_model()

        """save checkpoint path"""
        self.save_epoch = 1
        self.ckpt_path = self.config.path

        """log tools in the running phase"""
        self.steps = 0
        self.log_step = 10
        self.logger = self.config.logger
        if self.rank == 0:
            self.writer = SummaryWriter(log_dir=os.path.join(self.config.path, "tb"))
            self.writer.add_text('config', config.as_markdown(), 0)

    def construct_model(self):
        """get data loader"""
        input_size, input_channels, n_classes, train_data, valid_data = get_data(
            self.config.dataset, self.config.data_path, self.config.cutout_length, validation=True
        )

        if self.distributed:
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data, num_replicas=self.world_size, rank=self.rank
            )
        else:
            self.train_sampler = None

        self.train_loader = torch.utils.data.DataLoader(train_data,
                                                        batch_size=self.config.batch_size,
                                                        shuffle=(self.train_sampler is None),
                                                        num_workers=self.config.workers,
                                                        pin_memory=True,
                                                        sampler=self.train_sampler)
        self.valid_loader = torch.utils.data.DataLoader(valid_data,
                                                        batch_size=self.config.batch_size,
                                                        shuffle=False,
                                                        num_workers=self.config.workers,
                                                        pin_memory=True)
        self.sync_bn = self.config.amp_sync_bn
        self.opt_level = self.config.amp_opt_level
        print(f"sync_bn: {self.sync_bn}")

        """build model"""
        print("init model")
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.use_aux = self.config.aux_weight > 0.
        model = AugmentStage(input_size, input_channels, self.config.init_channels, n_classes, self.config.layers, self.use_aux, self.config.genotype, self.config.DAG)
        if self.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
        self.model = model.to(self.device)
        print("init model end!")

        """ build optimizer """
        print("get optimizer")
        momentum = self.config.momentum
        weight_decay = self.config.weight_decay
        # LARSSGD
        # exclude_bias_and_bn = self.config.exclude_bias_and_bn
        # params = collect_params([self.model], exclude_bias_and_bn=exclude_bias_and_bn)
        # self.optimizer = LARS(params, lr=self.max_lr, momentum=momentum, weight_decay=weight_decay)
        # SGD
        self.optimizer = torch.optim.SGD(model.parameters(), lr=self.max_lr, momentum=momentum, weight_decay=weight_decay)

        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.total_epochs)

        """init amp"""
        print("amp init!")
        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer, opt_level=self.opt_level
        )
        if self.distributed:
            self.model = DDP(self.model, delay_allreduce=True)
        print("amp init end!")
    
    def resume_model(self, model_path=None):
        if model_path is None and not self.resume_path:
            self.start_epoch = 0
            self.logger.info("--> No loaded checkpoint!")
        else:
            model_path = model_path or self.resume_path
            checkpoint = torch.load(model_path, map_location=self.device)

            self.start_epoch = checkpoint['epoch']
            self.steps = checkpoint['steps']
            self.model.load_state_dict(checkpoint['model'], strict=True)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            amp.load_state_dict(checkpoint['amp'])
            self.logger.info(f"--> Loaded checkpoint '{model_path}' (epoch {self.start_epoch})")
    
    def save_checkpoint(self, epoch, is_best=False):
        if epoch % self.save_epoch == 0 and self.rank == 0:
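            # State is assembled every save_epoch on rank 0, but only written
            # when this epoch is the new best.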
            state = {'config': self.config,
                     'epoch': epoch,
                     'steps': self.steps,
                     'model': self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'amp': amp.state_dict()
                    }
            if is_best:
                best_filename = os.path.join(self.ckpt_path, 'best.pth.tar')
                torch.save(state, best_filename)

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient
        """
        tensor_gather = [
            torch.ones_like(tensor) for _ in range(self.world_size)
        ]
        torch.distributed.all_gather(tensor_gather, tensor, async_op=False)

        output = torch.cat(tensor_gather, dim=0)
        return output

    def train_epoch(self, epoch, printer=print):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        # cur_step = epoch * len(self.train_loader)
        cur_lr = self.optimizer.param_groups[0]['lr']
        
        self.model.train()

        # for step, (X, y) in enumerate(self.train_loader):
        prefetcher = data_prefetcher(self.train_loader)
        X, y = prefetcher.next()
        i = 0
        while X is not None:
            i += 1
            N = X.size(0)
            self.steps += 1

            logits, aux_logits = self.model(X)

            loss = self.criterion(logits, y)
            if self.use_aux:
                loss += self.config.aux_weight * self.criterion(aux_logits, y)
            
            self.optimizer.zero_grad()
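            # Above O0, scale the loss before backward so fp16 gradients do not
            # underflow; O0 is plain fp32 and skips scaling.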
            if self.opt_level == 'O0':
                loss.backward()
            else:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
            self.optimizer.step()

            prec1, prec5 = accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            if self.steps % self.log_step == 0 and self.rank == 0:
                self.writer.add_scalar('train/lr', round(cur_lr, 5), self.steps)
                self.writer.add_scalar('train/loss', loss.item(), self.steps)
                self.writer.add_scalar('train/top1', prec1.item(), self.steps)
                self.writer.add_scalar('train/top5', prec5.item(), self.steps)

            if self.gpu == 0 and (i % self.config.print_freq == 0 or i == len(self.train_loader) - 1):
                printer(f'Train: Epoch: [{epoch}][{i}/{len(self.train_loader) - 1}]\t'
                        f'Step {self.steps}\t'
                        f'lr {round(cur_lr, 5)}\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                        f'Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})\t'
                    )

            X, y = prefetcher.next()

        if self.gpu == 0:
            printer("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch, self.total_epochs - 1, top1.avg))
    
    def val_epoch(self, epoch, printer):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        self.model.eval()

        prefetcher = data_prefetcher(self.valid_loader)
        X, y = prefetcher.next()
        i = 0

        with torch.no_grad():
            while X is not None:
                N = X.size(0)
                i += 1

                logits, _ = self.model(X)

                loss = self.criterion(logits, y)

                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), N)
                top1.update(prec1.item(), N)
                top5.update(prec5.item(), N)
                
                if self.rank == 0 and (i % self.config.print_freq == 0 or i == len(self.valid_loader) - 1):
                    printer(f'Valid: Epoch: [{epoch}][{i}/{len(self.valid_loader)}]\t'
                            f'Step {self.steps}\t'
                            f'Loss {losses.avg:.4f}\t'
                            f'Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})')
                
                X, y = prefetcher.next()
                
        if self.rank == 0:
            self.writer.add_scalar('val/loss', losses.avg, self.steps)
            self.writer.add_scalar('val/top1', top1.avg, self.steps)
            self.writer.add_scalar('val/top5', top5.avg, self.steps)

            printer("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch, self.total_epochs - 1, top1.avg))
        
        return top1.avg
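(The `data_prefetcher` used by `train_epoch`/`val_epoch` above is not shown; a minimal sketch in the style of the NVIDIA apex ImageNet example, assuming the loader yields `(input, target)` CPU tensors with `pin_memory=True`:)
import torch

class data_prefetcher:
    # Side-stream prefetcher: overlaps host-to-device copies with compute.
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # Asynchronous copies; pairs with pin_memory=True on the loader.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        # Block the compute stream until the pending copy finishes.
        torch.cuda.current_stream().wait_stream(self.stream)
        input, target = self.next_input, self.next_target
        if input is not None:
            # Keep the tensors alive on the copy stream until compute is done with them.
            input.record_stream(torch.cuda.current_stream())
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target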
Example #22
0
def main():
    global best_prec1, args

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    if not os.path.isdir(args.checkpoint) and args.local_rank == 0:
        mkdir_p(args.checkpoint)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print("Warning:  if --fp16 is not used, static_loss_scale will be ignored.")

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # shared param/delay all reduce turns off bucketing in DDP, for lower latency runs this can improve perf
        # for the older version of APEX please use shared_param, for newer one it is delay_allreduce
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   verbose=False)

    # optionally resume from a checkpoint
    title = 'ImageNet-' + args.arch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            if args.local_rank == 0:
                logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.', 'Valid Top5.'])

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = HybridTrainPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, data_dir=traindir, crop=crop_size, dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, data_dir=valdir, crop=crop_size, size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        adjust_learning_rate(optimizer, epoch, args)

        if args.local_rank == 0:
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        [train_loss, train_acc, avg_train_time] = train(train_loader, model, criterion, optimizer, epoch)
        total_time.update(avg_train_time)
        # evaluate on validation set
        [test_loss, prec1, prec5] = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            # append logger file
            logger.append([optimizer.param_groups[0]['lr'], train_loss, test_loss, train_acc, prec1, prec5])

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)
            # if need to save each epoch checkpoint, add: filename="epoch"+str(epoch+1)+"checkpoint.pth.tar"
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(prec1, prec5, args.total_batch_size / total_time.avg))

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()

    if args.local_rank == 0:
        logger.close()
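
Example #22 above delegates mixed-precision bookkeeping to apex's legacy `FP16_Optimizer` with a static loss scale. The mechanism itself is small enough to sketch in plain PyTorch (an illustration of the idea, not apex's implementation): scale the loss before `backward()` so fp16 gradients do not underflow, then unscale before stepping.

def scaled_backward_and_step(loss, model, optimizer, loss_scale=128.0):
    # Scale up so gradients too small for fp16 stay representable.
    (loss * loss_scale).backward()
    # Undo the scaling before the optimizer consumes the gradients.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(loss_scale)
    optimizer.step()
    optimizer.zero_grad()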
Example #23
0
def main():
    global best_prec1, args

    args = parse()
    print("opt_level = {}".format(args.opt_level))
    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
          type(args.keep_batchnorm_fp32))
    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    args.gpu = 0
    args.world_size = 1

    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        print(f"world_size {int(os.environ['WORLD_SIZE'])}")

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                global best_prec1
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    if (args.arch == "inception_v3"):
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    # Data loading code
    # traindir = os.path.join(args.data, 'train')
    # valdir = os.path.join(args.data, 'val')
    #
    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(crop_size),
    #         transforms.RandomHorizontalFlip(),
    #         # transforms.ToTensor(), Too slow
    #         # normalize,
    #     ]))
    # val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
    #     transforms.Resize(val_size),
    #     transforms.CenterCrop(crop_size),
    # ]))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomAffine(5),
        # transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_set = torchvision.datasets.CIFAR10(root='~/dataSet/cifar10',
                                             train=True,
                                             download=True,
                                             transform=transform_train)

    test_set = torchvision.datasets.CIFAR10(root='~/dataSet/cifar10',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_set)
        val_sampler = torch.utils.data.distributed.DistributedSampler(test_set)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
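
Example #23 scales the base learning rate linearly with the global batch size, `lr * batch_size * world_size / 256`, the usual large-minibatch heuristic with 256 as the reference batch. Spelled out as a helper (the reference value of 256 is just this snippet's convention):

def linear_scaled_lr(base_lr, per_gpu_batch, world_size, reference_batch=256):
    """Linear scaling rule: the LR grows with the global batch size."""
    return base_lr * (per_gpu_batch * world_size) / reference_batch

print(linear_scaled_lr(0.1, 64, 8))  # global batch 512 -> 0.2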
Example #24
0
def main():
    global best_prec1, args
    best_prec1 = 0
    args = parse()

    # test mode, use default args for sanity test
    if args.test:
        args.opt_level = None
        args.epochs = 1
        args.start_epoch = 0
        args.arch = 'resnet50'
        args.batch_size = 64
        args.data = []
        args.sync_bn = False
        args.data.append('/data/imagenet/train-jpeg/')
        args.data.append('/data/imagenet/val-jpeg/')
        print("Test mode - no DDP, no apex, RN50, 10 iterations")

    if not len(args.data):
        raise Exception("error: No data set provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # make apex optional
    if args.opt_level is not None or args.distributed or args.sync_bn:
        try:
            global DDP, amp, optimizers, parallel
            from apex.parallel import DistributedDataParallel as DDP
            from apex import amp, optimizers, parallel
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to run this example."
            )

    print("opt_level = {}".format(args.opt_level))
    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
          type(args.keep_batchnorm_fp32))
    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        print("using apex synced BN")
        model = parallel.convert_syncbn_model(model)

    if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
        if args.channels_last:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format
        model = model.cuda().to(memory_format=memory_format)
    else:
        model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.opt_level is not None:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                global best_prec1
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    if (args.arch == "inception_v3"):
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = create_dali_pipeline(batch_size=args.batch_size,
                                num_threads=args.workers,
                                device_id=args.local_rank,
                                seed=12 + args.local_rank,
                                data_dir=traindir,
                                crop=crop_size,
                                size=val_size,
                                dali_cpu=args.dali_cpu,
                                shard_id=args.local_rank,
                                num_shards=args.world_size,
                                is_training=True)
    pipe.build()
    train_loader = DALIClassificationIterator(
        pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)

    pipe = create_dali_pipeline(batch_size=args.batch_size,
                                num_threads=args.workers,
                                device_id=args.local_rank,
                                seed=12 + args.local_rank,
                                data_dir=valdir,
                                crop=crop_size,
                                size=val_size,
                                dali_cpu=args.dali_cpu,
                                shard_id=args.local_rank,
                                num_shards=args.world_size,
                                is_training=False)
    pipe.build()
    val_loader = DALIClassificationIterator(
        pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        avg_train_time = train(train_loader, model, criterion, optimizer,
                               epoch)
        total_time.update(avg_train_time)
        if args.test:
            break

        # evaluate on validation set
        [prec1, prec5] = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(
                          prec1, prec5,
                          args.total_batch_size / total_time.avg))

        train_loader.reset()
        val_loader.reset()
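
Example #24 switches to `torch.channels_last` when the installed PyTorch exposes it. The layout only helps if the inputs are converted too, not just the model weights; a minimal self-contained sketch (any recent PyTorch, CPU is fine):

import torch
import torchvision.models as models

model = models.resnet18().to(memory_format=torch.channels_last)
x = torch.randn(8, 3, 224, 224).contiguous(memory_format=torch.channels_last)
out = model(x)  # same shapes as NCHW; only the strides (memory layout) differ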
Example #25
0
def main():
    global best_prec1, args

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print("Warning:  if --fp16 is not used, static_loss_scale will be ignored.")

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # shared param/delay all reduce turns off bucketing in DDP, for lower latency runs this can improve perf
        # for the older version of APEX please use shared_param, for newer one it is delay_allreduce
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    if(args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = HybridTrainPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, data_dir=traindir, crop=crop_size, dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, data_dir=valdir, crop=crop_size, size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        [prec1, prec5] = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}'.format(prec1, prec5))

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()
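
Examples #22 and #25 call `reset()` on the DALI iterators after every epoch: unlike a PyTorch `DataLoader`, a `DALIClassificationIterator` is exhausted after one pass and must be reset before it can be iterated again. The `train()`/`validate()` functions these examples call (not shown) typically unpack batches like this (a sketch; `model`, `criterion`, and `optimizer` are assumed from context):

for epoch in range(start_epoch, epochs):
    for data in train_loader:
        images = data[0]["data"]                             # already decoded on the GPU
        target = data[0]["label"].squeeze(-1).cuda().long()  # DALI labels arrive as (N, 1)
        loss = criterion(model(images), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loader.reset()  # required before the next epoch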
Example #26
0
def train(args, train_dataset, model):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=t_total * 0.1,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        if args.fp16_opt_level == "O2":
            keep_batchnorm_fp32 = False
        else:
            keep_batchnorm_fp32 = True
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.fp16_opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(
            model,
            message_size=250000000,
            gradient_predivide_factor=torch.distributed.get_world_size())

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs = 0
    model.zero_grad()
    model.train()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Train(XX Epoch) Step(X/X) (loss=X.X)",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)  # multi-GPU does the scattering itself
            input_ids, input_mask, segment_ids, start_positions, end_positions = batch
            outputs = model(input_ids, segment_ids, input_mask,
                            start_positions, end_positions)
            loss = outputs  # this model returns the loss tensor directly (not the usual transformers output tuple)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                optimizer.zero_grad()
                global_step += 1
                epoch_iterator.set_description(
                    "Train(%d Epoch) Step(%d / %d) (loss=%5.5f)" %
                    (_, global_step, t_total, loss.item()))

        if args.local_rank in [-1, 0]:
            model_checkpoint = 'korquad_{0}_{1}_{2}_{3}.bin'.format(
                args.learning_rate, args.train_batch_size, epochs,
                int(args.num_train_epochs))
            logger.info(model_checkpoint)
            output_model_file = os.path.join(args.output_dir, model_checkpoint)
            if args.n_gpu > 1 or args.local_rank != -1:
                logger.info("** ** * Saving file * ** **(module)")
                torch.save(model.module.state_dict(), output_model_file)
            else:
                logger.info("** ** * Saving file * ** **")
                torch.save(model.state_dict(), output_model_file)
        epochs += 1
    logger.info("Training End!!!")
Example #27
0
def train300_mlperf_coco(args):
    from pycocotools.coco import COCO

    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    local_seed = set_seeds(args)
    # start timing here
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    validate_group_bn(args.bn_group)
    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    input_size = 300
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)

    val_annotate = os.path.join(args.data, "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data, "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    # Build the model
    model_options = {
        'backbone' : args.backbone,
        'use_nhwc' : args.nhwc,
        'pad_input' : args.pad_input,
        'bn_group' : args.bn_group,
    }

    ssd300 = SSD300(args.num_classes, **model_options)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.cuda()
    if args.opt_loss:
        loss_func = OptLoss(dboxes)
    else:
        loss_func = Loss(dboxes)
    loss_func.cuda()

    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    if args.use_fp16:
        ssd300 = network_to_half(ssd300)

    # Parallelize.  Need to do this after network_to_half.
    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank, "Delaying allreduces to the end of backward()")
        ssd300 = DDP(ssd300,
                     gradient_predivide_factor=N_gpu/8.0,
                     delay_allreduce=args.delay_allreduce,
                     retain_allreduce_buffers=args.use_fp16)

    # Create optimizer.  This must also be done after network_to_half.
    global_batch_size = (N_gpu * args.batch_size)
    mlperf_print(key=mlperf_compliance.constants.MODEL_BN_SPAN, value=args.bn_group*args.batch_size)
    mlperf_print(key=mlperf_compliance.constants.GLOBAL_BATCH_SIZE, value=global_batch_size)

    # mlperf only allows base_lr scaled by an integer
    base_lr = 2.5e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = args.wd
    static_loss_scale = 128.
    if args.use_fp16:
        if args.distributed and not args.delay_allreduce:
            # We can't create the flat master params yet, because we need to
            # imitate the flattened bucket structure that DDP produces.
            optimizer_created = False
        else:
            model_buckets = [[p for p in ssd300.parameters() if p.requires_grad
                              and p.type() == "torch.cuda.HalfTensor"],
                              [p for p in ssd300.parameters() if p.requires_grad
                               and p.type() == "torch.cuda.FloatTensor"]]
            flat_master_buckets = create_flat_master(model_buckets)
            optim = torch.optim.SGD(flat_master_buckets, lr=current_lr, momentum=current_momentum,
                                    weight_decay=current_weight_decay)
            optimizer_created = True
    else:
        optim = torch.optim.SGD(ssd300.parameters(), lr=current_lr, momentum=current_momentum,
                                weight_decay=current_weight_decay)
        optimizer_created = True

    mlperf_print(key=mlperf_compliance.constants.OPT_BASE_LR, value=current_lr)
    mlperf_print(key=mlperf_compliance.constants.OPT_WEIGHT_DECAY,
                         value=current_weight_decay)
    if args.warmup is not None:
        mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_STEPS,
                  value=args.warmup)
        mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_FACTOR,
                  value=args.warmup_factor)

    # Model is completely finished -- need to create separate copies, preserve parameters across
    # them, and jit
    ssd300_eval = SSD300(args.num_classes, backbone=args.backbone, use_nhwc=args.nhwc, pad_input=args.pad_input).cuda()
    if args.use_fp16:
        ssd300_eval = network_to_half(ssd300_eval)

    # Get the existing state from the train model
    # * if we use distributed, then we want .module
    train_model = ssd300.module if args.distributed else ssd300

    ssd300_eval.load_state_dict(train_model.state_dict())

    ssd300_eval.eval()


    print_message(args.local_rank, "epoch", "nbatch", "loss")
    eval_points = np.array(args.evaluation) * 32 / global_batch_size
    eval_points = list(map(int, list(eval_points)))

    iter_num = args.iteration
    avg_loss = 0.0

    start_elapsed_time = time.time()
    last_printed_iter = args.iteration
    num_elapsed_samples = 0

    # Generate normalization tensors
    mean, std = generate_mean_std(args)

    dummy_overflow_buf = torch.cuda.IntTensor([0])
    def step_maybe_fp16_maybe_distributed(optim):
        if args.use_fp16:
            if args.distributed:
                for flat_master, allreduce_buffer in zip(flat_master_buckets, ssd300.allreduce_buffers):
                    if allreduce_buffer is None:
                        raise RuntimeError("allreduce_buffer is None")
                    flat_master.grad = allreduce_buffer.float()
                    flat_master.grad.data.mul_(1./static_loss_scale)
            else:
                for flat_master, model_bucket in zip(flat_master_buckets, model_buckets):
                    flat_grad = apex_C.flatten([m.grad.data for m in model_bucket])
                    flat_master.grad = flat_grad.float()
                    flat_master.grad.data.mul_(1./static_loss_scale)
        optim.step()
        if args.use_fp16:
            # Use multi-tensor scale instead of loop & individual parameter copies
            for model_bucket, flat_master in zip(model_buckets, flat_master_buckets):
                multi_tensor_applier(
                    amp_C.multi_tensor_scale,
                    dummy_overflow_buf,
                    [apex_C.unflatten(flat_master.data, model_bucket), model_bucket],
                    1.0)

    input_c = 4 if args.pad_input else 3
    example_shape = [args.batch_size, 300, 300, input_c] if args.nhwc else [args.batch_size, input_c, 300, 300]
    example_input = torch.randn(*example_shape).cuda()

    if args.use_fp16:
        example_input = example_input.half()
    if args.jit:
        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,
        # the resulting ScriptModule will elide this control flow, resulting in allreduce
        # hooks not being called.  If we're running distributed, we need to extract and JIT
        # the wrapped .module.
        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks
        # to go out of scope, and therefore silently disappear.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input)
        # JIT the eval model too
        ssd300_eval = torch.jit.trace(ssd300_eval, example_input)

    # do a dummy fprop & bprop to make sure cudnnFind etc. are timed here
    ploc, plabel = ssd300(example_input)

    # produce a single dummy "loss" to make things easier
    loss = ploc[0,0,0] + plabel[0,0,0]
    dloss = torch.randn_like(loss)
    # Cause cudnnFind for dgrad, wgrad to run
    loss.backward(dloss)

    mlperf_print(key=mlperf_compliance.constants.INIT_STOP,
                 sync=True)
    ##### END INIT

    # This is the first place we touch anything related to data
    ##### START DATA TOUCHING
    mlperf_print(key=mlperf_compliance.constants.RUN_START,
                 sync=True)
    barrier()
    cocoGt = COCO(annotation_file=val_annotate, use_ext=True)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    if args.distributed:
        val_sampler = GeneralDistributedSampler(val_coco, pad=False)
    else:
        val_sampler = None

    if args.no_dali:
        train_trans = SSDTransformer(dboxes, (input_size, input_size), val=False)
        train_coco = COCODetection(train_coco_root, train_annotate, train_trans)

        if args.distributed:
            train_sampler = GeneralDistributedSampler(train_coco, pad=False)
        else:
            train_sampler = None

        train_loader = DataLoader(train_coco,
                                  batch_size=args.batch_size*args.input_batch_multiplier,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  num_workers=args.num_workers,
                                  collate_fn=partial(my_collate, is_training=True))
    else:
        train_pipe = COCOPipeline(args.batch_size*args.input_batch_multiplier, args.local_rank, train_coco_root,
                                  train_annotate, N_gpu, num_threads=args.num_workers,
                                  output_fp16=args.use_fp16, output_nhwc=args.nhwc,
                                  pad_output=args.pad_input, seed=local_seed - 2**31,
                                  use_nvjpeg=args.use_nvjpeg, use_roi=args.use_roi_decode,
                                  dali_cache=args.dali_cache,
                                  dali_async=(not args.dali_sync))
        print_message(args.local_rank, "time_check a: {secs:.9f}".format(secs=time.time()))
        train_pipe.build()
        print_message(args.local_rank, "time_check b: {secs:.9f}".format(secs=time.time()))
        test_run = train_pipe.run()
        train_loader = SingleDaliIterator(train_pipe, ['images', DALIOutput('bboxes', False, True), DALIOutput('labels', True, True)],
                                          train_pipe.epoch_size()['train_reader'], ngpu=N_gpu)

    train_loader = EncodingInputIterator(train_loader, dboxes=encoder.dboxes.cuda(), nhwc=args.nhwc,
                                         fake_input=args.fake_input, no_dali=args.no_dali)
    if args.input_batch_multiplier > 1:
        train_loader = RateMatcher(input_it=train_loader, output_size=args.batch_size)

    val_dataloader   = DataLoader(val_coco,
                                  batch_size=args.eval_batch_size,
                                  shuffle=False, # Note: distributed sampler is shuffled :(
                                  sampler=val_sampler,
                                  num_workers=args.num_workers)

    inv_map = {v:k for k,v in val_coco.label_map.items()}

    ##### END DATA TOUCHING
    i_eval = 0
    first_epoch = 1
    mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                 metadata={'first_epoch_num': first_epoch,
                           'epoch_count': args.evaluation[i_eval]*32/train_pipe.epoch_size()['train_reader'] },
                 sync=True)
    for epoch in range(args.epochs):
        mlperf_print(key=mlperf_compliance.constants.EPOCH_START,
                     metadata={'epoch_num': epoch + 1},
                     sync=True)
        for p in ssd300.parameters():
            p.grad = None

        for i, (img, bbox, label) in enumerate(train_loader):

            if args.profile_start is not None and iter_num == args.profile_start:
                torch.cuda.profiler.start()
                torch.cuda.synchronize()
                if args.profile_nvtx:
                    torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)

            if args.profile is not None and iter_num == args.profile:
                if args.profile_start is not None and iter_num >=args.profile_start:
                    # we turned cuda and nvtx profiling on, better turn it off too
                    if args.profile_nvtx:
                        torch.autograd._disable_profiler()
                    torch.cuda.profiler.stop()
                return

            if args.warmup is not None and optimizer_created:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr, args)
            if iter_num == ((args.decay1 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #1")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr

            if iter_num == ((args.decay2 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #2")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr

            if (img is None) or (bbox is None) or (label is None):
                print("No labels in batch")
                continue

            ploc, plabel = ssd300(img)
            ploc, plabel = ploc.float(), plabel.float()

            N = img.shape[0]
            gloc, glabel = Variable(bbox, requires_grad=False), \
                           Variable(label, requires_grad=False)
            loss = loss_func(ploc, plabel, gloc, glabel)

            if np.isfinite(loss.item()):
                avg_loss = 0.999*avg_loss + 0.001*loss.item()
            else:
                print("model exploded (corrupted by Inf or Nan)")
                sys.exit()

            num_elapsed_samples += N
            if args.local_rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time

                avg_samples_per_sec = num_elapsed_samples * N_gpu / elapsed_time

                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\
                            .format(iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")

                last_printed_iter = iter_num
                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            # loss scaling
            if args.use_fp16:
                loss = loss*static_loss_scale
            loss.backward()

            if not optimizer_created:
                # Imitate the model bucket structure created by DDP.
                # These will already be split by type (float or half).
                model_buckets = []
                for bucket in ssd300.active_i_buckets:
                    model_buckets.append([])
                    for active_i in bucket:
                        model_buckets[-1].append(ssd300.active_params[active_i])
                flat_master_buckets = create_flat_master(model_buckets)
                optim = torch.optim.SGD(flat_master_buckets, lr=current_lr, momentum=current_momentum,
                                        weight_decay=current_weight_decay)
                optimizer_created = True
                # Skip this first iteration because flattened allreduce buffers are not yet created.
                # step_maybe_fp16_maybe_distributed(optim)
            else:
                step_maybe_fp16_maybe_distributed(optim)

            # Likely a decent skew here, let's take this opportunity to set the gradients to None.
            # After DALI integration, playing with the placement of this is worth trying.
            for p in ssd300.parameters():
                p.grad = None

            if iter_num in eval_points:
                # Get the existing state from the train model
                # * if we use distributed, then we want .module
                train_model = ssd300.module if args.distributed else ssd300

                if args.distributed and args.allreduce_running_stats:
                    if get_rank() == 0: print("averaging bn running means and vars")
                    # make sure every node has the same running bn stats before
                    # using them to evaluate, or saving the model for inference
                    world_size = float(torch.distributed.get_world_size())
                    for bn_name, bn_buf in train_model.named_buffers(recurse=True):
                        if ('running_mean' in bn_name) or ('running_var' in bn_name):
                            torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                            bn_buf /= world_size

                if get_rank() == 0:
                    if not args.no_save:
                        print("saving model...")
                        torch.save({"model" : ssd300.state_dict(), "label_map": val_coco.label_info},
                                    "./models/iter_{}.pt".format(iter_num))

                ssd300_eval.load_state_dict(train_model.state_dict())
                succ = coco_eval(ssd300_eval,
                             val_dataloader,
                             cocoGt,
                             encoder,
                             inv_map,
                             args.threshold,
                             epoch,
                             iter_num,
                             args.eval_batch_size,
                             use_fp16=args.use_fp16,
                             local_rank=args.local_rank if args.distributed else -1,
                             N_gpu=N_gpu,
                             use_nhwc=args.nhwc,
                             pad_input=args.pad_input)
                mlperf_print(key=mlperf_compliance.constants.BLOCK_STOP,
                             metadata={'first_epoch_num': first_epoch},
                             sync=True)
                if succ:
                    return True
                if iter_num != max(eval_points):
                    i_eval += 1
                    first_epoch = epoch+1
                    mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                                 metadata={'first_epoch_num': first_epoch,
                                           'epoch_count': (args.evaluation[i_eval]-args.evaluation[i_eval-1])*32/train_pipe.epoch_size()['train_reader']},
                                 sync=True)
            iter_num += 1

        train_loader.reset()
        mlperf_print(key=mlperf_compliance.constants.EPOCH_STOP,
                     metadata={'epoch_num': epoch + 1},
                     sync=True)
    return False
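
One detail worth isolating from Example #27: before evaluation it all-reduces every BatchNorm `running_mean`/`running_var` buffer, because those statistics are updated from each rank's local batches and are never touched by DDP's gradient allreduce. The core of that loop as a standalone helper (assumes `torch.distributed` is already initialized):

import torch.distributed as dist

def average_bn_stats(model):
    """Average BatchNorm running statistics across all ranks."""
    world_size = float(dist.get_world_size())
    for name, buf in model.named_buffers(recurse=True):
        if 'running_mean' in name or 'running_var' in name:
            dist.all_reduce(buf, op=dist.ReduceOp.SUM)
            buf /= world_size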
Example #28
0
def train(cudaid, args, model):

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.size,
                            rank=cudaid)

    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    print('params: ', " T_warm: ", T_warm, " all_iteration: ", all_iteration,
          " lr: ", lr)
    #cuda_list=range(args.size)
    print('rank: ', cudaid)
    torch.cuda.set_device(cudaid)
    model.cuda(cudaid)

    accumulation_steps = int(args.batch_size / args.size / args.gpu_size)
    optimizer = apex.optimizers.FusedLAMB(model.parameters(),
                                          lr=lr,
                                          betas=(0.9, 0.98),
                                          eps=1e-6,
                                          weight_decay=0.0,
                                          max_grad_norm=1.0)
    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    model = DDP(model)

    #model = nn.DataParallel(model, device_ids=cuda_list)
    # torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=0, world_size=1)
    # torch.cuda.set_device(cudaid)

    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    #model=torch.nn.parallel.DistributedDataParallel(model, device_ids=cuda_list)
    #model = torch.nn.DataParallel(model)
    #model=apex.parallel.DistributedDataParallel(model)

    accum_batch_loss = 0
    iterator = NewsIterator(batch_size=args.gpu_size,
                            npratio=4,
                            feature_file=os.path.join(args.data_dir,
                                                      args.feature_file),
                            field=args.field)
    train_file = os.path.join(args.data_dir, args.data_file)
    #for epoch in range(0,100):
    batch_t = 0
    iteration = 0
    print('train...', args.field)
    #w=open(os.path.join(args.data_dir,args.log_file),'w')
    if cudaid == 0:
        writer = SummaryWriter(os.path.join(args.data_dir, args.log_file))
    epoch = 0
    model.train()
    # batch_t=52880-1
    # iteration=3305-1
    batch_t = 0
    iteration = 0
    step = 0
    best_score = -1
    #w=open(os.path.join(args.data_dir,args.log_file),'w')

    # model.eval()
    # auc=test(model,args)

    for epoch in range(0, 10):
        #while True:
        all_loss = 0
        all_batch = 0
        data_batch = iterator.load_data_from_file(train_file, cudaid,
                                                  args.size)
        for imp_index, user_index, his_id, candidate_id, label in data_batch:
            batch_t += 1
            assert candidate_id.shape[1] == 2
            his_id = his_id.cuda(cudaid)
            candidate_id = candidate_id.cuda(cudaid)
            label = label.cuda(cudaid)
            loss = model(his_id, candidate_id, label)

            sample_size = candidate_id.shape[0]
            loss = loss.sum() / sample_size / math.log(2)

            accum_batch_loss += float(loss)

            all_loss += float(loss)
            all_batch += 1

            loss = loss / accumulation_steps
            #loss.backward()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            # num=0
            # if cudaid==0:
            #     for p in model.parameters():
            #         if p.grad==None:
            #             print('error: ',p.size())
            #         else:
            #             print('ok: ',p.size())
            #             o=1
            #             for item in p.size():
            #                 o=o*item
            #             num+=o
            #     print(num)
            #     assert 1==0

            if (batch_t) % accumulation_steps == 0:

                iteration += 1
                adjust_learning_rate(optimizer, iteration)
                optimizer.step()
                optimizer.zero_grad()
                if cudaid == 0:
                    print(' batch_t: ', batch_t, ' iteration: ', iteration,
                          ' epoch: ', epoch, ' accum_batch_loss: ',
                          accum_batch_loss / accumulation_steps, ' lr: ',
                          optimizer.param_groups[0]['lr'])
                    writer.add_scalar('Loss/train',
                                      accum_batch_loss / accumulation_steps,
                                      iteration)
                    writer.add_scalar('Ltr/train',
                                      optimizer.param_groups[0]['lr'],
                                      iteration)
                accum_batch_loss = 0
                if iteration % 500 == 0 and cudaid == 0:
                    torch.cuda.empty_cache()
                    model.eval()
                    if cudaid == 0:
                        auc = test(model, args)
                        print(auc)
                        if auc > best_score:
                            torch.save(
                                model.state_dict(),
                                os.path.join(args.save_dir,
                                             'Plain_robert_dot_best.pkl'))
                            best_score = auc
                            print('best score: ', best_score)
                            writer.add_scalar('auc/valid', auc, step)
                            step += 1
                    torch.cuda.empty_cache()
                    model.train()

        if cudaid == 0:
            torch.save(
                model.state_dict(),
                os.path.join(args.save_dir,
                             'Plain_robert_dot' + str(epoch) + '.pkl'))
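
In Examples #28 and #29, `args.batch_size` is the target global batch size and `args.gpu_size` the per-GPU micro-batch, so `accumulation_steps = batch_size / world_size / gpu_size`. A quick worked check of that arithmetic (hypothetical numbers):

def accumulation_steps(global_batch, world_size, per_gpu_batch):
    steps, rem = divmod(global_batch, world_size * per_gpu_batch)
    assert rem == 0, "global batch must divide evenly across GPUs and micro-batches"
    return steps

print(accumulation_steps(1024, 4, 32))  # 8 micro-batches per optimizer step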
Example #29
0
def train(cudaid, args, model, roberta_dict, rerank, data, data_valid, label,
          label_valid):

    #pynvml.nvmlInit()
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=cudaid)

    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    print('params: ', " T_warm: ", args.T_warm, " all_iteration: ",
          args.all_iteration, " lr: ", args.lr)
    #cuda_list=range(args.world_size)
    print('rank: ', cudaid)
    torch.cuda.set_device(cudaid)
    model.cuda(cudaid)

    accumulation_steps = int(args.batch_size / args.world_size / args.gpu_size)
    optimizer = apex.optimizers.FusedLAMB(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.98),
        eps=1e-6,
        weight_decay=0.1,
        max_grad_norm=1.0)  #clip-norm = 0.0???

    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    model = DDP(model)

    accum_batch_loss = 0
    accum_batch_acc = 0
    all_batch_loss = 0
    accum_acc = 0
    accum_num = 0

    batch_t = 0
    iteration = 0
    print('train...', args.field)
    # makedirs with exist_ok avoids a race when several ranks create the
    # per-GPU log directory at the same time
    os.makedirs(os.path.join(args.data_dir, args.log_file,
                             'cuda_' + str(cudaid)),
                exist_ok=True)
    writer = SummaryWriter(
        os.path.join(args.data_dir, args.log_file, 'cuda_' + str(cudaid)))
    epoch = 0
    epoch_o = 0
    model.train()
    step = 0
    best_score = -1
    step_t = 0
    start_pos = None
    batch_t_arg = 0

    if args.model_file is not None:
        epoch_o = args.epoch
        iteration = args.iteration
        #batch_t=args.batch_t
        step = int(iteration / 10000) + 1
        if args.use_start_pos:
            # resume this rank's shard from where the previous run stopped
            start_pos = args.gpu_size * batch_t * 2 % (
                int(len(data) / args.world_size) + 1)
            batch_t_arg = args.batch_t
            batch_t = args.batch_t
        elif args.batch_one_epoch is not None:
            batch_t_arg = args.batch_t % args.batch_one_epoch
        else:
            batch_t_arg = args.batch_t

    for epoch in range(epoch_o, 10):
        data_batch = utils.get_batch_glue(data,
                                          label,
                                          roberta_dict,
                                          args.gpu_size,
                                          rerank=rerank,
                                          dist=True,
                                          cudaid=cudaid,
                                          size=args.world_size,
                                          start_pos=start_pos)
        start_pos = None  # later epochs start from the beginning again
        for token_list, label_list in data_batch:
            if epoch == epoch_o and batch_t < batch_t_arg:
                batch_t += 1
                continue
            batch_t += 1

            token_list = token_list.cuda(cudaid)
            label_list = label_list.cuda(cudaid)

            loss, sample_size, acc = model(token_list, label=label_list)


            if sample_size != 0:
                if args.num_classes != 1:
                    # divide by ln(2) to report the loss in bits, not nats
                    loss = loss / sample_size / math.log(2)
                else:
                    loss = loss / sample_size
                # keep the per-sample accuracy inside the guard so a batch
                # with sample_size == 0 cannot divide by zero
                accum_batch_acc += float(acc) / sample_size

            accum_batch_loss += float(loss)
            all_batch_loss += float(loss)

            accum_acc += float(acc)
            accum_num += sample_size

            # scale so the accumulated gradient averages over the effective batch
            loss = loss / accumulation_steps

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            if batch_t % accumulation_steps == 0:
                iteration += 1
                adjust_learning_rate(optimizer,
                                     iteration,
                                     lr=args.lr,
                                     T_warm=args.T_warm,
                                     all_iteration=args.all_iteration)
                optimizer.step()
                optimizer.zero_grad()
                if cudaid == 0:
                    print(' batch_t: ', batch_t, ' iteration: ', iteration,
                          ' epoch: ', epoch, ' accum_batch_loss: ',
                          accum_batch_loss / accumulation_steps,
                          ' accum_batch_acc: ',
                          accum_batch_acc / accumulation_steps,
                          ' accum_loss: ', all_batch_loss / batch_t,
                          ' accum_acc: ', accum_acc / accum_num,
                          ' lr: ', optimizer.param_groups[0]['lr'])
                    writer.add_scalar('Loss/train',
                                      accum_batch_loss / accumulation_steps,
                                      iteration)
                    writer.add_scalar('Accuracy/train',
                                      accum_batch_acc / accumulation_steps,
                                      iteration)
                    writer.add_scalar('Loss_all/train',
                                      all_batch_loss / batch_t, iteration)
                    writer.add_scalar('Ltr/train',
                                      optimizer.param_groups[0]['lr'],
                                      iteration)
                    writer.add_scalar('Accuracy_all/train',
                                      accum_acc / accum_num, iteration)

                accum_batch_loss = 0
                accum_batch_acc = 0

        data_batch_valid = utils.get_batch_glue(data_valid,
                                                label_valid,
                                                roberta_dict,
                                                args.valid_size,
                                                rerank=None,
                                                dist=True,
                                                cudaid=cudaid,
                                                size=args.world_size,
                                                start_pos=start_pos)
        accum_batch_loss_valid = 0
        accumulation_steps_valid = 0
        batch_t_valid = 0
        accum_num_valid = 0
        accum_acc_valid = 0

        torch.cuda.empty_cache()
        model.eval()
        with torch.no_grad():
            for token_list_valid, label_list_valid in data_batch_valid:
                batch_t_valid += token_list_valid.shape[0]

                token_list_valid = token_list_valid.cuda(cudaid)
                label_list_valid = label_list_valid.cuda(cudaid)

                loss_valid, sample_size_valid, acc_valid = model(
                    token_list_valid, label=label_list_valid)

                if args.num_classes != 1:
                    loss_valid = loss_valid / sample_size_valid / math.log(2)
                else:
                    loss_valid = loss_valid / sample_size_valid
                accum_batch_loss_valid += float(loss_valid)
                accum_acc_valid += float(acc_valid)
                accum_num_valid += sample_size_valid

                accumulation_steps_valid += 1

                if accumulation_steps_valid % 100 == 0:
                    print('batch_t_valid: ', batch_t_valid, ' cudaid: ', cudaid)

        accum_batch_loss_t = accum_batch_loss_valid / accumulation_steps_valid
        accum_acc_valid = accum_acc_valid / accum_num_valid

        if cudaid == 0:
            print('valid loss: ', accum_batch_loss_t, 'valid acc: ',
                  accum_acc_valid)

        writer.add_scalar('Loss/valid' + str(cudaid), accum_batch_loss_t, step)
        writer.add_scalar('Accuracy/valid' + str(cudaid), accum_acc_valid,
                          step)
        step += 1
        torch.cuda.empty_cache()
        model.train()
        if cudaid == 0:
            torch.save(
                model.state_dict(),
                os.path.join(args.save_dir,
                             'glue_roberta' + str(epoch) + '.pkl'))
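
adjust_learning_rate is called above with lr, T_warm, and all_iteration but
is not defined in this snippet. A plausible minimal sketch, assuming the
common linear-warmup / linear-decay schedule those argument names suggest
(the real schedule may differ):

def adjust_learning_rate(optimizer, iteration, lr=1e-4, T_warm=2000,
                         all_iteration=100000):
    # warm up linearly to `lr` over T_warm iterations, then decay
    # linearly to zero at all_iteration
    if iteration <= T_warm:
        new_lr = lr * iteration / T_warm
    else:
        new_lr = lr * max(
            0.0, (all_iteration - iteration) / (all_iteration - T_warm))
    for group in optimizer.param_groups:
        group['lr'] = new_lr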
Example #30
0
class Seq2SeqTrainer:
    """
    Seq2SeqTrainer
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 save_info={},
                 save_dir='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 loss_scaling={},
                 intra_epoch_eval=0,
                 prealloc_mode='always',
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param save_info: dict with additional state stored in each checkpoint
        :param save_dir: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type
        :param loss_scaling: options for dynamic loss scaling
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param prealloc_mode: controls preallocation,
            choices=['off', 'once', 'always']
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        """
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info
        self.save_dir = save_dir
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        self.opt_config = opt_config
        self.device = next(model.parameters()).device
        self.print_freq = print_freq
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size
        self.prealloc_mode = prealloc_mode
        self.preallocated = False

        self.distributed = torch.distributed.is_initialized()
        self.batch_first = model.batch_first

        params = self.model.parameters()

        if math == 'manual_fp16':
            self.fp_optimizer = FP16Optimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval'])
            params = self.fp_optimizer.fp32_params
        elif math == 'fp32':
            self.fp_optimizer = FP32Optimizer(self.model, grad_clip)

        opt_name = opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)

        if math == 'fp16':
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                cast_model_outputs=torch.float16,
                keep_batchnorm_fp32=False,
                opt_level='O2')

            self.fp_optimizer = AMPOptimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval'])

        if self.distributed:
            self.model = DistributedDataParallel(self.model)

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True, the optimizer updates the weights
        :param training: if True, runs the backward pass and optimizer step
        """
        pyprof2.init()  # profiling hooks; ideally initialized once, not per call
        src, src_length = src
        tgt, tgt_length = tgt
        src = src.to(self.device)
        tgt = tgt.to(self.device)
        src_length = src_length.to(self.device)

        num_toks = {}
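        # the target count drops one token per sentence: labels are the
        # sequence shifted by one, so the first token is input only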
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        with torch.autograd.profiler.emit_nvtx():
            profiler.start()

            if self.batch_first:
                output = self.model(src, src_length, tgt[:, :-1])
                tgt_labels = tgt[:, 1:]
                T, B = output.size(1), output.size(0)
            else:
                output = self.model(src, src_length, tgt[:-1])
                tgt_labels = tgt[1:]
                T, B = output.size(0), output.size(1)

            loss = self.criterion(output.view(T * B, -1),
                                  tgt_labels.contiguous().view(-1))

            loss_per_batch = loss.item()
            loss /= (B * self.iter_size)

            if training:
                self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                       update)

            loss_per_token = loss_per_batch / num_toks['tgt']
            loss_per_sentence = loss_per_batch / B

            profiler.stop()

        # this example is a profiling harness: it exits after one profiled
        # iteration, so the return below is never reached
        print('You can stop now')
        exit()

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        """
        if training:
            assert self.optimizer is not None
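            # choose `intra_epoch_eval` evenly spaced evaluation points within
            # the epoch, snapped to weight-update boundaries (iter_size)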
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval + 2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_per_token = AverageMeter()
        losses_per_sentence = AverageMeter()

        tot_tok_time = AverageMeter()
        src_tok_time = AverageMeter()
        tgt_tok_time = AverageMeter()

        batch_size = data_loader.batch_size

        end = time.time()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            update = False
            if i % self.iter_size == self.iter_size - 1:
                update = True

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                eval_fname = f'eval_epoch_{self.epoch}_iter_{i}'
                eval_path = os.path.join(self.save_dir, eval_fname)
                _, eval_stats = self.translator.run(
                    calc_bleu=True,
                    epoch=self.epoch,
                    iteration=i,
                    eval_path=eval_path,
                )
                test_bleu = eval_stats['bleu']

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                self.model.train()
                self.preallocate(data_loader.batch_size,
                                 data_loader.dataset.max_len,
                                 training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [
                    f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'
                ]
                if self.verbose:
                    log += [
                        f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'
                    ]
                log += [
                    f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'
                ]
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter %
                          self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, batch_size, max_length, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param batch_size: batch size for preallocation
        :param max_length: max sequence length for preallocation
        :param training: if True preallocates memory for backward pass
        """
        if self.prealloc_mode == 'always' or (self.prealloc_mode == 'once'
                                              and not self.preallocated):
            logging.info('Executing preallocation')
            torch.cuda.empty_cache()

            src_length = torch.full((batch_size, ),
                                    max_length,
                                    dtype=torch.int64)
            tgt_length = torch.full((batch_size, ),
                                    max_length,
                                    dtype=torch.int64)

            if self.batch_first:
                shape = (batch_size, max_length)
            else:
                shape = (max_length, batch_size)

            # fill with an arbitrary in-vocabulary token id (4) so the forward
            # pass allocates worst-case-size activations
            src = torch.full(shape, 4, dtype=torch.int64)
            tgt = torch.full(shape, 4, dtype=torch.int64)
            src = src, src_length
            tgt = tgt, tgt_length
            self.iterate(src, tgt, update=False, training=training)
            self.model.zero_grad()
            self.preallocated = True

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        self.preallocate(data_loader.batch_size,
                         data_loader.dataset.max_len,
                         training=True)

        output = self.feed_data(data_loader, training=True)

        self.model.zero_grad()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        self.preallocate(data_loader.batch_size,
                         data_loader.dataset.max_len,
                         training=False)

        output = self.feed_data(data_loader, training=False)

        self.model.zero_grad()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """
        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_dir, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state = dict(list(state.items()) + list(self.save_info.items()))

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
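
The trainer keeps at most keep_checkpoints periodic files by cycling an
identifier through range(keep_checkpoints) and reusing the same filename
template, so the oldest checkpoint is silently overwritten. The mechanism in
isolation (a sketch; only the filename rotation is shown):

from itertools import cycle

counter = cycle(range(5))          # keep_checkpoints = 5
filename_tmpl = 'checkpoint%s.pth'
for step in range(12):
    identifier = next(counter)     # 0, 1, 2, 3, 4, 0, 1, ...
    # the 6th save reuses 'checkpoint0.pth', the 7th 'checkpoint1.pth', ...
    print(filename_tmpl % identifier)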
Example #31
0
File: main.py Project: gongzg/DALI
def main():
    global best_prec1, args

    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, num_classes=args.num_classes)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](num_classes=args.num_classes)

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = DDP(model)

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(master_params, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    from nvidia.dali.plugin.pytorch import DALIClassificationIterator

    pipe = HybridPipe(batch_size=args.batch_size, num_threads=args.workers,
                      device_id=args.rank, data_dir=traindir)
    pipe.build()
    test_run = pipe.run()
    # 1281167 = size of the ImageNet-1k training set
    train_loader = DALIClassificationIterator(
        pipe, size=int(1281167 / args.world_size))

    pipe = HybridPipe(batch_size=args.batch_size, num_threads=args.workers,
                      device_id=args.rank, data_dir=valdir)
    pipe.build()
    test_run = pipe.run()
    # 50000 = size of the ImageNet-1k validation set
    val_loader = DALIClassificationIterator(
        pipe, size=int(50000 / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
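
save_checkpoint is not defined in this snippet. In the classic PyTorch
ImageNet reference script it writes the latest state to a fixed file and
copies it aside when it is the best so far; a minimal sketch along those
lines (filenames are assumptions):

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # always persist the most recent state, and keep a separate copy of
    # the best-performing model seen so far
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')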