Пример #1
0
def validate(args):
    """Evaluate a classification model on a dataset and return summary metrics.

    Creates the model named by ``args.model`` (optionally loading
    ``args.checkpoint``), moves it to the GPU(s), runs it over the data at
    ``args.data`` with gradients disabled, and logs running loss and
    precision every ``args.log_freq`` batches.

    Args:
        args: Parsed argument namespace. Attributes read here: ``model``,
            ``num_classes``, ``pretrained``, ``checkpoint``, ``use_ema``,
            ``num_gpu``, ``data``, ``tf_preprocessing``, ``batch_size``,
            ``workers`` and ``log_freq``. ``args.pretrained`` is overwritten
            in place.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``
        (each rounded to 3 decimals; the ``*_err`` values are ``100 - avg``)
        and ``param_count`` (parameter count divided by 1e6, 2 decimals).
    """
    # might as well try to validate something: with no checkpoint given,
    # fall back to requesting pretrained weights
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d', args.model, param_count)

    data_config = resolve_data_config(model, args)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    loader = create_loader(
        Dataset(args.data, load_bytes=args.tf_preprocessing),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        # no cropping when test-time pooling was applied to the model
        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(loader):
            target = target.cuda()
            images = images.cuda()
            batch_size = images.size(0)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss; gradients are already
            # disabled here, so the output can be passed on directly
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                # NOTE: the previous version also computed an images/sec
                # value that the format string never used; it was dropped
                # because the division could fail on a zero average time.
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        i, len(loader), batch_time=batch_time,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
        param_count=round(param_count / 1e6, 2))

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
Пример #2
0
def main():
    """Run the training script end to end.

    Parses arguments, configures (optionally distributed) execution, builds
    the model, optimizer, AMP helpers, EMA copy, data loaders and loss
    functions, then either evaluates ``args.eval_checkpoint`` once and
    returns, or runs the train/validate loop from the start epoch up to the
    scheduled number of epochs. A ``KeyboardInterrupt`` during the loop ends
    training quietly; the best metric seen so far is still logged.

    Raises:
        SystemExit: with status 1 when the training or validation folder is
            missing under ``args.data``.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # offset the seed by rank so each process draws a different random stream
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum(p.numel() for p in model.parameters())))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # resolve which AMP implementation (if any) to use
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    # move model to GPU(s)
    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.'
            )
            use_amp = None
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)

    # setup AMP loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception:
                # Best effort: training continues without SyncBN, but the
                # traceback is logged so the real cause is not lost.
                _logger.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        # `exit()` is an interactive helper injected by `site`; raise
        # SystemExit directly so this also works when `site` is not loaded.
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # fall back to the model's default interpolation when augmentation is
    # disabled or no training interpolation was specified
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)

    # the evaluation split may live in either 'val' or 'validation'
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss functions
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None

    if args.eval_checkpoint:  # evaluate the model
        load_checkpoint(model, args.eval_checkpoint, args.model_ema)
        val_metrics = validate(model, loader_eval, validate_loss_fn, args)
        print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
        return

    # only the local_rank 0 process gets an output directory and a saver
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # a lower value is better only when tracking the loss
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:  # train the model
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        amp_autocast=amp_autocast,
                                        loss_scaler=loss_scaler,
                                        model_ema=model_ema,
                                        mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                # when EMA is evaluated, its metrics replace the raw model's
                # for scheduling, the summary and checkpoint selection
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # output_dir stays '' on processes with local_rank != 0; skip
            # them so they are not handed the bare relative path
            # 'summary.csv' in the current working directory.
            if output_dir:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        # deliberate: allow Ctrl-C to stop training and still report below
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #3
0
def main():
    """Run the training script end to end.

    Parses arguments, optionally initialises wandb logging and distributed
    execution, resolves the AMP mode, builds the (optional) ``model_KD``
    model, the model, optimizer, EMA copy, datasets, loaders and loss
    functions, then runs the train/validate loop from the start epoch up to
    the scheduled number of epochs. Checkpoints, ``args.yaml`` and
    ``summary.csv`` are written only by the process with ``args.rank == 0``.
    A ``KeyboardInterrupt`` during the loop ends training quietly; the best
    metric seen so far is still logged.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    # distributed mode is enabled only when WORLD_SIZE is set and > 1
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    # optional extra model, built only when a path is given; it is handed to
    # train_one_epoch below (presumably a knowledge-distillation teacher —
    # confirm against build_kd_model)
    model_KD = None
    if args.kd_model_path is not None:
        model_KD = build_kd_model(args)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    # when num_classes was not given, adopt the value reported by the model
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN is used only when Apex AMP was selected above
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # The EMA model is created after cuda() and AMP setup but before the
        # DDP wrapper; note that in this function the SyncBN conversion has
        # already been applied to `model` above.
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP is used only when Apex AMP was selected above
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model,
                              device_ids=[args.local_rank],
                              broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(args.dataset,
                                   root=args.data_dir,
                                   split=args.train_split,
                                   is_training=True,
                                   class_map=args.class_map,
                                   download=args.dataset_download,
                                   batch_size=args.batch_size,
                                   repeats=args.epoch_repeats)
    dataset_eval = create_dataset(args.dataset,
                                  root=args.data_dir,
                                  split=args.val_split,
                                  is_training=False,
                                  class_map=args.class_map,
                                  download=args.dataset_download,
                                  batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        # with the prefetcher, mixup is applied in the collate function;
        # otherwise a Mixup object is passed to the training loop instead
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        # fall back to the interpolation from the resolved data config
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        # a falsy validation_batch_size falls back to the training batch size
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(
                smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    # output_dir stays None on every process except global rank 0
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            # default name: timestamp, model name and last input_size entry
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(
            args.output if args.output else './output/train', exp_name)
        # a lower value is better only when the tracked metric is the loss
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing,
                                max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            # not every sampler exposes set_epoch, hence the hasattr check
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn,
                                            model_KD=model_KD)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                # when the EMA model is evaluated, its metrics replace the
                # raw model's for scheduling, summary and checkpointing
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # only the rank-0 process has an output_dir to write a summary to
            if output_dir is not None:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None,
                               log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        # deliberate: allow Ctrl-C to stop training and still report below
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #4
0
def validate(args):
    """Time a model on the requested target, then score it on a validation set.

    The dataset at ``args.data`` is opened first and its class count replaces
    ``args.num_classes`` before the model is built. Depending on the flags the
    model is either converted with ``transform_model_to_mobilenet`` or has its
    stem weights rescaled, and it may be passed through ``fuse_bn``. Latency is
    measured next, after which the model is evaluated on CUDA.

    Args:
        args: parsed command-line namespace. ``pretrained``, ``prefetcher``
            and ``num_classes`` are overwritten on it.

    Returns:
        OrderedDict holding top-1/top-5 accuracy and error, the parameter
        count in millions, the last entry of the input size, the crop
        percentage (under the key ``cropt_pct``) and the interpolation mode.
    """
    # Without a checkpoint, fall back to pretrained weights.
    args.pretrained = args.pretrained or (not args.checkpoint)
    args.prefetcher = not args.no_prefetcher

    # An existing file with a '.tar' extension is read through DatasetTar;
    # anything else goes through the regular Dataset.
    is_tar_archive = (os.path.splitext(args.data)[1] == '.tar'
                      and os.path.isfile(args.data))
    dataset_cls = DatasetTar if is_tar_archive else Dataset
    dataset = dataset_cls(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)
    logging.info(f'Validation data has {len(dataset)} images')
    args.num_classes = len(dataset.class_to_idx)
    logging.info(f'setting num classes to {args.num_classes}')

    # Build the model; the keyword arguments are collected first for readability.
    model_kwargs = dict(
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        scriptable=args.torchscript,
        resnet_structure=args.resnet_structure,
        resnet_block=args.resnet_block,
        heaviest_network=args.heaviest_network,
        use_kernel_3=args.use_kernel_3,
        exp_r=args.exp_r,
        depth=args.depth,
        reduced_exp_ratio=args.reduced_exp_ratio,
        use_dedicated_pwl_se=args.use_dedicated_pwl_se,
        multipath_sampling=args.multipath_sampling,
        force_sync_gpu=args.force_sync_gpu,
        # An empty string is passed when the model is converted afterwards.
        mobilenet_string=('' if args.transform_model_to_mobilenet
                          else args.mobilenet_string),
        no_swish=args.no_swish,
        use_swish=args.use_swish,
    )
    model = create_model(args.model, **model_kwargs)
    data_config = resolve_data_config(vars(args), model=model)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, True, strict=True)

    convert_to_mobilenet = ('mobilenasnet' in args.model
                            and args.transform_model_to_mobilenet)
    if convert_to_mobilenet:
        model.eval()
        expected_latency = model.extract_expected_latency(
            file_name=args.lut_filename,
            batch_size=args.lut_measure_batch_size,
            iterations=args.repeat_measure,
            target=args.target_device)
        model.eval()
        converted_model, _string_model = transform_model_to_mobilenet(
            model, mobilenet_string=args.mobilenet_string)
        # Rebinding drops the last local reference to the pre-conversion model.
        model = converted_model
        model.eval()
        print('Model converted. Expected latency: {:0.2f}[ms]'.format(
            expected_latency * 1e3))

    elif args.normalize_weights:
        # Divide the stem conv weights by the per-channel std and shift the
        # bn1 bias to match (presumably folds the ImageNet input
        # normalisation into the model -- confirm against the data pipeline).
        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)
        std = torch.tensor(imagenet_std).view(1, -1, 1, 1)
        mean = torch.tensor(imagenet_mean).view(1, -1, 1, 1)
        stem_weight = model.conv_stem.weight.data
        bn_scale = model.bn1.weight.data
        bn_shift = model.bn1.bias.data
        model.conv_stem.weight.data = stem_weight / std
        # The shift is computed from the weights captured before the rescale.
        mean_response = stem_weight.sum(dim=[-1, -2]) @ (mean / std).squeeze()
        bn_denominator = torch.sqrt(model.bn1.running_var + model.bn1.eps)
        model.bn1.bias.data = bn_shift + (
            -bn_scale * mean_response / bn_denominator)

    if args.fuse_bn:
        model = fuse_bn(model)

    # On the 'gpu' and default paths the first measurement's result is
    # discarded (presumably a warm-up run).
    if args.target_device == 'gpu':
        measure_time(model, batch_size=64, target='gpu')
        measured_latency = measure_time(model, batch_size=64, target='gpu')
    elif args.target_device == 'onnx':
        measured_latency = measure_time_onnx(model)
    else:
        measure_time(model)
        measured_latency = measure_time(model)

    param_count = sum(p.numel() for p in model.parameters())
    flops = compute_flops(model, data_config['input_size'])
    logging.info(
        'Model {} created, param count: {}, flops: {}, Measured latency ({}): {:0.2f}[ms]'
        .format(args.model, param_count, flops / 1e9, args.target_device,
                measured_latency * 1e3))

    # Resolve the data config again for the (possibly replaced) model.
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.amp:
        model = amp.initialize(model, opt_level='O1')

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # Test-time pooling evaluates on the full image, hence no centre crop.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        squish=args.squish,
    )
    loader = create_loader(dataset, **loader_kwargs)

    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))
    model.cuda()
    model.eval()
    # With fewer than five classes the second accuracy uses k = num_classes.
    topk = (1, min(5, args.num_classes))
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        warmup_batch = torch.randn((args.batch_size, ) +
                                   data_config['input_size']).cuda()
        model(warmup_batch)
        end = time.time()
        for batch_idx, (images, labels) in enumerate(loader):
            if batch_idx == 0:
                # Restart the clock once the loader has produced its first batch.
                end = time.time()

            if args.no_prefetcher:
                labels = labels.cuda()
                images = images.cuda()

            if args.amp:
                images = images.half()

            # compute output
            logits = model(images)
            loss = criterion(logits, labels)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(logits.data, labels, topk=topk)

            n_samples = images.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1.item(), n_samples)
            top5.update(acc5.item(), n_samples)

            # measure elapsed time
            now = time.time()
            batch_time.update(now - end)
            end = time.time()

            if batch_idx % args.log_freq != 0:
                continue
            logging.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx,
                    len(loader),
                    batch_time=batch_time,
                    rate_avg=images.size(0) / batch_time.avg,
                    loss=losses,
                    top1=top1,
                    top5=top5))

    results = OrderedDict([
        ('top1', round(top1.avg, 4)),
        ('top1_err', round(100 - top1.avg, 4)),
        ('top5', round(top5.avg, 4)),
        ('top5_err', round(100 - top5.avg, 4)),
        ('param_count', round(param_count / 1e6, 2)),
        ('img_size', data_config['input_size'][-1]),
        ('cropt_pct', crop_pct),
        ('interpolation', data_config['interpolation']),
    ])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
Пример #5
0
def validate(args):
    """Evaluate a model on clean inputs and on FGM / PGD adversarial examples.

    For every batch the clean accuracy is measured, then adversarial examples
    are generated with ``fast_gradient_method`` and
    ``projected_gradient_descent`` (``np.inf`` norm, budget ``args.eps``) and
    the model is scored on those as well.

    Args:
        args: parsed command-line namespace. ``pretrained``, ``prefetcher``,
            the AMP flags and (when it was ``None``) ``num_classes`` are
            updated in place.

    Returns:
        OrderedDict with clean top-1/top-5 accuracy and error, top-1/top-5
        accuracy under each attack, the parameter count in millions, the last
        entry of the input size, the crop percentage (key ``cropt_pct``) and
        the interpolation mode.

    Raises:
        NotImplementedError: if ``args.real_labels`` is set; real-label
            scoring has not been adapted to the adversarial metrics.
    """
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f"    {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # Wrapping order: CUDA -> APEX AMP -> channels_last -> DataParallel.
    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        # The file lists one class index per line; turn it into a boolean
        # mask over all classes so it can index the logits' class dimension.
        with open(args.valid_labels, 'r') as f:
            valid_ids = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        # Real-label scoring was never adapted to the adversarial metrics.
        # Fail here instead of after the (expensive) evaluation loop.
        raise NotImplementedError(
            'real_labels is not supported in adversarial evaluation mode')

    # Test-time pooling evaluates on the full image, hence no centre crop.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # The whole loop cannot run under no_grad because the attacks need
    # gradients. They are handed only the model and the raw input, though, so
    # they cannot reuse a graph built by the warm-up or clean forward passes;
    # those passes therefore run without autograd to save memory.
    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
    if args.channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
    with torch.no_grad():
        model(input)
    end = time.time()
    for batch_idx, (input, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # compute output on the clean batch
        with torch.no_grad():
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

        # Generate adversarial examples for the current inputs
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=input,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=input,
            eps=args.eps,
            eps_iter=0.01,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        # Predict with the adversarial examples
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)

            if valid_labels is not None:
                # Apply the same class mask as for the clean logits so all
                # three accuracies are scored in the same label space.
                output_fgm_ae = output_fgm_ae[:, valid_labels]
                output_pgd_ae = output_pgd_ae[:, valid_labels]

        # measure accuracy and record loss
        batch_size = input.size(0)
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), batch_size)
        top5_fgm_ae.update(acc5_fgm_ae.item(), batch_size)
        top1_pgd_ae.update(acc1_pgd_ae.item(), batch_size)
        top5_pgd_ae.update(acc5_pgd_ae.item(), batch_size)

        # measure elapsed time (includes attack generation for this batch)
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            # Only the clean metrics are reported per batch.
            _logger.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=input.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    top1a, top5a = top1.avg, top5.avg
    top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
    top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
Пример #6
0
def validate(args):
    """Run one evaluation pass of a model over a validation split on CUDA.

    Args:
        args: parsed command-line namespace. ``pretrained``, ``prefetcher``,
            the AMP flags and (when it was ``None``) ``num_classes`` are
            updated in place.

    Returns:
        OrderedDict with top-1/top-5 accuracy and error, the parameter count
        in millions, the last entry of the input size, the crop percentage
        (key ``cropt_pct``) and the interpolation mode.
    """
    # With no checkpoint given, fall back to pretrained weights so there is
    # something to validate.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # --amp selects a backend: APEX is tried first, then native AMP.
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not (args.apex_amp and args.native_amp), "Only one AMP mode should be set."
    # Only native AMP needs a context manager; `suppress` acts as a no-op.
    amp_autocast = torch.cuda.amp.autocast if args.native_amp else suppress

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model_kwargs = dict(
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript,
    )
    model = create_model(args.model, **model_kwargs)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # Wrapping order: CUDA -> APEX AMP -> channels_last -> DataParallel.
    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(root=args.data,
                             name=args.dataset,
                             split=args.split,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)

    # Optional class filter: the file lists one class index per line and is
    # turned into a boolean mask over all classes.
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as label_file:
            allowed_ids = {int(row.rstrip()) for row in label_file}
            valid_labels = [
                idx in allowed_ids for idx in range(args.num_classes)
            ]

    real_labels = None
    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)

    # Test-time pooling evaluates on the full image, hence no centre crop.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
    )
    loader = create_loader(dataset, **loader_kwargs)

    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        warmup_batch = torch.randn((args.batch_size, ) +
                                   data_config['input_size']).cuda()
        if args.channels_last:
            warmup_batch = warmup_batch.contiguous(
                memory_format=torch.channels_last)
        model(warmup_batch)
        end = time.time()
        for batch_idx, (images, labels) in enumerate(loader):
            # Without the prefetcher the batch still has to be moved to the GPU.
            if args.no_prefetcher:
                labels = labels.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                logits = model(images)

            if valid_labels is not None:
                logits = logits[:, valid_labels]
            loss = criterion(logits, labels)

            if real_labels is not None:
                real_labels.add_result(logits)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(logits.detach(), labels, topk=(1, 5))
            n_samples = images.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1.item(), n_samples)
            top5.update(acc5.item(), n_samples)

            # measure elapsed time
            now = time.time()
            batch_time.update(now - end)
            end = time.time()

            if batch_idx % args.log_freq != 0:
                continue
            _logger.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx,
                    len(loader),
                    batch_time=batch_time,
                    rate_avg=images.size(0) / batch_time.avg,
                    loss=losses,
                    top1=top1,
                    top5=top5))

    if real_labels is None:
        top1a, top5a = top1.avg, top5.avg
    else:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        top5a = real_labels.get_accuracy(k=5)

    results = OrderedDict([
        ('top1', round(top1a, 4)),
        ('top1_err', round(100 - top1a, 4)),
        ('top5', round(top5a, 4)),
        ('top5_err', round(100 - top5a, 4)),
        ('param_count', round(param_count / 1e6, 2)),
        ('img_size', data_config['input_size'][-1]),
        ('cropt_pct', crop_pct),
        ('interpolation', data_config['interpolation']),
    ])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
Пример #7
0
def train_eval_fold(
    df: pd.DataFrame,
    image_dir: Path,
    args: argparse.Namespace,
    Config: dict,
    exp_name: str,
    use_amp: bool,
    checkpoint_path: Optional[Path],
) -> Optional[float]:
    """Train a model and validate it on the fold selected by ``args.fold``.

    Args:
        df: table with a ``fold`` column and the target column named by
            ``Config["target_col"]``.
        image_dir: image directory forwarded to ``init_datasets``.
        args: command-line namespace; only ``args.fold`` is read here.
        Config: experiment configuration dictionary.
        exp_name: experiment name; ``_f<fold>`` is appended for this run.
        use_amp: when true a ``NativeScaler`` is created and passed along.
        checkpoint_path: optional checkpoint to resume model, optimizer,
            loss scaler and (if enabled) EMA weights from.

    Returns:
        Whatever ``train_model`` returns for this fold.
    """
    # Rows of the selected fold are held out of training unless the config
    # asks to train on the full table; they always form the validation set.
    if Config["train_on_full"]:
        train_df = df.copy()
    else:
        in_train = df["fold"] != args.fold
        train_df = df[in_train].copy().reset_index(drop=True)
    in_val = df["fold"] == args.fold
    val_df = df[in_val].copy().reset_index(drop=True)

    train_ds, val_ds = init_datasets(
        Config,
        train_df,
        val_df,
        image_dir,
        txt_mod_name_or_path=Config["bert_name"],
        use_text=Config["arc_face_text"],
    )
    dataloaders = init_dataloaders(train_ds, val_ds, Config)
    logging.info(f"Data: train size: {len(train_ds)}, val_size: {len(val_ds)}")

    # The class count is derived from the largest target value in the training rows.
    num_classes = int(train_df[Config["target_col"]].max() + 1)
    model = init_model(num_classes, Config, pretrained=True)
    logging.info(
        f"Model {model} created, param count: {sum([m.numel() for m in model.parameters()]):_}"
    )
    model.cuda()

    if Config["channels_last"]:
        model = model.to(memory_format=torch.channels_last)

    optimizer = init_optimizer(model, Config["opt_conf"], diff_lr=Config["diff_lr"])
    logging.info(f"Using optimizer: {optimizer}")
    amp_scaler = None
    if use_amp:
        amp_scaler = NativeScaler()
    logging.info(f"AMP: {amp_scaler}")

    # Resume state stays None unless a checkpoint is supplied.
    resume_epoch = resume_loss = resume_score = None
    if checkpoint_path:
        resume_epoch, resume_loss, resume_score, _ = resume_checkpoint(
            model, checkpoint_path, optimizer=optimizer, loss_scaler=amp_scaler,
        )

    # Optional exponential moving average of the weights, created after the
    # model has been moved to CUDA.
    model_ema = None
    if Config["model_ema"]:
        ema_decay = Config["model_ema_decay"]
        ema_device = "cpu" if Config["model_ema_force_cpu"] else None
        model_ema = ModelEmaV2(model, decay=ema_decay, device=ema_device)
        if checkpoint_path:
            load_checkpoint(model_ema.module, checkpoint_path, use_ema=True)

    # Build the scheduler and, on resume, advance it to the resumed epoch.
    scheduler = init_scheduler(optimizer, Config["sch_conf"])
    if scheduler is not None and resume_epoch is not None:
        scheduler.step(resume_epoch, resume_loss)
        logging.info(
            f"""after resume and step: lr - {optimizer.param_groups[0]['lr']},
            initial lr -{optimizer.param_groups[0]['initial_lr']}"""
        )
    tr_criterion = ArcFaceLoss(num_classes, s=Config["s"], m=Config["m"])

    fold_exp_name = f"{exp_name}_f{args.fold}"
    return train_model(
        model=model,
        dataloaders=dataloaders,
        optimizer=optimizer,
        tr_criterion=tr_criterion,
        scheduler=scheduler,
        metrics_fn=binned_threshold_f1,
        exp_name=fold_exp_name,
        Config=Config,
        use_amp=use_amp,
        amp_scaler=amp_scaler,
        model_ema=model_ema,
        resume_epoch=resume_epoch,
        resume_loss=resume_loss,
        resume_score=resume_score,
    )
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
#    amp_autocast = suppress  # do nothing
#   if args.amp:
#        if has_native_amp:
#            args.native_amp = True
#        elif has_apex:
#            args.apex_amp = True
#        else:
#            _logger.warning("Neither APEX or Native Torch AMP is available.")
#    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
#    if args.native_amp:
#        amp_autocast = torch.cuda.amp.autocast
#        _logger.info('Validating in mixed precision with native PyTorch AMP.')
#   elif args.apex_amp:
#        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
#    else:
#        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

#    model = model.cuda()
#    if args.apex_amp:
#        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

#    if args.num_gpu > 1:
#        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

#   criterion = nn.CrossEntropyLoss().cuda()
    criterion = nn.CrossEntropyLoss()
    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # added for post quantization calibration

    calib_dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)
        

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    #Also create loader for calibration dataset
    calib_loader = create_loader(
        calib_dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print('Start calibration of quantization observers before post-quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    #post training static quantization
    if args.quant_option == 'static':
        qconfig_dict = {"": torch.quantization.default_static_qconfig} 
        model_to_quantize = copy.deepcopy(model_fp)
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        model_to_quantize.eval()
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate 
        with torch.no_grad():
            # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
            input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])) 
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            model(input)
            end = time.time()
            for batch_idx, (input, target) in enumerate(loader):

                if args.channels_last:
                    input = input.contiguous(memory_format=torch.channels_last)

                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)

                if real_labels is not None:
                    real_labels.add_result(output)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                top5.update(acc5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % args.log_freq == 0:
                    _logger.info(
                        'Test: [{0:>4d}/{1}]  '
                        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                        'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                        'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                            batch_idx, len(loader), batch_time=batch_time,
                            rate_avg=input.size(0) / batch_time.avg,
                            loss=losses, top1=top1, top5=top5))        
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)           
    #post training dynamic/weight only quantization    
    elif args.quant_option == 'dynamic':    
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # no calibration needed when we only have dynamici/weight_only quantization
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)       
    else:
        _logger.warning("Invalid quantization option. Set option to default(static)")
    #
    # fusion
    #
    model_to_quantize = copy.deepcopy(model_fp)
    model_fused = quantize_fx.fuse_fx(model_to_quantize)   

    model = model_fused

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
#        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])) 
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
 #           if args.no_prefetcher:
 #               target = target.cuda()
 #               input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
    #        with amp_autocast():
    #            output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
Пример #9
0
def validate(args):
    """Evaluate a classification model and return its summary metrics.

    Builds the model named by ``args.model`` (restoring ``args.checkpoint``
    when given), runs it over the dataset at ``args.data`` under
    ``torch.no_grad()`` and tracks loss, top-1/top-5 accuracy and a 40x40
    confusion matrix.

    Side effects: writes the running top-1 average after every batch to
    ``results.csv`` and, when ``args.hier_classify`` is set, the running
    parent-category accuracy to ``results-parent.csv``. Both files are
    created in the current working directory (``results-parent.csv`` is
    created even when hierarchical scoring is disabled). ``args.pretrained``
    and ``args.prefetcher`` are overwritten on the passed namespace.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``model``,
            ``num_classes``, ``pretrained``, ``checkpoint``, ``use_ema``,
            ``num_gpu``, ``fp16``, ``data``, ``tf_preprocessing``,
            ``batch_size``, ``workers``, ``no_prefetcher``, ``hier_classify``
            and ``log_freq``.

    Returns:
        OrderedDict with keys ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(m.numel() for m in model.parameters())
    logging.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    if args.fp16:
        model = model.half()

    criterion = nn.CrossEntropyLoss().cuda()

    # A path ending in .tar that is an existing file is read as a tar
    # archive; anything else goes through the regular Dataset class.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           fp16=args.fp16,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # NOTE: the confusion matrix and the parent groups below are hard-wired
    # to 40 classes, independently of args.num_classes.
    c_matrix = np.zeros((40, 40), dtype=int)
    labels = np.arange(0, 40, 1)
    # Class-index ranges that make up the four parent (coarse) categories.
    # Built once here instead of being rebuilt for every batch.
    parent_groups = [
        list(range(0, 6)),
        list(range(6, 14)),
        list(range(14, 37)),
        list(range(37, 40)),
    ]

    model.eval()
    end = time.time()
    # Context managers guarantee both CSV files are closed even if a batch
    # raises; newline='' is what the csv module expects so that it controls
    # the line terminators itself.
    with open('results.csv', 'w', newline='') as cf, \
            open('results-parent.csv', 'w', newline='') as cv, \
            torch.no_grad():
        writer = csv.writer(cf)
        writer_2 = csv.writer(cv)
        for i, (input, target) in enumerate(loader):
            # Only move the batch by hand when the prefetcher is disabled.
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))
            c_matrix += cal_confusions(output, target, labels=labels)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            writer.writerow([i, round(top1.avg, 4)])
            # compute the parent (coarse) category classification accuracy:
            # a prediction counts as correct when it falls in the same
            # parent group as the target, i.e. inside a diagonal block of
            # the confusion matrix accumulated so far.
            if args.hier_classify:
                corrects = 0.
                for group in parent_groups:
                    corrects += c_matrix[group][:, group].sum()
                parent_precision = corrects / c_matrix.sum()

                writer_2.writerow([i, round(parent_precision, 4)])
                logging.info('parent precision: {}'.format(parent_precision))

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    # The 'cropt_pct' key spelling is kept as-is because callers may read it.
    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    logging.info('confusion_matrix: \n {}'.format(c_matrix))
    # Overall accuracy recomputed from the matrix: trace / total count.
    logging.info('precision by confusion matrix: \n {}'.format(
        truediv(np.sum(np.diag(c_matrix)), np.sum(np.sum(c_matrix, axis=1)))))
    # with open('confusion_matrix.csv', 'w') as cf:
    #     writer = csv.writer(cf)
    #     for row in c_matrix:
    #         writer.writerow(row)
    #
    #     diag = np.diag(c_matrix)
    #     each_acc = truediv(diag, np.sum(c_matrix, axis=1))
    #     writer.writerow(each_acc)

    return results
Пример #10
0
def validate(args):
    """Run one evaluation pass over the validation data and report accuracy.

    The model class is looked up by evaluating ``args.model`` as a Python
    expression, instantiated from ``args.config_path`` and friends, and
    optionally restored from ``args.checkpoint``. Loss and top-1/top-5
    accuracy are averaged over every batch produced by the loader.

    ``args.pretrained`` is forced to ``False`` and ``args.prefetcher`` to
    ``True`` on the passed namespace.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``model``,
            ``config_path``, ``target_flops``, ``num_classes``,
            ``bn_momentum``, ``activation``, ``se``, ``checkpoint``,
            ``num_gpu``, ``lmdb``, ``data``, ``batch_size``, ``workers`` and
            ``log_freq``.

    Returns:
        OrderedDict with keys ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = False
    args.prefetcher = True

    # SECURITY NOTE(review): eval() executes whatever string is in
    # args.model; only safe while that value comes from a trusted source.
    model_factory = eval(args.model)
    model = model_factory(
        config_path=args.config_path,
        target_flops=args.target_flops,
        num_classes=args.num_classes,
        bn_momentum=args.bn_momentum,
        activation=args.activation,
        se=args.se,
    )

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, True)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d', args.model, param_count)

    data_config = resolve_data_config(vars(args), model=model)

    # Spread across several GPUs when requested; either way end up on CUDA.
    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # Pick the evaluation split: an LMDB archive or a plain 'val' folder.
    if args.lmdb:
        eval_dir = os.path.join(args.data, 'test_lmdb', 'test.lmdb')
        dataset_eval = ImageFolderLMDB(eval_dir, None, None)
    else:
        eval_dir = os.path.join(args.data, 'val')
        dataset_eval = Dataset(eval_dir)

    # Reported in the results only; it is not handed to create_loader.
    crop_pct = 1.0
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
    )
    loader = create_loader(dataset_eval, **loader_kwargs)

    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    progress_fmt = (
        'Test: [{0:>4d}/{1}]  '
        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
        'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
        'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})')

    model.eval()
    tick = time.time()
    with torch.no_grad():
        # The batches are fed to the model as delivered by the loader; no
        # explicit .cuda() transfer happens here (presumably the prefetcher
        # forced on above takes care of that -- confirm in create_loader).
        for step, (images, labels) in enumerate(loader):
            n_samples = images.size(0)

            # forward pass and loss
            logits = model(images)
            loss = criterion(logits, labels)

            # accumulate accuracy and loss, weighted by batch size
            prec1, prec5 = accuracy(logits.data, labels, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(prec1.item(), n_samples)
            top5.update(prec5.item(), n_samples)

            # wall-clock time spent on this batch
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.log_freq != 0:
                continue
            logging.info(progress_fmt.format(
                step,
                len(loader),
                batch_time=batch_time,
                rate_avg=n_samples / batch_time.avg,
                loss=losses,
                top1=top1,
                top5=top5))

    results = OrderedDict([
        ('top1', round(top1.avg, 4)),
        ('top1_err', round(100 - top1.avg, 4)),
        ('top5', round(top5.avg, 4)),
        ('top5_err', round(100 - top5.avg, 4)),
        ('param_count', round(param_count / 1e6, 2)),
        ('img_size', data_config['input_size'][-1]),
        ('cropt_pct', crop_pct),
        ('interpolation', data_config['interpolation']),
    ])

    summary_fmt = ' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'
    logging.info(summary_fmt.format(
        results['top1'], results['top1_err'],
        results['top5'], results['top5_err']))

    return results
Пример #11
0
def main(fold_i=0, data_=None, train_index=None, val_index=None):
    """Train and validate a model on one cross-validation fold.

    Parses the command-line arguments, builds the model, optimizer, AMP
    setup, optional EMA copy, LR scheduler and the train/eval loaders, then
    runs ``args.epochs`` epochs of ``train_epoch`` followed by ``validate``.
    After every epoch the validation score from ``get_score`` drives the LR
    scheduler, the summary CSV is updated, and a checkpoint is saved whenever
    the score beats the best one seen so far.

    Args:
        fold_i: Fold index. It is appended to ``args.output`` as
            ``fold_<fold_i>``, and the distributed process group is only
            initialised when it is 0.
        data_: Table of samples, indexed with ``.iloc`` (presumably a pandas
            DataFrame -- confirm with the caller). The ``None`` default is
            not usable: the function dereferences it unconditionally.
        train_index: Row positions of ``data_`` used for training.
        val_index: Row positions of ``data_`` used for validation.

    Returns:
        None. Results are emitted through logging, the checkpoint saver, the
        summary CSV and the visualizer.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    best_score = 0.0
    # every fold writes under its own output name
    args.output = args.output + 'fold_' + str(fold_i)
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        # the process group is created once, on the first fold only
        if fold_i == 0:
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    # offset the seed by rank so each process draws a different stream
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # wrap in DataParallel, move to GPU, enable channels last layout if set
    model = nn.DataParallel(model)
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-6)
    # setup automatic mixed-precision (AMP) loss scaling and op casting

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    # NOTE: the returned epoch is never read below; the epoch loop always
    # starts at 0, so resuming only restores weights/optimizer/scaler state.
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # The EMA copy is created after cuda(), the DataParallel wrapper and
        # AMP initialisation, and before the DDP wrapper below.
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP
    # The epoch count returned by the scheduler factory is not used: the
    # training loop below is bounded by args.epochs.
    lr_scheduler, _ = create_scheduler(args, optimizer)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)

    if args.local_rank == 0:
        # Report the number of epochs the loop will actually run instead of
        # a hard-coded value.
        _logger.info('Scheduled epochs: {}'.format(args.epochs))

    ##create DataLoader
    # Each transform pipeline is built once and reused for its loader.
    train_trans = get_riadd_train_transforms(args)
    valid_trans = get_riadd_valid_transforms(args)

    train_data = data_.iloc[train_index, :].reset_index(drop=True)
    dataset_train = RiaddDataSet(image_ids=train_data, baseImgPath=args.data)

    val_data = data_.iloc[val_index, :].reset_index(drop=True)
    dataset_eval = RiaddDataSet(image_ids=val_data, baseImgPath=args.data)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        transform=train_trans)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        transform=valid_trans)

    # # setup loss function
    # if args.jsd:
    #     assert num_aug_splits > 1  # JSD only valid with aug splits set
    #     train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    # elif mixup_active:
    #     # smoothing is handled with mixup target transform
    #     train_loss_fn = SoftTargetCrossEntropy().cuda()
    # elif args.smoothing:
    #     train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    # else:
    #     train_loss_fn = nn.CrossEntropyLoss().cuda()

    # Both training and validation use BCE-with-logits.
    validate_loss_fn = nn.BCEWithLogitsLoss().cuda()
    train_loss_fn = nn.BCEWithLogitsLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    vis = None
    output_dir = ''
    # Only local rank 0 creates the output directory, saver and visualizer.
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
        vis = Visualizer(env=args.output)

    try:
        for epoch in range(0, args.epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        amp_autocast=amp_autocast,
                                        loss_scaler=loss_scaler,
                                        model_ema=model_ema,
                                        mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)
            # The score is computed from the non-EMA model's predictions and
            # is what drives the scheduler and checkpoint selection below.
            score, scores = get_score(eval_metrics['valid_label'],
                                      eval_metrics['predictions'])
            ##visdom
            if vis is not None:
                vis.plot_curves({'None': epoch},
                                iters=epoch,
                                title='None',
                                xlabel='iters',
                                ylabel='None')
                vis.plot_curves(
                    {'learing rate': optimizer.param_groups[0]['lr']},
                    iters=epoch,
                    title='lr',
                    xlabel='iters',
                    ylabel='learing rate')
                vis.plot_curves({'train loss': float(train_metrics['loss'])},
                                iters=epoch,
                                title='train loss',
                                xlabel='iters',
                                ylabel='train loss')
                vis.plot_curves({'val loss': float(eval_metrics['loss'])},
                                iters=epoch,
                                title='val loss',
                                xlabel='iters',
                                ylabel='val loss')
                vis.plot_curves({'val score': float(score)},
                                iters=epoch,
                                title='val score',
                                xlabel='iters',
                                ylabel='val score')

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                # EMA metrics replace the regular ones in the summary file;
                # `score` above is left unchanged.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                # lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
                lr_scheduler.step(epoch + 1, score)

            # The header is written only while no checkpoint has been saved.
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None and score > best_score:
                # save proper checkpoint with eval metric
                best_score = score
                save_metric = best_score
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)
        # Drop the references and release cached GPU memory once the fold
        # has finished.
        del model
        del optimizer
        torch.cuda.empty_cache()
    except KeyboardInterrupt:
        # Ctrl-C ends training early; fall through to the final report.
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #12
0
def main():
    """Run batch inference over an image folder and write top-k class ids to CSV.

    The model is taken from one of three sources, in this order of precedence:
    ``args.torchvision_model``, ``args.hub_model`` (together with
    ``args.hub_model_github_or_dir``), or ``create_model(args.model, ...)``.
    Pretrained weights are requested whenever no checkpoint path is given.

    One line per input file is written to ``<args.output_dir>/topk_ids.csv``:
    the file's basename followed by ``k = min(args.topk, args.num_classes)``
    predicted class indices, as returned by ``Tensor.topk``.
    """
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model_name = args.model
    if args.torchvision_model:
        model_name = args.torchvision_model
        model = models.__dict__[args.torchvision_model](
            pretrained=args.pretrained, num_classes=args.num_classes)
        if args.checkpoint:
            load_checkpoint(model, args.checkpoint)
    elif args.hub_model and args.hub_model_github_or_dir:
        model_name = args.hub_model
        model = torch.hub.load(args.hub_model_github_or_dir,
                               args.hub_model,
                               pretrained=args.pretrained)
        if args.checkpoint:
            load_checkpoint(model, args.checkpoint)
    else:
        # create_model loads the checkpoint itself via checkpoint_path
        model = create_model(args.model,
                             num_classes=args.num_classes,
                             in_chans=3,
                             pretrained=args.pretrained,
                             checkpoint_path=args.checkpoint)

    _logger.info('Model %s created, param count: %d' %
                 (model_name, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, config)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    loader = create_loader(
        ImageDataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
        # no center-crop shrink when test-time pooling is active
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(loader):
            images = images.cuda()
            logits = model(images)
            # topk returns (values, indices); only the class indices are kept
            topk_ids.append(logits.topk(k)[1].cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_interval == 0:
                _logger.info(
                    'Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'
                    .format(batch_idx, len(loader), batch_time=batch_time))

    # Keep one row of k ids per sample. Squeezing here would collapse the
    # array when k == 1 or when there is a single image, breaking the per-row
    # iteration below.
    topk_ids = np.concatenate(topk_ids, axis=0)

    with open(os.path.join(args.output_dir, './topk_ids.csv'),
              'w') as out_file:
        filenames = loader.dataset.filenames(basename=True)
        for filename, label in zip(filenames, topk_ids):
            # Write however many ids were requested instead of assuming five.
            out_file.write('{0},{1}\n'.format(
                filename, ','.join(str(v) for v in label)))
Пример #13
0
 def init_weights(self, pretrained=False):
     """Load the SGNAS-A checkpoint from the user's home directory into this model.

     ``pretrained`` is accepted but not used by this implementation.
     """
     import os

     # A leading "~" is not expanded by file APIs on their own, so resolve it
     # explicitly. expanduser leaves an already-expanded path untouched, so
     # this stays correct even if load_checkpoint expands it too.
     checkpoint_path = os.path.expanduser("~/SGNAS/SGNAS_A_best.pth.tar")
     # NOTE(review): the third positional argument is presumably the
     # use_ema flag of load_checkpoint - confirm against its definition.
     load_checkpoint(self, checkpoint_path, False)
Пример #14
0
def main():
    """Train an image classifier, or run a single evaluation pass.

    Steps, in order:

    1. Parse arguments and set up single-process or distributed (NCCL) state.
    2. Pick the AMP backend (APEX, native, or none) from what is available.
    3. Build the model, optionally convert BN layers, and move it to the GPU.
    4. Create the optimizer, loss scaler, optional EMA copy and DDP wrapper.
    5. Build the train, eval and calibration loaders and the loss functions.
    6. Loop over epochs: train (unless ``args.eval_only``), optionally run the
       calibration pass, validate, step the LR scheduler, and save a
       checkpoint on the rank that owns the saver.

    With ``args.eval_only`` set, the loop stops after the first evaluation on
    every rank. A ``KeyboardInterrupt`` ends the loop quietly so the best
    metric seen so far is still reported.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    # offset the seed by rank so each process draws a different random stream
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        use_cos_reg=args.cos_reg_component > 0,
        checkpoint_path=args.initial_checkpoint)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        raise SystemExit(1)
    if args.use_lmdb:
        dataset_train = ImageFolderLMDB('../dataset_lmdb/train')
    else:
        dataset_train = Dataset(train_dir)

    # accept either "val" or "validation" as the evaluation folder name
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            raise SystemExit(1)
    if args.use_lmdb:
        dataset_eval = ImageFolderLMDB('../dataset_lmdb/val')
    else:
        dataset_eval = Dataset(eval_dir)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # Calibration loader: training images, non-training mode, no augmentation
    # and no mixup collate. It feeds the use_bn_calibration validate() pass.
    loader_cali = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.cali_batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        no_aug=True,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        if args.cos_reg_component > 0:
            args.use_cos_reg_component = True
            train_loss_fn = SoftTargetCrossEntropyCosReg(
                n_comn=args.cos_reg_component).cuda()
        else:
            train_loss_fn = SoftTargetCrossEntropy().cuda()
            args.use_cos_reg_component = False

    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # snapshot the current working directory next to the run outputs
        code_dir = get_outdir(output_dir, 'code')
        copy_tree(os.getcwd(), code_dir)
        # a lower value is better only when the tracked metric is the loss
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            if not args.eval_only:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
            if args.max_iter > 0:
                # calibration pass; its metrics are intentionally discarded
                _ = validate(model,
                             loader_cali,
                             validate_loss_fn,
                             args,
                             amp_autocast=amp_autocast,
                             use_bn_calibration=True)
            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                # when EMA is evaluated, its metrics replace the raw model's
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            if not args.eval_only:
                # train_metrics only exists when a training epoch actually ran
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

            if args.eval_only:
                # One evaluation pass is all that was asked for. Leave the
                # loop on every rank, not only the one that owns the saver;
                # otherwise the remaining ranks would keep iterating and
                # block on collectives the saver rank never joins.
                break

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #15
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    if args.control_amp == 'amp':
        args.amp = True
    elif args.control_amp == 'apex':
        args.apex_amp = True
    elif args.control_amp == 'native':
        args.native_amp = True

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    _logger.info(
        '====================\n\n'
        'Actfun: {}\n'
        'LR: {}\n'
        'Epochs: {}\n'
        'p: {}\n'
        'k: {}\n'
        'g: {}\n'
        'Extra channel multiplier: {}\n'
        'AMP: {}\n'
        'Weight Init: {}\n'
        '\n===================='.format(args.actfun, args.lr, args.epochs, args.p, args.k, args.g,
                                        args.extra_channel_mult, use_amp, args.weight_init))

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        actfun=args.actfun,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        p=args.p,
        k=args.k,
        g=args.g,
        extra_channel_mult=args.extra_channel_mult,
        weight_init_name=args.weight_init,
        partial_ho_actfun=args.partial_ho_actfun
    )

    if args.tl:
        if args.data == 'caltech101' and not os.path.exists('caltech101'):
            dir_root = r'101_ObjectCategories'
            dir_new = r'caltech101'
            dir_new_train = os.path.join(dir_new, 'train')
            dir_new_val = os.path.join(dir_new, 'val')
            dir_new_test = os.path.join(dir_new, 'test')
            if not os.path.exists(dir_new):
                os.mkdir(dir_new)
                os.mkdir(dir_new_train)
                os.mkdir(dir_new_val)
                os.mkdir(dir_new_test)

            for dir2 in os.listdir(dir_root):
                if dir2 != 'BACKGROUND_Google':
                    curr_path = os.path.join(dir_root, dir2)
                    new_path_train = os.path.join(dir_new_train, dir2)
                    new_path_val = os.path.join(dir_new_val, dir2)
                    new_path_test = os.path.join(dir_new_test, dir2)
                    if not os.path.exists(new_path_train):
                        os.mkdir(new_path_train)
                    if not os.path.exists(new_path_val):
                        os.mkdir(new_path_val)
                    if not os.path.exists(new_path_test):
                        os.mkdir(new_path_test)

                    train_upper = int(0.8 * len(os.listdir(curr_path)))
                    val_upper = int(0.9 * len(os.listdir(curr_path)))
                    curr_files_all = os.listdir(curr_path)
                    curr_files_train = curr_files_all[:train_upper]
                    curr_files_val = curr_files_all[train_upper:val_upper]
                    curr_files_test = curr_files_all[val_upper:]

                    for file in curr_files_train:
                        copyfile(os.path.join(curr_path, file),
                                 os.path.join(new_path_train, file))
                    for file in curr_files_val:
                        copyfile(os.path.join(curr_path, file),
                                 os.path.join(new_path_val, file))
                    for file in curr_files_test:
                        copyfile(os.path.join(curr_path, file),
                                 os.path.join(new_path_test, file))
        time.sleep(5)

    if args.tl:
        pre_model = create_model(
            args.model,
            pretrained=True,
            actfun='swish',
            num_classes=args.num_classes,
            drop_rate=args.drop,
            drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
            drop_path_rate=args.drop_path,
            drop_block_rate=args.drop_block,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            scriptable=args.torchscript,
            checkpoint_path=args.initial_checkpoint,
            p=args.p,
            k=args.k,
            g=args.g,
            extra_channel_mult=args.extra_channel_mult,
            weight_init_name=args.weight_init,
            partial_ho_actfun=args.partial_ho_actfun
        )
        model = MLP.MLP(actfun=args.actfun,
                        input_dim=1280,
                        output_dim=args.num_classes,
                        k=args.k,
                        p=args.p,
                        g=args.g,
                        num_params=400_000,
                        permute_type='shuffle')
        pre_model_layers = list(pre_model.children())
        pre_model = torch.nn.Sequential(*pre_model_layers[:-1])
    else:
        pre_model = None

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.tl:
        pre_model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if args.tl:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
    else:
        optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    if args.local_rank == 0:
        _logger.info('\n--------------------\nModel:\n' + repr(model) + '--------------------')

    # optionally resume from a checkpoint
    resume_epoch = None
    resume_path = os.path.join(args.resume, 'recover.pth.tar')
    if args.resume and os.path.exists(resume_path):
        resume_epoch = resume_checkpoint(
            model, resume_path,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    cp_loaded = None
    resume_epoch = None
    checkname = 'recover'
    if args.actfun != 'swish':
        checkname = '{}_'.format(args.actfun) + checkname
    check_path = os.path.join(args.check_path, checkname) + '.pth'
    loader = None
    if os.path.isfile(check_path):
        loader = check_path
    elif args.load_path != '' and os.path.isfile(args.load_path):
        loader = args.load_path
    if loader is not None:
        cp_loaded = torch.load(loader)
        model.load_state_dict(cp_loaded['model'])
        optimizer.load_state_dict(cp_loaded['optimizer'])
        resume_epoch = cp_loaded['epoch']
        model.cuda()
        loss_scaler.load_state_dict(cp_loaded['amp'])
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)
        _logger.info('============ LOADED CHECKPOINT: Epoch {}'.format(resume_epoch))

    model_raw = model

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume and os.path.exists(resume_path):
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if cp_loaded is not None:
            model_ema.load_state_dict(cp_loaded['model_ema'])

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer, dataset_train)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if cp_loaded is not None:
        lr_scheduler.load_state_dict(cp_loaded['scheduler'])

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=args.resume, decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    fieldnames = ['seed', 'weight_init', 'actfun', 'epoch', 'max_lr', 'lr', 'train_loss', 'eval_loss', 'eval_acc1', 'eval_acc5', 'ema']
    filename = 'output'
    if args.actfun != 'swish':
        filename = '{}_'.format(args.actfun) + filename
    outfile_path = os.path.join(args.output, filename) + '.csv'
    if not os.path.exists(outfile_path):
        with open(outfile_path, mode='w') as out_file:
            writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
            writer.writeheader()

    try:
        for epoch in range(start_epoch, num_epochs):

            if os.path.exists(args.check_path):
                amp_loss = None
                if use_amp == 'native':
                    amp_loss = loss_scaler.state_dict()
                elif use_amp == 'apex':
                    amp_loss = amp.state_dict()
                if model_ema is not None:
                    ema_save = model_ema.state_dict()
                else:
                    ema_save = None

                torch.save({'model': model_raw.state_dict(),
                            'model_ema': ema_save,
                            'optimizer': optimizer.state_dict(),
                            'scheduler': lr_scheduler.state_dict(),
                            'epoch': epoch,
                            'amp': amp_loss
                            }, check_path)
                _logger.info('============ SAVED CHECKPOINT: Epoch {}'.format(epoch))

            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
                pre_model=pre_model)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                                    pre_model=pre_model)

            with open(outfile_path, mode='a') as out_file:
                writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
                writer.writerow({'seed': args.seed,
                                 'actfun': args.actfun,
                                 'epoch': epoch,
                                 'lr': train_metrics['lr'],
                                 'train_loss': train_metrics['loss'],
                                 'eval_loss': eval_metrics['loss'],
                                 'eval_acc1': eval_metrics['top1'],
                                 'eval_acc5': eval_metrics['top5'],
                                 'ema': False
                                 })

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    pre_model=pre_model)
                eval_metrics = ema_eval_metrics

                with open(outfile_path, mode='a') as out_file:
                    writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
                    writer.writerow({'seed': args.seed,
                                     'weight_init': args.weight_init,
                                     'actfun': args.actfun,
                                     'epoch': epoch,
                                     'max_lr': args.lr,
                                     'lr': train_metrics['lr'],
                                     'train_loss': train_metrics['loss'],
                                     'eval_loss': eval_metrics['loss'],
                                     'eval_acc1': eval_metrics['top1'],
                                     'eval_acc5': eval_metrics['top5'],
                                     'ema': True
                                     })

            if lr_scheduler is not None and args.sched != 'onecycle':
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                args.seed, epoch, args.lr, args.epochs, args.batch_size, args.actfun,
                train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def main():
    """Train and evaluate an EfficientDet detector from command-line arguments.

    The function parses arguments, optionally initialises NCCL distributed
    training (when the ``WORLD_SIZE`` environment variable is > 1), builds an
    EfficientDet model whose class-prediction conv is replaced by a freshly
    initialised 5-class layer, and then runs the epoch loop: train, validate,
    step the LR scheduler and checkpoint.

    Only the process with ``args.local_rank == 0`` creates the output
    directory, writes ``args.yaml`` / ``summary.csv`` and saves checkpoints.

    A ``KeyboardInterrupt`` raised inside the epoch loop stops training
    quietly; the best metric recorded so far is still logged.

    Raises:
        AssertionError: if the resolved global rank is negative.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    # Distributed mode is decided solely by the WORLD_SIZE environment variable.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    #torch.manual_seed(args.seed + args.rank)

    # create model
    config = get_efficientdet_config(args.model)
    config.redundant_bias = args.redundant_bias  # redundant conv + BN bias layers (True to match official models)
    model = EfficientDet(config)
    # The initial checkpoint is loaded BEFORE the class head is swapped, so the
    # checkpoint is matched against the head built from the unmodified config.
    if args.initial_checkpoint:
        load_checkpoint(model, args.initial_checkpoint)

    # Replace the class-prediction conv with a new head for this task.
    # Single source of truth for the class count (used for both the config
    # and the head's channel count).
    num_classes = 5
    config.num_classes = num_classes
    # NOTE(review): 9 is presumably the number of anchors per location -- confirm
    # against the anchor settings in `config`.
    model.class_net.predict.conv_pw = create_conv2d(config.fpn_channels,
                                                    9 * num_classes,
                                                    1,
                                                    padding=config.pad_type,
                                                    bias=True)
    variance_scaling(model.class_net.predict.conv_pw.weight)
    # Bias chosen so that sigmoid(bias) == 0.01 for every output at start.
    model.class_net.predict.conv_pw.bias.data.fill_(-math.log((1 - 0.01) /
                                                              0.01))
    model = DetBenchTrain(model, config)

    model.cuda()
    print(model.model.class_net.predict.conv_pw)
    # FIXME create model factory, pretrained zoo
    # model = create_model(
    #     args.model,
    #     pretrained=args.pretrained,
    #     num_classes=args.num_classes,
    #     drop_rate=args.drop,
    #     drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=args.drop_block,
    #     global_pool=args.gp,
    #     bn_tf=args.bn_tf,
    #     bn_momentum=args.bn_momentum,
    #     bn_eps=args.bn_eps,
    #     checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum(m.numel()
                                      for m in model.parameters())))

    optimizer = create_optimizer(args, model)
    # Mixed precision is only enabled when Apex is importable AND requested.
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(_unwrap_bench(model),
                                                       args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        # Older Apex builds may not expose amp.load_state_dict, hence the check.
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    # Drop the reference to the loaded checkpoint contents before training.
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '')
        #resume=args.resume)  # FIXME bit of a mess with bench
        if args.resume:
            load_checkpoint(_unwrap_bench(model_ema),
                            args.resume,
                            use_ema=True)

    if args.distributed:
        if args.sync_bn:
            # Best effort: a failed conversion is reported (now with the
            # traceback, so the cause is visible) and training continues with
            # whatever `model` was before the failing call.
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception:
                logging.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        # NOTE(review): `DDP` is called with Apex-style arguments in one branch
        # and torch-style arguments in the other, so it presumably resolves to
        # a different class depending on `has_apex` -- confirm at the import.
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        # Fast-forward the schedule so the LR matches the starting epoch.
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_anno_set = 'train_small'
    train_annotation_path = os.path.join(args.data, 'annotations_small',
                                         'train_annotations.json')
    train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir),
                                  train_annotation_path)

    # FIXME cutmix/mixup worth investigating?
    # collate_fn = None
    # if args.prefetcher and args.mixup > 0:
    #     collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=config.image_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        #re_prob=args.reprob,  # FIXME add back various augmentations
        #re_mode=args.remode,
        #re_count=args.recount,
        #re_split=args.resplit,
        #color_jitter=args.color_jitter,
        #auto_augment=args.aa,
        interpolation=args.train_interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        #collate_fn=collate_fn,
        pin_mem=args.pin_mem,
    )

    # NOTE: with the validation paths below commented out, evaluation runs on
    # the very same image directory and annotation file as training.
    #train_anno_set = 'valid_small'
    #train_annotation_path = os.path.join(args.data, 'annotations_small', f'valid_annotations.json')
    #train_image_dir = train_anno_set
    dataset_eval = CocoDetection(os.path.join(args.data, train_image_dir),
                                 train_annotation_path)

    loader_eval = create_loader(
        dataset_eval,
        input_size=config.image_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        #distributed=args.distributed,
        pin_mem=args.pin_mem,
    )

    evaluator = COCOEvaluator(dataset_eval.coco, distributed=args.distributed)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    # Only local rank 0 owns the output directory and the checkpoint saver.
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss-type metric improves downwards; anything else is
        # treated as higher-is-better by the saver.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, args, evaluator)

            # When an EMA copy lives on the GPU, its metrics REPLACE the raw
            # model's metrics for scheduling, the summary and checkpointing.
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    # NOTE(review): this passes the EMA wrapper rather than
                    # `model_ema.ema` -- confirm distribute_bn handles that.
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                # NOTE(review): unlike the call above, `evaluator` is not passed
                # here -- confirm validate() is meant to run without it for EMA.
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if saver is not None:
                # The header is written only until the first checkpoint has
                # produced a best metric.
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    _unwrap_bench(model),
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=_unwrap_bench(model_ema),
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends training early without a traceback.
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #17
0
def get_model(model_name=None):
    """Construct the named network with weights loaded, in eval mode on ``device``.

    Args:
        model_name: Identifier of the architecture to build. A falsy value
            (``None``, ``''``) falls back to ``args.model``.

    Returns:
        The model after ``.eval()`` and ``.to(device)``.

    Resolution order:
        1. exact torchvision aliases (built with ``pretrained=True``);
        2. exact torch.hub entry points of the semi-supervised ImageNet repo;
        3. any name containing ``'seresnet50'`` (local checkpoint file);
        4. the two T2T-ViT names (checkpoint looked up in ``checkpoint_paths``);
        5. everything else via ``create_model``, optionally with a
           temperature-specific checkpoint and attention scale.
    """
    name = model_name if model_name else args.model

    # (alias, torchvision.models attribute). Attributes are resolved lazily
    # with getattr so only the selected constructor is ever looked up.
    torchvision_aliases = (
        ('resnet18', 'resnet18'),
        ('alexnet', 'alexnet'),
        ('squeezenet', 'squeezenet1_0'),
        ('vgg16', 'vgg16'),
        ('densenet', 'densenet161'),
        ('inception', 'inception_v3'),
        ('googlenet', 'googlenet'),
        ('shufflenet', 'shufflenet_v2_x1_0'),
        ('mobilenet', 'mobilenet_v2'),
        ('resnet50_32x4d', 'resnext50_32x4d'),
        ('wide_resnet50_2', 'wide_resnet50_2'),
        ('mnasnet', 'mnasnet1_0'),
    )
    # Hub entry points share their name with the alias used to request them.
    hub_entrypoints = (
        'resnext50_32x4d_ssl',
        'resnext50_32x4d_swsl',
        'resnet50_swsl',
    )
    t2t_names = ('T2t_vit_t_14', 'T2t_vit_t_24')

    constructor_attr = None
    for alias, attr in torchvision_aliases:
        if name == alias:
            constructor_attr = attr
            break

    if constructor_attr is not None:
        net = getattr(torchvision.models, constructor_attr)(pretrained=True)
    elif name in hub_entrypoints:
        net = torch.hub.load(
            'facebookresearch/semi-supervised-ImageNet1K-models', name)
    elif 'seresnet50' in name:
        net = se_resnet50(num_classes=1000)
        state = torch.load("../checkpoint/seresnet50-60a8950a85b2b.pkl")
        net.load_state_dict(state)
    elif name in t2t_names:
        net = create_model(
            name,
            pretrained=False,
            num_classes=args.num_classes,
            in_chans=3,
        )
        load_checkpoint(net, checkpoint_paths[name], True)
    else:
        net = create_model(
            name,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
        )
        if not args.pretrained:
            if args.set_temperature:
                # Temperature-specific weights plus a matching attention scale
                # written into every block.
                checkpoint_key = f"{name}_tem{args.set_temperature}"
                load_checkpoint(net, checkpoint_paths[checkpoint_key], True)
                for block in net.blocks:
                    block.attn.scale = 768**(-1 / args.set_temperature)
                print("Set temperature to: ", net.blocks[0].attn.scale)
            else:
                load_checkpoint(net, checkpoint_paths[name], True)

    net = net.eval()
    return net.to(device)
Пример #18
0
def main():
    """Train a timm model under a SparseML recipe, then export it to ONNX.

    All configuration comes from ``_parse_args()``. In order, this function:

    1. configures (optionally distributed) CUDA execution and seeds torch,
    2. resolves which AMP backend to use (APEX, native, or none),
    3. optionally resolves ``args.initial_checkpoint`` from SparseZoo,
    4. builds the model, optimizer, data loaders and loss functions,
    5. wraps the optimizer in a SparseML ``ScheduledOptimizer`` driven by
       ``args.sparseml_recipe``,
    6. runs the train / validate loop, saving checkpoints on local rank 0,
    7. exports the trained model to ONNX when an output directory exists.

    A ``KeyboardInterrupt`` raised during training or export is swallowed so
    that the best metric recorded so far is still logged.

    Raises:
        ValueError: if ``args.initial_checkpoint`` is ``"zoo"`` but
            ``args.sparseml_recipe`` is not a ``zoo:`` stub.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    # offset the seed by rank so each process draws a different random stream
    torch.manual_seed(args.seed + args.rank)

    ####################################################################################
    # Start - SparseML optional load weights from SparseZoo
    ####################################################################################
    if args.initial_checkpoint == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if args.sparseml_recipe.startswith("zoo:"):
            args.initial_checkpoint = Zoo.download_recipe_base_framework_files(
                args.sparseml_recipe,
                extensions=[".pth.tar", ".pth"]
            )[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub.  When initial-checkpoint is set to 'zoo'. "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
            )
    elif args.initial_checkpoint and args.initial_checkpoint.startswith("zoo:"):
        # Load weights from a SparseZoo model stub.
        # download_framework_files returns a list of downloaded paths, while
        # create_model(checkpoint_path=...) below takes a single path, so pick
        # the first match (same convention as the recipe branch above).
        zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
        args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"])[0]
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''  # stays empty on every process except local rank 0
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # a lower value is better only when the tracked metric is the loss
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    sparseml_loggers = (
        [PythonLogger(), TensorBoardLogger(log_path=output_dir)]
        if output_dir
        else None
    )
    manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model,
        manager,
        steps_per_epoch=len(loader_train),
        loggers=sparseml_loggers
    )
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe")
        lr_scheduler = None
    if manager.max_epochs:
        _logger.info(
            f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe"
        )
        num_epochs = manager.max_epochs
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                # when EMA weights are evaluated, their metrics drive LR stepping and checkpointing
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir:
                # Only the process that owns an output directory writes the
                # summary; otherwise other ranks would append to a stray
                # 'summary.csv' in the current working directory.
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        if output_dir:
            _logger.info(
                f"training complete, exporting ONNX to {output_dir}/model.onnx"
            )
            exporter = ModuleExporter(model, output_dir)
            # a single random sample with the configured input size is passed to the exporter
            exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################

    except KeyboardInterrupt:
        # allow a manual interrupt to fall through to the best-metric report below
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
Пример #19
0
def _best_f1_threshold(scores, targets):
    """Sweep thresholds 1.000, 0.999, ..., 0.001 and return the best one by F1.

    Args:
        scores: numpy array of per-sample scores; a sample is predicted
            positive when its score is >= the threshold.
        targets: numpy array of ground-truth labels passed to ``f1_score``.

    Returns:
        ``(best_threshold, best_f1)``. Ties keep the highest threshold; if no
        threshold yields an F1 above 0 the result is ``(1.0, 0.0)``.
    """
    best_f1 = 0.0
    best_th = 1.0
    flat_targets = targets.squeeze()
    for i in range(1000, 0, -1):
        th = i * 0.001
        real_pred = (scores >= th) * 1.0
        f1 = f1_score(flat_targets, real_pred.squeeze())
        if f1 > best_f1:
            best_f1 = f1
            best_th = th
    return best_th, best_f1


def validate(args):
    """Evaluate a model and report top-1 accuracy plus the best-threshold F1.

    Builds the model and data loader from ``args``, runs inference over the
    whole dataset, and on the last batch sweeps decision thresholds over the
    sigmoid of logit column 1 to find the best F1. When ``args.temp_scaling``
    is set, the sweep is repeated on logits divided by a fixed temperature.

    Args:
        args: parsed command-line namespace; mutated in place (``pretrained``,
            ``prefetcher`` and possibly ``apex_amp`` / ``native_amp``).

    Returns:
        OrderedDict with keys ``top1``, ``top1_err``, ``f1``, ``f1_err``,
        ``param_count``, ``img_size``, ``cropt_pct`` and ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          train_mode='val',
                          fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    if args.valid_labels:
        # file holds one class index per line; build a boolean mask over all classes
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    m = nn.Sigmoid()

    # Defined up front so the code after the loop never hits an unbound name,
    # even if the loader yields no batches.
    best_f1 = 0.0
    best_th = 1.0
    logits = None
    targets = None

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()

        logits_list = []
        target_list = []

        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))

            logits_list.append(output)
            target_list.append(target)

            if last_batch:
                # every batch has been seen: run the threshold sweep over the
                # whole dataset so the final log line can report it
                logits = torch.cat(logits_list).cuda()
                targets = torch.cat(target_list).cuda()

                targets_cpu = targets.cpu().numpy()
                sigmoided = m(logits)[:, 1].cpu().numpy()
                best_th, best_f1 = _best_f1_threshold(sigmoided, targets_cpu)

            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'thresh: {thresh:>7.4f}  '
                    'f1: {f1:>7.4f}'.format(batch_idx,
                                            len(loader),
                                            batch_time=batch_time,
                                            rate_avg=input.size(0) /
                                            batch_time.avg,
                                            loss=losses,
                                            top1=top1,
                                            thresh=best_th,
                                            f1=best_f1))

    print(best_th, best_f1)

    # for temperature scaling
    if args.temp_scaling and logits is not None:
        # NOTE: the temperature is a fixed constant here; an LBFGS fit of the
        # temperature against NLL (with before/after ECE reporting) was
        # previously sketched at this point but never enabled.
        temperature = torch.tensor([11.0], device=logits.device)

        logits = logits / temperature  # broadcasts over every logit
        targets_cpu = targets.cpu().numpy()
        sigmoided = m(logits)[:, 1].detach().cpu().numpy()

        best_th, best_f1 = _best_f1_threshold(sigmoided, targets_cpu)

        print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces the top-1 value at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    # F1 comes from the threshold sweep in both modes (previously left
    # undefined in real-labels mode, which raised NameError below)
    f1a = best_f1
    # NOTE(review): f1a looks like a 0-1 score while f1_err subtracts it from
    # 100 like a percentage — confirm the intended scale.
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          f1=f1a,
                          f1_err=round(100 - f1a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'],
        results['f1_err']))

    return results
def validate(args):
    """Run inference over a dataset, collecting raw outputs and optional metrics.

    Builds the model and loader from ``args``, runs every batch through the
    model and keeps the per-sample outputs. When ``args.has_eval_label`` is
    set, the labels are kept as well and loss / top-1 / top-k accuracy are
    tracked and logged.

    Args:
        args: parsed command-line namespace; ``pretrained`` and ``prefetcher``
            are overwritten in place.

    Returns:
        Tuple ``(results, prediction, true_label)``:

        * ``results`` - OrderedDict with keys ``top1``, ``top1_err``, ``topk``,
          ``topk_err``, ``param_count``, ``img_size``, ``cropt_pct`` and
          ``interpolation``; the accuracy values are dummy zeros when
          ``args.has_eval_label`` is false.
        * ``prediction`` - numpy array of all model outputs stacked along
          axis 0, or ``None`` if the loader produced no batches.
        * ``true_label`` - numpy array of all targets stacked along axis 0, or
          ``None`` when labels were not collected.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model_kwargs = dict(pretrained=args.pretrained,
                        num_classes=args.num_classes,
                        in_chans=3,
                        scriptable=args.torchscript)
    if 'inception' in args.model:
        model_kwargs['aux_logits'] = True  # ! add aux loss
    model = create_model(args.model, **model_kwargs)

    # ! add more layer to classifier layer
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    if args.has_eval_label:
        # the loss is only needed (and only used) when gold labels exist
        criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map,
                          args=args)

    if args.valid_labels:
        # file holds one class index per line; build a boolean mask over all classes
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # 'blank' is default Image.BILINEAR https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    # Per-batch arrays are gathered in lists and stacked once after the loop;
    # re-concatenating the running array on every batch copies all previous
    # rows each time and is quadratic in the dataset size.
    prediction_chunks = []
    label_chunks = []

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):

            if args.has_eval_label:
                # keep the true labels alongside the predictions
                label_chunks.append(target.detach().cpu().numpy())

            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                # some models return several outputs (e.g. with aux logits); keep the first
                output = output[0]

            if valid_labels is not None:
                # keep only the columns of the valid labels
                output = output[:, valid_labels]

            # save raw outputs, one row per sample: batchsize x label
            prediction_chunks.append(output.detach().cpu().numpy())

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)
                acc1, acck = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acck.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    # None (not an empty array) when nothing was collected, as before
    prediction = (np.concatenate(prediction_chunks, axis=0)
                  if prediction_chunks else None)
    true_label = (np.concatenate(label_chunks, axis=0)
                  if label_chunks else None)

    if not args.has_eval_label:
        top1a, topka = 0, 0  # just dummy, because we don't know ground labels
    elif real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        topka = real_labels.get_accuracy(k=args.topk)
    else:
        top1a, topka = top1.avg, topk.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          topk=round(topka, 4),
                          topk_err=round(100 - topka, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'],
        results['topk_err']))

    return results, prediction, true_label
Пример #21
0
def main():
    """Train a model on a COCO-format dataset and evaluate it after each epoch.

    The function parses the command-line arguments, configures single-process
    or distributed CUDA execution, builds the model (``bench_task='train'``),
    the optimizer, optional NVIDIA Apex AMP and an optional EMA copy of the
    model, then creates the training and evaluation loaders from
    ``<args.data>/train`` and ``<args.data>/val`` (annotations are read from
    ``<args.data>/annotations/instances_{train,val}.json``).

    On the process with ``args.local_rank == 0`` it additionally collects the
    ground-truth boxes of the evaluation set, creates the output directory,
    writes ``args.yaml`` and ``summary.csv`` and saves checkpoints.

    A ``KeyboardInterrupt`` stops the epoch loop early; the best metric seen
    so far (if any) is still logged before returning.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    # Derive positive flags from the negated command-line switches.
    args.pretrained_backbone = not args.no_pretrained_backbone
    args.prefetcher = not args.no_prefetcher
    # Distributed mode is enabled only when the WORLD_SIZE environment
    # variable is present and greater than 1.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    # Single-process defaults; overwritten below in distributed mode.
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    # Offset the seed by the global rank so every process seeds differently.
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        bench_task='train',
        pretrained=args.pretrained,
        pretrained_backbone=args.pretrained_backbone,
        redundant_bias=args.redundant_bias,
        checkpoint_path=args.initial_checkpoint,
    )
    # FIXME decide which args to keep and overlay on config / pass to backbone
    #     num_classes=args.num_classes,
    #     drop_rate=args.drop,
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=args.drop_block,
    # The loader input size is taken from the model's own config.
    input_size = model.config.image_size

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    # The model is moved to the GPU before the optimizer is created.
    model.cuda()
    optimizer = create_optimizer(args, model)
    # Mixed precision is used only when Apex is installed AND --amp was given.
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(unwrap_bench(model),
                                                       args.resume)
    # Optimizer / AMP state is restored only when the checkpoint provided
    # extra state and --no-resume-opt was not requested.
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        # Guard against amp builds that do not expose load_state_dict.
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    # Drop the local reference to the checkpoint state once it was applied.
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model, decay=args.model_ema_decay)
        #resume=args.resume)  # FIXME bit of a mess with bench
        if args.resume:
            # The EMA weights are loaded separately from the same checkpoint.
            load_checkpoint(unwrap_bench(model_ema), args.resume, use_ema=True)

    if args.distributed:
        if args.sync_bn:
            try:
                # Prefer the Apex SyncBN conversion when Apex is available.
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            # NOTE(review): any failure is only logged; training continues
            # with the unconverted model and the exception detail is dropped.
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        # NOTE(review): `DDP` is used with Apex-style arguments in one branch
        # and torch-style arguments in the other; which class the name is
        # bound to depends on the file's imports - confirm.
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    # Bring the scheduler to the epoch training actually starts from.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # train_anno_set = 'train2017'
    # train_annotation_path = os.path.join(args.data, 'annotations', f'instances_{train_anno_set}.json')
    # train_image_dir = train_anno_set
    #dataset_train = CocoDetection("/workspace/data/images",
    #                            "/workspace/data/datatrain90n.json")
    # Training split: images in <data>/train, annotations in
    # <data>/annotations/instances_train.json.
    train_anno_set = 'train'
    train_annotation_path = os.path.join(args.data, 'annotations',
                                         f'instances_{train_anno_set}.json')
    train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir),
                                  train_annotation_path)
    # FIXME cutmix/mixup worth investigating?
    # collate_fn = None
    # if args.prefetcher and args.mixup > 0:
    #     collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    # NOTE(review): mean/std below are hard-coded; presumably statistics of
    # the author's dataset - confirm before reusing on other data.
    loader_train = create_loader(
        dataset_train,
        input_size=input_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        #re_prob=args.reprob,  # FIXME add back various augmentations
        #re_mode=args.remode,
        #re_count=args.recount,
        #re_split=args.resplit,
        #color_jitter=args.color_jitter,
        #auto_augment=args.aa,
        interpolation=args.train_interpolation,
        mean=[0.4533, 0.4744,
              0.4722],  #[0.4846, 0.5079, 0.5005],#[0.485, 0.456, 0.406],
        std=[0.2823, 0.2890,
             0.3084],  #[0.2687, 0.2705, 0.2869],#[0.485, 0.456, 0.406],
        num_workers=args.workers,
        distributed=args.distributed,
        #collate_fn=collate_fn,
        pin_mem=args.pin_mem,
    )
    # Validation split. NOTE: the `train_*` variable names are reused here
    # even though they now describe the 'val' split.
    train_anno_set = 'val'
    train_annotation_path = os.path.join(args.data, 'annotations',
                                         f'instances_{train_anno_set}.json')
    train_image_dir = train_anno_set
    dataset_eval = CocoDetection(os.path.join(args.data, train_image_dir),
                                 train_annotation_path)
    # train_anno_set = 'val'
    # train_annotation_path = os.path.join(args.data, 'annotations', f'instances_{train_anno_set}.json')
    # train_image_dir = train_anno_set
    #   dataset_eval = CocoDetection("/workspace/data/val/images",
    #                               "/workspace/data/dataval90n.json")
    # The `distributed` argument is deliberately not passed for the eval
    # loader (it is commented out below).
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        mean=[0.4535, 0.4744, 0.4724],  #[0.4851, 0.5083, 0.5009],
        std=[0.2835, 0.2903, 0.3098],  #[0.2690, 0.2709, 0.2877],
        num_workers=args.workers,
        #distributed=args.distributed,
        pin_mem=args.pin_mem,
    )

    # for xx,item in dataset_train :
    #     print("out",type(xx))

    #     break

    # exit()
    # Collect one ground-truth BoundingBox per annotated object of the eval
    # set. This is done on local rank 0 only; other processes hand an empty
    # list to the evaluator.
    array_of_gt = []
    if args.local_rank == 0:
        for _, item in tqdm(dataset_eval):
            # print(item)
            for i in range(len(item['cls'])):
                # NOTE(review): x/w are read from bbox indices 1/3 and y/h
                # from indices 0/2, which assumes boxes are stored as
                # [y1, x1, y2, x2]; with BBFormat.XYX2Y2 the w/h arguments
                # presumably carry x2/y2 - TODO confirm against the dataset.
                array_of_gt.append(
                    BoundingBox(imageName=str(item["img_id"]),
                                classId=item["cls"][i],
                                x=item["bbox"][i][1] * item['img_scale'],
                                y=item["bbox"][i][0] * item['img_scale'],
                                w=item["bbox"][i][3] * item['img_scale'],
                                h=item["bbox"][i][2] * item['img_scale'],
                                typeCoordinates=CoordinatesType.Absolute,
                                bbType=BBType.GroundTruth,
                                format=BBFormat.XYX2Y2,
                                imgSize=(item['img_size'][0],
                                         item['img_size'][1])))

    evaluator = COCOEvaluator(dataset_eval.coco,
                              distributed=args.distributed,
                              gtboxes=array_of_gt)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    # The saver and output directory exist on local rank 0 only; elsewhere
    # `saver` stays None and `output_dir` stays ''.
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        # Experiment directory name: <timestamp>-<model name>.
        exp_name = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A 'loss' metric improves when it decreases; any other metric is
        # treated as "higher is better".
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        # Persist the exact arguments of this run next to the checkpoints.
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)


#     print(model)
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                # Tell the sampler which epoch is running.
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # the overhead of evaluating with coco style datasets is fairly high, so just ema or non, not both
            if model_ema is not None:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                eval_metrics = validate(model_ema.ema,
                                        loader_eval,
                                        args,
                                        evaluator,
                                        log_suffix=' (EMA)')
            else:
                eval_metrics = validate(model, loader_eval, args, evaluator)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if saver is not None:
                # The CSV header is written only while no best metric has
                # been recorded yet, i.e. before the first checkpoint save.
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    unwrap_bench(model),
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=unwrap_bench(model_ema),
                    metric=save_metric,
                    use_amp=use_amp)

    # Ctrl-C ends training quietly so the best-metric summary below still runs.
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #22
0
def validate(args):
    """Evaluate a classification model on an ``ImageFolder`` dataset.

    The model is created from ``args.model`` (optionally loading
    ``args.checkpoint``), moved to the GPU, optionally scripted with
    TorchScript and wrapped in ``DataParallel``, and then run once over the
    images found under ``args.data`` using a resize-256 / center-crop-224
    pipeline with fixed normalisation constants.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``pretrained``,
            ``checkpoint``, ``use_ema``, ``num_classes``, ``torchscript``,
            ``num_gpu``, ``data``, ``batch_size``, ``workers`` and
            ``log_freq``. ``args.pretrained`` is overwritten in place.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``
        (rounded to 4 places; errors are ``100 - accuracy``) and
        ``param_count`` (millions of parameters, rounded to 2 places).
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    if args.torchscript:
        # NOTE(review): optimized_execution is exposed as a context manager;
        # called bare like this it presumably has no lasting effect - confirm.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    # Evaluation pipeline: resize to 256, center-crop 224, then normalise.
    # (The unused OpenCV-based duplicate of this pipeline and the unused loss
    # criterion that used to be built here were dead code and are gone.)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    eval_transform = transforms.Compose([
        transforms.Resize((256), interpolation=2),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.data, eval_transform),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    batch_time = AverageMeter()
    # The loss is not computed in this variant, so this meter is never
    # updated; it is kept only because the progress log format references it.
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(loader):
            target = target.cuda()
            images = images.cuda()

            # The model returns a pair; only the first element is scored.
            output, _ = model(images)

            # measure accuracy
            batch_size = images.size(0)
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=batch_size / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2))

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
Пример #23
0
def main():
    """Train a model built with ``bench_task='train'`` and evaluate it every epoch.

    The function parses the command-line arguments, configures single-process
    or distributed CUDA execution, resolves which mixed-precision backend
    (NVIDIA Apex, native Torch AMP, or none) to use, builds the model,
    optimizer, optional EMA copy and DDP wrapper, obtains the loaders and the
    evaluator from ``create_datasets_and_loaders`` and then runs the epoch
    loop.

    On the process with ``args.local_rank == 0`` it creates the output
    directory, writes ``args.yaml`` and ``summary.csv`` and saves checkpoints.
    The process exits with status 1 when the model has fewer classes than the
    training dataset reports. A ``KeyboardInterrupt`` stops the epoch loop
    early; the best metric seen so far (if any) is still logged.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    # Derive positive flags from the negated command-line switches.
    args.pretrained_backbone = not args.no_pretrained_backbone
    args.prefetcher = not args.no_prefetcher
    # Distributed mode is enabled only when the WORLD_SIZE environment
    # variable is present and greater than 1.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    # Single-process defaults; overwritten below in distributed mode.
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    # `use_amp` ends up as 'apex', 'native' or None (float32 training).
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            logging.warning("Neither APEX or native Torch AMP is available, using float32. "
                            "Install NVIDA apex or upgrade to PyTorch 1.6.")

    # An explicit Apex request takes precedence over a native AMP request;
    # either one falls back to float32 with a warning when unavailable.
    if args.apex_amp:
        if has_apex:
            use_amp = 'apex'
        else:
            logging.warning("APEX AMP not available, using float32. Install NVIDA apex")
    elif args.native_amp:
        if has_native_amp:
            use_amp = 'native'
        else:
            logging.warning("Native AMP not available, using float32. Upgrade to PyTorch 1.6.")

    # Offset the seed by the global rank so every process seeds differently.
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        bench_task='train',
        num_classes=args.num_classes,
        pretrained=args.pretrained,
        pretrained_backbone=args.pretrained_backbone,
        redundant_bias=args.redundant_bias,
        label_smoothing=args.smoothing,
        new_focal=args.new_focal,
        jit_loss=args.jit_loss,
        bench_labeler=args.bench_labeler,
        checkpoint_path=args.initial_checkpoint,
    )
    model_config = model.config  # grab before we obscure with DP/DDP wrappers

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()])))

    # The model is moved to the GPU (and optionally to channels_last memory
    # format) before the optimizer is created.
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)

    # Defaults for float32 training: `suppress` stands in for the autocast
    # context (presumably contextlib.suppress used as a no-op - confirm) and
    # no loss scaler is used.
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            logging.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            logging.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            logging.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    # The optimizer and loss scaler are handed to resume_checkpoint only when
    # --no-resume-opt was not requested.
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            unwrap_bench(model), args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model, decay=args.model_ema_decay)
        if args.resume:
            # FIXME bit of a mess with bench, cannot use the load in ModelEma
            load_checkpoint(unwrap_bench(model_ema), args.resume, use_ema=True)

    if args.distributed:
        if args.sync_bn:
            try:
                # Apex SyncBN is used only when Apex is installed and native
                # AMP was not selected; otherwise the torch conversion is used.
                if has_apex and use_amp != 'native':
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            # NOTE(review): any failure is only logged; training continues
            # with the unconverted model and the exception detail is dropped.
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        # The DDP implementation follows the same Apex-vs-native rule as SyncBN.
        if has_apex and use_amp != 'native':
            if args.local_rank == 0:
                logging.info("Using apex DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.device])
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    # Bring the scheduler to the epoch training actually starts from.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    loader_train, loader_eval, evaluator = create_datasets_and_loaders(args, model_config)

    # Sanity-check the class count: too few model classes is fatal, too many
    # only produces a warning.
    if model_config.num_classes < loader_train.dataset.parser.max_label:
        logging.error(
            f'Model {model_config.num_classes} has fewer classes than dataset {loader_train.dataset.parser.max_label}.')
        exit(1)
    if model_config.num_classes > loader_train.dataset.parser.max_label:
        logging.warning(
            f'Model {model_config.num_classes} has more classes than dataset {loader_train.dataset.parser.max_label}.')

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    # The saver and output directory exist on local rank 0 only; elsewhere
    # `saver` stays None and `output_dir` stays ''.
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        # Experiment directory name: <timestamp>-<model name>.
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A 'loss' metric improves when it decreases; any other metric is
        # treated as "higher is better".
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model, optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, decreasing=decreasing, unwrap_fn=unwrap_bench)
        # Persist the exact arguments of this run next to the checkpoints.
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                # Tell the sampler which epoch is running.
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # the overhead of evaluating with coco style datasets is fairly high, so just ema or non, not both
            if model_ema is not None:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                eval_metrics = validate(model_ema.ema, loader_eval, args, evaluator, log_suffix=' (EMA)')
            else:
                eval_metrics = validate(model, loader_eval, args, evaluator)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if saver is not None:
                # The CSV header is written only while no best metric has
                # been recorded yet, i.e. before the first checkpoint save.
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(epoch=epoch, metric=eval_metrics[eval_metric])

    # Ctrl-C ends training quietly so the best-metric summary below still runs.
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
Пример #24
0
def validate(args):
    """Run one evaluation pass of a classification model and summarise it.

    Builds the model named by ``args.model`` (loading ``args.checkpoint``
    when given), resolves the data configuration from the arguments and the
    model, creates a loader over ``args.data`` (a ``.tar`` file or a
    directory) and accumulates loss and accuracy over every batch.

    Args:
        args: Parsed command-line namespace. ``args.pretrained`` and
            ``args.prefetcher`` are overwritten in place.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # With no checkpoint to load, fall back to pretrained weights so that
    # there is something meaningful to validate.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # Build the network, then optionally overwrite its weights.
    model = create_model(
        args.model, num_classes=args.num_classes, in_chans=3, pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # Move to the GPU first; AMP (when requested) wraps the CUDA model.
    model = model.cuda()
    if args.amp:
        model = amp.initialize(model, opt_level='O1')

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # A path that ends in '.tar' and is an existing file is read through the
    # tar-backed dataset; everything else through the regular dataset.
    is_tar_file = os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data)
    dataset_cls = DatasetTar if is_tar_file else Dataset
    dataset = dataset_cls(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # Test-time pooling evaluates on the full image, i.e. without cropping.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(loader):
            # Without the prefetcher the batch still has to be moved to the
            # GPU (and optionally cast to half precision) by hand.
            if args.no_prefetcher:
                labels = labels.cuda()
                images = images.cuda()
                if args.fp16:
                    images = images.half()

            logits = model(images)
            loss = criterion(logits, labels)

            # NOTE(review): the second accuracy is computed with k=2 but is
            # stored and reported under the "top5" / "Acc@5" names - confirm
            # whether this is intentional (e.g. a dataset with < 5 classes).
            acc1, acc5 = accuracy(logits.data, labels, topk=(1, 2))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # Per-batch wall-clock time.
            batch_time.update(time.time() - end)
            end = time.time()

            # Only every `log_freq`-th batch produces a progress line.
            if step % args.log_freq != 0:
                continue
            logging.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    step, len(loader), batch_time=batch_time,
                    rate_avg=images.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    # Assemble the summary in a fixed key order.
    results = OrderedDict()
    results['top1'] = round(top1.avg, 4)
    results['top1_err'] = round(100 - top1.avg, 4)
    results['top5'] = round(top5.avg, 4)
    results['top5_err'] = round(100 - top5.avg, 4)
    results['param_count'] = round(param_count / 1e6, 2)
    results['img_size'] = data_config['input_size'][-1]
    results['cropt_pct'] = crop_pct
    results['interpolation'] = data_config['interpolation']

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
Пример #25
0
def main():
    """Entry point of the training script: set everything up, then train.

    The function runs these stages strictly in order:

    1. Parse arguments, optionally start a wandb run, and work out whether
       this is a distributed (multi-process) run from the ``WORLD_SIZE``
       environment variable.
    2. Resolve which AMP implementation (``'apex'``, ``'native'`` or none)
       will be used.
    3. Build the data loaders and the model, apply the optional split-BN /
       channels-last / SyncBN conversions, and create the optimizer.
    4. Initialise AMP, optionally resume from a checkpoint, optionally
       create an EMA copy of the model, and wrap the model for distributed
       training.
    5. Create the LR scheduler, loss functions and (on rank 0 only) the
       checkpoint saver and output directory.
    6. Loop over epochs: train, evaluate (the EMA metrics replace the plain
       model metrics when EMA is evaluated), step the scheduler, write the
       summary row and save a checkpoint.

    A ``KeyboardInterrupt`` raised inside the epoch loop is swallowed so the
    best-metric line at the end is still logged. Nothing is returned.
    """
    setup_default_logging()
    # `args` is the parsed namespace; `args_text` is its textual form, which
    # is written verbatim to `args.yaml` in the output directory further down.
    args, args_text = _parse_args()

    # wandb logging is best-effort: if the package is missing only a warning
    # is emitted and training continues without it.
    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # Derived flag stored on `args`; it is not read again inside this function.
    args.prefetcher = not args.no_prefetcher
    # Distributed mode is inferred purely from the WORLD_SIZE environment
    # variable: more than one process means distributed.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    # Single-process defaults; overwritten below when running distributed.
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        # Each process is pinned to the GPU matching its local rank.
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    # Resulting `use_amp` is 'apex', 'native' or None (plain float32).
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    # If both explicit flags end up set and both backends are available,
    # apex wins because it is tested first.
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        # AMP was requested but the requested backend is not usable.
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    # NOTE(review): the rank is passed alongside the seed, presumably so each
    # process gets a different random stream - confirm in `random_seed`.
    random_seed(args.seed, args.rank)

    if args.fuser:
        set_jit_fuser(args.fuser)

    # Loaders come from a name-keyed helper; only the 'train' and 'test'
    # entries of the returned mapping are used here.
    data_splits = get_data_splits_by_name(
        dataset_name=args.dataset_name,
        data_root=args.data_dir,
        batch_size=args.batch_size,
    )
    loader_train, loader_eval = data_splits['train'], data_splits['test']

    # The registry lookup returns a factory callable; the model name is
    # lower-cased for the lookup. The number of output classes is taken from
    # the training dataset rather than from the command line.
    model_wrapper_fn = MODEL_WRAPPER_REGISTRY.get(
        model_name=args.model.lower(),
        dataset_name=args.pretraining_original_dataset)
    model = model_wrapper_fn(pretrained=args.pretrained,
                             progress=True,
                             num_classes=len(loader_train.dataset.classes))

    # Informational logging throughout is gated on local rank 0 so that it
    # is not repeated by every process.
    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    # In this function `data_config` is only read for its 'input_size' entry
    # when building the default experiment name below.
    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        # At least two splits are always requested, even when only
        # `args.resplit` satisfied the assertion above.
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        # Split BN and Sync BN are mutually exclusive.
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    # The optimizer is created after all module conversions above, so it is
    # built from the parameters of the final (converted) model.
    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    # NOTE(review): `suppress` is presumably `contextlib.suppress`, used here
    # as a no-op context manager factory - confirm against the file imports.
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        # Apex returns patched versions of both the model and the optimizer.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    # With `--no-resume-opt` neither the optimizer nor the loss scaler is
    # handed to `resume_checkpoint`, so only the model is passed for restore.
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # NOTE(review): in this function the SyncBN conversion has already
        # run above, so the "before SyncBN" part of the note does not match
        # the actual statement order - confirm which is intended.
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            # The EMA weights are loaded from the same checkpoint file.
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model,
                              device_ids=[args.local_rank],
                              broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    # When not starting from epoch 0, move the scheduler to the start epoch
    # before training begins.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # setup loss function
    # Priority: JSD loss > label smoothing (BCE or CE variant) > plain CE.
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing)
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(
                smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    # Validation always uses plain cross entropy, whatever the training loss.
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    # `saver` and `output_dir` stay None on every process except global
    # rank 0, which is what disables summary writing and checkpointing there.
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            # Default experiment name: <timestamp>-<model>-<last input dim>.
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(
            args.output if args.output else './output/train', exp_name)
        # A 'loss' metric is better when lower; any other metric is treated
        # as better when higher.
        decreasing = True if eval_metric == 'loss' else False
        # Regular and recovery checkpoints share the same directory.
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing,
                                max_history=args.checkpoint_hist)
        # Persist the exact arguments of this run next to its checkpoints.
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            # Only samplers that expose `set_epoch` are told the epoch.
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                # Third argument is True for 'reduce', False for 'broadcast'.
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            # The EMA weights are evaluated only when they are not forced
            # onto the CPU.
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                # From here on the EMA metrics replace the plain-model ones:
                # they drive the scheduler step, the summary row and the
                # checkpoint metric below.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                # `best_metric` is still None on the first pass through the
                # loop (it is only assigned by `save_checkpoint` below), so
                # the header is requested for the first row of this run.
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None,
                               log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        # Deliberate: an interrupt stops training early but still falls
        # through to the best-metric report below.
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))