Пример #1
0
 def _init_model_ema(self):
     """Create the EMA wrapper for ``self.model``.

     Decay and device choice are read from ``self.cfg.model_ema``; the
     ``resume`` argument is taken from ``self.cfg.load_checkpoint``.

     :return: the newly constructed ``ModelEma`` instance.
     """
     ema_cfg = self.cfg.model_ema
     tracked_model = self.model
     ema_decay = ema_cfg.model_ema_decay
     # An empty device string is passed through unless CPU is forced.
     ema_device = 'cpu' if ema_cfg.model_ema_force_cpu else ''
     resume_from = self.cfg.load_checkpoint
     return ModelEma(tracked_model,
                     decay=ema_decay,
                     device=ema_device,
                     resume=resume_from)
Пример #2
0
 def _init_model_ema(self):
     """Create the EMA wrapper for the trainer's model.

     Decay and device choice are read from ``self.config.model_ema``. No
     checkpoint is handed to the wrapper (``resume=None``).

     :return: the newly constructed ``ModelEma`` instance.
     """
     ema_cfg = self.config.model_ema
     tracked_model = self.trainer.model
     ema_decay = ema_cfg.model_ema_decay
     # An empty device string is passed through unless CPU is forced.
     ema_device = 'cpu' if ema_cfg.model_ema_force_cpu else ''
     return ModelEma(tracked_model,
                     decay=ema_decay,
                     device=ema_device,
                     resume=None)
Пример #3
0
 def _init_model_ema(self):
     """Create the EMA wrapper for ``self.model``.

     Decay and device choice are read from ``self.cfg.model_ema``. The
     ``resume`` argument is ``self.cfg.get('load_checkpoint', None)``, i.e.
     ``None`` when no checkpoint entry is configured.

     :return: the newly constructed ``ModelEma`` instance.
     """
     ema_cfg = self.cfg.model_ema
     tracked_model = self.model
     ema_decay = ema_cfg.model_ema_decay
     # An empty device string is passed through unless CPU is forced.
     ema_device = 'cpu' if ema_cfg.model_ema_force_cpu else ''
     resume_from = self.cfg.get('load_checkpoint', None)
     return ModelEma(tracked_model,
                     decay=ema_decay,
                     device=ema_device,
                     resume=resume_from)
Пример #4
0
def _load_checkpoint_on_cpu(location):
    """Load a checkpoint from an ``https`` URL or a local path, mapped to CPU."""
    if location.startswith('https'):
        return torch.hub.load_state_dict_from_url(location,
                                                  map_location='cpu',
                                                  check_hash=True)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    return torch.load(location, map_location='cpu')


def main(args):
    """Train a model, or only evaluate it when ``args.eval`` is set.

    The function builds the datasets, samplers and loaders, creates the
    model (optionally initialised from ``args.finetune``), the optional EMA
    copy, optimizer, scheduler and (distillation-wrapped) criterion,
    optionally resumes from ``args.resume``, and then either evaluates once
    or runs the epoch loop, writing a checkpoint and a ``log.txt`` line per
    epoch when ``args.output_dir`` is set.

    Args:
        args: parsed command-line namespace; it is mutated in place
            (``nb_classes``, ``lr`` and possibly ``start_epoch`` are set).

    Raises:
        NotImplementedError: when finetuning is combined with distillation
            outside of evaluation mode.
    """
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # Offset the seed by the rank so each process draws different randomness.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Rank-aware samplers are used unconditionally; the replica count and
    # rank are passed explicitly from the utils helpers.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train,
                                  num_replicas=num_tasks,
                                  rank=global_rank,
                                  shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print(
                'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                'This will slightly alter validation results as extra duplicate entries are added to achieve '
                'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        checkpoint = _load_checkpoint_on_cpu(args.finetune)

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop classifier weights whose shape does not match the new model.
        for k in ('head.weight', 'head.bias', 'head_dist.weight',
                  'head_dist.bias'):
            if k in checkpoint_model and checkpoint_model[
                    k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                        embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                     size=(new_size, new_size),
                                                     mode='bicubic',
                                                     align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Scale the base learning rate linearly with the global batch size.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size(
    ) / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        checkpoint = _load_checkpoint_on_cpu(args.teacher_path)
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(criterion, teacher_model,
                                 args.distillation_type,
                                 args.distillation_alpha,
                                 args.distillation_tau)

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = _load_checkpoint_on_cpu(args.resume)
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                # A checkpoint written with EMA disabled carries no EMA state.
                ema_state = checkpoint.get('model_ema')
                if ema_state is not None:
                    utils._load_checkpoint_for_ema(model_ema, ema_state)
                else:
                    print('Checkpoint has no EMA weights; keeping the '
                          'freshly created EMA model')
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is always epoch-seeded, so it must be told the
        # epoch even in single-process runs; otherwise every epoch reuses
        # the same shuffle order.
        data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            set_training_mode=args.finetune ==
            ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_path = output_dir / ('checkpoint_%04d.pth' % (epoch))
            # With EMA disabled there is no EMA module to serialise.
            ema_state = (get_state_dict(model_ema)
                         if model_ema is not None else None)
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': ema_state,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        log_stats = {f'train_{k}': v for k, v in train_stats.items()}
        if not args.train_without_eval:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')
            log_stats.update(
                {f'test_{k}': v
                 for k, v in test_stats.items()})
        log_stats['epoch'] = epoch
        log_stats['n_parameters'] = n_parameters

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Пример #5
0
def main():
    """Retrain the child network selected by ``cfg.NET.SELECTION``.

    Parses the config, builds the child network from a fixed table of
    architectures, sets up distributed training, the optimizer, optional
    checkpoint resume, optional EMA, data loaders, losses and LR scheduler,
    then runs the train/validate loop. Logging, TensorBoard writing and the
    summary CSV are handled by the process with ``local_rank == 0``.

    Raises:
        ValueError: if ``cfg.NET.SELECTION`` is not one of the supported ids.
        SystemExit: if the train or validation folder does not exist.
    """
    args, cfg = parse_config_args('child net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'retrain.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection
    if cfg.NET.SELECTION == 481:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 114:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 604:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k5_s2_e6_c192_se0.25'
    ]
    # Stage i repeats its block once per entry of arch_list[i + 1]
    # (arch_list[0] and arch_list[-1] belong to the two stem entries).
    arch_def = ([[stem[0]]] +
                [[block] * len(arch_list[idx + 1])
                 for idx, block in enumerate(choice_block_pool)] +
                [[stem[1]]])

    # generate childnet
    model = gen_childnet(arch_list,
                         arch_def,
                         num_classes=cfg.DATASET.NUM_CLASSES,
                         drop_rate=cfg.NET.DROPOUT_RATE,
                         global_pool=cfg.NET.GP)

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    # NOTE(review): no checkpoint saver is created in this function, so
    # `saver` stays None and best_metric/best_epoch are never updated.
    best_metric, best_epoch, saver = None, None, None

    # initialize distributed parameters
    distributed = cfg.NUM_GPU > 1
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process {} with {} GPUs.'.format(
            args.local_rank, cfg.NUM_GPU))

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(cfg, model)

    # optionally resume from a checkpoint
    resume_epoch = None
    if cfg.AUTO_RESUME:
        resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)
        optimizer.load_state_dict(resume_state['optimizer'])
        del resume_state  # release the loaded state once it has been applied

    model_ema = None
    if cfg.NET.EMA.USE:
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None)

    if distributed:
        if cfg.BATCHNORM.SYNC_BN:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                # Best effort: training continues without SyncBN.
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 with exception {}'
                        .format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        # Every rank stops; only rank 0 owns a logger to report the reason.
        if args.local_rank == 0:
            logger.error(
                'Training folder does not exist at: {}'.format(train_dir))
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 auto_augment=cfg.AUGMENTATION.AA,
                                 num_aug_splits=0,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD,
                                 num_workers=cfg.WORKERS,
                                 distributed=distributed,
                                 collate_fn=None,
                                 pin_memory=cfg.DATASET.PIN_MEM,
                                 interpolation='random',
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 re_prob=cfg.AUGMENTATION.RE_PROB)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir):
        if args.local_rank == 0:
            logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
        raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        interpolation='bicubic',
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=cfg.WORKERS,
        distributed=distributed,
        pin_memory=cfg.DATASET.PIN_MEM)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        model_ema=model_ema,
                                        logger=logger,
                                        writer=writer,
                                        local_rank=args.local_rank)

            eval_metrics = validate(epoch,
                                    model,
                                    loader_eval,
                                    validate_loss_fn,
                                    cfg,
                                    logger=logger,
                                    writer=writer,
                                    local_rank=args.local_rank)

            if model_ema is not None and not cfg.NET.EMA.FORCE_CPU:
                # When an on-device EMA model exists, its metrics replace
                # those of the raw model for scheduling and bookkeeping.
                ema_eval_metrics = validate(epoch,
                                            model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            cfg,
                                            log_suffix='_EMA',
                                            logger=logger,
                                            writer=writer,
                                            local_rank=args.local_rank)
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # Only rank 0 writes the shared summary file, and the header is
            # requested once, on the first epoch of this run (best_metric
            # never changes here, so it cannot serve as a first-pass flag).
            if args.local_rank == 0:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=epoch == start_epoch)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    cfg,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch

            if args.local_rank == 0:
                logger.info('*** Best metric: {0} (epoch {1})'.format(
                    best_record, best_ep))

    except KeyboardInterrupt:
        # Allow a manual interrupt to end training without a traceback.
        pass

    # logger is None on every rank other than 0.
    if best_metric is not None and logger is not None:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Пример #6
0
def main():
    """Evaluate a trained child network selected by ``cfg.NET.SELECTION``.

    Parses the config, builds the child network from a fixed table of
    architectures, restores its weights from ``cfg.RESUME_PATH``, and runs a
    single validation pass. The EMA weights are evaluated when
    ``cfg.NET.EMA.USE`` is set; otherwise the restored model itself is
    evaluated.

    Raises:
        ValueError: if ``cfg.NET.SELECTION`` is unsupported or
            ``cfg.AUTO_RESUME`` is not ``True``.
        FileNotFoundError: if ``cfg.RESUME_PATH`` does not exist.
        SystemExit: if the validation folder does not exist.
    """
    args, cfg = parse_config_args('child net testing')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                      cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'test.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection
    if cfg.NET.SELECTION == 470:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 42:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 285:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 600:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Test Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k3_s2_e6_c192_se0.25'
    ]
    # Stage i repeats its block once per entry of arch_list[i + 1]
    # (arch_list[0] and arch_list[-1] belong to the two stem entries).
    arch_def = ([[stem[0]]] +
                [[block] * len(arch_list[idx + 1])
                 for idx, block in enumerate(choice_block_pool)] +
                [[stem[1]]])

    # generate childnet
    model = gen_childnet(arch_list,
                         arch_def,
                         num_classes=cfg.DATASET.NUM_CLASSES,
                         drop_rate=cfg.NET.DROPOUT_RATE,
                         global_pool=cfg.NET.GP)

    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info("Training on Process {} with {} GPUs.".format(
            args.local_rank, cfg.NUM_GPU))

    # resume model from checkpoint
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if cfg.AUTO_RESUME is not True:
        raise ValueError('Testing requires cfg.AUTO_RESUME to be True.')
    if not os.path.exists(cfg.RESUME_PATH):
        raise FileNotFoundError(
            'Checkpoint does not exist at: {}'.format(cfg.RESUME_PATH))
    resume_checkpoint(model, cfg.RESUME_PATH)

    model = model.cuda()

    model_ema = None
    if cfg.NET.EMA.USE:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=cfg.NET.EMA.DECAY,
                             device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
                             resume=cfg.RESUME_PATH)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir):
        # Every rank stops; only rank 0 owns a logger to report the reason.
        if args.local_rank == 0:
            logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
        raise SystemExit(1)

    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        num_workers=cfg.WORKERS,
        distributed=True,
        pin_memory=cfg.DATASET.PIN_MEM,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD)

    validate_loss_fn = nn.CrossEntropyLoss().cuda()
    if model_ema is not None:
        # Prefer the EMA weights when an EMA model was configured.
        eval_model, log_suffix = model_ema.ema, '_EMA'
    else:
        # Without EMA there is no `.ema` module; evaluate the restored model.
        eval_model, log_suffix = model, ''
    validate(0,
             eval_model,
             loader_eval,
             validate_loss_fn,
             cfg,
             log_suffix=log_suffix,
             logger=logger,
             writer=writer,
             local_rank=args.local_rank)
Пример #7
0
def main(args):
    """Run the training / evaluation pipeline configured by ``args``.

    In order, this function: initialises distributed mode, seeds the RNGs,
    builds the train/val datasets and data loaders, optionally builds a
    ``Mixup`` transform, creates the model (plus an optional EMA copy and an
    optional DDP wrapper), the optimizer, the LR scheduler and the loss,
    optionally restores state from ``args.resume``, and then either
    evaluates once (``args.eval``) or trains from ``args.start_epoch`` up to
    ``args.epochs``, writing a checkpoint and a JSON log line per epoch.

    Args:
        args: Parsed option namespace. It is mutated in place:
            ``nb_classes`` is set from the training dataset, ``lr`` is
            rescaled by the global batch size, and ``start_epoch`` is
            overwritten when a full training checkpoint is resumed.

    Returns:
        None.
    """
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by the value of get_rank())
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # The training sampler is always a distributed-style sampler; the replica
    # count and rank are passed explicitly.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train,
                                  num_replicas=num_tasks,
                                  rank=global_rank,
                                  shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Validation runs unsharded and in order, with a 1.5x larger batch.
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=int(
                                                      1.5 * args.batch_size),
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
    )

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    # Keep a handle on the bare module: checkpoints and the optimizer use it
    # so that nothing depends on the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: the configured LR is taken to be for a global
    # batch of 512 samples.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size(
    ) / 512.0
    args.lr = linear_scaled_lr
    # Build the optimizer from the unwrapped module (the same parameter
    # tensors as the DDP wrapper) so that model-level hooks the optimizer
    # factory may look up are visible.
    # NOTE(review): presumably this matters for `no_weight_decay()` - confirm
    # against create_optimizer.
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Pick the loss from whether a Mixup transform exists, not from
    # args.mixup alone: Mixup is also created for CutMix-only configurations.
    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        # NOTE(review): both loaders below unpickle the checkpoint; only
        # resume from sources you trust.
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Training state is restored only when continuing training and the
        # checkpoint carries all of it.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch, loss_scaler,
                                      args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_state = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
            }
            # model_ema stays None when args.model_ema is falsy; only
            # serialise the EMA weights when an EMA model actually exists.
            if model_ema is not None:
                checkpoint_state['model_ema'] = get_state_dict(model_ema)
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(checkpoint_state, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        # One JSON object per epoch, appended by the main process only.
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Пример #8
0
def main(args):
    """Run self-supervised (SSL) training or finetuning configured by ``args``.

    In order, this function: initialises distributed mode, disables label
    smoothing / Mixup / CutMix / AutoAugment / random erasing when
    ``args.training_mode == 'SSL'``, seeds the RNGs, builds datasets and
    loaders, creates the model (optionally loading ``args.finetune`` weights
    with an interpolated position embedding), optionally freezes everything
    but the heads for linear evaluation, builds the optimizer / scheduler /
    loss, optionally resumes from ``args.resume``, and then either evaluates
    once (``args.eval``) or trains from ``args.start_epoch`` up to
    ``args.epochs``. Checkpointing and validation happen every
    ``args.validate_every`` epochs; a JSON log line is written every epoch.

    Args:
        args: Parsed option namespace. It is mutated in place (augmentation
            options in SSL mode, ``SiT_LinearEvaluation``, ``nb_classes``,
            ``lr``, and ``start_epoch`` on resume).

    Returns:
        None.
    """
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of Self-supervise training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of Self-supervise training")
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

        if args.SiT_LinearEvaluation == 1:
            print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
            args.SiT_LinearEvaluation = 0

    utils.print_args(args)

    device = torch.device(args.device)
    # Seed is offset by the value of get_rank().
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)

    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train,
        batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True, collate_fn=collate_fn)

    data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False, collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model, pretrained=False, num_classes=args.nb_classes,
        drop_rate=args.drop, drop_path_rate=args.drop_path, representation_size=args.representation_size,
        drop_block_rate=None, training_mode=args.training_mode)

    if args.finetune:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from sources you trust.
        checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop head weights whose shape does not match the freshly built model.
        for k in ['rot_head.weight', 'rot_head.bias', 'contrastive_head.weight', 'contrastive_head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # Side length of the (assumed square) patch grid in the checkpoint
        # and in the new model.
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        # The leading extra tokens are kept as-is; only the per-patch
        # position tokens are resampled.
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)

        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True

        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True

        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True

            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '', resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: the configured LR is taken to be for a global
    # batch of 512 samples.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and args.mixup > 0.:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Training state is restored only when continuing training and the
        # checkpoint carries all of it.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    # Validation only runs every `validate_every` epochs, but the log line is
    # built every epoch. Start from an empty mapping so that an initial
    # non-validation epoch (e.g. after resuming) logs no test_* keys instead
    # of raising NameError. Later non-validation epochs re-log the most
    # recent validation stats, as before.
    test_stats = {}
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)

        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_state = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }
                # model_ema stays None when args.model_ema is falsy; only
                # serialise the EMA weights when an EMA model actually exists.
                if model_ema is not None:
                    checkpoint_state['model_ema'] = get_state_dict(model_ema)
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master(checkpoint_state, checkpoint_path)

            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device, epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)

                # Accuracy is reported for the finetune path only.
                print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        # One JSON object per epoch, appended by the main process only.
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Пример #9
0
def main(args, ds_init):
    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval_during_finetuning:
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val,
            sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False)
    else:
        data_loader_val = None

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )

    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0],
                        args.input_size // patch_size[1])
    args.patch_size = patch_size

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load ckpt from %s" % args.finetune)
        checkpoint_model = None
        for model_key in args.model_key.split('|'):
            if model_key in checkpoint:
                checkpoint_model = checkpoint[model_key]
                print("Load state_dict by model_key = %s" % model_key)
                break
        if checkpoint_model is None:
            checkpoint_model = checkpoint
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[
                    k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
            print(
                "Expand the shared relative position embedding to each transformer block. "
            )
            num_layers = model.get_num_layers()
            rel_pos_bias = checkpoint_model[
                "rel_pos_bias.relative_position_bias_table"]
            for i in range(num_layers):
                checkpoint_model["blocks.%d.attn.relative_position_bias_table"
                                 % i] = rel_pos_bias.clone()

            checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

        all_keys = list(checkpoint_model.keys())
        for key in all_keys:
            if "relative_position_index" in key:
                checkpoint_model.pop(key)

            if "relative_position_bias_table" in key:
                rel_pos_bias = checkpoint_model[key]
                src_num_pos, num_attn_heads = rel_pos_bias.size()
                dst_num_pos, _ = model.state_dict()[key].size()
                dst_patch_shape = model.patch_embed.patch_shape
                if dst_patch_shape[0] != dst_patch_shape[1]:
                    raise NotImplementedError()
                num_extra_tokens = dst_num_pos - (
                    dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
                src_size = int((src_num_pos - num_extra_tokens)**0.5)
                dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
                if src_size != dst_size:
                    print("Position interpolate for %s from %dx%d to %dx%d" %
                          (key, src_size, src_size, dst_size, dst_size))
                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r**n) / (1.0 - r)

                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    # if q > 1.090307:
                    #     q = 1.090307

                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q**(i + 1)

                    r_ids = [-_ for _ in reversed(dis)]

                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    print("Original positions = %s" % str(x))
                    print("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []

                    for i in range(num_attn_heads):
                        z = rel_pos_bias[:, i].view(src_size,
                                                    src_size).float().numpy()
                        f = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(
                            torch.Tensor(f(dx, dy)).contiguous().view(
                                -1, 1).to(rel_pos_bias.device))

                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens),
                                                 dim=0)
                    checkpoint_model[key] = new_rel_pos_bias

        # interpolate position embedding
        if 'pos_embed' in checkpoint_model:
            pos_embed_checkpoint = checkpoint_model['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_patches = model.patch_embed.num_patches
            num_extra_tokens = model.pos_embed.shape[-2] - num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int(
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
            # height (== width) for the new position embedding
            new_size = int(num_patches**0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                print("Position interpolate from %dx%d to %dx%d" %
                      (orig_size, orig_size, new_size, new_size))
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                                embedding_size).permute(
                                                    0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens,
                    size=(new_size, new_size),
                    mode='bicubic',
                    align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model['pos_embed'] = new_pos_embed

        utils.load_state_dict(model,
                              checkpoint_model,
                              prefix=args.model_prefix)
        # model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size(
    )
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" %
          num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(
            list(args.layer_decay**(num_layers + 1 - i)
                 for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    skip_weight_decay_list = model.no_weight_decay()
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add(
                "blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        optimizer_params = get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            assigner.get_layer_id if assigner is not None else None,
            assigner.get_scale if assigner is not None else None)
        model, optimizer, _, _ = ds_init(
            args=args,
            model=model,
            model_parameters=optimizer_params,
            dist_init_required=not args.distributed,
        )

        print("model.gradient_accumulation_steps() = %d" %
              model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        optimizer = create_optimizer(args,
                                     model_without_ddp,
                                     skip_list=skip_weight_decay_list,
                                     get_num_layer=assigner.get_layer_id
                                     if assigner is not None else None,
                                     get_layer_scale=assigner.get_scale
                                     if assigner is not None else None)
        loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr,
        args.min_lr,
        args.epochs,
        num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(args.weight_decay,
                                                args.weight_decay_end,
                                                args.epochs,
                                                num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" %
          (max(wd_schedule_values), min(wd_schedule_values)))

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    utils.auto_load_model(args=args,
                          model=model,
                          model_without_ddp=model_without_ddp,
                          optimizer=optimizer,
                          loss_scaler=loss_scaler,
                          model_ema=model_ema)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch *
                                args.update_freq)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch,
            update_freq=args.update_freq,
        )
        if args.output_dir and args.save_ckpt:
            if (epoch +
                    1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(args=args,
                                 model=model,
                                 model_without_ddp=model_without_ddp,
                                 optimizer=optimizer,
                                 loss_scaler=loss_scaler,
                                 epoch=epoch,
                                 model_ema=model_ema)
        if data_loader_val is not None:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            if max_accuracy < test_stats["acc1"]:
                max_accuracy = test_stats["acc1"]
                if args.output_dir and args.save_ckpt:
                    utils.save_model(args=args,
                                     model=model,
                                     model_without_ddp=model_without_ddp,
                                     optimizer=optimizer,
                                     loss_scaler=loss_scaler,
                                     epoch="best",
                                     model_ema=model_ema)

            print(f'Max accuracy: {max_accuracy:.2f}%')
            if log_writer is not None:
                log_writer.update(test_acc1=test_stats['acc1'],
                                  head="perf",
                                  step=epoch)
                log_writer.update(test_acc5=test_stats['acc5'],
                                  head="perf",
                                  step=epoch)
                log_writer.update(test_loss=test_stats['loss'],
                                  head="perf",
                                  step=epoch)

            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                **{f'test_{k}': v
                   for k, v in test_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }
        else:
            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                # **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"),
                      mode="a",
                      encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))