Example #1
def data_creator(config):
    # torch.manual_seed(args.seed + torch.distributed.get_rank())

    args = config["args"]

    train_dir = join(args.data, "train")
    val_dir = join(args.data, "val")

    if args.mock_data:
        util.mock_data(train_dir, val_dir)

    # todo: verbose should depend on rank
    data_config = resolve_data_config(vars(args), verbose=True)
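    # data_config carries the resolved input_size, interpolation, mean/std and
    # crop_pct, taken from the CLI args with timm defaults as fallback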

    dataset_train = Dataset(train_dir)
    dataset_eval = Dataset(val_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        # collate conflict (need to support deinterleaving in collate mixup)
        assert args.num_aug_splits == 0
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    common_params = dict(
        input_size=data_config["input_size"],
        use_prefetcher=args.prefetcher,
        mean=data_config["mean"],
        std=data_config["std"],
        num_workers=1,
        distributed=args.distributed,
        pin_memory=args.pin_mem)

    train_loader = create_loader(
        dataset_train,
        is_training=True,
        batch_size=config[BATCH_SIZE],
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        collate_fn=collate_fn,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        num_aug_splits=args.num_aug_splits,  # always 0 right now
        **common_params)
    eval_loader = create_loader(
        dataset_eval,
        is_training=False,
        batch_size=args.validation_batch_size_multiplier * config[BATCH_SIZE],
        interpolation=data_config["interpolation"],
        crop_pct=data_config["crop_pct"],
        **common_params)

    return train_loader, eval_loader
Example #2
 def _evaluate(self, val_data, metric_name=None):
     if self._problem_type == REGRESSION:
         validate_loss_fn = nn.MSELoss()
     else:
         validate_loss_fn = nn.CrossEntropyLoss()
     validate_loss_fn = validate_loss_fn.to(self.ctx[0])
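     # convert the validation data to a torch-compatible dataset before
     # handing it to timm's create_loader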
     val_data = val_data.to_torch()
     val_loader = create_loader(
         val_data,
         input_size=self._data_cfg.input_size,
         batch_size=self._data_cfg.validation_batch_size_multiplier *
         self._train_cfg.batch_size,
         is_training=False,
         use_prefetcher=self._misc_cfg.prefetcher,
         interpolation=self._data_cfg.interpolation,
         mean=self._data_cfg.mean,
         std=self._data_cfg.std,
         num_workers=self._misc_cfg.num_workers,
         distributed=False,
         crop_pct=self._data_cfg.crop_pct,
         pin_memory=self._misc_cfg.pin_mem,
     )
     return self.validate(self.net,
                          val_loader,
                          validate_loss_fn,
                          amp_autocast=self._amp_autocast,
                          metric_name=metric_name)
Example #3
 def _init_dataloader(self):
     """Init dataloader from timm."""
     if self.horovod and hvd.local_rank() == 0:
         FileOps.copy_folder(self.cfg.dataset.remote_data_dir,
                             self.cfg.dataset.data_dir)
     if self.horovod:
         hvd.join()
     args = self.cfg.dataset
     train_dir = os.path.join(self.cfg.dataset.data_dir, 'train')
     dataset_train = Dataset(train_dir)
     world_size, rank = None, None
     if self.horovod:
         world_size, rank = hvd.size(), hvd.rank()
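     # pass horovod's world size and rank so the loader's distributed sampler
     # gives each process its own shard of the training data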
     self.train_loader = create_loader(dataset_train,
                                       input_size=tuple(args.input_size),
                                       batch_size=args.batch_size,
                                       is_training=True,
                                       use_prefetcher=self.cfg.prefetcher,
                                       rand_erase_prob=args.reprob,
                                       rand_erase_mode=args.remode,
                                       rand_erase_count=args.recount,
                                       color_jitter=args.color_jitter,
                                       auto_augment=args.aa,
                                       interpolation='random',
                                       mean=tuple(args.mean),
                                       std=tuple(args.std),
                                       num_workers=args.workers,
                                       distributed=self.horovod,
                                       world_size=world_size,
                                       rank=rank)
     valid_dir = os.path.join(self.cfg.dataset.data_dir, 'val')
     dataset_eval = Dataset(valid_dir)
     self.valid_loader = create_loader(dataset_eval,
                                       input_size=tuple(args.input_size),
                                       batch_size=4 * args.batch_size,
                                       is_training=False,
                                       use_prefetcher=self.cfg.prefetcher,
                                       interpolation=args.interpolation,
                                       mean=tuple(args.mean),
                                       std=tuple(args.std),
                                       num_workers=args.workers,
                                       distributed=self.horovod,
                                       world_size=world_size,
                                       rank=rank)
Example #4
def validate(args):
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    if args.no_jit:
        eval_step = lambda images, labels: eval_forward(
            model, variables, images, labels)
    else:
        eval_step = jax.jit(lambda images, labels: eval_forward(
            model, variables, images, labels))

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data)
    else:
        dataset = Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=False,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=8,
                           crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
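        # the loader yields NCHW torch tensors; convert to NHWC numpy arrays
        # for the JAX model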
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += top1_count
        correct_top5 += top5_count
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
Example #5
def validate(args):
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')

    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())

    dataset = create_dataset('imagenet', args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
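        # no layout transpose here; the objax model consumes the loader's
        # NCHW tensors directly (as numpy arrays)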
        images = images.numpy()
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
Example #6
    def _predict_feature(self, x, **kwargs):
        if isinstance(x, str):
            return self._predict_feature((x,))
        elif isinstance(x, pd.DataFrame):
            assert 'image' in x.columns, "Expect column `image` for input images"
            df = self._predict_feature(tuple(x['image']))
            df = df.set_index(x.index)
            df['image'] = x['image']
            return df
        elif isinstance(x, (list, tuple)):
            assert isinstance(x[0], str), "expect image paths in list/tuple input"
            loader = create_loader(
                ImageListDataset(x),
                input_size=self._data_cfg.input_size,
                batch_size=self._train_cfg.batch_size,
                use_prefetcher=self._misc_cfg.prefetcher,
                interpolation=self._data_cfg.interpolation,
                mean=self._data_cfg.mean,
                std=self._data_cfg.std,
                num_workers=self._misc_cfg.num_workers,
                crop_pct=self._data_cfg.crop_pct
            )

            self.net.eval()

            results = []
            with torch.no_grad():
                for input, _ in loader:
                    input = input.to(self.ctx[0])
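                    # forward_features lives on the wrapped module when the
                    # net is inside a DataParallel/DistributedDataParallel wrapper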
                    try:
                        features = self.net.forward_features(input)
                    except AttributeError:
                        features = self.net.module.forward_features(input)
                    for f in features:
                        f = f.cpu().numpy().flatten()
                        results.append({'image_feature': f})
            df = pd.DataFrame(results)
            df['image'] = x
            return df
        elif not isinstance(x, torch.Tensor):
            raise ValueError('Input is not supported: {}'.format(type(x)))
        with torch.no_grad():
            input = x.to(self.ctx[0])
            feature = self.net.forward_features(input)
            result = [{'image_feature': feature}]
        df = pd.DataFrame(result)
        return df
Example #7
    def __init__(self,
                 save_path=None,
                 train_batch_size=256,
                 test_batch_size=512,
                 valid_size=None,
                 n_worker=32,
                 resize_scale=0.08,
                 distort_color=None,
                 image_size=224,
                 tf_preprocessing=False,
                 num_replicas=None,
                 rank=None,
                 use_prefetcher=False,
                 pin_memory=False,
                 fp16=False):

        warnings.filterwarnings('ignore')

        dataset = Dataset(os.path.join(save_path, "val"),
                          load_bytes=tf_preprocessing)
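        # a throwaway efficientnet_b0 is created only to pull timm's default
        # normalization, interpolation and crop settings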
        dummy_model = create_model('efficientnet_b0')
        data_config = resolve_data_config({}, model=dummy_model)

        test_loader = create_loader(dataset,
                                    input_size=image_size,
                                    batch_size=test_batch_size,
                                    use_prefetcher=use_prefetcher,
                                    interpolation=data_config['interpolation'],
                                    mean=data_config['mean'],
                                    std=data_config['std'],
                                    num_workers=n_worker,
                                    crop_pct=data_config['crop_pct'],
                                    pin_memory=pin_memory,
                                    fp16=fp16,
                                    tf_preprocessing=tf_preprocessing)

        self.test = test_loader
Example #8
File: train.py  Project: joskid/sparseml
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    ####################################################################################
    # Start - SparseML optional load weights from SparseZoo
    ####################################################################################
    if args.initial_checkpoint == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if args.sparseml_recipe.startswith("zoo:"):
            args.initial_checkpoint = Zoo.download_recipe_base_framework_files(
                args.sparseml_recipe,
                extensions=[".pth.tar", ".pth"]
            )[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub.  When initial-checkpoint is set to 'zoo'. "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
            )
    elif args.initial_checkpoint.startswith("zoo:"):
        # Load weights from a SparseZoo model stub
        zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
        args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"])
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
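    # when resuming, fast-forward the LR schedule to the starting epoch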
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    sparseml_loggers = (
        [PythonLogger(), TensorBoardLogger(log_path=output_dir)]
        if output_dir
        else None
    )
    manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model,
        manager,
        steps_per_epoch=len(loader_train),
        loggers=sparseml_loggers
    )
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe")
        lr_scheduler = None
    if manager.max_epochs:
        _logger.info(
            f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe"
        )
        num_epochs = manager.max_epochs or num_epochs
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        if output_dir:
            _logger.info(
                f"training complete, exporting ONNX to {output_dir}/model.onnx"
            )
            exporter = ModuleExporter(model, output_dir)
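            # trace the model with a single dummy batch of the resolved input size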
            exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
Example #9
def main():
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained,
                         checkpoint_path=args.checkpoint)

    _logger.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = (
        model,
        False) if args.no_test_pool else apply_test_time_pool(model, config)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    loader = create_loader(
        ImageDataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
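        # skip center cropping (crop_pct=1.0) when test-time pooling is active
        # so pooling sees the full image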
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):
            input = input.cuda()
            labels = model(input)
            topk = labels.topk(k)[1]
            topk_ids.append(topk.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'
                    .format(batch_idx, len(loader), batch_time=batch_time))

    topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

    with open(os.path.join(args.output_dir, 'topk_ids.csv'), 'w') as out_file:
        filenames = loader.dataset.filenames(basename=True)
        for filename, label in zip(filenames, topk_ids):
            out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
                filename, label[0], label[1], label[2], label[3], label[4]))
Example #10
def main():
    args, cfg = parse_config_args('child net testing')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'test.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection
    if cfg.NET.SELECTION == 470:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 42:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 285:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 600:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Test Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k3_s2_e6_c192_se0.25'
    ]
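    # repeat each stage's choice block once per op listed for that stage in arch_list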
    arch_def = [[stem[0]]] + [[
        choice_block_pool[idx]
        for repeat_times in range(len(arch_list[idx + 1]))
    ] for idx in range(len(choice_block_pool))] + [[stem[1]]]

    # generate childnet
    model = gen_childnet(arch_list,
                         arch_def,
                         num_classes=cfg.DATASET.NUM_CLASSES,
                         drop_rate=cfg.NET.DROPOUT_RATE,
                         global_pool=cfg.NET.GP)

    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info("Training on Process {} with {} GPUs.".format(
            args.local_rank, cfg.NUM_GPU))

    # resume model from checkpoint
    assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH)
    _, __ = resume_checkpoint(model, cfg.RESUME_PATH)

    model = model.cuda()

    model_ema = None
    if cfg.NET.EMA.USE:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=cfg.NET.EMA.DECAY,
                             device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
                             resume=cfg.RESUME_PATH)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir) and args.local_rank == 0:
        logger.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)

    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        num_workers=cfg.WORKERS,
        distributed=True,
        pin_memory=cfg.DATASET.PIN_MEM,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD)

    # only test accuracy of model-EMA
    validate_loss_fn = nn.CrossEntropyLoss().cuda()
    validate(0,
             model_ema.ema,
             loader_eval,
             validate_loss_fn,
             cfg,
             log_suffix='_EMA',
             logger=logger,
             writer=writer,
             local_rank=args.local_rank)
Example #11
def main():
    args, cfg = parse_config_args('child net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'retrain.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection
    if cfg.NET.SELECTION == 481:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 114:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 604:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k5_s2_e6_c192_se0.25'
    ]
    arch_def = [[stem[0]]] + [[
        choice_block_pool[idx]
        for repeat_times in range(len(arch_list[idx + 1]))
    ] for idx in range(len(choice_block_pool))] + [[stem[1]]]

    # generate childnet
    model = gen_childnet(arch_list,
                         arch_def,
                         num_classes=cfg.DATASET.NUM_CLASSES,
                         drop_rate=cfg.NET.DROPOUT_RATE,
                         global_pool=cfg.NET.GP)

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver = None, None, None

    # initialize distributed parameters
    distributed = cfg.NUM_GPU > 1
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process {} with {} GPUs.'.format(
            args.local_rank, cfg.NUM_GPU))

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(cfg, model)

    # optionally resume from a checkpoint
    resume_state, resume_epoch = {}, None
    if cfg.AUTO_RESUME:
        resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)
        optimizer.load_state_dict(resume_state['optimizer'])
        del resume_state

    model_ema = None
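    # keep an exponential moving average of the weights, restored from the
    # same checkpoint when auto-resume is enabled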
    if cfg.NET.EMA.USE:
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None)

    if distributed:
        if cfg.BATCHNORM.SYNC_BN:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex '
                        'or Torch >= 1.1. Exception: {}'.format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir) and args.local_rank == 0:
        logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 auto_augment=cfg.AUGMENTATION.AA,
                                 num_aug_splits=0,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD,
                                 num_workers=cfg.WORKERS,
                                 distributed=distributed,
                                 collate_fn=None,
                                 pin_memory=cfg.DATASET.PIN_MEM,
                                 interpolation='random',
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 re_prob=cfg.AUGMENTATION.RE_PROB)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir) and args.local_rank == 0:
        logger.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        interpolation='bicubic',
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=cfg.WORKERS,
        distributed=distributed,
        pin_memory=cfg.DATASET.PIN_MEM)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        model_ema=model_ema,
                                        logger=logger,
                                        writer=writer,
                                        local_rank=args.local_rank)

            eval_metrics = validate(epoch,
                                    model,
                                    loader_eval,
                                    validate_loss_fn,
                                    cfg,
                                    logger=logger,
                                    writer=writer,
                                    local_rank=args.local_rank)

            if model_ema is not None and not cfg.NET.EMA.FORCE_CPU:
                ema_eval_metrics = validate(epoch,
                                            model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            cfg,
                                            log_suffix='_EMA',
                                            logger=logger,
                                            writer=writer,
                                            local_rank=args.local_rank)
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    cfg,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch

            if args.local_rank == 0:
                logger.info('*** Best metric: {0} (epoch {1})'.format(
                    best_record, best_ep))

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Example #12
def main():
    get_logger("./")
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)
    logging.info("Exponential : {}".format(args.model_ema_decay))
    logging.info("Color Jitter : {}".format(args.color_jitter))
    logging.info("Model EMA Decay : {}".format(args.model_ema_decay))

    torch.manual_seed(args.seed + args.rank)
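    # the model class named on the command line is looked up and instantiated via eval()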
    model = eval(args.model)(config_path=args.config_path,
                             target_flops=args.target_flops,
                             num_classes=args.num_classes,
                             bn_momentum=args.bn_momentum,
                             activation=args.activation,
                             se=args.se)

    if os.path.exists(args.initial_checkpoint):
        load_checkpoint(model, args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if args.resume:
        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    logging.info(args.weight_decay)
    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state["optimizer"])

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    if args.lmdb:
        train_dir = os.path.join(args.data, 'train_lmdb', 'train.lmdb')
        dataset_train = ImageFolderLMDB(train_dir, None, None)
    else:
        train_dir = os.path.join(args.data, 'train')
        dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        color_jitter=args.color_jitter,
        interpolation='random',
        # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    if args.lmdb:
        eval_dir = os.path.join(args.data, 'test_lmdb', 'test.lmdb')
        dataset_eval = ImageFolderLMDB(eval_dir, None, None)
    else:
        eval_dir = os.path.join(args.data, 'val')
        dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
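
When mixup is enabled above, the targets become soft label distributions (label smoothing is folded into the mixup transform), so training uses SoftTargetCrossEntropy while validation keeps hard-label cross entropy. A minimal sketch of the soft-target loss that class is assumed to compute (the class name here is illustrative, not the library's):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTargetCrossEntropySketch(nn.Module):
    """Cross entropy against a soft target distribution:
    mean over the batch of -sum_c target[c] * log_softmax(logits)[c]."""

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        loss = torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1)
        return loss.mean()
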
def train_imagenet_dq():
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = xm.xla_device()
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)
    device = xm.xla_device()
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=args.initial_checkpoint,
        args=args).to(device)
    flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=args.display_info)
    print('Flops:  ' + flops)
    print('Params: ' + params)
    if args.KD_train:
        teacher_model = create_model(
            "efficientnet_b7_dq",
            pretrained=True,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=args.initial_checkpoint,
            args=args)

        flops_teacher, params_teacher = get_model_complexity_info(teacher_model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
        print("Using KD training...")
        print("FLOPs of teacher model: ", flops_teacher)
        print("Params of teacher model: ", params_teacher)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)
        # import pdb;pdb.set_trace()
    torch.manual_seed(42)
    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        # device = xm.xla_device()
        # devices = (
        #     xm.get_xla_supported_devices(
        #     max_devices=num_cores) if num_cores != 0 else [])
        # model = nn.DataParallel(model, device_ids=devices).cuda()
        # model = model.to(device)
        if args.KD_train:
            teacher_model = nn.DataParallel(teacher_model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        # device = xm.xla_device()
        # model = model.to(device)
        if args.KD_train:
            teacher_model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # import pdb; pdb.set_trace()
        model_e = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=args.initial_checkpoint,
            args=args).to(device)
        model_ema = ModelEma(
            model_e,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info('Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    if args.auto_augment:
        print('Using auto data augmentation...')
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        use_auto_aug=args.auto_augment,
        use_mixcut=args.mixcut,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        logging.error('Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy()
        validate_loss_fn = nn.CrossEntropyLoss()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
        validate_loss_fn = nn.CrossEntropyLoss()
    else:
        train_loss_fn = nn.CrossEntropyLoss()
        validate_loss_fn = train_loss_fn
    if args.KD_train:
        train_loss_fn = nn.KLDivLoss(reduction='batchmean')

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

    def train_epoch(
            epoch, model, loader, optimizer, loss_fn, args,
            lr_scheduler=None, saver=None, output_dir='', use_amp=False,
            model_ema=None, teacher_model=None, loader_len=0):

        if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
            if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
                loader.mixup_enabled = False

        batch_time_m = AverageMeter()
        data_time_m = AverageMeter()
        losses_m = AverageMeter()

        model.train()
        if args.KD_train:
            teacher_model.eval()

        end = time.time()
        last_idx = loader_len - 1
        num_updates = epoch * loader_len
        for batch_idx, (input, target) in loader:
            last_batch = batch_idx == last_idx
            data_time_m.update(time.time() - end)
            if not args.prefetcher:
                # input = input.cuda()
                # target = target.cuda()
                if args.mixup > 0.:
                    lam = 1.
                    if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                        lam = np.random.beta(args.mixup, args.mixup)
                    input.mul_(lam).add_(input.flip(0), alpha=1 - lam)
                    target = mixup_target(target, args.num_classes, lam, args.smoothing)

            r = np.random.rand(1)
            if args.beta > 0 and r < args.cutmix_prob:
                # generate mixed sample
                lam = np.random.beta(args.beta, args.beta)
                rand_index = torch.randperm(input.size()[0])
                target_a = target
                target_b = target[rand_index]
                bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
                input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
                # adjust lambda to exactly match pixel ratio
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
                # compute output
                # torch.autograd.Variable wrappers are no-ops in modern PyTorch; use the tensors directly
                output = model(input)
                loss = loss_fn(output, target_a) * lam + loss_fn(output, target_b) * (1. - lam)
            else:
                # NOTE: KD training is mutually exclusive with cutmix for now; FIXME later
                output = model(input)
                if args.KD_train:
                    # teacher_model.cuda()
                    teacher_outputs_tmp = []
                    assert input.shape[0] % args.teacher_step == 0
                    step_size = input.shape[0] // args.teacher_step
                    with torch.no_grad():
                        # run the teacher in smaller chunks to limit peak memory
                        for k in range(0, input.shape[0], step_size):
                            input_tmp = input[k:k + step_size]
                            teacher_outputs_tmp.append(teacher_model(input_tmp))
                            # torch.cuda.empty_cache()
                    # import pdb; pdb.set_trace()
                    teacher_outputs = torch.cat(teacher_outputs_tmp)
                    alpha = args.KD_alpha
                    T = args.KD_temperature
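                    # Hinton-style distillation: KL divergence between temperature-softened student
                    # and teacher distributions, weighted by alpha * T^2 so its gradient scale stays
                    # comparable to the hard-label cross-entropy term weighted by (1 - alpha)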
                    loss = loss_fn(F.log_softmax(output / T, dim=1),
                                   F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
                        F.cross_entropy(output, target) * (1. - alpha)
                else:
                    loss = loss_fn(output, target)
            if not args.distributed:
                losses_m.update(loss.item(), input.size(0))

            optimizer.zero_grad()
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            #optimizer.step()
            xm.optimizer_step(optimizer)

            # torch.cuda.synchronize()
            if model_ema is not None:
                model_ema.update(model)
            num_updates += 1

            batch_time_m.update(time.time() - end)
            if last_batch or batch_idx % args.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)

                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data, args.world_size)
                    losses_m.update(reduced_loss.item(), input.size(0))

                if args.local_rank == 0:
                    logging.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                        'LR: {lr:.3e}  '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, loader_len,
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            batch_time=batch_time_m,
                            rate=input.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m))

                    if args.save_images and output_dir:
                        torchvision.utils.save_image(
                            input,
                            os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                            padding=0,
                            normalize=True)

            if saver is not None and args.recovery_interval and (
                    last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                save_epoch = epoch + 1 if last_batch else epoch
                saver.save_recovery(
                    model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

            end = time.time()

        return OrderedDict([('loss', losses_m.avg)])


    def validate(model, loader, loss_fn, args, log_suffix='', loader_len=0):
        batch_time_m = AverageMeter()
        losses_m = AverageMeter()
        prec1_m = AverageMeter()
        prec5_m = AverageMeter()

        model.eval()

        end = time.time()
        last_idx = loader_len - 1
        with torch.no_grad():
            for batch_idx, (input, target) in loader:
                last_batch = batch_idx == last_idx
                # if not args.prefetcher:
                #     input = input.cuda()
                #     target = target.cuda()

                output = model(input)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                loss = loss_fn(output, target)
                prec1, prec5 = accuracy(output, target, topk=(1, 5))

                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data, args.world_size)
                    prec1 = reduce_tensor(prec1, args.world_size)
                    prec5 = reduce_tensor(prec5, args.world_size)
                else:
                    reduced_loss = loss.data

                # torch.cuda.synchronize()

                losses_m.update(reduced_loss.item(), input.size(0))
                prec1_m.update(prec1.item(), output.size(0))
                prec5_m.update(prec5.item(), output.size(0))

                batch_time_m.update(time.time() - end)
                end = time.time()
                if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                    log_name = 'Test' + log_suffix
                    logging.info(
                        '{0}: [{1:>4d}/{2}]  '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                        'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                        'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            log_name, batch_idx, last_idx,
                            batch_time=batch_time_m, loss=losses_m,
                            top1=prec1_m, top5=prec5_m))

        metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

        return metrics
    try:
        # import pdb;pdb.set_trace()
        for epoch in range(start_epoch, num_epochs):
            loader_len=len(loader_train)
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            # import pdb; pdb.set_trace()
            if args.KD_train:
                train_metrics = train_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                    use_amp=use_amp, model_ema=model_ema, teacher_model=teacher_model,
                    loader_len=loader_len)
            else:
                para_loader = dp.ParallelLoader(loader_train, [device])
                train_metrics = train_epoch(
                    epoch, model, para_loader.per_device_loader(device), optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                    use_amp=use_amp, model_ema=model_ema, loader_len=loader_len)

            # def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32,
            #                 overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
            #                 per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None,
            #                 scale_approx_mult_bits=None):
            # import distiller
            # import pdb; pdb.set_trace()
            # quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args)
            # quantizer.prepare_model(distiller.get_dummy_input(input_shape=model.input_shape))
            # quantizer = distiller.quantization.PostTrainLinearQuantizer(model, bits_activations=8, bits_parameters=8)
            # quantizer.prepare_model()

            # distiller.utils.assign_layer_fq_names(model)
            # # msglogger.info("Generating quantization calibration stats based on {0} users".format(args.qe_calibration))
            # collector = distiller.data_loggers.QuantCalibrationStatsCollector(model)
            # with collector_context(collector):
            #     eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
            #     # Here call your model evaluation function, making sure to execute only
            #     # the portion of the dataset specified by the qe_calibration argument
            # yaml_path = './dir/quantization_stats.yaml'
            # collector.save(yaml_path)
            loader_len_val = len(loader_eval)
            para_loader = dp.ParallelLoader(loader_eval, [device])
            eval_metrics = validate(model, para_loader.per_device_loader(device), validate_loss_fn, args, loader_len=loader_len_val)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
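
rand_bbox and reduce_tensor are referenced inside train_epoch and validate above but are not defined in this snippet; minimal sketches following the canonical CutMix box sampler and the usual all-reduce averaging helper (assumptions, not code recovered from this file):

import numpy as np
import torch
import torch.distributed as dist


def rand_bbox(size, lam):
    # size: input shape (N, C, H, W); the cut box covers roughly a (1 - lam) fraction of the area
    H, W = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_h, cut_w = int(H * cut_rat), int(W * cut_rat)
    # sample the box center uniformly, then clip the box to the image bounds
    cy, cx = np.random.randint(H), np.random.randint(W)
    bbx1 = np.clip(cy - cut_h // 2, 0, H)
    bby1 = np.clip(cx - cut_w // 2, 0, W)
    bbx2 = np.clip(cy + cut_h // 2, 0, H)
    bby2 = np.clip(cx + cut_w // 2, 0, W)
    return bbx1, bby1, bbx2, bby2


def reduce_tensor(tensor, world_size):
    # average a metric/loss tensor across all distributed processes
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt
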
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
#    amp_autocast = suppress  # do nothing
#    if args.amp:
#        if has_native_amp:
#            args.native_amp = True
#        elif has_apex:
#            args.apex_amp = True
#        else:
#            _logger.warning("Neither APEX or Native Torch AMP is available.")
#    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
#    if args.native_amp:
#        amp_autocast = torch.cuda.amp.autocast
#        _logger.info('Validating in mixed precision with native PyTorch AMP.')
#    elif args.apex_amp:
#        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
#    else:
#        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

#    model = model.cuda()
#    if args.apex_amp:
#        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

#    if args.num_gpu > 1:
#        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

#   criterion = nn.CrossEntropyLoss().cuda()
    criterion = nn.CrossEntropyLoss()
    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # added for post quantization calibration

    calib_dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

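    # test-time pooling evaluates the full, uncropped image, so center cropping is
    # disabled (crop_pct = 1.0) whenever it is active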
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    # Also create a loader for the calibration dataset
    calib_loader = create_loader(
        calib_dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    _logger.info('Starting calibration of quantization observers for post-training quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    # post-training static quantization
    if args.quant_option == 'static':
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        # prepare: insert observers into the copied FP32 model
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate 
        with torch.no_grad():
            # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
            input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])) 
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            model_prepared(input)
            end = time.time()
            for batch_idx, (input, target) in enumerate(loader):
                if args.channels_last:
                    input = input.contiguous(memory_format=torch.channels_last)

                # forward through the observed model so its observers record activation ranges
                output = model_prepared(input)

                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)

                if real_labels is not None:
                    real_labels.add_result(output)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                top5.update(acc5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % args.log_freq == 0:
                    _logger.info(
                        'Test: [{0:>4d}/{1}]  '
                        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                        'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                        'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                            batch_idx, len(loader), batch_time=batch_time,
                            rate_avg=input.size(0) / batch_time.avg,
                            loss=losses, top1=top1, top5=top5))        
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)           
    #post training dynamic/weight only quantization    
    elif args.quant_option == 'dynamic':    
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # no calibration is needed for dynamic / weight-only quantization
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)       
    else:
        _logger.warning("Invalid quantization option %s; expected 'static' or 'dynamic'.", args.quant_option)
    #
    # fusion
    #
    # fuse conv/bn/relu patterns in a fresh FP32 copy; note that the evaluation below runs
    # this fused float model, not model_quantized
    model_to_quantize = copy.deepcopy(model)
    model_fused = quantize_fx.fuse_fx(model_to_quantize)

    model = model_fused

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
#        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])) 
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
 #           if args.no_prefetcher:
 #               target = target.cuda()
 #               input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output (AMP autocast is not used in this quantization path)
            output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
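
The validate() above interleaves FX graph-mode post-training quantization with evaluation and timing. As a compact reference, here is a minimal sketch of the same prepare / calibrate / convert flow in isolation. It targets the older torch.quantization.quantize_fx API used above (newer PyTorch releases also require an example_inputs argument to prepare_fx); quantize_static_fx, backend and num_batches are illustrative names, not part of the original script.

import copy
import torch
from torch.quantization import get_default_qconfig, quantize_fx


def quantize_static_fx(model_fp32, data_loader, backend='qnnpack', num_batches=10):
    # work on a copy so the float model stays available for comparison
    model = copy.deepcopy(model_fp32).eval()
    qconfig_dict = {"": get_default_qconfig(backend)}
    prepared = quantize_fx.prepare_fx(model, qconfig_dict)    # insert observers
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):         # run calibration batches
            prepared(images)
            if i + 1 >= num_batches:
                break
    return quantize_fx.convert_fx(prepared)                   # swap in quantized modules
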
示例#15
0
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
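    # contextlib.suppress() with no arguments is a no-op context manager, so the loop below
    # can always wrap the forward pass in `with amp_autocast():` whether or not AMP is enabled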
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX nor native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = (
        model, False) if args.no_test_pool else apply_test_time_pool(
            model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(root=args.data,
                             name=args.dataset,
                             split=args.split,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(
            k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          top5=round(top5a, 4),
                          top5_err=round(100 - top5a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
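
AverageMeter is used by every training and validation loop in these examples but never defined in the snippets; a minimal sketch of the usual running-average helper it is assumed to match:

class AverageMeter:
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # n is the number of samples the value was computed over (e.g. the batch size)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
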
示例#16
0
    filenames = [os.path.splitext(f)[0] for f in dataset.filenames()]

    # get appropriate transform for model's default pretrained config
    data_config = resolve_data_config(m['args'], model=model, verbose=True)
    test_time_pool = False
    if m['ttp']:
        model, test_time_pool = apply_test_time_pool(model, data_config)
        data_config['crop_pct'] = 1.0

    batch_size = m['batch_size']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=batch_size,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=6,
        crop_pct=data_config['crop_pct'],
        pin_memory=True)

    evaluator = ImageNetEvaluator(
        root=DATA_ROOT,
        model_name=m['paper_model_name'],
        paper_arxiv_id=m['paper_arxiv_id'],
        model_description=m.get('model_description', None),
    )
    model.cuda()
    model.eval()
    with torch.no_grad():
示例#17
0
def main():

    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available, using float32. "
            "Install NVIDIA Apex or upgrade to PyTorch 1.6+.")

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        use_cos_reg=args.cos_reg_component > 0,
        checkpoint_path=args.initial_checkpoint)
    with torch.cuda.device(0):
        input = torch.randn(1, 3, 224, 224)
        size_for_madd = 224 if args.img_size is None else args.img_size
        # flops, params = get_model_complexity_info(model, (3, size_for_madd, size_for_madd), as_strings=True, print_per_layer_stat=True)
        # print("=>Flops:  " + flops)
        # print("=>Params: " + params)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    if args.use_lmdb:
        dataset_train = ImageFolderLMDB('../dataset_lmdb/train')
    else:
        dataset_train = Dataset(train_dir)
    # dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    if args.use_lmdb:
        dataset_eval = ImageFolderLMDB('../dataset_lmdb/val')
    else:
        dataset_eval = Dataset(eval_dir)
    # dataset_eval = Dataset(eval_dir)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    loader_cali = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.cali_batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        no_aug=True,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        if args.cos_reg_component > 0:
            args.use_cos_reg_component = True
            train_loss_fn = SoftTargetCrossEntropyCosReg(
                n_comn=args.cos_reg_component).cuda()
        else:
            train_loss_fn = SoftTargetCrossEntropy().cuda()
            args.use_cos_reg_component = False

    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        code_dir = get_outdir(output_dir, 'code')
        copy_tree(os.getcwd(), code_dir)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            if not args.eval_only:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
            if args.max_iter > 0:
                _ = validate(model,
                             loader_cali,
                             validate_loss_fn,
                             args,
                             amp_autocast=amp_autocast,
                             use_bn_calibration=True)
            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            if not args.eval_only:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)
                if args.eval_only:
                    break

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Example #18
def main():
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0
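
    # Hedged note: WORLD_SIZE (and the local rank consumed above) are expected to be
    # provided per process by the launcher, e.g. `torchrun --nproc_per_node=N train.py`
    # or `python -m torch.distributed.launch --nproc_per_node=N train.py`; single-GPU
    # runs simply leave WORLD_SIZE unset and fall through to the non-distributed path.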

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         drop_connect_rate=0.2,
                         checkpoint_path=args.initial_checkpoint,
                         args=args)
    flops, params = get_model_complexity_info(
        model, (3, 224, 224),
        as_strings=True,
        print_per_layer_stat=args.display_info)
    print('Flops:  ' + flops)
    print('Params: ' + params)
    if args.KD_train:
        teacher_model = create_model("efficientnet_b7_dq",
                                     pretrained=True,
                                     num_classes=args.num_classes,
                                     drop_rate=args.drop,
                                     global_pool=args.gp,
                                     bn_tf=args.bn_tf,
                                     bn_momentum=args.bn_momentum,
                                     bn_eps=args.bn_eps,
                                     drop_connect_rate=0.2,
                                     checkpoint_path=args.initial_checkpoint,
                                     args=args)

        flops_teacher, params_teacher = get_model_complexity_info(
            teacher_model, (3, 224, 224),
            as_strings=True,
            print_per_layer_stat=False)
        print("Using KD training...")
        print("FLOPs of teacher model: ", flops_teacher)
        print("Params of teacher model: ", params_teacher)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
        if args.KD_train:
            teacher_model = nn.DataParallel(teacher_model,
                                            device_ids=list(range(
                                                args.num_gpu))).cuda()
    else:
        model.cuda()
        if args.KD_train:
            teacher_model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)
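
    # Hedged sketch of the exponential-moving-average update ModelEma applies after each
    # optimizer step (simplified; the real helper works on the state_dict and strips the
    # DataParallel/DDP 'module.' prefix):
    def _ema_update_sketch(ema_net, net, decay=args.model_ema_decay):
        with torch.no_grad():
            for e, p in zip(ema_net.parameters(), net.parameters()):
                e.mul_(decay).add_(p.detach(), alpha=1.0 - decay)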

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)
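
    # Illustrative sketch of the mixup idea that FastCollateMixup implements at collate
    # time (simplified: the real collate also assembles the uint8 batch tensor and the
    # label-smoothed one-hot targets):
    def _mixup_sketch(x, y_onehot, alpha=0.2):
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
        perm = torch.randperm(x.size(0))
        return lam * x + (1. - lam) * x[perm], lam * y_onehot + (1. - lam) * y_onehot[perm]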

    if args.auto_augment:
        print('using auto data augmentation...')
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='bicubic',  # FIXME: cleanly resolve vs data_config['interpolation']
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        use_auto_aug=args.auto_augment,
        use_mixcut=args.mixcut,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        logging.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    if args.KD_train:
        train_loss_fn = nn.KLDivLoss(reduction='batchmean').cuda()
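
    # Hedged sketch of how the KLDivLoss configured above is typically combined with the
    # teacher during KD training (temperature T and the weighting are assumptions; the
    # actual train_epoch may differ):
    def _kd_loss_sketch(student_logits, teacher_logits, T=1.0):
        import torch.nn.functional as F
        return F.kl_div(F.log_softmax(student_logits / T, dim=1),
                        F.softmax(teacher_logits / T, dim=1),
                        reduction='batchmean') * (T * T)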

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            if args.KD_train:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            use_amp=use_amp,
                                            model_ema=model_ema,
                                            teacher_model=teacher_model)
            else:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            use_amp=use_amp,
                                            model_ema=model_ema)

            # (Optional) post-training quantization with distiller was prototyped here,
            # roughly:
            #   quantizer = distiller.quantization.PostTrainLinearQuantizer(
            #       model, bits_activations=8, bits_parameters=8)
            #   quantizer.prepare_model()
            #   collector = distiller.data_loggers.QuantCalibrationStatsCollector(model)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Example #19
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(model, args)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    loader = create_loader(
        Dataset(args.data, load_bytes=args.tf_preprocessing),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
        param_count=round(param_count / 1e6, 2))

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
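
# Minimal top-k accuracy helper in the spirit of the `accuracy` call above (a sketch,
# not the library implementation):
def _topk_accuracy_sketch(output, target, topk=(1, 5)):
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)
    correct = pred.t().eq(target.view(1, -1))
    return [correct[:k].reshape(-1).float().sum().item() * 100.0 / target.size(0)
            for k in topk]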
Example #20
def main():
    import os

    args, args_text = _parse_args()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = 'train'
        if args.gate_train:
            exp_name += '-dynamic'
        if args.slim_train:
            exp_name += '-slimmable'
        exp_name += '-{}'.format(args.model)
        exp_info = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, exp_name, exp_info)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    setup_default_logging(outdir=output_dir, local_rank=args.local_rank)

    torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        # torch.distributed.init_process_group(backend='nccl',
        #                                      init_method='tcp://127.0.0.1:23334',
        #                                      rank=args.local_rank,
        #                                      world_size=int(os.environ['WORLD_SIZE']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # --------- random seed -----------
    random.seed(args.seed)  # TODO: do we need same seed on all GPU?
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_path_rate=args.drop_path,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    if args.train_mode == 'se':
        optimizer = create_optimizer(args, model.get_se())
    elif args.train_mode == 'bn':
        optimizer = create_optimizer(args, model.get_bn())
    elif args.train_mode == 'all':
        optimizer = create_optimizer(args, model)
    elif args.train_mode == 'gate':
        optimizer = create_optimizer(args, model.get_gate())
    else:
        raise ValueError('Unknown train_mode: {}'.format(args.train_mode))

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    if resume_state and not args.no_resume_opt:
        # ----------- Load Optimizer ---------
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True
                        )  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ------------- data --------------
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    collate_fn = None
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )
    loader_bn = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ------------- loss_fn --------------
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    if args.ieb:
        distill_loss_fn = SoftTargetCrossEntropy().cuda()
    else:
        distill_loss_fn = None

    model_profiling(model, 224, 224, 1, 3, use_cuda=True,
                    verbose=args.local_rank == 0)

    if not args.test_mode:
        # start training
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            train_metrics = OrderedDict([('loss', 0.)])
            # train
            if args.gate_train:
                train_metrics = train_epoch_slim_gate(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    train_loss_fn,
                    args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step)
            else:
                train_metrics = train_epoch_slim(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    loss_fn=train_loss_fn,
                    distill_loss_fn=distill_loss_fn,
                    args=args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step,
                )
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # eval
            if args.gate_train:
                eval_sample_list = ['dynamic']
            else:
                if epoch % 10 == 0 and epoch != 0:
                    eval_sample_list = ['smallest', 'largest', 'uniform']
                else:
                    eval_sample_list = ['smallest', 'largest']

            eval_metrics = [
                validate_slim(model,
                              loader_eval,
                              validate_loss_fn,
                              args,
                              model_mode=model_mode)
                for model_mode in eval_sample_list
            ]

            if model_ema is not None and not args.model_ema_force_cpu:

                ema_eval_metrics = [
                    validate_slim(model_ema.ema,
                                  loader_eval,
                                  validate_loss_fn,
                                  args,
                                  model_mode=model_mode)
                    for model_mode in eval_sample_list
                ]

                eval_metrics = ema_eval_metrics

            if isinstance(eval_metrics, list):
                eval_metrics = eval_metrics[0]

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # save
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)
        # end training
        if best_metric is not None:
            logging.info('*** Best metric: {0} (epoch {1})'.format(
                best_metric, best_epoch))

    # test
    eval_metrics = []
    for choice in range(args.num_choice):
        # reset bn if not smallest or largest
        if choice != 0 and choice != args.num_choice - 1:
            for layer in model.modules():
                if isinstance(layer, nn.BatchNorm2d) or \
                        isinstance(layer, nn.SyncBatchNorm) or \
                        (has_apex and isinstance(layer, apex.parallel.SyncBatchNorm)):
                    layer.reset_running_stats()
            model.train()
            with torch.no_grad():
                for batch_idx, (input, target) in enumerate(loader_bn):
                    if args.slim_train:
                        if hasattr(model, 'module'):
                            model.module.set_mode('uniform', choice=choice)
                        else:
                            model.set_mode('uniform', choice=choice)
                        model(input)

                    if batch_idx % 1000 == 0 and batch_idx != 0:
                        print('Subnet {} : reset bn for {} steps'.format(
                            choice, batch_idx))
                        break
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

        eval_metrics.append(
            validate_slim(model,
                          loader_eval,
                          validate_loss_fn,
                          args,
                          model_mode=choice))
    if args.local_rank == 0:
        print('Test results of the last epoch:\n', eval_metrics)
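
    # Hedged sketch of the BatchNorm re-calibration idea used in the loop above for
    # intermediate sub-networks; resetting the running stats and running forward passes
    # in train mode re-estimates them (momentum=None, a cumulative average, is an
    # optional refinement not used above):
    def _recalibrate_bn_sketch(net, loader, num_batches=100):
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()
                m.momentum = None
        net.train()
        with torch.no_grad():
            for i, (x, _) in enumerate(loader):
                net(x.cuda())
                if i + 1 >= num_batches:
                    break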
Example #21
def validate(args):
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f"    {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])        
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # NOTE: no torch.no_grad() wrapper here; the adversarial attacks below need input gradients.
    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
    if args.channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
    model(input)
    end = time.time()
    for batch_idx, (input, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # compute output
        with amp_autocast():
            output = model(input)

        if valid_labels is not None:
            output = output[:, valid_labels]
        loss = criterion(output, target)

        if real_labels is not None:
            real_labels.add_result(output)

        # Generate adversarial examples for current inputs
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=input,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=input, 
            eps=args.eps, 
            eps_iter=0.01, 
            nb_iter=40, 
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        # Predict with Adversarial Examples
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1.item(), input.size(0))
        top5.update(acc5.item(), input.size(0))

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), input.size(0))
        top5_fgm_ae.update(acc5_fgm_ae.item(), input.size(0))
        top1_pgd_ae.update(acc1_pgd_ae.item(), input.size(0))
        top5_pgd_ae.update(acc5_pgd_ae.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=input.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        raise NotImplementedError  # TODO: not yet adapted for the adversarial-examples mode
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
        top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
        top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
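
# Hedged sketch of the single FGSM step that `fast_gradient_method` performs in the
# L-inf case (using true labels for simplicity; cleverhans falls back to the model's
# own predictions when no labels are passed, as in the call above):
def _fgsm_sketch(model, x, y, eps):
    x = x.clone().detach().requires_grad_(True)
    loss = nn.CrossEntropyLoss()(model(x), y)
    loss.backward()
    return (x + eps * x.grad.sign()).detach()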
Example #22
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available, using float32. "
            "Install NVIDIA Apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    model_KD = None
    if args.kd_model_path is not None:
        model_KD = build_kd_model(args)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')
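
    # Hedged sketch of how amp_autocast / loss_scaler are typically used in the training
    # step for the native-AMP path (timm's NativeScaler is callable and wraps scale,
    # backward and optimizer step; the exact logic lives in train_one_epoch):
    def _amp_step_sketch(net, x, y, loss_fn, opt):
        opt.zero_grad()
        with amp_autocast():
            loss = loss_fn(net(x), y)
        if loss_scaler is not None:
            loss_scaler(loss, opt)
        else:
            loss.backward()
            opt.step()
        return loss.detach()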

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model,
                              device_ids=[args.local_rank],
                              broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(args.dataset,
                                   root=args.data_dir,
                                   split=args.train_split,
                                   is_training=True,
                                   class_map=args.class_map,
                                   download=args.dataset_download,
                                   batch_size=args.batch_size,
                                   repeats=args.epoch_repeats)
    dataset_eval = create_dataset(args.dataset,
                                  root=args.data_dir,
                                  split=args.val_split,
                                  is_training=False,
                                  class_map=args.class_map,
                                  download=args.dataset_download,
                                  batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)
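
    # Hedged usage note: with the prefetcher enabled, mixing happens inside
    # FastCollateMixup at collate time; otherwise `mixup_fn` is applied to each GPU batch
    # in the train loop (roughly `input, target = mixup_fn(input, target)`), which is why
    # only one of the two is created above.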

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(
                smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()
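
    # Reference sketch of label-smoothing cross-entropy (what LabelSmoothingCrossEntropy
    # above computes, up to implementation details): blend the NLL of the true class with
    # the mean negative log-probability over all classes.
    def _label_smoothing_ce_sketch(logits, target, smoothing=0.1):
        import torch.nn.functional as F
        logp = F.log_softmax(logits, dim=-1)
        nll = -logp.gather(1, target.unsqueeze(1)).squeeze(1)
        return ((1. - smoothing) * nll - smoothing * logp.mean(dim=-1)).mean()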

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(
            args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing,
                                max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn,
                                            model_KD=model_KD)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None,
                               log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
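An aside between examples: the loss setup above picks exactly one training criterion by priority (JSD with aug splits, then mixup/cutmix soft targets, then label smoothing, then plain cross-entropy). A minimal standalone sketch of that decision logic, assuming a recent timm where these loss classes are exported from timm.loss; the argument names below are illustrative, not the script's own:

import torch.nn as nn
from timm.loss import (JsdCrossEntropy, BinaryCrossEntropy,
                       SoftTargetCrossEntropy, LabelSmoothingCrossEntropy)

def build_train_loss(num_aug_splits=0, mixup_active=False, smoothing=0.1,
                     use_jsd=False, use_bce=False, bce_target_thresh=None):
    # Priority mirrors the example: JSD > mixup soft targets > smoothing > CE.
    if use_jsd:
        assert num_aug_splits > 1, 'JSD loss only makes sense with aug splits'
        return JsdCrossEntropy(num_splits=num_aug_splits, smoothing=smoothing)
    if mixup_active:
        # mixup/cutmix already emit soft targets with smoothing folded in
        if use_bce:
            return BinaryCrossEntropy(target_threshold=bce_target_thresh)
        return SoftTargetCrossEntropy()
    if smoothing:
        if use_bce:
            return BinaryCrossEntropy(smoothing=smoothing,
                                      target_threshold=bce_target_thresh)
        return LabelSmoothingCrossEntropy(smoothing=smoothing)
    return nn.CrossEntropyLoss()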
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # enable the auxiliary classifier head for Inception models
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(args.model,
                             pretrained=args.pretrained,
                             num_classes=args.num_classes,
                             in_chans=3,
                             scriptable=args.torchscript)

    # optionally replace the global pool / classifier head with a custom FC stack
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()  # only used when gold labels are available

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map,
                          args=args)

    if args.valid_labels:
        # each line of the valid_labels file is a class index to keep
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],  # an empty interpolation string defaults to Image.BILINEAR, see https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    prediction = None  # accumulated model outputs, returned to the caller
    true_label = None  # accumulated ground-truth labels (when available)

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):  # targets may be dummies when no eval labels exist

            if args.has_eval_label:  # keep the true labels so they can be returned alongside predictions
                if true_label is None:
                    true_label = target.cpu().data.numpy()
                else:
                    true_label = np.concatenate(
                        (true_label, target.cpu().data.numpy()), axis=0)

            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]  # some models return (logits, aux_logits); keep the main logits

            if valid_labels is not None:
                output = output[:, valid_labels]  # restrict scoring to the valid label subset

            # accumulate predictions; per-batch np.concatenate is simple but O(n^2) overall
            # row order follows the loader's iteration order over the dataset files
            if prediction is None:
                prediction = output.cpu().data.numpy()  # batchsize x label
            else:  # stack
                prediction = np.concatenate(
                    (prediction, output.cpu().data.numpy()), axis=0)

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)  # only meaningful when the eval set has gold labels
                acc1, acc5 = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    if not args.has_eval_label:
        top1a, topka = 0, 0  # dummy values, since ground-truth labels are unavailable
    else:
        if real_labels is not None:
            # real labels mode replaces topk values at the end
            top1a, topka = real_labels.get_accuracy(
                k=1), real_labels.get_accuracy(k=args.topk)
        else:
            top1a, topka = top1.avg, topk.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          topk=round(topka, 4),
                          topk_err=round(100 - topka, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'],
        results['topk_err']))

    return results, prediction, true_label
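A note on the prediction stacking in the validate() above: calling np.concatenate once per batch copies the full array every time, which is what the original 'too slow' comment alludes to. A minimal sketch of the usual alternative, collecting per-batch arrays in a list and concatenating once at the end (the function and names are illustrative, not part of the example):

import numpy as np
import torch

def collect_predictions(model, loader, device='cuda'):
    """Run inference and return stacked logits and labels as numpy arrays."""
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for input, target in loader:
            output = model(input.to(device))
            if isinstance(output, (tuple, list)):
                output = output[0]  # keep the main logits if aux outputs are returned
            preds.append(output.cpu().numpy())
            labels.append(target.numpy())
    return np.concatenate(preds, axis=0), np.concatenate(labels, axis=0)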
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    #from torchvision.datasets import ImageNet
    #dataset = ImageNet(args.data, split='val')
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 2))  # note: k=2 here even though the meters and logs below are labelled top5
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
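For reference, the acc1/acc5 values tracked by both validate() variants come from a top-k accuracy helper. A minimal sketch of that computation, written out here to make the semantics explicit (equivalent in spirit to the accuracy utility these scripts import, not a copy of it):

import torch

def topk_accuracy(output, target, topk=(1, 5)):
    # output: (batch, num_classes) logits, target: (batch,) class indices
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                        # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # bool hits per rank
    batch_size = target.size(0)
    return [correct[:k].reshape(-1).float().sum().item() * 100.0 / batch_size
            for k in topk]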
Example #25
0
File: train.py Project: zeta1999/Cream
def main():
    args, cfg = parse_config_args('super net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.', args.local_rank,
                    cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # initialize meta matching networks
    MetaMN = MetaMatchingNetwork(cfg)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[1][0])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    (sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    #initialize prioritized board
    prioritized_board = PrioritizedBoard(cfg,
                                         CHOICE_NUM=choice_num,
                                         sta_num=sta_num)

    # initialize flops look-up table
    model_est = FlopsEst(model)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(
            model, cfg.RESUME_PATH)

    # create optimizer and resume from checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            logger.info(
                'Failed to enable Synchronized BatchNorm. '
                'Install Apex or Torch >= 1.1. Exception: %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info(
                "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
            )
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 use_prefetcher=True,
                                 re_prob=cfg.AUGMENTATION.RE_PROB,
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 interpolation='random',
                                 num_workers=cfg.WORKERS,
                                 distributed=True,
                                 collate_fn=None,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(dataset_eval,
                                input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                            cfg.DATASET.IMAGE_SIZE),
                                batch_size=4 * cfg.DATASET.BATCH_SIZE,
                                is_training=False,
                                use_prefetcher=True,
                                num_workers=cfg.WORKERS,
                                distributed=True,
                                crop_pct=DEFAULT_CROP_PCT,
                                mean=IMAGENET_DEFAULT_MEAN,
                                std=IMAGENET_DEFAULT_STD,
                                interpolation=cfg.DATASET.INTERPOLATION)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver, best_children_pool = None, None, None, []
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    # training scheme
    try:
        for epoch in range(start_epoch, num_epochs):
            loader_train.sampler.set_epoch(epoch)

            # train one epoch
            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        prioritized_board,
                                        MetaMN,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        logger=logger,
                                        est=model_est,
                                        local_rank=args.local_rank)

            # evaluate one epoch
            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    prioritized_board,
                                    MetaMN,
                                    cfg,
                                    local_rank=args.local_rank,
                                    logger=logger)

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, cfg, epoch=epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
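One detail worth calling out in the training loop above: loader_train.sampler.set_epoch(epoch) is what makes a DistributedSampler reshuffle differently on every epoch; without it each epoch sees the same ordering. A minimal plain-PyTorch sketch of the pattern (the dataset and batch size are placeholders, and the sampler is given explicit num_replicas/rank so it runs without an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(1000).float())            # placeholder dataset
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)  # explicit values avoid needing init_process_group here
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # seeds the shuffle with the epoch number
    for (batch,) in loader:
        pass                  # training step would go here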
def train_imagenet():
    torch.manual_seed(42)

    device = xm.xla_device()
    # model = get_model_property('model_fn')().to(device)
    model = create_model(
        FLAGS.model,
        pretrained=FLAGS.pretrained,
        num_classes=FLAGS.num_classes,
        drop_rate=FLAGS.drop,
        global_pool=FLAGS.gp,
        bn_tf=FLAGS.bn_tf,
        bn_momentum=FLAGS.bn_momentum,
        bn_eps=FLAGS.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=FLAGS.initial_checkpoint,
        args=FLAGS).to(device)
    model_ema=None
    if FLAGS.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # import pdb; pdb.set_trace()
        model_e = create_model(
            FLAGS.model,
            pretrained=FLAGS.pretrained,
            num_classes=FLAGS.num_classes,
            drop_rate=FLAGS.drop,
            global_pool=FLAGS.gp,
            bn_tf=FLAGS.bn_tf,
            bn_momentum=FLAGS.bn_momentum,
            bn_eps=FLAGS.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=FLAGS.initial_checkpoint,
            args=FLAGS).to(device)
        model_ema = ModelEma(
            model_e,
            decay=FLAGS.model_ema_decay,
            device='cpu' if FLAGS.model_ema_force_cpu else '',
            resume=FLAGS.resume)
    print('==> Preparing data..')
    img_dim = 224
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                    torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                    torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    # else:
    #     normalize = transforms.Normalize(
    #         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    #     train_dataset = torchvision.datasets.ImageFolder(
    #         os.path.join(FLAGS.data, 'train'),
    #         transforms.Compose([
    #             transforms.RandomResizedCrop(img_dim),
    #             transforms.RandomHorizontalFlip(),
    #             transforms.ToTensor(),
    #             normalize,
    #         ]))
    #     train_dataset_len = len(train_dataset.imgs)
    #     resize_dim = max(img_dim, 256)
    #     test_dataset = torchvision.datasets.ImageFolder(
    #         os.path.join(FLAGS.data, 'val'),
    #         # Matches Torchvision's eval transforms except Torchvision uses size
    #         # 256 resize for all models both here and in the train loader. Their
    #         # version crashes during training on 299x299 images, e.g. inception.
    #         transforms.Compose([
    #             transforms.Resize(resize_dim),
    #             transforms.CenterCrop(img_dim),
    #             transforms.ToTensor(),
    #             normalize,
    #         ]))

    #     train_sampler = None
    #     if xm.xrt_world_size() > 1:
    #         train_sampler = torch.utils.data.distributed.DistributedSampler(
    #             train_dataset,
    #             num_replicas=xm.xrt_world_size(),
    #             rank=xm.get_ordinal(),
    #             shuffle=True)
    #     train_loader = torch.utils.data.DataLoader(
    #         train_dataset,
    #         batch_size=FLAGS.batch_size,
    #         sampler=train_sampler,
    #         shuffle=False if train_sampler else True,
    #         num_workers=FLAGS.workers)
    #     test_loader = torch.utils.data.DataLoader(
    #         test_dataset,
    #         batch_size=FLAGS.batch_size,
    #         shuffle=False,
    #         num_workers=FLAGS.workers)
    else:
        train_dir = os.path.join(FLAGS.data, 'train')
        data_config = resolve_data_config(model, FLAGS, verbose=FLAGS.local_rank == 0)
        dataset_train = Dataset(train_dir)

        collate_fn = None
        if not FLAGS.no_prefetcher and FLAGS.mixup > 0:
            collate_fn = FastCollateMixup(FLAGS.mixup, FLAGS.smoothing, FLAGS.num_classes)
        train_loader = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=True,
            use_prefetcher=not FLAGS.no_prefetcher,
            rand_erase_prob=FLAGS.reprob,
            rand_erase_mode=FLAGS.remode,
            interpolation='bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
            collate_fn=collate_fn,
            use_auto_aug=FLAGS.auto_augment,
            use_mixcut=FLAGS.mixcut,
        )

        eval_dir = os.path.join(FLAGS.data, 'val')
        train_dataset_len = len(train_loader)
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)

        test_loader = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=False,
            use_prefetcher=FLAGS.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
        )


    writer = None
    start_epoch = 0
    if FLAGS.output and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.output)
    optimizer = create_optimizer(FLAGS, model)
    lr_scheduler, num_epochs = create_scheduler(FLAGS, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    # optimizer = optim.SGD(
    #     model.parameters(),
    #     lr=FLAGS.lr,
    #     momentum=FLAGS.momentum,
    #     weight_decay=5e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
        
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=FLAGS.smoothing)
    validate_loss_fn = nn.CrossEntropyLoss()
    # loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = train_loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if model_ema is not None:
                model_ema.update(model)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy
    def test_loop_fn_ema(loader):
        total_samples = 0
        correct = 0
        model_ema.ema.eval()  # evaluate the EMA copy of the weights
        for x, (data, target) in loader:
            output = model_ema.ema(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy
    accuracy = 0.0
    for epoch in range(1, FLAGS.epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        if model_ema is not None:
            accuracy = test_loop_fn_ema(para_loader.per_device_loader(device))
            print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)

        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
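The ModelEma wrapper used above keeps an exponential moving average of the weights and is evaluated separately from the live model. A minimal sketch of the underlying update rule, independent of timm (the class name and decay value here are illustrative):

import copy
import torch

class SimpleEma:
    def __init__(self, model, decay=0.9998):
        self.decay = decay
        self.ema = copy.deepcopy(model).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # ema_param <- decay * ema_param + (1 - decay) * live_param
        for ema_v, v in zip(self.ema.state_dict().values(),
                            model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
            else:
                ema_v.copy_(v)  # integer buffers (e.g. num_batches_tracked) are copied

In a training loop this would be called right after the optimizer step, which is where the example calls model_ema.update(model).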
Example #27
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

                # if saver.cmp(best_metric, save_metric):
                #     _logger.info(f"Metric is no longer improving [BEST: {best_metric}, CURRENT: {save_metric}]"
                #                  f"\nFinishing training process")
                #     if epoch > 15:
                #         break

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        message = '*** Best metric: <{0:.2f}>, epoch: <{1}>, path: <{2}> ***'\
            .format(best_metric, best_epoch, output_dir)
        _logger.info(message)
        print(message)
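The amp_autocast / loss_scaler pair in example #27 abstracts over APEX and native AMP. For the native path, the pattern it wraps reduces to torch.cuda.amp autocast plus a GradScaler; a minimal sketch of a single training step under that assumption (the helper name is illustrative):

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, input, target, loss_fn, optimizer):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # backward on the scaled loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
    scaler.update()
    return loss.detach()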
Example #28
0
def main():
    args = parser.parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        print(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print(
                'Warning: AMP does not work well with nn.DataParallel, disabling. '
                'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        if args.distributed and args.sync_bn and has_apex:
            model = convert_syncbn_model(model)
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        model = DDP(model, delay_allreduce=True)
        if model_ema is not None and not args.model_ema_force_cpu:
            # must also distribute EMA model to allow validation
            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
            model_ema.ema_has_module = True

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Example #29
0
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

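    # Offset the seed by the global rank so each distributed worker draws
    # different shuffling/augmentation randomness while staying reproducible.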
    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_connect_rate=args.drop_connect,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

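    # Augmentation splits: when > 1, each batch carries several differently
    # augmented views of the same images (used by split BatchNorm, AugMix and
    # the JSD loss configured further down).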
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

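    # Mixed precision via NVIDIA Apex (opt_level O1) only when Apex is
    # installed and --amp was requested; otherwise training stays in FP32.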
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm (%s). Install Apex or Torch >= 1.1', e)
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model, device_ids=[args.local_rank])  # device_ids can take a device string in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

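    # Build the LR schedule, then resolve the starting epoch: an explicit
    # --start-epoch always overrides the epoch recovered from a resumed
    # checkpoint, and the scheduler is fast-forwarded to match.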
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

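    # Wrap the dataset so each sample yields num_aug_splits augmented views,
    # matching the split BatchNorm / JSD configuration above and below.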
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

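    # Loss selection, in priority order: JSD (requires aug splits), soft-target
    # CE for mixup (label smoothing is folded into the mixed targets),
    # label-smoothing CE, then plain CE. Validation always uses plain CE.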
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Example #30
0
def main():
    start_endpoint = "http://localhost:3000/start"
    stop_endpoint = "http://localhost:3000/stop"
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # Write results alongside the evaluated checkpoint
    output_dir = os.path.dirname(args.checkpoint)

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint)

    logging.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    # config = resolve_data_config(vars(args), model=model)
    # model, test_time_pool = apply_test_time_pool(model, config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    dataset_eval = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True)

    data_config = resolve_data_config(vars(args), model=model)

    # Override the model's default normalization with CIFAR-100 statistics
    data_config['mean'] = (0.5071, 0.4865, 0.4409)
    data_config['std'] = (0.2673, 0.2564, 0.2762)

    loader = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=data_config['crop_pct']
    )
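    # timm's create_loader is expected to attach the eval transform
    # (resize/center-crop/normalize with the CIFAR-100 stats overridden above)
    # to the torchvision dataset, so CIFAR100 is constructed without one.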

    model.eval()

    batch_time = AverageMeter()

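    # Power-measurement protocol: the meter at localhost:3000 presumably reports
    # idle power on /start and loaded power on /stop; the first batch is
    # excluded from timing below as GPU warm-up.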
    with torch.no_grad():
        idle_power = requests.post(url=start_endpoint)
        idle_json = idle_power.json()
        for batch_idx, (input, _) in enumerate(loader):
            input = input.cuda()

            tstart = time.time()
            output = model(input)
            tend = time.time()

            if batch_idx != 0:
                batch_time.update(tend - tstart)

                if batch_idx % args.log_freq == 0:
                    print('Predict: [{0}/{1}] Time {batch_time.val:.6f} ({batch_time.avg:.6f})'.format(
                        batch_idx, len(loader), batch_time=batch_time), end='\r')

    load_power = requests.post(url=stop_endpoint)
    load_json = load_power.json()
    fps = 1 / batch_time.avg  # note: this is batches per second; images/sec would be args.batch_size / batch_time.avg
    inference_power = float(load_json['load']) - float(idle_json['idle'])
    stats = [{'FPS': [float(fps)]},
             {'Total_Power': [float(inference_power)]}]
    with open(os.path.join(output_dir, '{}_fps_cifar.yaml'.format(args.model)), 'w') as f:
        yaml.safe_dump(stats, f)