Example #1
class RunnerWrapper(nn.Module):  # class header inferred from the super() call below; fragment of a larger module
    def __init__(self, args, training=True):
        super(RunnerWrapper, self).__init__()
        # Store general options
        self.args = args
        self.training = training

        # Read names lists from the args
        self.load_names(args)

        # Initialize classes for the networks
        # Copy the list so the name lists stored on self are not mutated by the += below
        nets_names = list(self.nets_names_test)
        if self.training:
            nets_names += self.nets_names_train
        nets_names = list(set(nets_names))

        self.nets = nn.ModuleDict()

        for net_name in sorted(nets_names):
            self.nets[net_name] = importlib.import_module(f'networks.{net_name}').NetworkWrapper(args)

            if args.num_gpus > 1:
                # Apex is only needed for multi-gpu training
                from apex import parallel

                self.nets[net_name] = parallel.convert_syncbn_model(self.nets[net_name])

        # Set nets that are not training into eval mode
        for net_name in self.nets.keys():
            if net_name not in self.nets_names_to_train:
                self.nets[net_name].eval()

        # Initialize classes for the losses
        if self.training:
            losses_names = list(set(self.losses_names_train + self.losses_names_test))
            self.losses = nn.ModuleDict()

            for loss_name in sorted(losses_names):
                self.losses[loss_name] = importlib.import_module(f'losses.{loss_name}').LossWrapper(args)

        # Spectral norm
        if args.spn_layers:
            spn_layers = utils.parse_str_to_list(args.spn_layers, sep=',')
            spn_nets_names = utils.parse_str_to_list(args.spn_networks, sep=',')

            for net_name in spn_nets_names:
                self.nets[net_name].apply(lambda module: utils.spectral_norm(module, apply_to=spn_layers, eps=args.eps))

            # Remove spectral norm in modules in exceptions
            spn_exceptions = utils.parse_str_to_list(args.spn_exceptions, sep=',')

            for full_module_name in spn_exceptions:
                if not full_module_name:
                    continue

                parts = full_module_name.split('.')

                # Get the module that needs to be changed
                module = self.nets[parts[0]]
                for part in parts[1:]:
                    module = getattr(module, part)

                module.apply(utils.remove_spectral_norm)

        # Weight averaging
        if args.wgv_mode != 'none':
            # Apply weight averaging only for networks that are being trained
            for net_name in self.nets_names_to_train:
                self.nets[net_name].apply(lambda module: utils.weight_averaging(module, mode=args.wgv_mode, momentum=args.wgv_momentum))

        # Check which networks are being trained and put the rest into the eval mode
        for net_name in self.nets.keys():
            if net_name not in self.nets_names_to_train:
                self.nets[net_name].eval()

        # Set the same batchnorm momentum across all modules
        if self.training:
            self.apply(lambda module: utils.set_batchnorm_momentum(module, args.bn_momentum))

        # Store a history of losses and images for visualization
        self.losses_history = {
            True: {}, # self.training = True
            False: {}}
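
The spectral-norm block above goes through project-specific helpers (utils.spectral_norm, utils.remove_spectral_norm, utils.parse_str_to_list). As a rough, self-contained sketch of the same apply/remove pattern using PyTorch's built-in torch.nn.utils functions (the filtering by layer-class name is an assumption, not the project's actual helper):

import torch.nn as nn
import torch.nn.utils as nn_utils

def apply_spectral_norm(module, apply_to=('Conv2d', 'Linear'), eps=1e-4):
    # Wrap matching leaf modules with spectral normalization; meant to be used via net.apply(...)
    if type(module).__name__ in apply_to and hasattr(module, 'weight'):
        nn_utils.spectral_norm(module, eps=eps)

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
net.apply(apply_spectral_norm)           # add spectral norm to the conv layers
nn_utils.remove_spectral_norm(net[2])    # exception: strip it from one module again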
Example #2
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=8, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='localization_')
    arg('--data-dir', type=str, default="/home/selim/datasets/xview/train")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=1)
    arg("--local_rank", default=0, type=int)
    arg("--opt-level", default='O0', type=str)
    arg("--predictions", default="../oof_preds", type=str)
    arg("--test_every", type=int, default=1)

    args = parser.parse_args()

    if args.distributed:
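        # env:// init reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment (e.g. set by torch.distributed.launch)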
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = models.__dict__[conf['network']](seg_classes=conf['num_classes'],
                                             backbone_arch=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    mask_loss_function = losses.__dict__[conf["mask_loss"]["type"]](
        **conf["mask_loss"]["params"]).cuda()
    loss_functions = {"mask_loss": mask_loss_function}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    dice_best = 0
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = XviewSingleDataset(
        mode="train",
        fold=args.fold,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf['input']),
        multiplier=conf["data_multiplier"],
        normalize=conf["input"].get("normalize", None))
    data_val = XviewSingleDataset(
        mode="val",
        fold=args.fold,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_val_transforms(conf['input']),
        normalize=conf["input"].get("normalize", None))
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            data_train)

    train_data_loader = DataLoader(data_train,
                                   batch_size=batch_size,
                                   num_workers=args.workers,
                                   shuffle=train_sampler is None,
                                   sampler=train_sampler,
                                   pin_memory=False,
                                   drop_last=True)
    val_batch_size = 1
    val_data_loader = DataLoader(data_val,
                                 batch_size=val_batch_size,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + args.prefix +
                                   conf['encoder'])
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            if conf['optimizer'].get('zero_decoder', False):
                for key in state_dict.copy().keys():
                    if key.startswith("module.final"):
                        del state_dict[key]
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    dice_best = checkpoint.get('dice_best', 0)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(args.prefix, conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
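        # Apex DDP; delay_allreduce=True defers the gradient all-reduce to the end of the backward pass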
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    for epoch in range(start_epoch, conf['optimizer']['schedule']['epochs']):
        if train_sampler:
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder_stages.eval()
            for p in model.module.encoder_stages.parameters():
                p.requires_grad = False
        else:
            print("Unfreezing encoder!!!")
            model.module.encoder_stages.train()
            for p in model.module.encoder_stages.parameters():
                p.requires_grad = True
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank)

        model = model.eval()
        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice_best,
                }, args.output_dir + '/' + snapshot_name + "_last")
            if epoch % args.test_every == 0:
                preds_dir = os.path.join(args.predictions, snapshot_name)
                dice_best = evaluate_val(args,
                                         val_data_loader,
                                         dice_best,
                                         model,
                                         snapshot_name=snapshot_name,
                                         current_epoch=current_epoch,
                                         optimizer=optimizer,
                                         summary_writer=summary_writer,
                                         predictions_dir=preds_dir)
        current_epoch += 1
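
The resume block above strips the 'module.' prefix that DataParallel adds to checkpoint keys with k[7:]. A slightly more defensive, self-contained variant of the same idea (illustrative only, not the pipeline's code):

import torch.nn as nn

def strip_module_prefix(state_dict):
    # Remove the 'module.' prefix added by (Distributed)DataParallel, if present
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)                  # state_dict keys gain a 'module.' prefix
state = strip_module_prefix(wrapped.state_dict())
model.load_state_dict(state, strict=True)         # loads cleanly into the unwrapped model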
Example #3
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', default=True)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()

    teacher_model = classifiers.__dict__[conf['network']](
        encoder=conf['encoder'])

    teacher_model = teacher_model.cuda()

    if args.distributed:
        model = convert_syncbn_model(model)
    ohem = conf.get("ohem_samples", None)
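    # OHEM (online hard example mining) needs per-sample losses, so the reduction is switched to 'none' when it is enabled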
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    # optimizer over the teacher's parameters; only needed so the teacher can be passed through amp.initialize below
    teacher_optimizer, _ = create_optimizer(conf['optimizer'], teacher_model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset_kn(
        mode="train",
        oversample_real=not args.no_oversample,
        fold=args.fold,
        padding_part=args.padding_part,
        hardcore=not args.no_hardcore,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        label_smoothing=args.label_smoothing,
        folds_csv=args.folds_csv,
        #    transforms=create_train_transforms(conf["size"]),
        normalize=conf.get("normalize", None))
    # data_val = DeepFakeClassifierDataset_kn(mode="val",
    #                                      fold=args.fold,
    #                                      padding_part=args.padding_part,
    #                                      crops_dir=args.crops_dir,
    #                                      data_path=args.data_dir,
    #                                      folds_csv=args.folds_csv,
    #                                     #  transforms=create_val_transforms(conf["size"]),
    #                                      normalize=conf.get("normalize", None))

    data_train_student = DeepFakeClassifierDataset_kn(
        mode="train",
        oversample_real=not args.no_oversample,
        fold=args.fold,
        padding_part=args.padding_part,
        hardcore=not args.no_hardcore,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        label_smoothing=args.label_smoothing,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf["size"]),
        normalize=conf.get("normalize", None))
    data_val_student = DeepFakeClassifierDataset_kn(
        mode="val",
        fold=args.fold,
        padding_part=args.padding_part,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_val_transforms(conf["size"]),
        normalize=conf.get("normalize", None))

    # val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
    #                              pin_memory=False)

    val_data_loader_student = DataLoader(data_val_student,
                                         batch_size=batch_size * 2,
                                         num_workers=args.workers,
                                         shuffle=False,
                                         pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' +
                                   conf.get("prefix", args.prefix) +
                                   conf['encoder'] + "_" + str(args.fold))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            teacher_model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

        teacher_model, _ = amp.initialize(teacher_model,
                                          teacher_optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix",
                                                 args.prefix), conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
        teacher_model = DataParallel(teacher_model).cuda()

    teacher_model.eval()
    # register each block, in order to extract the blocks' feature maps
    # (hooks must be attached to the underlying network when a model is wrapped in (Distributed)DataParallel)
    student_net = model.module if hasattr(model, 'module') else model
    teacher_net = teacher_model.module if hasattr(teacher_model, 'module') else teacher_model
    for name, block in student_net.encoder.blocks.named_children():
        block.register_forward_hook(hook_function)

    for name, block in teacher_net.encoder.blocks.named_children():
        block.register_forward_hook(teacher_hook_function)

    data_val_student.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']

    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train,
                                       batch_size=batch_size,
                                       num_workers=args.workers,
                                       shuffle=train_sampler is None,
                                       sampler=train_sampler,
                                       pin_memory=False,
                                       drop_last=True)

        train_data_loader_student = DataLoader(data_train_student,
                                               batch_size=batch_size,
                                               num_workers=args.workers,
                                               shuffle=train_sampler is None,
                                               sampler=train_sampler,
                                               pin_memory=False,
                                               drop_last=True)

        train_epoch(current_epoch, loss_functions, model, teacher_model,
                    optimizer, scheduler, train_data_loader,
                    train_data_loader_student, summary_writer, conf,
                    args.local_rank, args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                }, args.output_dir + '/' + snapshot_name + "_last")
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                args.output_dir + '/' + snapshot_name + "_{}".format(current_epoch))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args,
                                        val_data_loader_student,
                                        bce_best,
                                        model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
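
Example #3 registers a forward hook on every encoder block of the student and teacher so their intermediate feature maps can be compared by a distillation loss; the hook functions themselves are defined elsewhere. A minimal, generic sketch of that pattern (names here are illustrative, not the original hook_function):

import torch
import torch.nn as nn

features = {}

def make_hook(name):
    # Forward hook: stash this block's output so a distillation loss can compare student/teacher maps
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook

encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
for name, block in encoder.named_children():
    block.register_forward_hook(make_hook(name))

_ = encoder(torch.randn(1, 3, 32, 32))
print(sorted(features))   # ['0', '1', '2'] -> one captured feature map per block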
Example #4
loss_fn = OHEMLoss(ignore_index=255, numel_frac=0.05)
loss_fn = loss_fn.cuda()


scheduler = CosineAnnealingScheduler(
    optimizer, 'lr',
    args.learning_rate, 1e-6,
    cycle_size=args.epochs * len(train_loader),
)
scheduler = create_lr_scheduler_with_warmup(
    scheduler, 0, args.learning_rate, 1000)


model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
if args.distributed:
    model = convert_syncbn_model(model)
    model = DistributedDataParallel(model)


trainer = create_segmentation_trainer(
    model, optimizer, loss_fn,
    device=device,
    use_f16=True,
)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)


evaluator = create_segmentation_evaluator(
    model,
    device=device,
    num_classes=19,
)
Example #5
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    else:
        # load the dataset using structure DataLoader (part of torch.utils.data)
        dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # instantiate Generator(nn.Module) and load in cpu/gpu
    generator = Generator().to(device)
    
    ## DEFINE CHECKPOINT
    # checkpoints are used during training to save a model (model parameters I suppose)
    # here we are only testing the pre-trained model, thus we load (torch.load) the model
    if args.distributed:
        g_checkpoint = torch.load(args.checkpoint_path, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        # convert BatchNorm layers to synced BN before wrapping the model in DistributedDataParallel
        generator = parallel.convert_syncbn_model(generator)
        generator = parallel.DistributedDataParallel(generator)
    else:
        g_checkpoint = torch.load(args.checkpoint_path)
    
    generator.load_state_dict(g_checkpoint['model_state_dict'], strict=False)
    step = g_checkpoint['step']
    alpha = g_checkpoint['alpha']
    iteration = g_checkpoint['iteration']
    print('pre-trained model is loaded step:%d, alpha:%d iteration:%d'%(step, alpha, iteration))
    MSE_Loss = nn.MSELoss()

    # notify all layers that you are in eval mode instead of training mode
    generator.eval()

    test(dataloader, generator, MSE_Loss, step, alpha)
Example #6
def main(opt):
    """
    Trains SRVP and saves the resulting model.

    Parameters
    ----------
    opt : helper.DotDict
        Contains the training configuration.
    """
    ##################################################################################################################
    # Setup
    ##################################################################################################################
    # Device handling (CPU, GPU, multi GPU)
    if opt.device is None:
        device = torch.device('cpu')
        opt.n_gpu = 0
    else:
        opt.n_gpu = len(opt.device)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.device[opt.local_rank])
        device = torch.device('cuda:0')
        torch.cuda.set_device(0)
        # In the case of multi GPU: sets up distributed training
        if opt.n_gpu > 1 or opt.local_rank > 0:
            torch.distributed.init_process_group(backend='nccl')
            # Since we are in distributed mode, divide batch size by the number of GPUs
            assert opt.batch_size % opt.n_gpu == 0
            opt.batch_size = opt.batch_size // opt.n_gpu
    # Seed
    if opt.seed is None:
        opt.seed = random.randint(1, 10000)
    else:
        assert isinstance(opt.seed, int) and opt.seed > 0
    print(f'Learning on {opt.n_gpu} GPU(s) (seed: {opt.seed})')
    random.seed(opt.seed)
    np.random.seed(opt.seed + opt.local_rank)
    torch.manual_seed(opt.seed)
    # cuDNN
    if opt.n_gpu > 1 or opt.local_rank > 0:
        assert torch.backends.cudnn.enabled
        cudnn.deterministic = True
    # Mixed-precision training
    if opt.torch_amp and not torch_amp_imported:
        raise ImportError(
            'Mixed-precision not supported by this PyTorch version, upgrade PyTorch or use Apex'
        )
    if opt.apex_amp and not apex_amp_imported:
        raise ImportError(
            'Apex not installed (https://github.com/NVIDIA/apex)')

    ##################################################################################################################
    # Data
    ##################################################################################################################
    print('Loading data...')
    # Load data
    dataset = data.load_dataset(opt, True)
    trainset = dataset.get_fold('train')
    valset = dataset.get_fold('val')
    # Change validation sequence length, if specified
    if opt.seq_len_test is not None:
        valset.change_seq_len(opt.seq_len_test)

    # Handle random seed for dataloader workers
    def worker_init_fn(worker_id):
        np.random.seed(
            (opt.seed + itr + opt.local_rank * opt.n_workers + worker_id) %
            (2**32 - 1))

    # Dataloader
    sampler = None
    shuffle = True
    if opt.n_gpu > 1:
        # Let the distributed sampler shuffle for the distributed case
        sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        shuffle = False
    train_loader = DataLoader(trainset,
                              batch_size=opt.batch_size,
                              collate_fn=data.collate_fn,
                              sampler=sampler,
                              num_workers=opt.n_workers,
                              shuffle=shuffle,
                              drop_last=True,
                              pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(
        valset,
        batch_size=opt.batch_size_test,
        collate_fn=data.collate_fn,
        num_workers=opt.n_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        worker_init_fn=worker_init_fn) if opt.local_rank == 0 else None

    ##################################################################################################################
    # Model
    ##################################################################################################################
    # Build model
    print('Building model...')
    model = srvp.StochasticLatentResidualVideoPredictor(
        opt.nx, opt.nc, opt.nf, opt.nhx, opt.ny, opt.nz, opt.skipco,
        opt.nt_inf, opt.nh_inf, opt.nlayers_inf, opt.nh_res, opt.nlayers_res,
        opt.archi)
    model.init(res_gain=opt.res_gain)
    # Make the batch norms in the model synchronized in the distributed case
    if opt.n_gpu > 1:
        if opt.apex_amp:
            from apex.parallel import convert_syncbn_model
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)

    ##################################################################################################################
    # Optimizer
    ##################################################################################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    opt.n_iter = opt.lr_scheduling_burnin + opt.lr_scheduling_n_iter
    lr_sch_n_iter = opt.lr_scheduling_n_iter
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda i: max(0, (lr_sch_n_iter - i) / lr_sch_n_iter))

    ##################################################################################################################
    # Automatic Mixed Precision
    ##################################################################################################################
    scaler = None
    if opt.torch_amp:
        scaler = torch_amp.GradScaler()
    if opt.apex_amp:
        model, optimizer = apex_amp.initialize(
            model,
            optimizer,
            opt_level=opt.amp_opt_lvl,
            keep_batchnorm_fp32=opt.keep_batchnorm_fp32,
            verbosity=opt.apex_verbose)

    ##################################################################################################################
    # Multi GPU
    ##################################################################################################################
    if opt.n_gpu > 1:
        if opt.apex_amp:
            from apex.parallel import DistributedDataParallel
            forward_fn = DistributedDataParallel(model)
        else:
            forward_fn = torch.nn.parallel.DistributedDataParallel(model)
    else:
        forward_fn = model

    ##################################################################################################################
    # Training
    ##################################################################################################################
    cudnn.benchmark = True  # Activate benchmarks to select the fastest algorithms
    assert opt.n_iter > 0
    itr = 0
    finished = False
    # Progress bar
    if opt.local_rank == 0:
        pb = tqdm(total=opt.n_iter, ncols=0)
    # Current and best model evaluation metric (lower is better)
    val_metric = None
    best_val_metric = None
    try:
        while not finished:
            if sampler is not None:
                sampler.set_epoch(opt.seed + itr)
            # -------- TRAIN --------
            for batch in train_loader:
                # Stop when the given number of optimization steps have been done
                if itr >= opt.n_iter:
                    finished = True
                    status_code = 0
                    break

                itr += 1
                model.train()
                # Optimization step on batch
                # Allow PyTorch's mixed-precision computations if required, while remaining backward compatible
                with (torch_amp.autocast()
                      if opt.torch_amp else nullcontext()):
                    loss, nll, kl_y_0, kl_z = train(forward_fn, optimizer,
                                                    scaler, batch, device, opt)

                # Learning rate scheduling
                if itr >= opt.lr_scheduling_burnin:
                    lr_scheduler.step()

                # Evaluation and model saving are performed on the process with local rank zero
                if opt.local_rank == 0:
                    # Evaluation
                    if itr % opt.val_interval == 0:
                        model.eval()
                        val_metric = evaluate(forward_fn, val_loader, device,
                                              opt)
                        if best_val_metric is None or best_val_metric > val_metric:
                            best_val_metric = val_metric
                            torch.save(
                                model.state_dict(),
                                os.path.join(opt.save_path, 'model_best.pt'))

                    # Checkpointing
                    if opt.chkpt_interval is not None and itr % opt.chkpt_interval == 0:
                        torch.save(
                            model.state_dict(),
                            os.path.join(opt.save_path, f'model_{itr}.pt'))

                # Progress bar
                if opt.local_rank == 0:
                    pb.set_postfix(
                        {
                            'loss': loss,
                            'nll': nll,
                            'kl_y_0': kl_y_0,
                            'kl_z': kl_z,
                            'val_metric': val_metric,
                            'best_val_metric': best_val_metric
                        },
                        refresh=False)
                    pb.update()

    except KeyboardInterrupt:
        status_code = 130

    if opt.local_rank == 0:
        pb.close()
    # Save model
    print('Saving...')
    if opt.local_rank == 0:
        torch.save(model.state_dict(), os.path.join(opt.save_path, 'model.pt'))
    print('Done')
    return status_code
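
Example #6 supports both Apex AMP (apex_amp.initialize) and PyTorch-native AMP (autocast plus GradScaler). For reference, the native pattern on its own, as a minimal sketch (requires a CUDA device):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 16, device='cuda')
    y = torch.randint(0, 4, (8,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():         # run the forward pass and loss in mixed precision
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()           # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)                  # unscales gradients, skips the step on inf/NaN
    scaler.update()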
Example #7
def main():
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
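        # with the prefetcher enabled, mixup (and label smoothing) is done inside the collate function so batches arrive already mixed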
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        logging.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
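
The loss selection above distinguishes mixup (soft targets), label smoothing, and plain cross entropy. For reference, a stand-alone version of label-smoothing cross entropy in the spirit of timm's LabelSmoothingCrossEntropy (a sketch, not the library code):

import torch
import torch.nn.functional as F

def label_smoothing_ce(logits, target, smoothing=0.1):
    # (1 - s) * NLL of the true class + s * mean NLL over all classes
    logp = F.log_softmax(logits, dim=-1)
    nll = -logp.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    smooth = -logp.mean(dim=-1)
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
print(label_smoothing_ce(logits, target))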
Example #8
def train():
    if args.local_rank == 0:
        logger.info('Initializing....')
    cudnn.enabled = True
    cudnn.benchmark = True
    # torch.manual_seed(1)
    # torch.cuda.manual_seed(1)
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.local_rank == 0:
        write_config_into_log(cfg)

    if args.local_rank == 0:
        logger.info('Building model......')
    if cfg.pretrained:
        model = make_model(cfg)
        model.load_param(cfg)
        if args.local_rank == 0:
            logger.info('Loaded pretrained model from {0}'.format(
                cfg.pretrained))
    else:
        model = make_model(cfg)

    if args.sync_bn:
        if args.local_rank == 0: logging.info("using apex synced BN")
        model = convert_syncbn_model(model)
    model.cuda()
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.Adam(
        [{
            'params': model.module.base.parameters(),
            'lr': cfg.get_lr(0)[0]
        }, {
            'params': model.module.classifiers.parameters(),
            'lr': cfg.get_lr(0)[1]
        }],
        weight_decay=cfg.weight_decay)

    celoss = nn.CrossEntropyLoss().cuda()
    softloss = SoftLoss()
    sp_kd_loss = SP_KD_Loss()
    criterions = [celoss, softloss, sp_kd_loss]

    cfg.batch_size = cfg.batch_size // args.world_size
    cfg.num_workers = cfg.num_workers // args.world_size
    train_loader, val_loader = make_dataloader(cfg)

    if args.local_rank == 0:
        logger.info('Begin training......')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train_one_epoch(train_loader, val_loader, model, criterions, optimizer,
                        epoch, cfg)

        total_acc = test(cfg, val_loader, model, epoch)
        if args.local_rank == 0:
            with open(cfg.test_log, 'a+') as f:
                f.write('Epoch {0}: Acc is {1:.4f}\n'.format(epoch, total_acc))
            torch.save(obj=model.state_dict(),
                       f=os.path.join(
                           cfg.snapshot_dir,
                           'ep{}_acc{:.4f}.pth'.format(epoch, total_acc)))
            logger.info('Model saved')
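
Several of these examples switch between Apex's convert_syncbn_model and the built-in torch.nn.SyncBatchNorm.convert_sync_batchnorm. A minimal sketch of that fallback (the Apex import is guarded because Apex may not be installed):

import torch.nn as nn

try:
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
if has_apex:
    model = convert_syncbn_model(model)                      # Apex SyncBN
else:
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)   # native SyncBN (PyTorch >= 1.1)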
Example #9
def main():
    global n_eval_epoch

    ## dataloader
    dataset_train = ImageNet(datapth, mode='train', cropsize=cropsize)
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        dataset_train, shuffle=True)
    batch_sampler_train = torch.utils.data.sampler.BatchSampler(
        sampler_train, batchsize, drop_last=True
    )
    dl_train = DataLoader(
        dataset_train, batch_sampler=batch_sampler_train, num_workers=num_workers, pin_memory=True
    )
    dataset_eval = ImageNet(datapth, mode='val', cropsize=cropsize)
    sampler_val = torch.utils.data.distributed.DistributedSampler(
        dataset_eval, shuffle=False)
    batch_sampler_val = torch.utils.data.sampler.BatchSampler(
        sampler_val, batchsize * 2, drop_last=False
    )
    dl_eval = DataLoader(
        dataset_eval, batch_sampler=batch_sampler_val,
        num_workers=4, pin_memory=True
    )
    n_iters_per_epoch = len(dataset_train) // n_gpus // batchsize
    n_iters = n_epoches * n_iters_per_epoch


    ## model
    #  model = EfficientNet(model_type, n_classes)
    model = build_model(**model_args)


    ## sync bn
    #  if use_sync_bn: model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    init_model_weights(model)
    model.cuda()
    if use_sync_bn: model = parallel.convert_syncbn_model(model)
    crit = nn.CrossEntropyLoss()
    #  crit = LabelSmoothSoftmaxCEV3(lb_smooth)
    #  crit = SoftmaxCrossEntropyV2()

    ## optimizer
    optim = set_optimizer(model, lr, opt_wd, momentum, nesterov=nesterov)

    ## apex
    model, optim = amp.initialize(model, optim, opt_level=fp16_level)

    ## ema
    ema = EMA(model, ema_alpha)

    ## ddp training
    model = parallel.DistributedDataParallel(model, delay_allreduce=True)
    #  local_rank = dist.get_rank()
    #  model = nn.parallel.DistributedDataParallel(
    #      model, device_ids=[local_rank, ], output_device=local_rank
    #  )

    ## log meters
    time_meter = TimeMeter(n_iters)
    loss_meter = AvgMeter()
    logger = logging.getLogger()


    # for mixup
    label_encoder = OnehotEncoder(n_classes=model_args['n_classes'], lb_smooth=lb_smooth)
    mixuper = MixUper(mixup_alpha, mixup=mixup)

    ## train loop
    for e in range(n_epoches):
        sampler_train.set_epoch(e)
        model.train()
        for idx, (im, lb) in enumerate(dl_train):
            im, lb = im.cuda(), lb.cuda()
            #  lb = label_encoder(lb)
            #  im, lb = mixuper(im, lb)
            optim.zero_grad()
            logits = model(im)
            loss = crit(logits, lb) #+ cal_l2_loss(model, weight_decay)
            #  loss.backward()
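            # Apex AMP: scale the loss before backward so fp16 gradients do not underflow; gradients are unscaled before optim.step()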
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
            optim.step()
            torch.cuda.synchronize()
            ema.update_params()
            time_meter.update()
            loss_meter.update(loss.item())
            if (idx + 1) % 200 == 0:
                t_intv, eta = time_meter.get()
                lr_log = scheduler.get_lr_ratio() * lr
                msg = 'epoch: {}, iter: {}, lr: {:.4f}, loss: {:.4f}, time: {:.2f}, eta: {}'.format(
                    e + 1, idx + 1, lr_log, loss_meter.get()[0], t_intv, eta)
                logger.info(msg)
            scheduler.step()
        torch.cuda.empty_cache()
        if (e + 1) % n_eval_epoch == 0:
            if e > 50: n_eval_epoch = 5
            logger.info('evaluating...')
            acc_1, acc_5, acc_1_ema, acc_5_ema = evaluate(ema, dl_eval)
            msg = 'epoch: {}, naive_acc1: {:.4}, naive_acc5: {:.4}, ema_acc1: {:.4}, ema_acc5: {:.4}'.format(e + 1, acc_1, acc_5, acc_1_ema, acc_5_ema)
            logger.info(msg)
    if dist.is_initialized() and dist.get_rank() == 0:
        torch.save(model.module.state_dict(), './res/model_final.pth')
        torch.save(ema.ema_model.state_dict(), './res/model_final_ema.pth')
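
Example #9 keeps an exponential moving average of the weights via a project-specific EMA class and evaluates both the raw and the EMA model. A minimal, generic EMA sketch of the same idea (the interface differs from the EMA class used above):

import copy
import torch
import torch.nn as nn

class SimpleEMA:
    def __init__(self, model, alpha=0.999):
        self.alpha = alpha
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update_params(self, model):
        # ema <- alpha * ema + (1 - alpha) * current weights, called after every optimizer step
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.mul_(self.alpha).add_(p, alpha=1.0 - self.alpha)

model = nn.Linear(4, 2)
ema = SimpleEMA(model, alpha=0.99)
# ... inside the training loop, after optimizer.step():
ema.update_params(model)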
Example #10
    def _init_network(self, **kwargs):
        load_only = kwargs.get('load_only', False)
        if not self.num_class and self._problem_type != REGRESSION:
            raise ValueError(
                'This is a classification problem and we are not able to create network when `num_class` is unknown. \
                It should be inferred from dataset or resumed from saved states.'
            )
        assert len(self.classes) == self.num_class

        # Disable syncBatchNorm as it's only supported on DDP
        if self._train_cfg.sync_bn:
            self._logger.info(
                'Disable Sync batch norm as it is not supported for now.')
            update_cfg(self._cfg, {'train': {'sync_bn': False}})

        # ctx
        self.found_gpu = False
        valid_gpus = []
        if self._cfg.gpus:
            valid_gpus = self._torch_validate_gpus(self._cfg.gpus)
            self.found_gpu = True
            if not valid_gpus:
                self.found_gpu = False
                self._logger.warning(
                    'No gpu detected, fallback to cpu. You can ignore this warning if this is intended.'
                )
            elif len(valid_gpus) != len(self._cfg.gpus):
                self._logger.warning(
                    f'Loaded on gpu({valid_gpus}), different from gpu({self._cfg.gpus}).'
                )
        self.ctx = [torch.device(f'cuda:{gid}') for gid in valid_gpus
                    ] if self.found_gpu else [torch.device('cpu')]
        self.valid_gpus = valid_gpus

        if not self.found_gpu and self.use_amp:
            self.use_amp = None
            self._logger.warning('Training on cpu. AMP disabled.')
            update_cfg(self._cfg, {
                'misc': {
                    'amp': False,
                    'apex_amp': False,
                    'native_amp': False
                }
            })

        if not self.found_gpu and self._misc_cfg.prefetcher:
            self._logger.warning('Training on cpu. Prefetcher disabled.')
            update_cfg(self._cfg, {'misc': {'prefetcher': False}})
            self._logger.warning('Training on cpu. SyncBatchNorm disabled.')
            update_cfg(self._cfg, {'train': {'sync_bn': False}})

        random_seed(self._misc_cfg.seed)

        if not self.net:
            self.net = create_model(
                self._img_cls_cfg.model,
                pretrained=self._img_cls_cfg.pretrained and not load_only,
                num_classes=max(self.num_class, 1),
                global_pool=self._img_cls_cfg.global_pool_type,
                drop_rate=self._augmentation_cfg.drop,
                drop_path_rate=self._augmentation_cfg.drop_path,
                drop_block_rate=self._augmentation_cfg.drop_block,
                bn_momentum=self._train_cfg.bn_momentum,
                bn_eps=self._train_cfg.bn_eps,
                scriptable=self._misc_cfg.torchscript)

            self._logger.info(
                f'Model {safe_model_name(self._img_cls_cfg.model)} created, '
                f'param count: {sum([m.numel() for m in self.net.parameters()])}')
        else:
            self._logger.info(
                'Using the user-provided model; the model specified in the config is ignored.')
            out_features = list(self.net.children())[-1].out_features
            if self._problem_type != REGRESSION:
                assert out_features == self.num_class, f'Custom model out_feature {out_features} != num_class {self.num_class}.'
            else:
                assert out_features == 1, f'Regression problem expects num_out_feature == 1, got {out_features} instead.'

        resolve_data_config(self._cfg, model=self.net)

        self.net = self.net.to(self.ctx[0])

        # setup synchronized BatchNorm
        if self._train_cfg.sync_bn:
            if has_apex and self.use_amp != 'native':
                # Apex SyncBN preferred unless native amp is activated
                self.net = convert_syncbn_model(self.net)
            else:
                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.net)
            self._logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

        if self._misc_cfg.torchscript:
            assert not self.use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
            assert not self._train_cfg.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
            self.net = torch.jit.script(self.net)
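
The torchscript branch above scripts the network, which is why Apex AMP and SyncBatchNorm are ruled out first. A tiny stand-alone illustration of torch.jit.script on a simple module:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
scripted = torch.jit.script(net)           # compile the module to TorchScript
print(scripted(torch.randn(1, 8)).shape)   # torch.Size([1, 2])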
Example #11
def main():
    args, cfg = parse_config_args('nni.cream.supernet')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.', args.local_rank,
                    cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[7])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    (sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    # initialize flops look-up table
    model_est = FlopsEst(model)
    flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(
            model, cfg.RESUME_PATH)

    # create optimizer and resume from checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            if args.local_rank == 0:
                # logger is None on non-zero ranks, so guard before logging
                logger.info(
                    'Failed to enable Synchronized BatchNorm. '
                    'Install Apex or Torch >= 1.1. Exception: %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info(
                "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
            )
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        if args.local_rank == 0:
            # logger is None on non-zero ranks
            logger.error('Training folder does not exist at: %s', train_dir)
        sys.exit(1)

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 use_prefetcher=True,
                                 re_prob=cfg.AUGMENTATION.RE_PROB,
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 interpolation='random',
                                 num_workers=cfg.WORKERS,
                                 distributed=True,
                                 collate_fn=None,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        if args.local_rank == 0:
            # logger is None on non-zero ranks
            logger.error('Validation folder does not exist at: %s', eval_dir)
        sys.exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(dataset_eval,
                                input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                            cfg.DATASET.IMAGE_SIZE),
                                batch_size=4 * cfg.DATASET.BATCH_SIZE,
                                is_training=False,
                                use_prefetcher=True,
                                num_workers=cfg.WORKERS,
                                distributed=True,
                                crop_pct=DEFAULT_CROP_PCT,
                                mean=IMAGENET_DEFAULT_MEAN,
                                std=IMAGENET_DEFAULT_STD,
                                interpolation=cfg.DATASET.INTERPOLATION)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    mutator = RandomMutator(model)

    trainer = CreamSupernetTrainer(model,
                                   train_loss_fn,
                                   validate_loss_fn,
                                   optimizer,
                                   num_epochs,
                                   loader_train,
                                   loader_eval,
                                   mutator=mutator,
                                   batch_size=cfg.DATASET.BATCH_SIZE,
                                   log_frequency=cfg.LOG_INTERVAL,
                                   meta_sta_epoch=cfg.SUPERNET.META_STA_EPOCH,
                                   update_iter=cfg.SUPERNET.UPDATE_ITER,
                                   slices=cfg.SUPERNET.SLICE,
                                   pool_size=cfg.SUPERNET.POOL_SIZE,
                                   pick_method=cfg.SUPERNET.PICK_METHOD,
                                   choice_num=choice_num,
                                   sta_num=sta_num,
                                   acc_gap=cfg.ACC_GAP,
                                   flops_dict=flops_dict,
                                   flops_fixed=flops_fixed,
                                   local_rank=args.local_rank,
                                   callbacks=[
                                       LRSchedulerCallback(lr_scheduler),
                                       ModelCheckpoint(output_dir)
                                   ])

    trainer.train()
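
Example #11 fixes every random seed before building the supernet and disables cuDNN autotuning for reproducibility. A minimal sketch of that block, assuming only torch and numpy (the helper name fix_random_seeds is illustrative):

import numpy as np
import torch


def fix_random_seeds(seed: int) -> None:
    # Seed every RNG the training loop touches and make cuDNN deterministic;
    # disabling benchmark mode trades some speed for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
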
Example #12
def main():
    logger = None
    output_dir = ''
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        import random
        port = random.randint(0, 50000)
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://'
        )  # tcp://127.0.0.1:{}'.format(port), rank=args.local_rank, world_size=8)
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model, sta_num, size_factor = _gen_supernet_large(
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        resunit=args.resunit,
        dil_conv=args.dil_conv,
        slice=args.slice)

    if args.local_rank == 0:
        print("Model Searched Using FLOPs {}".format(size_factor * 32))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)
    if args.local_rank == 0:
        '''output dir'''
        output_base = args.output if args.output else './experiments'
        exp_name = args.model
        output_dir = get_outdir(output_base, 'search', exp_name)
        log_file = os.path.join(output_dir, "search.log")
        logger = get_logger(log_file)
    if args.local_rank == 0:
        logger.info(args)

    choice_num = 6
    if args.resunit:
        choice_num += 1
    if args.dil_conv:
        choice_num += 2

    if args.local_rank == 0:
        logger.info("Choice_num: {}".format(choice_num))

    model_est = LatencyEst(model)

    if os.path.exists(args.initial_checkpoint):
        load_checkpoint(model, args.initial_checkpoint)

    if args.local_rank == 0:
        logger.info('Model %s created, param count: %d' %
                    (args.model, sum([m.numel() for m in model.parameters()])))

    # data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if args.resume:
        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer_supernet(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logger.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True
                        )  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)

    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    if args.tiny:
        from lib.dataset.tiny_imagenet import get_newimagenet
        [loader_train,
         loader_eval], [train_sampler, test_sampler
                        ] = get_newimagenet(args.data, args.batch_size)
    else:
        train_dir = os.path.join(args.data, 'train')
        if not os.path.exists(train_dir):
            if args.local_rank == 0:
                # logger is None on non-zero ranks
                logger.error(
                    'Training folder does not exist at: {}'.format(train_dir))
            exit(1)
        dataset_train = Dataset(train_dir)

        collate_fn = None

        loader_train = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            is_training=True,
            use_prefetcher=args.prefetcher,
            re_prob=args.reprob,
            re_mode=args.remode,
            color_jitter=args.color_jitter,
            interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            collate_fn=collate_fn,
        )

        eval_dir = os.path.join(args.data, 'val')
        if not os.path.isdir(eval_dir):
            if args.local_rank == 0:
                # logger is None on non-zero ranks
                logger.error(
                    'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)

        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=4 * args.batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
        )

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    best_children_pool = []
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                if args.tiny:
                    train_sampler.set_epoch(epoch)
                else:
                    loader_train.sampler.set_epoch(epoch)
            '''2020.10.19 large_model=True !'''
            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        CHOICE_NUM=choice_num,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        logger=logger,
                                        val_loader=loader_eval,
                                        use_amp=use_amp,
                                        model_ema=model_ema,
                                        est=model_est,
                                        sta_num=sta_num,
                                        large_model=True)

            # eval_metrics = OrderedDict([('loss', 0.0), ('prec1', 0.0), ('prec5', 0.0)])
            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    CHOICE_NUM=choice_num,
                                    sta_num=sta_num)

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
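
Example #12 resolves the starting epoch from either --start-epoch or a resumed checkpoint, then fast-forwards the LR schedule. A minimal sketch of that ordering; resolve_start_epoch is an illustrative name, and the scheduler is assumed to accept an epoch index in step(), as it does above.

def resolve_start_epoch(arg_start_epoch, resume_epoch, lr_scheduler=None):
    # An explicitly passed --start-epoch always overrides the checkpoint epoch.
    start_epoch = 0
    if arg_start_epoch is not None:
        start_epoch = arg_start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    # Fast-forward the schedule so a resumed run continues at the right LR.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    return start_epoch
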
Example #13
    _, loader = get_train_loader(args.data_root,
                                 args.num_src,
                                 total_steps,
                                 args.batch_size, {
                                     'interval_scale': args.interval_scale,
                                     'resize_width': resize_width,
                                     'resize_height': resize_height,
                                     'crop_width': crop_width,
                                     'crop_height': crop_height
                                 },
                                 num_workers=args.num_workers)

    model = Model()
    model.cuda()
    model = apex_parallel.convert_syncbn_model(model)
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()
             if p.requires_grad])))
    compute_loss = Loss()

    model = nn.DataParallel(model)

    if args.load_path is None:
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d,
                              nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if m.weight.requires_grad:
                    nn.init.xavier_uniform_(m.weight)
Example #14
def main(fold_i=0, data_=None, train_index=None, val_index=None):
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    best_score = 0.0
    args.output = args.output + 'fold_' + str(fold_i)
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        if fold_i == 0:
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available; using float32. "
            "Install NVIDIA Apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-6)
    # setup automatic mixed-precision (AMP) loss scaling and op casting

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    ##create DataLoader
    train_data = data_.iloc[train_index, :].reset_index(drop=True)
    dataset_train = RiaddDataSet(image_ids=train_data, baseImgPath=args.data)

    val_data = data_.iloc[val_index, :].reset_index(drop=True)
    dataset_eval = RiaddDataSet(image_ids=val_data, baseImgPath=args.data)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    train_trans = get_riadd_train_transforms(args)
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        transform=train_trans)

    valid_trans = get_riadd_valid_transforms(args)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        transform=valid_trans)

    # # setup loss function
    # if args.jsd:
    #     assert num_aug_splits > 1  # JSD only valid with aug splits set
    #     train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    # elif mixup_active:
    #     # smoothing is handled with mixup target transform
    #     train_loss_fn = SoftTargetCrossEntropy().cuda()
    # elif args.smoothing:
    #     train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    # else:
    #     train_loss_fn = nn.CrossEntropyLoss().cuda()

    validate_loss_fn = nn.BCEWithLogitsLoss().cuda()
    train_loss_fn = nn.BCEWithLogitsLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    vis = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
        vis = Visualizer(env=args.output)

    try:
        for epoch in range(0, args.epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        amp_autocast=amp_autocast,
                                        loss_scaler=loss_scaler,
                                        model_ema=model_ema,
                                        mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)
            # predictions = gather_tensor(eval_metrics)
            predictions, valid_label = gather_predict_label(eval_metrics, args)
            valid_label = valid_label.detach().cpu().numpy()
            predictions = predictions.detach().cpu().numpy()
            score, scores = get_score(valid_label, predictions)
            ##visdom
            if vis is not None:
                vis.plot_curves({'None': epoch},
                                iters=epoch,
                                title='None',
                                xlabel='iters',
                                ylabel='None')
                vis.plot_curves(
                    {'learning rate': optimizer.param_groups[0]['lr']},
                    iters=epoch,
                    title='lr',
                    xlabel='iters',
                    ylabel='learning rate')
                vis.plot_curves({'train loss': float(train_metrics['loss'])},
                                iters=epoch,
                                title='train loss',
                                xlabel='iters',
                                ylabel='train loss')
                vis.plot_curves({'val loss': float(eval_metrics['loss'])},
                                iters=epoch,
                                title='val loss',
                                xlabel='iters',
                                ylabel='val loss')
                vis.plot_curves({'val score': float(score)},
                                iters=epoch,
                                title='val score',
                                xlabel='iters',
                                ylabel='val score')

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                # lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
                lr_scheduler.step(epoch + 1, score)

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None and score > best_score:
                # save proper checkpoint with eval metric
                best_score = score
                save_metric = best_score
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)
        del model
        del optimizer
        torch.cuda.empty_cache()
    except KeyboardInterrupt:
        pass
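
The fold-training example above resolves the AMP backend before creating the optimizer: --amp tries Apex first for backwards compatibility, then native torch.cuda.amp, and otherwise falls back to float32. A minimal sketch of that resolution (resolve_amp and its arguments are illustrative names, not the repository's API):

from contextlib import suppress

import torch

try:
    from apex import amp  # noqa: F401
    has_apex = True
except ImportError:
    has_apex = False
has_native_amp = hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast')


def resolve_amp(want_amp, want_apex, want_native):
    # `--amp` tries Apex before native AMP for backwards compatibility.
    if want_amp:
        if has_apex:
            want_apex = True
        elif has_native_amp:
            want_native = True
    if want_apex and has_apex:
        return 'apex'
    if want_native and has_native_amp:
        return 'native'
    return None  # neither requested/available: train in float32


# The autocast context stays a no-op unless native AMP is selected;
# Apex handles casting itself via amp.initialize().
use_amp = resolve_amp(want_amp=True, want_apex=False, want_native=False)
amp_autocast = torch.cuda.amp.autocast if use_amp == 'native' else suppress
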
Example #15
def main():
    args, args_text = _parse_args()

    if args.local_rank == 0:
        args.local_rank = local_rank
        print("rank:{0},word_size:{1},dist_url:{2}".format(
            local_rank, word_size, dist_url))

    if args.model_selection == 470:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        arch_def = [
            # stage 0, 112x112 in
            ['ds_r1_k3_s1_e1_c16_se0.25'],
            # stage 1, 112x112 in
            [
                'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25',
                'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25'
            ],
            # stage 2, 56x56 in
            [
                'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25',
                'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'
            ],
            # stage 3, 28x28 in
            [
                'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25',
                'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r2_k3_s1_e4_c80_se0.25'
            ],
            # stage 4, 14x14in
            [
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25'
            ],
            # stage 5, 14x14in
            [
                'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
                'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25'
            ],
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c320_se0.25'],
        ]
        args.img_size = 224
    elif args.model_selection == 42:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        arch_def = [
            # stage 0, 112x112 in
            ['ds_r1_k3_s1_e1_c16_se0.25'],
            # stage 1, 112x112 in
            ['ir_r1_k3_s2_e4_c24_se0.25'],
            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'],
            # stage 3, 28x28 in
            ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25'],
            # stage 4, 14x14in
            [
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
                'ir_r1_k3_s1_e6_c96_se0.25'
            ],
            # stage 5, 14x14in
            ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25'],
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c320_se0.25'],
        ]
        args.img_size = 96
    elif args.model_selection == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        arch_def = [
            # stage 0, 112x112 in
            ['ds_r1_k3_s1_e1_c16_se0.25'],
            # stage 1, 112x112 in
            ['ir_r1_k3_s2_e4_c24_se0.25'],
            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k3_s2_e4_c40_se0.25'],
            # stage 3, 28x28 in
            ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e4_c80_se0.25'],
            # stage 4, 14x14in
            ['ir_r1_k3_s1_e6_c96_se0.25'],
            # stage 5, 14x14in
            ['ir_r1_k5_s2_e6_c192_se0.25'],
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c320_se0.25'],
        ]
        args.img_size = 64
    elif args.model_selection == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        arch_def = [
            # stage 0, 112x112 in
            ['ds_r1_k3_s1_e1_c16_se0.25'],
            # stage 1, 112x112 in
            ['ir_r1_k3_s2_e4_c24_se0.25'],
            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k3_s2_e4_c40_se0.25'],
            # stage 3, 28x28 in
            ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25'],
            # stage 4, 14x14in
            [
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
                'ir_r1_k3_s1_e6_c96_se0.25'
            ],
            # stage 5, 14x14in
            ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25'],
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c320_se0.25'],
        ]
        args.img_size = 160
    elif args.model_selection == 285:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        arch_def = [
            # stage 0, 112x112 in
            ['ds_r1_k3_s1_e1_c16_se0.25'],
            # stage 1, 112x112 in
            ['ir_r1_k3_s2_e4_c24_se0.25'],
            # stage 2, 56x56 in
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'],
            # stage 3, 28x28 in
            [
                'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25',
                'ir_r1_k3_s2_e6_c80_se0.25'
            ],
            # stage 4, 14x14in
            [
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25'
            ],
            # stage 5, 14x14in
            [
                'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25',
                'ir_r1_k5_s2_e6_c192_se0.25'
            ],
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c320_se0.25'],
        ]
        args.img_size = 224
    elif args.model_selection == 600:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        arch_def = [
            # stage 0, 112x112 in
            ['ds_r1_k3_s1_e1_c16_se0.25'],
            # stage 1, 112x112 in
            [
                'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s2_e4_c24_se0.25',
                'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s2_e4_c24_se0.25',
                'ir_r1_k3_s2_e4_c24_se0.25'
            ],
            # stage 2, 56x56 in
            [
                'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
                'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
                'ir_r1_k5_s2_e4_c40_se0.25'
            ],
            # stage 3, 28x28 in
            [
                'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25',
                'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25',
                'ir_r1_k3_s1_e4_c80_se0.25'
            ],
            # stage 4, 14x14in
            [
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
                'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25'
            ],
            # stage 5, 14x14in
            [
                'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
                'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
                'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'
            ],
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c320_se0.25'],
        ]
        args.img_size = 224
    else:
        raise ValueError("Not Supported!")

    model = _gen_childnet(arch_list,
                          arch_def,
                          num_classes=args.num_classes,
                          drop_rate=args.drop,
                          drop_path_rate=args.drop_path,
                          global_pool=args.gp,
                          bn_momentum=args.bn_momentum,
                          bn_eps=args.bn_eps,
                          pool_bn=args.pool_bn,
                          zero_gamma=args.zero_gamma)

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './experiments'
        exp_name = '-'.join([
            args.name,
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'retrain', exp_name)
        logger = get_logger(os.path.join(output_dir, 'retrain.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'config.yaml'), 'w') as f:
            f.write(args_text)
    else:
        writer = None
        logger = None

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1 and args.local_rank == 0:
            logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    args.distributed = True
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        # torch.distributed.init_process_group(backend='nccl', init_method='env://')
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.local_rank == 0:
        if args.distributed:
            logger.info(
                'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                % (args.rank, args.world_size))
        else:
            logger.info('Training with a single process on %d GPUs.' %
                        args.num_gpu)

    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if args.local_rank == 0:
        scope(model, input_size=(3, 224, 224))

    if os.path.exists(args.initial_checkpoint):
        load_checkpoint(model, args.initial_checkpoint)

    if args.local_rank == 0:
        logger.info('Model %s created, param count: %d' %
                    (args.model, sum([m.numel() for m in model.parameters()])))

    if args.num_gpu > 1:
        if args.amp:
            if args.local_rank == 0:
                logger.warning(
                    'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
                )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logger.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                    )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir) and args.local_rank == 0:
        logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.exists(eval_dir) and args.local_rank == 0:
        logger.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=0,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)

    start_epoch = 0
    if args.start_epoch is not None:
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        best_record = 0
        best_ep = 0
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema,
                                        logger=logger,
                                        writer=writer)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(epoch,
                                    model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    logger=logger,
                                    writer=writer)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(epoch,
                                            model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)',
                                            logger=logger,
                                            writer=writer)
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch

            if args.local_rank == 0:
                logger.info('*** Best metric: {0} (epoch {1})'.format(
                    best_record, best_ep))

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
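
Example #15 keeps its own best-metric bookkeeping next to the CheckpointSaver, which treats 'loss' as the only lower-is-better metric. A minimal sketch of that comparison (is_improvement is an illustrative name):

def is_improvement(eval_metric, new_value, best_value):
    # The checkpoint savers above treat 'loss' as lower-is-better and every
    # other eval metric (e.g. top-1 accuracy) as higher-is-better.
    if best_value is None:
        return True
    decreasing = eval_metric == 'loss'
    return new_value < best_value if decreasing else new_value > best_value
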
Example #16
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.binarizable:
        Model_binary_patch(model)


    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()

    else:
        model.cuda()


    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        print('Using amp.')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    else:
        print('Do NOT use amp.')
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    resume_state = None

    if args.freeze_binary:
        Model_freeze_binary(model)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info('Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if args.reset_lr_scheduler is not None:
        lr_scheduler.base_values = len(lr_scheduler.base_values)*[args.reset_lr_scheduler]
        lr_scheduler.step(start_epoch)

    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # Using pruner to get sparse weights
    if args.prune:
        pruner = Pruner(model, 0, 100, 0.75)
    else:
        pruner = None

    dataset_train = torchvision.datasets.CIFAR100(root=os.path.expanduser('~/Downloads/CIFAR100'), train=True, download=True)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader_CIFAR100(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        is_clean_data=args.clean_train,
    )


    dataset_eval = torchvision.datasets.CIFAR100(root=os.path.expanduser('~/Downloads/CIFAR100'), train=False, download=True)

    loader_eval = create_loader_CIFAR100(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy(multiplier=args.softmax_multiplier).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    saver_last_10_epochs = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        os.makedirs(output_dir+'/Top')
        os.makedirs(output_dir+'/Last')
        saver = CheckpointSaver(checkpoint_dir=output_dir + '/Top', decreasing=decreasing, max_history=10)  # Save the results of the top 10 epochs
        saver_last_10_epochs = CheckpointSaver(checkpoint_dir=output_dir + '/Last', decreasing=decreasing, max_history=10) # Save the results of the last 10 epochs
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
            f.write('==============================')
            f.write(model.__str__())

    tensorboard_writer = SummaryWriter(output_dir)

    try:
        for epoch in range(start_epoch, num_epochs):

            global alpha
            alpha = get_alpha(epoch, args)

            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            if pruner:
                pruner.on_epoch_begin(epoch)  # pruning

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                use_amp=use_amp, tensorboard_writer=tensorboard_writer,
                pruner=pruner)

            if pruner:
                pruner.print_statistics()

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, tensorboard_writer=tensorboard_writer, epoch=epoch)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch, metric=save_metric, use_amp=use_amp)
            if saver_last_10_epochs is not None:
                # save the checkpoint in each of the last 10 epochs
                _, _ = saver_last_10_epochs.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch, metric=epoch, use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    logging.info('The checkpoint of the last epoch is: \n')
    logging.info(saver_last_10_epochs.checkpoint_files[0][0])
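A minimal sketch of the Apex-versus-native SyncBatchNorm fallback used in the example above; `has_apex` is resolved here by a guarded import (an assumption, since the original resolves it elsewhere in the module):

import torch
import torch.nn as nn

try:
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False


def to_sync_bn(model: nn.Module) -> nn.Module:
    # Replace every BatchNorm layer with a synchronized variant; call this
    # before wrapping the model in (Apex or native) DistributedDataParallel.
    if has_apex:
        return convert_syncbn_model(model)
    return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)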
Beispiel #17
0
def train():
    print("local_rank:", args.local_rank)
    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
    )
    torch.manual_seed(0)

    if not args.eval_net:
        train_ds = dataset_desc.Dataset('train')
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_ds)
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=config.mini_batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=4,
            sampler=train_sampler,
            pin_memory=True)

        val_ds = dataset_desc.Dataset('test')
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
        val_loader = torch.utils.data.DataLoader(
            val_ds,
            batch_size=config.val_mini_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            sampler=val_sampler)
    else:
        test_ds = dataset_desc.Dataset('test')
        test_loader = torch.utils.data.DataLoader(
            test_ds,
            batch_size=config.test_mini_batch_size,
            shuffle=False,
            num_workers=20)

    rndla_cfg = ConfigRandLA
    model = FFB6D(n_classes=config.n_objects,
                  n_pts=config.n_sample_points,
                  rndla_cfg=rndla_cfg,
                  n_kps=config.n_keypoints)
    model = convert_syncbn_model(model)
    device = torch.device('cuda:{}'.format(args.local_rank))
    print('local_rank:', args.local_rank)
    model.to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    opt_level = args.opt_level
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=opt_level,
    )
    model = nn.DataParallel(model)

    # default value
    it = -1  # for the initialize value of `LambdaLR` and `BNMomentumScheduler`
    best_loss = 1e10
    start_epoch = 1

    # load status from checkpoint
    if args.checkpoint is not None:
        checkpoint_status = load_checkpoint(model,
                                            optimizer,
                                            filename=args.checkpoint[:-8])
        if checkpoint_status is not None:
            it, start_epoch, best_loss = checkpoint_status
        if args.eval_net:
            assert checkpoint_status is not None, "Failed loading model."

    if not args.eval_net:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        clr_div = 6
        lr_scheduler = CyclicLR(
            optimizer,
            base_lr=1e-5,
            max_lr=1e-3,
            cycle_momentum=False,
            step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch //
            clr_div // args.gpus,
            step_size_down=config.n_total_epoch *
            train_ds.minibatch_per_epoch // clr_div // args.gpus,
            mode='triangular')
    else:
        lr_scheduler = None

    bnm_lmbd = lambda it: max(
        args.bn_momentum * args.bn_decay**
        (int(it * config.mini_batch_size / args.decay_step)),
        bnm_clip,
    )
    bnm_scheduler = pt_utils.BNMomentumScheduler(model,
                                                 bn_lambda=bnm_lmbd,
                                                 last_epoch=it)

    it = max(it, 0)  # for the initialize value of `trainer.train`

    if args.eval_net:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2),
            OFLoss(),
            args.test,
        )
    else:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2).to(device),
            OFLoss().to(device),
            args.test,
        )

    checkpoint_fd = config.log_model_dir

    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"),
        best_name=os.path.join(checkpoint_fd, "FFB6D_best"),
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )

    if args.eval_net:
        start = time.time()
        val_loss, res = trainer.eval_epoch(test_loader,
                                           is_test=True,
                                           test_pose=args.test_pose)
        end = time.time()
        print("\nUse time: ", end - start, 's')
    else:
        trainer.train(it,
                      start_epoch,
                      config.n_total_epoch,
                      train_loader,
                      None,
                      val_loader,
                      best_loss=best_loss,
                      tot_iter=config.n_total_epoch *
                      train_ds.minibatch_per_epoch // args.gpus,
                      clr_div=clr_div)

        if start_epoch == config.n_total_epoch:
            _ = trainer.eval_epoch(val_loader)
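Example #17 initializes Apex AMP, but the actual training step lives inside `Trainer.train`, which is not shown. The following is a hedged sketch of the canonical Apex AMP step; `model`, `optimizer`, `criterion`, `images`, and `targets` are placeholders rather than names from the original:

from apex import amp  # assumes Apex is installed, as in the example


def amp_train_step(model, optimizer, criterion, images, targets):
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    # amp.scale_loss scales the loss so FP16 gradients do not underflow,
    # then unscales them before optimizer.step().
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    return loss.item()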
Beispiel #18
0
def main():
    args.can_print = (args.distributed
                      and args.local_rank == 0) or (not args.distributed)

    log_out_dir = f'{RESULT_DIR}/logs/{args.out_dir}'
    os.makedirs(log_out_dir, exist_ok=True)
    if args.can_print:
        log = Logger()
        log.open(f'{log_out_dir}/log.train.txt', mode='a')
    else:
        log = None

    model_out_dir = f'{RESULT_DIR}/models/{args.out_dir}'
    if args.can_print:
        log.write(
            f'>> Creating directory if it does not exist:\n>> {model_out_dir}\n'
        )
    os.makedirs(model_out_dir, exist_ok=True)

    # set cuda visible device
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    if args.distributed:
        torch.cuda.set_device(args.local_rank)

    # set random seeds
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)

    model_params = {}
    model_params['architecture'] = args.arch
    model_params['num_classes'] = args.num_classes
    model_params['in_channels'] = args.in_channels
    model_params['can_print'] = args.can_print
    model = init_network(model_params)

    # move network to gpu
    if args.distributed:
        dist.init_process_group(backend='nccl', init_method='env://')
        model = convert_syncbn_model(model)
    model.cuda()
    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model)

    # define loss function (criterion)
    try:
        criterion = eval(args.loss)().cuda()
    except:
        raise RuntimeError(f'Loss {args.loss} not available!')

    start_epoch = 0
    best_score = 0
    best_epoch = 0

    # define scheduler
    try:
        scheduler = eval(args.scheduler)(model)
    except:
        raise RuntimeError(f'Scheduler {args.scheduler} not available!')

    # optionally resume from a checkpoint
    reset_epoch = True
    pretrained_file = None
    if args.model_file:
        reset_epoch = True
        pretrained_file = args.model_file
    if args.resume:
        reset_epoch = False
        pretrained_file = f'{model_out_dir}/{args.resume}'
    if pretrained_file and os.path.isfile(pretrained_file):
        # load checkpoint weights and update model and optimizer
        if args.can_print:
            log.write(f'>> Loading checkpoint:\n>> {pretrained_file}\n')

        checkpoint = torch.load(pretrained_file)
        if not reset_epoch:
            start_epoch = checkpoint['epoch']
            best_epoch = checkpoint['best_epoch']
            best_score = checkpoint['best_score']
        model.module.load_state_dict(checkpoint['state_dict'])
        if args.can_print:
            if reset_epoch:
                log.write(f'>>>> loaded checkpoint:\n>>>> {pretrained_file}\n')
            else:
                log.write(
                    f'>>>> loaded checkpoint:\n>>>> {pretrained_file} (epoch {checkpoint["epoch"]:.2f})\n'
                )
    else:
        if args.can_print:
            log.write(f'>> No checkpoint found at {pretrained_file}\n')

    # Data loading code
    train_transform = eval(f'train_multi_augment{args.aug_version}')
    train_split_file = f'{DATA_DIR}/split/{args.split_type}/random_train_cv0.csv'
    valid_split_file = f'{DATA_DIR}/split/{args.split_type}/random_valid_cv0.csv'
    train_dataset = RetrievalDataset(
        args,
        train_split_file,
        transform=train_transform,
        data_type='train',
    )
    valid_dataset = RetrievalDataset(
        args,
        valid_split_file,
        transform=None,
        data_type='valid',
    )
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset)
    else:
        train_sampler = RandomSampler(train_dataset)
        valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=image_collate,
    )
    valid_loader = DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=image_collate,
    )

    train(args, train_loader, valid_loader, model, criterion, scheduler, log,
          best_epoch, best_score, start_epoch, model_out_dir)
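The `train(...)` function called at the end of Example #18 is defined elsewhere; when a `DistributedSampler` is used as above, that loop normally has to call `set_epoch` before each epoch so shuffling differs per epoch. A small illustrative sketch (function names are hypothetical):

from torch.utils.data.distributed import DistributedSampler


def run_epochs(train_loader, num_epochs, train_one_epoch):
    for epoch in range(num_epochs):
        sampler = getattr(train_loader, 'sampler', None)
        if isinstance(sampler, DistributedSampler):
            # Re-seed the sampler so each epoch shuffles the shards differently.
            sampler.set_epoch(epoch)
        train_one_epoch(train_loader, epoch)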
Beispiel #19
0
def main():
    args, cfg = parse_config_args('super net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.', args.local_rank,
                    cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # initialize meta matching networks
    MetaMN = MetaMatchingNetwork(cfg)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[1][0])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    (sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    #initialize prioritized board
    prioritized_board = PrioritizedBoard(cfg,
                                         CHOICE_NUM=choice_num,
                                         sta_num=sta_num)

    # initialize flops look-up table
    model_est = FlopsEst(model)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(
            model, cfg.RESUME_PATH)

    # create optimizer and resume from checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            logger.info(
                'Failed to enable Synchronized BatchNorm. '
                'Install Apex or Torch >= 1.1 with Exception %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info(
                "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
            )
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 use_prefetcher=True,
                                 re_prob=cfg.AUGMENTATION.RE_PROB,
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 interpolation='random',
                                 num_workers=cfg.WORKERS,
                                 distributed=True,
                                 collate_fn=None,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(dataset_eval,
                                input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                            cfg.DATASET.IMAGE_SIZE),
                                batch_size=4 * cfg.DATASET.BATCH_SIZE,
                                is_training=False,
                                use_prefetcher=True,
                                num_workers=cfg.WORKERS,
                                distributed=True,
                                crop_pct=DEFAULT_CROP_PCT,
                                mean=IMAGENET_DEFAULT_MEAN,
                                std=IMAGENET_DEFAULT_STD,
                                interpolation=cfg.DATASET.INTERPOLATION)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver, best_children_pool = None, None, None, []
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    # training scheme
    try:
        for epoch in range(start_epoch, num_epochs):
            loader_train.sampler.set_epoch(epoch)

            # train one epoch
            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        prioritized_board,
                                        MetaMN,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        logger=logger,
                                        est=model_est,
                                        local_rank=args.local_rank)

            # evaluate one epoch
            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    prioritized_board,
                                    MetaMN,
                                    cfg,
                                    local_rank=args.local_rank,
                                    logger=logger)

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, cfg, epoch=epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
Beispiel #20
0
scheduler = CosineAnnealingScheduler(
    optimizer,
    'lr',
    args.learning_rate,
    args.learning_rate / 1000,
    args.epochs * len(train_loader),
)
scheduler = create_lr_scheduler_with_warmup(scheduler, 0, args.learning_rate,
                                            1000)

(teacher, student), optimizer = amp.initialize([teacher, student],
                                               optimizer,
                                               opt_level="O2")
if args.distributed:
    student = convert_syncbn_model(student)
    teacher = DistributedDataParallel(teacher)
    student = DistributedDataParallel(student)


def create_segmentation_distillation_trainer(student,
                                             teacher,
                                             optimizer,
                                             supervised_loss_fn,
                                             distillation_loss_fn,
                                             device,
                                             use_f16=True,
                                             non_blocking=True):
    from ignite.engine import Engine, Events, _prepare_batch
    from ignite.metrics import RunningAverage, Loss
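The distillation setup above receives its `distillation_loss_fn` from elsewhere. As an illustration only, a temperature-scaled soft-target loss in the style of Hinton et al. could look like the sketch below (the temperature value and reduction are assumptions, not taken from the original):

import torch.nn.functional as F


def soft_distillation_loss(student_logits, teacher_logits, temperature=4.0):
    # KL divergence between temperature-softened teacher and student
    # distributions, scaled by T**2 so gradients keep a comparable magnitude.
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * temperature ** 2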
Beispiel #21
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(args.dataset,
                                   root=args.data_dir,
                                   split=args.train_split,
                                   is_training=True,
                                   batch_size=args.batch_size,
                                   repeats=args.epoch_repeats)
    dataset_eval = create_dataset(args.dataset,
                                  root=args.data_dir,
                                  split=args.val_split,
                                  is_training=False,
                                  batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(
            args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing,
                                max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
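When `use_amp == 'native'` in Example #21, `train_one_epoch` (not shown) runs the forward pass under `autocast` and backpropagates through the `NativeScaler`/`GradScaler`. A minimal sketch of that step with placeholder names:

import torch

scaler = torch.cuda.amp.GradScaler()


def native_amp_step(model, optimizer, loss_fn, images, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(images), targets)
    scaler.scale(loss).backward()   # scale the loss to avoid FP16 underflow
    scaler.step(optimizer)          # unscales grads, skips the step on inf/nan
    scaler.update()                 # adapt the scale factor for the next step
    return loss.detach()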
Beispiel #22
0
    def __init__(self,
                 serialization_dir,
                 params,
                 model,
                 loss,
                 alphabet,
                 local_rank=0,
                 world_size=1,
                 sync_bn=False,
                 opt_level='O0',
                 keep_batchnorm_fp32=None,
                 loss_scale=1.0):
        self.alphabet = alphabet
        self.clip_grad_norm = params.get('clip_grad_norm', None)
        self.clip_grad_value = params.get('clip_grad_value', None)
        self.warmup_epochs = params.get('warmup_epochs', 0)
        self.label_smoothing = params.get('label_smoothing', 0.0)

        self.local_rank = local_rank
        self.loss = loss.cuda()
        self.model = model
        self.best_monitor = float('inf')
        self.monitor = params.get('monitor', 'loss')
        self.serialization_dir = serialization_dir
        self.distributed = world_size > 1
        self.world_size = world_size
        self.epoch = 0
        self.num_epochs = 0

        self.start_epoch = 0
        self.start_iteration = 0
        self.start_time = 0
        self.iterations_per_epoch = None

        self.time_since_last = time.time()

        self.save_every = params.get('save_every', 60 * 10)  # 10 minutes

        if sync_bn:
            logger.info('Using Apex `sync_bn`')
            self.model = convert_syncbn_model(self.model)

        # keep the (possibly SyncBN-converted) module, not the original reference
        self.model = self.model.cuda()

        # Setup optimizer
        parameters = [(n, p) for n, p in self.model.named_parameters()
                      if p.requires_grad]
        self.optimizer = optimizers.from_params(params.pop("optimizer"),
                                                parameters,
                                                world_size=self.world_size)

        self.model, self.optimizer = amp.initialize(
            self.model,
            self.optimizer,
            opt_level=opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32,
            loss_scale=loss_scale)

        # Setup lr scheduler
        scheduler_params = params.pop('lr_scheduler', None)
        self.lr_scheduler = None
        if scheduler_params:
            self.lr_scheduler = lr_schedulers.from_params(
                scheduler_params, self.optimizer)

        self.base_lrs = list(
            map(lambda group: group['initial_lr'],
                self.optimizer.param_groups))

        # Setup metrics
        metrics_params = params.pop('metrics', [])
        if 'loss' not in metrics_params:
            metrics_params = ['loss'] + metrics_params

        # Initializing history
        self.metrics = {}
        for phase in ['train', 'val']:
            self.metrics[phase] = metrics.from_params(metrics_params,
                                                      alphabet=alphabet)

        if self.distributed:
            self.model = DistributedDataParallel(self.model,
                                                 delay_allreduce=True)
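The trainer in Example #22 stores `clip_grad_norm` and `clip_grad_value` but the step that applies them is not shown. A hedged sketch of how gradient clipping is typically combined with Apex AMP, clipping the FP32 master parameters after the scaled backward pass (function and argument names are placeholders):

import torch
from apex import amp


def clipped_backward(loss, optimizer, clip_grad_norm=None, clip_grad_value=None):
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Gradients are unscaled when the context exits, so clip the FP32 master
    # parameters that Apex maintains before stepping the optimizer.
    if clip_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad_norm)
    if clip_grad_value is not None:
        torch.nn.utils.clip_grad_value_(amp.master_params(optimizer), clip_grad_value)
    optimizer.step()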
Beispiel #23
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    torch.manual_seed(args.seed + args.rank)

    # create model
    config = get_efficientdet_config(args.model)
    config.redundant_bias = args.redundant_bias  # redundant conv + BN bias layers (True to match official models)
    model = EfficientDet(config)
    model = DetBenchTrain(model, config)

    # FIXME create model factory, pretrained zoo
    # model = create_model(
    #     args.model,
    #     pretrained=args.pretrained,
    #     num_classes=args.num_classes,
    #     drop_rate=args.drop,
    #     drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=args.drop_block,
    #     global_pool=args.gp,
    #     bn_tf=args.bn_tf,
    #     bn_momentum=args.bn_momentum,
    #     bn_eps=args.bn_eps,
    #     checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    model.cuda()
    optimizer = create_optimizer(args, model)
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(_unwrap_bench(model),
                                                       args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '')
        #resume=args.resume)  # FIXME bit of a mess with bench
        if args.resume:
            load_checkpoint(_unwrap_bench(model_ema),
                            args.resume,
                            use_ema=True)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_anno_set = 'train2017'
    train_annotation_path = os.path.join(args.data, 'annotations',
                                         f'instances_{train_anno_set}.json')
    train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir),
                                  train_annotation_path)

    # FIXME cutmix/mixup worth investigating?
    # collate_fn = None
    # if args.prefetcher and args.mixup > 0:
    #     collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=config.image_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        #re_prob=args.reprob,  # FIXME add back various augmentations
        #re_mode=args.remode,
        #re_count=args.recount,
        #re_split=args.resplit,
        #color_jitter=args.color_jitter,
        #auto_augment=args.aa,
        interpolation=args.train_interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        #collate_fn=collate_fn,
        pin_mem=args.pin_mem,
    )

    val_anno_set = 'val2017'
    val_annotation_path = os.path.join(args.data, 'annotations',
                                       f'instances_{val_anno_set}.json')
    val_image_dir = val_anno_set
    dataset_eval = CocoDetection(os.path.join(args.data, val_image_dir),
                                 val_annotation_path)

    loader_eval = create_loader(
        dataset_eval,
        input_size=config.image_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        #distributed=args.distributed,
        pin_mem=args.pin_mem,
    )

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    _unwrap_bench(model),
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=_unwrap_bench(model_ema),
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
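Example #23 keeps an EMA copy of the detection model via `ModelEma`. The sketch below is not timm's implementation, just a minimal illustration of the underlying exponential-moving-average update (the decay value is an assumption):

import copy

import torch


class SimpleEma:
    def __init__(self, model, decay=0.9998):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        ema_state = self.ema.state_dict()
        for name, value in model.state_dict().items():
            if value.dtype.is_floating_point:
                # ema = decay * ema + (1 - decay) * current
                ema_state[name].mul_(self.decay).add_(value, alpha=1.0 - self.decay)
            else:
                ema_state[name].copy_(value)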
Beispiel #24
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    # ! build model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            aux_logits=True,  # ! add aux loss
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            checkpoint_path=args.initial_checkpoint)
    else:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
            drop_path_rate=args.drop_path,
            drop_block_rate=args.drop_block,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            checkpoint_path=args.initial_checkpoint)

    # ! add more layer to classifier layer
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.'
            )
            use_amp = None
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)
    # ! optimizer
    if args.classification_layer_name is not None:  # ! add our own classification layer... doesn't really improve very much though...
        args.classification_layer_name = args.classification_layer_name.strip(
        ).split()
        print('classification_layer_name {}'.format(
            args.classification_layer_name))
        optimizer = create_optimizer(
            args,
            model,
            filter_bias_and_bn=args.filter_bias_and_bn,
            classification_layer_name=args.classification_layer_name)

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    if args.num_gpu > 1:  # ! these lines used to sit above the optimizer creation; moved here so they override the Apex warning above.
        if args.amp:
            _logger.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            _logger.warning('... ignoring this warning and keeping AMP enabled ...')
            # args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                _logger.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir, args=args)

    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,  # ! see file auto_augment.py
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        args=args)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir, args=args)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier *
        args.batch_size,  # ! so we can eval faster
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        args=args)

    # add weighted loss for each label class
    if args.weighted_cross_entropy is None:
        weighted_cross_entropy = torch.ones(args.num_classes)
    else:
        weighted_cross_entropy = args.weighted_cross_entropy.strip().split()
        weighted_cross_entropy = torch.FloatTensor(
            [float(w) for w in weighted_cross_entropy])  # parse the whitespace-separated string into per-class weights

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss(
            weight=weighted_cross_entropy).cuda()

    if args.weighted_cross_entropy_eval:
        validate_loss_fn = nn.CrossEntropyLoss(
            weight=weighted_cross_entropy).cuda()
    else:
        validate_loss_fn = nn.CrossEntropyLoss().cuda(
        )  # ! evaluate with unweighted CE as usual; pass weight=weighted_cross_entropy here to mirror training

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = 0
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        amp_autocast=amp_autocast,
                                        loss_scaler=loss_scaler,
                                        model_ema=model_ema,
                                        mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

            # early stop
            if epoch - best_epoch > args.early_stop_counter:
                _logger.info(
                    '*** Best metric: {0} (epoch {1}) (current epoch {2})'.
                    format(best_metric, best_epoch, epoch))
                break  # ! exit

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
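
Example #24 above turns the --weighted-cross-entropy string into per-class weights for nn.CrossEntropyLoss. A small self-contained sketch of that parsing step (make_weighted_ce is an illustrative helper, not part of the example):

import torch
import torch.nn as nn

def make_weighted_ce(weight_str, num_classes):
    # No string given -> uniform weights, i.e. plain cross entropy.
    if weight_str is None:
        weights = torch.ones(num_classes)
    else:
        weights = torch.tensor([float(w) for w in weight_str.split()])
        assert weights.numel() == num_classes, 'expected one weight per class'
    return nn.CrossEntropyLoss(weight=weights)

# e.g. loss_fn = make_weighted_ce('1.0 2.0 0.5', num_classes=3)
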
Example #25
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
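
Example #25 restores the optimizer and Apex AMP state from the dict returned by resume_checkpoint. A hedged sketch of the same resume pattern with plain torch.load (resume_training_state and the checkpoint keys are assumptions, not timm's actual helper):

import torch

def resume_training_state(model, optimizer, ckpt_path, restore_opt=True):
    # Load on CPU, restore the weights, optionally the optimizer state,
    # and return the next epoch to run.
    state = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    if restore_opt and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    return state.get('epoch', 0) + 1
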
Example #26
def set_syncbn(net):
    if has_apex:
        net = parallel.convert_syncbn_model(net)
    else:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    return net
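
Possible usage of set_syncbn from Example #26, assuming a distributed process group has already been initialized and local_rank holds this process's GPU index (the toy model below is only for illustration):

import torch.nn as nn

local_rank = 0  # placeholder; normally parsed from args or the environment
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = set_syncbn(model).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
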
Example #27
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0
    DistributedManager.set_args(args)
    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_connect_rate=args.drop_connect,
                         drop_path_rate=args.drop_path,
                         drop_block_rate=args.drop_block,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.initial_checkpoint_pruned:
        try:
            data_config = resolve_data_config(vars(args),
                                              model=model,
                                              verbose=args.local_rank == 0)
            model2 = load_module_from_ckpt(
                model,
                args.initial_checkpoint_pruned,
                input_size=data_config['input_size'][1])
            logging.info("New pruned model adapted from the checkpoint")
        except Exception as e:
            raise RuntimeError(e)
    else:
        model2 = model

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model2.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        model2 = nn.DataParallel(model2,
                                 device_ids=list(range(args.num_gpu))).cuda()
    else:
        model2.cuda()

    use_amp = False

    if args.distributed:
        model2 = nn.parallel.distributed.DistributedDataParallel(
            model2,
            device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            eval_dir = os.path.join(args.data, 'test')
            if not os.path.isdir(eval_dir):
                logging.error(
                    'Validation folder does not exist at: {}'.format(eval_dir))
                exit(1)

    test_dir = os.path.join(args.data, 'test')
    if not os.path.isdir(test_dir):
        test_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(test_dir):
            test_dir = os.path.join(args.data, 'val')
            if not os.path.isdir(test_dir):
                logging.error(
                    'Test folder does not exist at: {}'.format(test_dir))
                exit(1)

    dataset_eval = Dataset(eval_dir)
    if args.prune_test:
        dataset_test = Dataset(test_dir)
    else:
        dataset_test = Dataset(train_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )
    len_loader = int(
        len(loader_eval) * (4 * args.batch_size) / args.batch_size_prune)
    if args.prune_test:
        len_loader = None
    if args.prune:
        loader_p = create_loader(
            dataset_test,
            input_size=data_config['input_size'],
            batch_size=args.batch_size_prune,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
        )
        if 'resnet' in model2.__class__.__name__.lower() or (
                hasattr(model2, 'module')
                and 'resnet' in model2.module.__class__.__name__.lower()):
            list_channel_to_prune = compute_num_channels_per_layer_taylor(
                model2,
                data_config['input_size'],
                loader_p,
                pruning_ratio=args.pruning_ratio,
                taylor_file=args.taylor_file,
                local_rank=args.local_rank,
                len_data_loader=len_loader,
                prune_skip=args.prune_skip,
                taylor_abs=args.taylor_abs,
                prune_conv1=args.prune_conv1,
                use_time=args.use_time,
                distributed=args.distributed)
            new_net = redesign_module_resnet(
                model2,
                list_channel_to_prune,
                use_amp=use_amp,
                distributed=args.distributed,
                local_rank=args.local_rank,
                input_size=data_config['input_size'][1])
        else:
            list_channel_to_prune = compute_num_channels_per_layer_taylor(
                model2,
                data_config['input_size'],
                loader_p,
                pruning_ratio=args.pruning_ratio,
                taylor_file=args.taylor_file,
                local_rank=args.local_rank,
                len_data_loader=len_loader,
                prune_pwl=not args.no_pwl,
                taylor_abs=args.taylor_abs,
                use_se=not args.use_eca,
                use_time=args.use_time,
                distributed=args.distributed)
            new_net = redesign_module_efnet(
                model2,
                list_channel_to_prune,
                use_amp=use_amp,
                distributed=args.distributed,
                local_rank=args.local_rank,
                input_size=data_config['input_size'][1],
                use_se=not args.use_eca)

        new_net.train()
        model.train()
        if isinstance(model, nn.DataParallel) or isinstance(model, DDP):
            model = model.module
        else:
            model = model.cuda()

        co_mod = build_co_train_model(
            model,
            new_net.module.cpu() if hasattr(new_net, 'module') else new_net,
            gamma=args.gamma_knowledge,
            only_last=args.only_last,
            progressive_IKD_factor=args.progressive_IKD_factor)
        optimizer = create_optimizer(args, co_mod)

        del model
        del new_net
        gc.collect()
        torch.cuda.empty_cache()

        if args.num_gpu > 1:
            if args.amp:
                logging.warning(
                    'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
                )
                args.amp = False
            co_mod = nn.DataParallel(co_mod,
                                     device_ids=list(range(
                                         args.num_gpu))).cuda()
        else:
            co_mod = co_mod.cuda()

        use_amp = False
        if has_apex and args.amp:
            co_mod, optimizer = amp.initialize(co_mod,
                                               optimizer,
                                               opt_level='O1')
            use_amp = True
        if args.local_rank == 0:
            logging.info('NVIDIA APEX {}. AMP {}.'.format(
                'installed' if has_apex else 'not installed',
                'on' if use_amp else 'off'))

        if args.distributed:
            if args.sync_bn:
                try:
                    if has_apex and use_amp:
                        co_mod = convert_syncbn_model(co_mod)
                    else:
                        co_mod = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                            co_mod)
                    if args.local_rank == 0:
                        logging.info(
                            'Converted model to use Synchronized BatchNorm.')
                except Exception as e:
                    logging.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                    )
            if has_apex and use_amp:
                co_mod = DDP(co_mod, delay_allreduce=False)
            else:
                if args.local_rank == 0 and use_amp:
                    logging.info(
                        "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                    )
                co_mod = nn.parallel.distributed.DistributedDataParallel(
                    co_mod,
                    device_ids=[args.local_rank
                                ])  # can use device str in Torch >= 1.1
            # NOTE: EMA model does not need to be wrapped by DDP
            co_mod.train()

        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        start_epoch = 0
        if args.start_epoch is not None:
            # a specified start_epoch will always override the resume epoch
            start_epoch = args.start_epoch
        if lr_scheduler is not None and start_epoch > 0:
            lr_scheduler.step(start_epoch)

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        if args.local_rank == 0:
            logging.info('First validation')
        co_mod.eval()
        eval_metrics = validate(co_mod, loader_eval, validate_loss_fn, args)
        if args.local_rank == 0:
            logging.info(f'Prec@top1 : {eval_metrics["prec1"]}')
        co_mod.train()
        for epoch in range(start_epoch, num_epochs):
            torch.cuda.empty_cache()
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        co_mod,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=None)
            torch.cuda.empty_cache()
            eval_metrics = validate(co_mod, loader_eval, validate_loss_fn,
                                    args)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    co_mod,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=None,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
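
Example #27 co-trains the pruned network against the original one via build_co_train_model with a gamma_knowledge weight. A minimal sketch of that teacher/student idea (CoTrainLoss is illustrative, not the repository's implementation):

import torch.nn as nn
import torch.nn.functional as F

class CoTrainLoss(nn.Module):
    def __init__(self, gamma=1.0):
        super().__init__()
        self.gamma = gamma
        self.task_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, target):
        # Task loss on the pruned student plus a distillation term pulling
        # the student's logits toward the frozen teacher's.
        ce = self.task_loss(student_logits, target)
        ikd = F.mse_loss(student_logits, teacher_logits.detach())
        return ce + self.gamma * ikd
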
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.pretrained_backbone = not args.no_pretrained_backbone
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:' + str(args.cuda_device)
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            logging.warning(
                "Neither APEX or native Torch AMP is available, using float32. "
                "Install NVIDA apex or upgrade to PyTorch 1.6.")

    if args.apex_amp:
        if has_apex:
            use_amp = 'apex'
        else:
            logging.warning(
                "APEX AMP not available, using float32. Install NVIDA apex")
    elif args.native_amp:
        if has_native_amp:
            use_amp = 'native'
        else:
            logging.warning(
                "Native AMP not available, using float32. Upgrade to PyTorch 1.6."
            )

    torch.manual_seed(args.seed + args.rank)

    with set_layer_config(scriptable=args.torchscript):
        model = create_model(
            args.model,
            bench_task='train',
            num_classes=args.num_classes,
            pretrained=args.pretrained,
            pretrained_backbone=args.pretrained_backbone,
            redundant_bias=args.redundant_bias,
            label_smoothing=args.smoothing,
            legacy_focal=args.legacy_focal,
            jit_loss=args.jit_loss,
            soft_nms=args.soft_nms,
            bench_labeler=args.bench_labeler,
            checkpoint_path=args.initial_checkpoint,
        )
    model_config = model.config  # grab before we obscure with DP/DDP wrappers

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.distributed and args.sync_bn:
        if has_apex and use_amp != 'native':
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            logging.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model, force native amp with `--native-amp` flag'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model. Use `--dist-bn reduce` instead of `--sync-bn`'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            logging.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            logging.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            logging.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            unwrap_bench(model),
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(model, decay=args.model_ema_decay)
        if args.resume:
            load_checkpoint(unwrap_bench(model_ema), args.resume, use_ema=True)

    if args.distributed:
        if has_apex and use_amp != 'native':
            if args.local_rank == 0:
                logging.info("Using apex DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.device])
        # NOTE: EMA model does not need to be wrapped by DDP...
        if model_ema is not None and not args.resume:
            # ...but it is a good idea to sync EMA copy of weights
            # NOTE: ModelEma init could be moved after DDP wrapper if using PyTorch DDP, not Apex.
            model_ema.set(model)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    loader_train, loader_eval, evaluator = create_datasets_and_loaders(
        args, model_config)

    if model_config.num_classes < loader_train.dataset.parser.max_label:
        logging.error(
            f'Model {model_config.num_classes} has fewer classes than dataset {loader_train.dataset.parser.max_label}.'
        )
        exit(1)
    if model_config.num_classes > loader_train.dataset.parser.max_label:
        logging.warning(
            f'Model {model_config.num_classes} has more classes than dataset {loader_train.dataset.parser.max_label}.'
        )

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model,
                                optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                decreasing=decreasing,
                                unwrap_fn=unwrap_bench)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        amp_autocast=amp_autocast,
                                        loss_scaler=loss_scaler,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # the overhead of evaluating with coco style datasets is fairly high, so just ema or non, not both
            if model_ema is not None:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                eval_metrics = validate(model_ema.module,
                                        loader_eval,
                                        args,
                                        evaluator,
                                        log_suffix=' (EMA)')
            else:
                eval_metrics = validate(model, loader_eval, args, evaluator)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if saver is not None:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch=epoch, metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
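
The detection script above keeps an EMA copy of the weights (ModelEmaV2) and evaluates it instead of the live model. A minimal sketch of the underlying update rule, ema = decay * ema + (1 - decay) * model (SimpleEma is illustrative, not timm's class):

import copy
import torch

class SimpleEma:
    def __init__(self, model, decay=0.9998):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # Blend every floating-point tensor of the shadow copy toward the
        # live model; copy integer buffers (e.g. num_batches_tracked) as-is.
        for e, m in zip(self.ema.state_dict().values(), model.state_dict().values()):
            if e.dtype.is_floating_point:
                e.mul_(self.decay).add_(m, alpha=1 - self.decay)
            else:
                e.copy_(m)
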
def main():
    import os

    args, args_text = _parse_args()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = 'train'
        if args.gate_train:
            exp_name += '-dynamic'
        if args.slim_train:
            exp_name += '-slimmable'
        exp_name += '-{}'.format(args.model)
        exp_info = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, exp_name, exp_info)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    setup_default_logging(outdir=output_dir, local_rank=args.local_rank)

    torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        # torch.distributed.init_process_group(backend='nccl',
        #                                      init_method='tcp://127.0.0.1:23334',
        #                                      rank=args.local_rank,
        #                                      world_size=int(os.environ['WORLD_SIZE']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # --------- random seed -----------
    random.seed(args.seed)  # TODO: do we need the same seed on all GPUs?
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_path_rate=args.drop_path,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))
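        # convert_splitbn_model swaps each BatchNorm for a split variant that keeps separate
        # running statistics per augmentation split, so clean and heavily-augmented views do
        # not pollute each other's BN stats (the auxiliary-BN idea popularized by AdvProp).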

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    if args.train_mode == 'gate':
        optimizer = create_optimizer(args, model.get_gate())
    else:
        optimizer = create_optimizer(args, model)

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            checkpoint_path=args.resume,
            optimizer=optimizer if not args.no_resume_opt else None,
            log_info=args.local_rank == 0,
            strict=False)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume,
                             log_info=args.local_rank == 0,
                             resume_strict=False)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm (%s). Install Apex or Torch >= 1.1',
                    e)
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True
                        )  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ------------- data --------------
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    collate_fn = None
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )
    loader_bn = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ------------- loss_fn --------------
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    if args.ieb:
        distill_loss_fn = SoftTargetCrossEntropy().cuda()
    else:
        distill_loss_fn = None

    # profile the model once at 224x224; only rank 0 reports verbosely
    model_profiling(model, 224, 224, 1, 3, use_cuda=True, verbose=args.local_rank == 0)

    if not args.test_mode:
        # start training
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            train_metrics = OrderedDict([('loss', 0.)])
            # train
            if args.gate_train:
                train_metrics = train_epoch_slim_gate(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    train_loss_fn,
                    args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step)
            else:
                train_metrics = train_epoch_slim(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    loss_fn=train_loss_fn,
                    distill_loss_fn=distill_loss_fn,
                    args=args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step,
                )
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                torch.cuda.synchronize()
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # eval
            if args.gate_train:
                eval_metrics = [
                    validate_gate(model, loader_eval, validate_loss_fn, args)
                ]
                if model_ema is not None and not args.model_ema_force_cpu:
                    ema_eval_metrics = [
                        validate_gate(model_ema.ema,
                                      loader_eval,
                                      validate_loss_fn,
                                      args,
                                      log_suffix='(EMA)')
                    ]

                    eval_metrics = ema_eval_metrics
            else:
                if epoch % 10 == 0 and epoch != 0:
                    eval_sample_list = ['smallest', 'largest', 'uniform']
                else:
                    eval_sample_list = ['smallest', 'largest']
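                # evaluate the slimmable model at its smallest and largest settings every epoch;
                # the extra 'uniform' setting only runs every 10th epoch to keep evaluation cheap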
                eval_metrics = [
                    validate_slim(model,
                                  loader_eval,
                                  validate_loss_fn,
                                  args,
                                  model_mode=model_mode)
                    for model_mode in eval_sample_list
                ]
                if model_ema is not None and not args.model_ema_force_cpu:
                    ema_eval_metrics = [
                        validate_slim(model_ema.ema,
                                      loader_eval,
                                      validate_loss_fn,
                                      args,
                                      log_suffix='(EMA)',
                                      model_mode=model_mode)
                        for model_mode in eval_sample_list
                    ]

                    eval_metrics = ema_eval_metrics

            if isinstance(eval_metrics, list):
                eval_metrics = eval_metrics[0]

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # save
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)
        # end training
        if best_metric is not None:
            logging.info('*** Best metric: {0} (epoch {1})'.format(
                best_metric, best_epoch))

    # test
    eval_metrics = []

    # reset bn
    if args.reset_bn:
        if args.local_rank == 0:
            logging.info("Recalibrating BatchNorm statistics...")
        if model_ema is not None and not args.model_ema_force_cpu:
            model_list = [model, model_ema.ema]
        else:
            model_list = [model]
        for model_ in model_list:
            for layer in model_.modules():
                if isinstance(layer, (nn.BatchNorm2d, nn.SyncBatchNorm)) or \
                        (has_apex and isinstance(layer, apex.parallel.SyncBatchNorm)):
                    layer.reset_running_stats()
            model_.train()
            with torch.no_grad():
                for batch_idx, (input, target) in enumerate(loader_bn):
                    for choice in range(args.num_choice):
                        if args.slim_train:
                            if hasattr(model_, 'module'):
                                model_.module.set_mode('uniform',
                                                       choice=choice)
                            else:
                                model_.set_mode('uniform', choice=choice)
                            model_(input)

                    if batch_idx % 1000 == 0 and batch_idx != 0:
                        break
            if args.local_rank == 0:
                logging.info("Finish recalibrating BatchNorm statistics.")
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                torch.cuda.synchronize()
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model_, args.world_size,
                              args.dist_bn == 'reduce')

    # dynamic
    if args.gate_train:
        eval_metrics = [
            validate_gate(model, loader_eval, validate_loss_fn, args)
        ]
        if model_ema is not None and not args.model_ema_force_cpu:
            ema_eval_metrics = [
                validate_gate(model_ema.ema,
                              loader_eval,
                              validate_loss_fn,
                              args,
                              log_suffix='(EMA)')
            ]

            eval_metrics = ema_eval_metrics
    # supernet
    for choice in range(args.num_choice):
        eval_metrics.append(
            validate_slim(model,
                          loader_eval,
                          validate_loss_fn,
                          args,
                          model_mode=choice))
    if model_ema is not None and not args.model_ema_force_cpu:
        for choice in range(args.num_choice):
            eval_metrics.append(
                validate_slim(model_ema.ema,
                              loader_eval,
                              validate_loss_fn,
                              args,
                              log_suffix='(EMA)',
                              model_mode=choice))
    if args.local_rank == 0:
        best_metric, best_epoch = saver.save_checkpoint(model,
                                                        optimizer,
                                                        args,
                                                        epoch=0,
                                                        model_ema=model_ema,
                                                        metric=eval_metrics[0],
                                                        use_amp=use_amp)

    if args.local_rank == 0:
        logging.info('Test results of the last epoch:\n%s', eval_metrics)
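# A minimal, generic sketch of the BatchNorm recalibration pattern used in the args.reset_bn
# branch above. Names and the batch format are illustrative assumptions, not taken from this
# script; the real loop additionally cycles through the slimmable width choices.
import torch
import torch.nn as nn


def recalibrate_bn(model, calib_batches, max_batches=1000):
    """Reset BN running stats and re-estimate them with forward passes only."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for layer in model.modules():
        if isinstance(layer, bn_types):
            layer.reset_running_stats()
    model.train()  # BN only updates running stats in train mode
    with torch.no_grad():  # no gradients needed, only the forward statistics
        for batch_idx, inputs in enumerate(calib_batches):
            model(inputs)
            if batch_idx + 1 >= max_batches:
                break
    return model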
Example #30
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    with open(args.model_config) as f:
        net_config = json.load(f)
    model = NSGANetV2.build_from_config(net_config,
                                        drop_connect_rate=args.drop_path)
    init = torch.load(args.initial_checkpoint,
                      map_location='cpu')['state_dict']
    model.load_state_dict(init)

    if args.reset_classifier:
        NSGANetV2.reset_classifier(model,
                                   last_channel=model.classifier.in_features,
                                   n_classes=args.num_classes,
                                   dropout_rate=args.drop)

    # create a dummy model just to resolve the default data config expected by this timm training script
    dummy_model = create_model('efficientnet_b0')

    # add a teacher model
    if args.teacher:
        # use the supernet at full scale to supervise the training of the subnets; this follows
        # https://github.com/mit-han-lab/once-for-all/blob/4decacf9d85dbc948a902c6dc34ea032d711e0a9/train_ofa_net.py#L125
        from evaluator import OFAEvaluator
        supernet = OFAEvaluator(n_classes=1000, model_path=args.teacher)
        teacher, _ = supernet.sample({
            'ks': [7] * 20,
            'e': [6] * 20,
            'd': [4] * 5
        })
        # as alternative, you could simply use a pretrained model from timm, e.g. "tf_efficientnet_b1",
        # "mobilenetv3_large_100", etc. See https://rwightman.github.io/pytorch-image-models/results/ for full list.
        # import timm
        # teacher = timm.create_model(args.teacher, pretrained=True)
        args.kd_ratio = 1.0
        print("teacher model loaded")
    else:
        args.kd_ratio = 0.0
        teacher = None  # keep a defined placeholder so later references to `teacher` are safe

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=dummy_model,
                                      verbose=args.local_rank == 0)
    if args.img_size is not None:
        # override the input image resolution
        data_config['input_size'] = (3, args.img_size, args.img_size)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
        if args.teacher:
            teacher = nn.DataParallel(teacher,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model.cuda()
        if args.teacher: teacher.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm (%s). Install Apex or Torch >= 1.1',
                    e)
        if has_apex:
            model = DDP(model, delay_allreduce=True)
            if args.teacher: teacher = DDP(teacher, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
            if args.teacher:
                teacher = DDP(teacher, device_ids=[args.local_rank])

        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema,
                                        teacher=teacher)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
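# train_epoch above receives `teacher` and args.kd_ratio without showing how the distillation
# term enters the loss. A minimal sketch of one common formulation (soft targets from the
# teacher blended with the hard-label loss); this is an assumption for illustration, not the
# loss actually computed inside train_epoch.
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, target, kd_ratio=1.0, temperature=1.0):
    """Blend hard-label cross-entropy with a temperature-scaled soft-target term."""
    hard = F.cross_entropy(student_logits, target)
    soft_targets = F.softmax(teacher_logits.detach() / temperature, dim=1)
    soft = -(soft_targets * F.log_softmax(student_logits / temperature, dim=1)).sum(dim=1).mean()
    return hard + kd_ratio * (temperature ** 2) * soft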