def compare_backbone(n=50):
    """Compare the mmdet ResNet backbone with the local ``ResNet`` port.

    Both networks are built for depth ``n``, loaded with the same checkpoint
    ``CKPT[n]`` through ``load_dict`` and run (in train mode) on the same
    random batch. For every pair of corresponding output feature maps the
    summed absolute element-wise difference is printed; it is 0 only when the
    two maps are identical.

    Args:
        n: ResNet depth; must be a key of ``BB_CFG`` (and of ``CKPT``).
    """
    assert n in BB_CFG
    print('Compare backbone:', n)
    # backbone from mmdet
    mmdet_res = build_backbone(BB_CFG[n])
    lyq_res = ResNet(n)
    pre = torch.load(CKPT[n])

    load_dict(mmdet_res, pre)
    load_dict(lyq_res, pre)

    mmdet_res.train()
    lyq_res.train()

    img_data = torch.rand(10, 3, 224, 224)
    # Only forward outputs are compared, so no autograd graph is needed.
    with torch.no_grad():
        mmdet_feats = mmdet_res(img_data)
        lyq_feats = lyq_res(img_data)

    print('Number of feats from mmdet backbone:', len(mmdet_feats))
    print('Number of feats from lyq backbone:', len(lyq_feats))

    for mf, lf in zip(mmdet_feats, lyq_feats):
        # Take the absolute value before summing: a signed sum lets positive
        # and negative differences cancel and can report 0 for tensors that
        # actually differ.
        diff = (mf - lf).abs().sum().item()
        print('diff:', diff)
Пример #2
0
    def __init__(
        self,
        backbone: Union[nn.Module, Mapping],
        neck: Union[nn.Module, Mapping, None] = None,
        *,
        pretrained_backbone: Optional[str] = None,
        output_shapes: List[ShapeSpec],
        output_names: Optional[List[str]] = None,
    ):
        """
        Args:
            backbone: either a backbone module or a mmdet config dict that defines a
                backbone. The backbone takes a 4D image tensor and returns a
                sequence of tensors.
            neck: either a neck module or a mmdet config dict that defines a
                neck. The neck takes outputs of backbone and returns a
                sequence of tensors. If None, no neck is used.
            pretrained_backbone: defines the backbone weights that can be loaded by
                mmdet, such as "torchvision://resnet50".
            output_shapes: shape for every output of the backbone (or neck, if given).
                stride and channels are often needed.
            output_names: names for every output of the backbone (or neck, if given).
                By default, will use "out0", "out1", ...

        Raises:
            ValueError: if ``output_names`` is given and does not have one entry
                per element of ``output_shapes``.
        """
        super().__init__()
        # Names and shapes are paired element-wise; a length mismatch would
        # otherwise be silently truncated wherever the two lists are zipped.
        if output_names and len(output_names) != len(output_shapes):
            raise ValueError(
                f"output_names has {len(output_names)} entries but "
                f"output_shapes has {len(output_shapes)}"
            )

        if isinstance(backbone, Mapping):
            from mmdet.models import build_backbone

            backbone = build_backbone(_to_container(backbone))
        self.backbone = backbone

        if isinstance(neck, Mapping):
            from mmdet.models import build_neck

            neck = build_neck(_to_container(neck))
        self.neck = neck

        # It's confusing that backbone weights are given as a separate argument,
        # but "neck" weights, if any, are part of neck itself. This is the interface
        # of mmdet so we follow it. Reference:
        # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py
        logger.info(f"Initializing mmdet backbone weights: {pretrained_backbone} ...")
        self.backbone.init_weights(pretrained_backbone)
        # train() in mmdet modules is non-trivial, and has to be explicitly
        # called. Reference:
        # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py
        self.backbone.train()
        if self.neck is not None:
            logger.info("Initializing mmdet neck weights ...")
            # A Sequential container has no init_weights of its own; initialize
            # each contained neck module instead.
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
            self.neck.train()

        self._output_shapes = output_shapes
        if not output_names:
            output_names = [f"out{i}" for i in range(len(output_shapes))]
        self._output_names = output_names
Пример #3
0
 def __init__(self, backbone, neck=None, bbox_head=None, train_cfg=None,
              test_cfg=None, pretrained=None):
     """Initialize the detector by building its backbone.

     Args:
         backbone: config handed to ``build_backbone``; the result is stored
             as ``self.backbone``.
         neck, bbox_head, train_cfg, test_cfg, pretrained: accepted but not
             read by the code shown here.
     """
     super(SNet, self).__init__()
     built_backbone = build_backbone(backbone)
     self.backbone = built_backbone
Пример #4
0
    def __init__(self,
                 num_streams,
                 backbones,
                 aggregation_mlp_channels=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg=dict(type='ReLU'),
                 suffixes=('net0', 'net1'),
                 **kwargs):
        """Build ``num_streams`` backbones plus 1x1 conv layers fusing them.

        Args:
            num_streams: number of backbone streams.
            backbones: a single backbone config dict (deep-copied once per
                stream) or a list of ``num_streams`` configs. Each config must
                contain ``fp_channels``; ``fp_channels[-1][-1]`` is taken as
                that stream's output channel count.
            aggregation_mlp_channels: output widths of the aggregation layers.
                The summed stream channels ``C`` are prepended as the input
                width. Defaults to ``[C, C // 2, C // num_streams]``. The
                caller's sequence is not modified.
            conv_cfg: conv config passed to every ``ConvModule``.
            norm_cfg: norm config passed to every ``ConvModule``.
            act_cfg: activation config passed to every ``ConvModule``.
            suffixes: one suffix per stream, stored as ``self.suffixes``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        assert isinstance(backbones, (dict, list))
        if isinstance(backbones, dict):
            backbones = [copy.deepcopy(backbones) for _ in range(num_streams)]

        assert len(backbones) == num_streams
        assert len(suffixes) == num_streams

        self.backbone_list = nn.ModuleList()
        # Rename the ret_dict with different suffixes.
        self.suffixes = suffixes

        out_channels = 0

        for backbone_cfg in backbones:
            out_channels += backbone_cfg['fp_channels'][-1][-1]
            self.backbone_list.append(build_backbone(backbone_cfg))

        # Feature aggregation layers
        if aggregation_mlp_channels is None:
            aggregation_mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            # Build a new list instead of insert(): insert() would mutate the
            # caller's (config) list, so building twice from the same config
            # would prepend out_channels again.
            aggregation_mlp_channels = [out_channels] + list(
                aggregation_mlp_channels)

        self.aggregation_layers = nn.Sequential()
        for i in range(len(aggregation_mlp_channels) - 1):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(
                    aggregation_mlp_channels[i],
                    aggregation_mlp_channels[i + 1],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))
Пример #5
0
 def __init__(self,
              backbone,
              neck=None,
              bbox_head=None,
              train_cfg=None,
              test_cfg=None,
              init_cfg=None,
              pretrained=None):
     """Build backbone, optional neck and bbox head of a single-stage detector.

     Args:
         backbone: config for ``build_backbone``.
         neck: optional config for ``build_neck``; no ``self.neck`` is set
             when it is None.
         bbox_head: config for ``build_head``; ``train_cfg`` and
             ``test_cfg`` are added to a copy of it before building.
         train_cfg: training config, stored on ``self`` and the head config.
         test_cfg: testing config, stored on ``self`` and the head config.
         init_cfg: forwarded to the parent ``__init__``.
         pretrained: accepted but not read by this constructor.
     """
     super(SingleStage3DDetector, self).__init__(init_cfg)
     self.backbone = build_backbone(backbone)
     if neck is not None:
         self.neck = build_neck(neck)
     # Update a shallow copy so the caller's config is left unchanged
     # (otherwise it gains train_cfg/test_cfg keys as a side effect).
     bbox_head = bbox_head.copy()
     bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
     self.bbox_head = build_head(bbox_head)
     self.train_cfg = train_cfg
     self.test_cfg = test_cfg
Пример #6
0
def main():
    """Build backbone, neck and RPN head from the config, then exit.

    The config path is ``args.config`` joined onto ``root`` (defined outside
    this function). The three modules are switched to eval mode; nothing is
    run or exported yet.
    """
    import sys

    args = parse_args()

    config = mmcv.Config.fromfile(os.path.join(root, args.config))

    with torch.no_grad():
        backbone = build_backbone(config.model.backbone)
        neck = build_neck(config.model.neck)
        rpn_head = build_head(config.model.rpn_head)

    backbone.eval()
    neck.eval()
    rpn_head.eval()

    # torch.jit.save(rpn_head, './rpn_head.pt')
    # Use sys.exit rather than the ``exit`` helper, which is injected by the
    # ``site`` module and is missing when Python runs without it (e.g. -S).
    sys.exit()
Пример #7
0
 def __init__(self,
              backbone,
              neck,
              neck_3d,
              bbox_head,
              n_voxels,
              anchor_generator,
              train_cfg=None,
              test_cfg=None,
              pretrained=None,
              init_cfg=None):
     """Build the 2D backbone/neck, 3D neck, bbox head and anchor generator.

     Args:
         backbone: config for ``build_backbone``.
         neck: config for ``build_neck`` (2D neck).
         neck_3d: config for ``build_neck`` (3D neck).
         bbox_head: config for ``build_head``; ``train_cfg`` and ``test_cfg``
             are added to a copy of it before building.
         n_voxels: stored as ``self.n_voxels``.
         anchor_generator: config for ``build_anchor_generator``.
         train_cfg: training config, stored on ``self`` and the head config.
         test_cfg: testing config, stored on ``self`` and the head config.
         pretrained: accepted but not read by this constructor.
         init_cfg: forwarded to the parent ``__init__``.
     """
     super().__init__(init_cfg=init_cfg)
     self.backbone = build_backbone(backbone)
     self.neck = build_neck(neck)
     self.neck_3d = build_neck(neck_3d)
     # Update a shallow copy so the caller's config is left unchanged
     # (otherwise it gains train_cfg/test_cfg keys as a side effect).
     bbox_head = bbox_head.copy()
     bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
     self.bbox_head = build_head(bbox_head)
     self.n_voxels = n_voxels
     self.anchor_generator = build_anchor_generator(anchor_generator)
     self.train_cfg = train_cfg
     self.test_cfg = test_cfg
Пример #8
0
def main():
    """Print a positional, side-by-side shape comparison of model vs checkpoint.

    The model is built from ``args.cfg_file`` (as a detector when
    ``args.is_detector`` is set, otherwise as a backbone). Its ``state_dict``
    entries are paired, in order, with the ``'state_dict'`` of
    ``args.ckpt_file``. For each pair, one line is printed with: whether the
    shapes match, the model key and shape, and the checkpoint key (minus a
    leading ``module.network.`` prefix) and shape.
    """
    args = parse_args()

    cfg = Config.fromfile(args.cfg_file)
    if args.is_detector:
        model = build_detector(cfg.model,
                               train_cfg=cfg.train_cfg,
                               test_cfg=cfg.test_cfg)
    else:
        model = build_backbone(cfg.model)

    ckpt = torch.load(args.ckpt_file, map_location='cpu')
    state_dict = ckpt['state_dict']
    if args.is_detector:
        # Drop '.fc.' entries (presumably the classification head) before
        # pairing against a detector's weights.
        for k in [k for k in state_dict if '.fc.' in k]:
            del state_dict[k]

    model_state_dict = model.state_dict()
    if len(model_state_dict) != len(state_dict):
        # zip() below pairs by position and stops at the shorter dict, so a
        # count mismatch would otherwise go unnoticed.
        print(f'Warning: model has {len(model_state_dict)} entries, '
              f'checkpoint has {len(state_dict)}; only the first '
              f'{min(len(model_state_dict), len(state_dict))} are compared.')

    prefix = 'module.network.'
    for k1, k2 in zip(model_state_dict, state_dict):
        sz1 = model_state_dict[k1].size()
        sz2 = state_dict[k2].size()
        # Strip the prefix only when present; slicing unconditionally would
        # chop the leading characters off keys that lack it.
        ckpt_key = k2[len(prefix):] if k2.startswith(prefix) else k2
        print(
            f'{sz1==sz2} -- {k1}: {sz1} ------------ {ckpt_key}: {sz2}'
        )
Пример #9
0
def test(args, model_dir, distributed):
    """Evaluate the checkpoint in ``model_dir`` on every dataset in ``imagenet_x``.

    ``model_dir`` must contain ``<model_name>.py`` (the config) and
    ``current.pth`` (the weights). ``<model_name>`` is the directory basename
    with any ``-top1-...`` suffix (optionally followed by
    ``-top1_reallabels-...``) removed.

    Args:
        args: parsed CLI arguments. ``args.work_dir`` and ``args.config`` are
            overwritten, and ``args.seed`` defaults to 23.
        model_dir: directory holding the config and checkpoint.
        distributed: whether the model is wrapped for distributed evaluation.

    Returns:
        ``None`` if the config file is missing. Otherwise a list
        ``[model_name, num_params, (top1, top5), ...]`` with one tuple per
        evaluated dataset, followed by a real-label tuple when ``validate``
        returns four values. Datasets whose data root is missing are skipped.
    """
    global has_imagenet_reassessed

    args.work_dir = model_dir
    model_name = osp.basename(model_dir)
    orig_top1_acc = 0.
    orig_top1_acc_reallabels = 0.
    # remove top1 tag, e.g. "<name>-top1-76.12%[-top1_reallabels-...]"
    tag = '-top1-'
    idx = model_name.find(tag)
    tag1 = '-top1_reallabels-'
    idx1 = model_name.find(tag1)
    if idx >= 0:
        if idx1 >= 0:
            # The character just before the real-labels tag (the '%') is
            # excluded from the number.
            top1_str = model_name[idx + len(tag):idx1 - 1]
            orig_top1_acc_reallabels = float(
                model_name[idx1 + len(tag1):-2]) / 100.
        else:
            # No real-labels tag: the accuracy runs to the end of the name.
            # Slicing to ``idx1 - 1`` (== -2 here) would drop a digit.
            top1_str = model_name[idx + len(tag):].rstrip('%')
        orig_top1_acc = float(top1_str) / 100.
        model_name = model_name[:idx]

    args.config = os.path.join(model_dir, f'{model_name}.py')
    if not os.path.exists(args.config):
        print(f'Not found {args.config}')
        return None

    cfg = Config.fromfile(args.config)
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    cfg.amp_opt_level = args.amp_opt_level
    if not has_apex:
        cfg.amp_opt_level = 'O0'
    cfg.amp_static_loss_scale = args.amp_static_loss_scale
    cfg.print_freq = args.print_freq

    # Apply the default seed before copying it into cfg, so cfg.seed equals
    # the seed actually passed to set_random_seed below.
    if args.seed is None:
        args.seed = 23
    cfg.seed = args.seed

    # init the logger before other steps
    log_file = osp.join(cfg.work_dir, 'evalImageNetX.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([('{}: {}'.format(k, v))
                          for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)

    # log cfg
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    logger.info('Set random seed to {}, deterministic: {}'.format(
        args.seed, args.deterministic))
    set_random_seed(args.seed, deterministic=args.deterministic)

    # model
    model = build_backbone(cfg.model)
    num_params = sum([m.numel() for m in model.parameters()]) / 1e6
    logger.info('Model {} created, param count: {:.3f}M'.format(
        model_name, num_params))

    ckpt_file = os.path.join(model_dir, 'current.pth')

    load_checkpoint(model, ckpt_file, logger=logger)

    # ckpt = torch.load(ckpt_file, map_location='cpu')
    # state_dict = ckpt['model']
    # for k in list(state_dict.keys()):
    #     if k.startswith('module.'):
    #         # remove prefix
    #         state_dict[k[len("module."):]] = state_dict[k]
    #         # delete renamed k
    #         del state_dict[k]
    # model.load_state_dict(state_dict)

    if not distributed and len(cfg.gpu_ids) > 1:
        if cfg.amp_opt_level != 'O0':
            logger.warning(
                'AMP does not work well with nn.DataParallel, disabling.' +
                'Use distributed mode for multi-GPU AMP.')
            cfg.amp_opt_level = 'O0'
        model = nn.DataParallel(model, device_ids=list(cfg.gpu_ids)).cuda()
    else:
        model.cuda()

    # loss
    criterion_val = torch.nn.CrossEntropyLoss().cuda()

    # optimizer (only needed so amp.initialize can wrap the model)
    lr = cfg.optimizer['lr']
    lr *= cfg.batch_size * dist.get_world_size() / cfg.autoscale_lr_factor
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=cfg.optimizer['momentum'],
                                weight_decay=cfg.optimizer['weight_decay'],
                                nesterov=cfg.optimizer['nesterov'])

    if cfg.amp_opt_level != 'O0':
        loss_scale = cfg.amp_static_loss_scale if cfg.amp_static_loss_scale  \
            else 'dynamic'
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=cfg.amp_opt_level,
                                          loss_scale=loss_scale,
                                          verbosity=1)
        model = AttnNorm2Float(model)

    if distributed:
        if cfg.amp_opt_level != 'O0':
            model = DDP(model, delay_allreduce=True)
        else:
            model = DDP1(model, device_ids=[args.local_rank])

    result = [model_name, num_params]
    for dataset in imagenet_x:
        # data
        if dataset == 'imagenet-1k':
            cfg.data_root = 'data/ILSVRC2015/Data/CLS-LOC'
        else:
            cfg.data_root = f'data/{dataset}'
        if not os.path.exists(cfg.data_root):
            logger.info(f'not found {cfg.data_root}')
            continue

        indices_in_1k = None
        if dataset in ['imagenet-a', 'imagenet-o']:
            indices_in_1k = adv_indices_in_1k

        real_labels = False
        if dataset == 'imagenet-1k':
            real_labels_file = os.path.join(cfg.data_root,
                                            'reassessed-imagenet', 'real.json')
            if os.path.exists(real_labels_file):
                val_loader = get_val_loader(cfg,
                                            cfg.data_cfg['val_cfg'],
                                            distributed,
                                            real_json=real_labels_file)
                real_labels = True
                has_imagenet_reassessed = True
            else:
                logger.info(
                    f'not found {cfg.data_root} {real_labels_file} ' +
                    'consider to download real labels at ' +
                    'https://github.com/google-research/reassessed-imagenet')
                val_loader = get_val_loader(cfg, cfg.data_cfg['val_cfg'],
                                            distributed)
        else:
            val_loader = get_val_loader(cfg, cfg.data_cfg['val_cfg'],
                                        distributed)

        # eval
        results = validate(val_loader,
                           model,
                           criterion_val,
                           cfg,
                           logger,
                           distributed,
                           indices_in_1k=indices_in_1k,
                           real_labels=real_labels)

        result.append((round(results[0], 3), round(results[1], 3)))
        logger.info(
            f'** {model_name} - {dataset} top1-acc {results[0]:.3%}, top5-acc {results[1]:.3%}'
        )

        if len(results) == 4:
            result.append((round(results[2], 3), round(results[3], 3)))
            logger.info(
                f'** {model_name} - {dataset} top1-acc_reallabels {results[2]:.3%}, top5-acc_reallabels {results[3]:.3%}'
            )

    return result
Пример #10
0
def main(**kwargs):
    """Run inference on the val or test split, then save and/or evaluate results.

    Keyword arguments override the parsed CLI arguments of the same name, so
    the function can also be driven programmatically. ``args.config`` may be
    a config file path or an already-loaded config object.

    Returns:
        dict of reports built on rank 0: the ``coco_eval`` result (when
        ``--out`` and ``--eval`` are given) and/or a ``'bbox'`` entry holding
        ``defect_eval`` output (when ``--json_out`` is given and outputs are
        not dicts). Empty if no evaluation ran.

    Raises:
        ValueError: if ``--out`` does not end with ``.pkl``/``.pickle``.
    """
    args = parse_args()
    # Programmatic overrides take precedence over the parsed CLI values.
    for k, v in kwargs.items():
        setattr(args, k, v)

    assert args.out or args.show or args.json_out, \
        ('Please specify at least one operation (save or show the results) '
         'with the argument "--out" or "--show" or "--json_out"')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    # json_out is used as a filename prefix below, so drop a ".json" suffix.
    if args.json_out is not None and args.json_out.endswith('.json'):
        args.json_out = args.json_out[:-5]

    if isinstance(args.config, str):
        cfg = mmcv.Config.fromfile(args.config)
    else:
        cfg = args.config

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.val.test_mode = True
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    if args.mode == 'val':
        dataset = build_dataset(cfg.data.val)
    else:
        dataset = build_dataset(cfg.data.test)

    data_loader = build_dataloader(dataset,
                                   imgs_per_gpu=args.imgs_per_gpu,
                                   workers_per_gpu=cfg.data.workers_per_gpu,
                                   dist=distributed,
                                   shuffle=False)

    # build the model and load checkpoint
    # NOTE(review): only cfg.model['backbone'] is built, yet it is passed to
    # the detection test/eval helpers below — confirm this is intended.
    model = build_backbone(cfg.model['backbone'])
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # old versions did not save class info (or any meta) in checkpoints;
    # fall back to the dataset's classes for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs, result_times = single_gpu_test(model, data_loader, args.show)
    else:
        model = MMDistributedDataParallel(model.cuda())
        outputs, result_times = multi_gpu_test(model, data_loader, args.tmpdir,
                                               args.gpu_collect)

    rank, _ = get_dist_info()
    rpts = {}
    if args.out and rank == 0:
        print('\nwriting results to {}'.format(args.out))
        mmcv.dump(outputs, args.out)
        eval_types = args.eval
        if eval_types:
            print('Starting evaluate {}'.format(' and '.join(eval_types)))
            if eval_types == ['proposal_fast']:
                result_file = args.out
                rpts = coco_eval(result_file,
                                 eval_types,
                                 dataset.coco,
                                 classwise=True)
            else:
                if not isinstance(outputs[0], dict):
                    result_files = results2json(dataset, outputs, args.out)
                    rpts = coco_eval(result_files,
                                     eval_types,
                                     dataset.coco,
                                     classwise=True)
                else:
                    # NOTE(review): rpts is overwritten per output name, so
                    # only the last name's report is returned.
                    for name in outputs[0]:
                        print('\nEvaluating {}'.format(name))
                        outputs_ = [out[name] for out in outputs]
                        result_file = args.out + '.{}'.format(name)
                        result_files = results2json(dataset, outputs_,
                                                    result_file)
                        rpts = coco_eval(result_files,
                                         eval_types,
                                         dataset.coco,
                                         classwise=True)

    # Save predictions in the COCO json format
    if args.json_out and rank == 0:
        if not isinstance(outputs[0], dict):
            rpts['bbox'] = dict(log={}, data={})
            defect_rpt = defect_eval(outputs, dataset.coco.dataset,
                                     result_times)
            rpts['bbox']['log']['defect_eval'] = defect_rpt['log']
            rpts['bbox']['data']['defect_eval'] = defect_rpt['data']
        else:
            for name in outputs[0]:
                outputs_ = [out[name] for out in outputs]
                result_file = args.json_out + '.{}'.format(name)
                results2json(dataset, outputs_, result_file)
    return rpts
Пример #11
0
    def __init__(self,
                 num_streams,
                 backbones,
                 aggregation_mlp_channels=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg=dict(type='ReLU'),
                 suffixes=('net0', 'net1'),
                 init_cfg=None,
                 pretrained=None,
                 **kwargs):
        """Build ``num_streams`` backbones plus 1x1 conv layers fusing them.

        Args:
            num_streams: number of backbone streams.
            backbones: a single backbone config dict (deep-copied once per
                stream) or a list of ``num_streams`` configs. Each config must
                contain ``fp_channels``; ``fp_channels[-1][-1]`` is taken as
                that stream's output channel count.
            aggregation_mlp_channels: output widths of the aggregation layers.
                The summed stream channels ``C`` are prepended as the input
                width. Defaults to ``[C, C // 2, C // num_streams]``. The
                caller's sequence is not modified.
            conv_cfg: conv config passed to every ``ConvModule``.
            norm_cfg: norm config passed to every ``ConvModule``.
            act_cfg: activation config passed to every ``ConvModule``.
            suffixes: one suffix per stream, stored as ``self.suffixes``.
            init_cfg: forwarded to the parent ``__init__``.
            pretrained: deprecated. A string is converted to
                ``init_cfg=dict(type='Pretrained', checkpoint=pretrained)``
                with a warning; it may not be combined with ``init_cfg``.
            **kwargs: accepted and ignored.
        """
        super().__init__(init_cfg=init_cfg)
        assert isinstance(backbones, (dict, list))
        if isinstance(backbones, dict):
            backbones = [copy.deepcopy(backbones) for _ in range(num_streams)]

        assert len(backbones) == num_streams
        assert len(suffixes) == num_streams

        self.backbone_list = nn.ModuleList()
        # Rename the ret_dict with different suffixes.
        self.suffixes = suffixes

        out_channels = 0

        for backbone_cfg in backbones:
            out_channels += backbone_cfg['fp_channels'][-1][-1]
            self.backbone_list.append(build_backbone(backbone_cfg))

        # Feature aggregation layers
        if aggregation_mlp_channels is None:
            aggregation_mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            # Build a new list instead of insert(): insert() would mutate the
            # caller's (config) list, so building twice from the same config
            # would prepend out_channels again.
            aggregation_mlp_channels = [out_channels] + list(
                aggregation_mlp_channels)

        self.aggregation_layers = nn.Sequential()
        for i in range(len(aggregation_mlp_channels) - 1):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(aggregation_mlp_channels[i],
                           aggregation_mlp_channels[i + 1],
                           1,
                           padding=0,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg,
                           act_cfg=act_cfg,
                           bias=True,
                           inplace=True))

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
Пример #12
0
def main():
    """Train (or, with ``--eval``, only evaluate) an ImageNet classifier.

    Builds the config, logger, model, data loaders, losses, optimizer and
    scheduler from the CLI arguments. Then it does one of three things:
    profiles a single forward pass (``--profiling``), evaluates once
    (``--eval``), or trains for epochs ``cfg.start_epoch`` through
    ``cfg.total_epochs`` (inclusive), validating and checkpointing after each
    epoch. On rank 0 the work dir is finally renamed with a
    ``-top1-<best_acc1>`` suffix.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.start_epoch is not None:
        cfg.start_epoch = args.start_epoch
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    cfg.amp_opt_level = args.amp_opt_level
    if not has_apex:
        cfg.amp_opt_level = 'O0'
    cfg.amp_static_loss_scale = args.amp_static_loss_scale
    cfg.eval = args.eval
    if cfg.eval:
        assert os.path.isfile(cfg.load_from)
    cfg.debug = args.debug
    cfg.print_freq = args.print_freq if not cfg.debug else 10
    cfg.save_freq = args.save_freq
    cfg.profiling = args.profiling
    if args.seed is None:
        args.seed = 23

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([('{}: {}'.format(k, v))
                          for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)

    # log cfg
    logger.info('Distributed training: {}'.format(distributed))
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed

    # model
    model = build_backbone(cfg.model)
    logger.info('Model {} created, param count: {:.3f}M'.format(
        cfg.model['type'],
        sum([m.numel() for m in model.parameters()]) / 1e6))

    if cfg.debug and dist.get_rank() == 0:
        print(model)

    if cfg.eval:
        load_pretrained(cfg, model, logger)

    if not distributed and len(cfg.gpu_ids) > 1:
        if cfg.amp_opt_level != 'O0':
            logger.warning(
                'AMP does not work well with nn.DataParallel, disabling.' +
                'Use distributed mode for multi-GPU AMP.')
            cfg.amp_opt_level = 'O0'
        model = nn.DataParallel(model, device_ids=list(cfg.gpu_ids)).cuda()
    else:
        model.cuda()

    # data
    fast_collate_mixup = None
    if hasattr(cfg.data_cfg['train_cfg'], 'mix_up_rate') and \
        cfg.data_cfg['train_cfg']['mix_up_rate'] > 0.:
        fast_collate_mixup = FastCollateMixup(
            cfg.data_cfg['train_cfg']['mix_up_rate'],
            cfg.data_cfg['train_cfg']['label_smoothing_rate'],
            cfg.data_cfg['train_cfg']['num_classes']
        )

    train_loader = get_train_loader(
        cfg, cfg.data_cfg['train_cfg'], distributed,
        fast_collate_mixup=fast_collate_mixup)

    real_labels_file = os.path.join(
        cfg.data_root, 'reassessed-imagenet', 'real.json')
    if os.path.exists(real_labels_file):
        val_loader = get_val_loader(cfg, cfg.data_cfg['val_cfg'],
                                    distributed, real_json=real_labels_file)
        real_labels = True
    else:
        logger.info(f'not found {cfg.data_root} {real_labels_file} ' +
                    'consider to download real labels at ' +
                    'https://github.com/google-research/reassessed-imagenet')
        val_loader = get_val_loader(cfg, cfg.data_cfg['val_cfg'], distributed)
        real_labels = False

    # loss
    if hasattr(cfg.data_cfg['train_cfg'], 'mix_up_rate') and \
        cfg.data_cfg['train_cfg']['mix_up_rate'] > 0.:
        criterion_train = SoftTargetCrossEntropy().cuda()
        criterion_val = torch.nn.CrossEntropyLoss().cuda()
    elif hasattr(cfg.data_cfg['train_cfg'], 'label_smoothing_rate') and \
        cfg.data_cfg['train_cfg']['label_smoothing_rate'] > 0.:
        criterion_train = LabelSmoothingCrossEntropy(
            cfg.data_cfg['train_cfg']['label_smoothing_rate']
        ).cuda()
        criterion_val = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterion_train = torch.nn.CrossEntropyLoss().cuda()
        criterion_val = criterion_train

    # optimizer
    lr = cfg.optimizer['lr']
    lr *= cfg.batch_size * dist.get_world_size() / cfg.autoscale_lr_factor
    if hasattr(cfg.optimizer, 'remove_norm_weigth_decay') and \
        cfg.optimizer['remove_norm_weigth_decay']:
        # Norm-layer parameters get no weight decay; all others use the
        # configured value.
        norm_params, base_params = separate_norm_params(model)
        optimizer = torch.optim.SGD([
            {'params': base_params,
                'weight_decay': cfg.optimizer['weight_decay']},
            {'params': norm_params, 'weight_decay': 0.0}],
                                    lr=lr,
                                    momentum=cfg.optimizer['momentum'],
                                    nesterov=cfg.optimizer['nesterov'])
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=cfg.optimizer['momentum'],
                                    weight_decay=cfg.optimizer['weight_decay'],
                                    nesterov=cfg.optimizer['nesterov'])

    if cfg.amp_opt_level != 'O0':
        loss_scale = cfg.amp_static_loss_scale if cfg.amp_static_loss_scale  \
            else 'dynamic'
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=cfg.amp_opt_level,
                                          loss_scale=loss_scale,
                                          verbosity=1)
        model = AttnNorm2Float(model)

    if distributed:
        if cfg.amp_opt_level != 'O0':
            model = DDP(model, delay_allreduce=True)
        else:
            model = DDP1(model, device_ids=[args.local_rank])

    if cfg.profiling:
        x = torch.randn((2, 3, 224, 224), requires_grad=True).cuda()
        with torch.autograd.profiler.profile(use_cuda=True) as prof:
            model(x)
        prof.export_chrome_trace(os.path.join(cfg.work_dir, 'profiling.log'))
        logger.info(f"{prof}")
        return

    # scheduler
    scheduler = get_scheduler(optimizer, len(train_loader), cfg)

    # eval
    if cfg.eval:
        validate(val_loader, model, criterion_val, cfg, logger, distributed,
                 real_labels=real_labels)
        return

    # optionally resume from a checkpoint
    if cfg.resume_from:
        assert os.path.isfile(cfg.resume_from)
        load_checkpoint(cfg, model, optimizer, scheduler, logger)

    # training
    # NOTE(review): best_acc1 / best_acc1_reallabels are not defined in this
    # function; presumably module-level globals updated by validate — confirm.
    for epoch in range(cfg.start_epoch, cfg.total_epochs + 1):
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        train(epoch, train_loader, model,
              criterion_train, optimizer, scheduler, cfg, logger, distributed)
        used_time = time.time() - tic
        remaining_time = (cfg.total_epochs - epoch) * used_time / 3600
        logger.info(
            f'epoch {epoch}, total time {used_time:.2f} sec, estimated remaining time {remaining_time:.2f} hr')

        # real_labels is a bool (never None), so test its truth value; the
        # former ``is not None`` check made the second branch unreachable.
        if real_labels:
            test_acc, is_best, _, is_best_reallabels = validate(
                val_loader, model, criterion_val, cfg, logger, distributed,
                real_labels=real_labels)

            if dist.get_rank() == 0 and (epoch % cfg.save_freq == 0 or is_best or is_best_reallabels):
                save_checkpoint(cfg, epoch, model, optimizer,
                                best_acc1, best_acc1_reallabels,
                                scheduler, logger, is_best, is_best_reallabels)
        else:
            test_acc, is_best = validate(
                val_loader, model, criterion_val, cfg, logger, distributed)
            if dist.get_rank() == 0 and (epoch % cfg.save_freq == 0 or is_best):
                save_checkpoint(cfg, epoch, model, optimizer,
                                best_acc1, best_acc1_reallabels,
                                scheduler, logger, is_best, False)

        if cfg.debug:
            break

    # rename folder
    if dist.get_rank() == 0:
        os.rename(cfg.work_dir, cfg.work_dir+f'-top1-{best_acc1:.2%}')