コード例 #1
0
ファイル: main.py プロジェクト: volgachen/deit
def main(args):
    """Train or evaluate a DeiT vision transformer.

    Builds the train/val datasets and loaders, creates ``args.model``
    (optionally initialised from ``args.finetune`` with its position
    embedding resized to the new patch grid), wraps the classification loss
    in ``DistillationLoss``, optionally resumes from ``args.resume``, and then
    either evaluates once (``args.eval``) or trains from ``args.start_epoch``
    up to ``args.epochs``, saving a checkpoint and appending a JSON log line
    per epoch under ``args.output_dir`` when it is set.

    Args:
        args: Parsed command-line namespace. ``args.nb_classes`` and
            ``args.lr`` are overwritten in place (the latter with the
            batch-size-scaled learning rate), as is ``args.start_epoch`` when
            resuming from a full training checkpoint.

    Raises:
        NotImplementedError: if fine-tuning is combined with distillation
            while not in evaluation mode.
    """
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so that each process
    # draws a different random stream)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Distributed-style samplers are used even when args.distributed is False
    # (the distributed check was deliberately disabled); num_replicas/rank
    # come from utils and are presumably 1/0 in a single-process run.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train,
                                  num_replicas=num_tasks,
                                  rank=global_rank,
                                  shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print(
                'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                'This will slightly alter validation results as extra duplicate entries are added to achieve '
                'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Evaluation needs no gradients, so a larger batch is used.
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  sampler=sampler_val,
                                                  batch_size=int(
                                                      1.5 * args.batch_size),
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in [
                'head.weight', 'head.bias', 'head_dist.weight',
                'head_dist.bias'
        ]:
            # Drop classifier weights whose shape differs from the new model
            # (e.g. another number of classes). Keys the model does not have
            # at all are left for load_state_dict(strict=False) to ignore
            # instead of raising KeyError on the state_dict lookup.
            if (k in checkpoint_model and k in state_dict
                    and checkpoint_model[k].shape != state_dict[k].shape):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                        embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                     size=(new_size, new_size),
                                                     mode='bicubic',
                                                     align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: args.lr is specified for a total batch of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size(
    ) / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Whenever mixup *or* cutmix is active the Mixup transform produces soft
    # targets (and applies label smoothing itself), so a soft-target loss is
    # required; checking only args.mixup missed the cutmix-only setup.
    if mixup_active:
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.teacher_path,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(criterion, teacher_model,
                                 args.distillation_type,
                                 args.distillation_alpha,
                                 args.distillation_tau)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                # Checkpoints written without EMA store None (or no key) here.
                if checkpoint.get('model_ema') is not None:
                    utils._load_checkpoint_for_ema(model_ema,
                                                   checkpoint['model_ema'])
                else:
                    print("Warning: checkpoint has no 'model_ema' state; EMA weights were not restored")
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is a DistributedSampler/RASampler even in a
        # single-process run, and such samplers only reshuffle when told the
        # epoch; gating this on args.distributed replayed one order forever.
        if hasattr(data_loader_train.sampler, 'set_epoch'):
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            set_training_mode=args.finetune ==
            ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / ('checkpoint_%04d.pth' % (epoch))]
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # get_state_dict(None) would call None.state_dict()
                        'model_ema': (get_state_dict(model_ema)
                                      if model_ema is not None else None),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        if not args.train_without_eval:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                **{f'test_{k}': v
                   for k, v in test_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }
        else:
            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
コード例 #2
0
def main(args):
    """Self-supervised pre-training ('SSL') or fine-tuning of a SiT model.

    In ``'SSL'`` training mode, label smoothing, mixup and cutmix are zeroed,
    reprob/recount are zeroed, AutoAugment (``args.aa``) is cleared and
    ``args.SiT_LinearEvaluation`` is forced to 0. The function then builds
    the datasets/loaders and model (optionally initialised from
    ``args.finetune`` with a resized position embedding). When
    ``args.SiT_LinearEvaluation == 1`` every parameter is frozen except the
    rotation/contrastive heads (and their pre-logits layers when
    ``args.representation_size`` is set). After an optional resume from
    ``args.resume`` it either evaluates once (``args.eval``) or trains,
    validating and checkpointing every ``args.validate_every`` epochs and
    appending one JSON log line per epoch to ``<output_dir>/log.txt``.

    Args:
        args: Parsed command-line namespace. Several fields are overwritten
            in place: the augmentation settings above in SSL mode,
            ``nb_classes``, ``lr`` (batch-size scaled) and ``start_epoch``
            when resuming a full training checkpoint.
    """
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of Self-supervise training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of Self-supervise training")
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

        if args.SiT_LinearEvaluation == 1:
            print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
            args.SiT_LinearEvaluation = 0

    utils.print_args(args)

    device = torch.device(args.device)
    # offset the seed by rank so that each process draws a different stream
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Distributed-style train samplers are used even in a single-process run.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)

    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train,
        batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True, collate_fn=collate_fn)

    data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False, collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model, pretrained=False, num_classes=args.nb_classes,
        drop_rate=args.drop, drop_path_rate=args.drop_path, representation_size=args.representation_size,
        drop_block_rate=None, training_mode=args.training_mode)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['rot_head.weight', 'rot_head.bias', 'contrastive_head.weight', 'contrastive_head.bias']:
            # Drop head weights whose shape differs from the new model; keys
            # the model lacks entirely are left for strict=False to ignore
            # instead of raising KeyError on the state_dict lookup.
            if k in checkpoint_model and k in state_dict and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding: extra (non-patch) tokens are kept,
        # the square grid of patch tokens is resized bicubically
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)

        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True

        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True

        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True

            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '', resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: args.lr is specified for a total batch of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and mixup_active:
        # mixup *or* cutmix yields soft targets, which need a soft-target loss
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                # Checkpoints written without EMA store None (or no key) here.
                if checkpoint.get('model_ema') is not None:
                    utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
                else:
                    print("Warning: checkpoint has no 'model_ema' state; EMA weights were not restored")
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        # NOTE(review): the training loop below calls evaluate_SSL with
        # (epoch, output_dir) as well, and this path uses evaluate_SSL and
        # reads 'acc1' regardless of training_mode -- confirm evaluate_SSL's
        # signature/defaults and return keys.
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is a DistributedSampler/RASampler even in a
        # single-process run and only reshuffles when told the epoch.
        if hasattr(data_loader_train.sampler, 'set_epoch'):
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)

        # Reset every epoch so non-validation epochs never log stale test
        # metrics (or hit an undefined name before the first validation,
        # e.g. when resuming at an epoch not divisible by validate_every).
        test_stats = None
        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # get_state_dict(None) would call None.state_dict()
                        'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device, epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)

                print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print(f'Max accuracy: {max_accuracy:.2f}%')

        # Key order matches the previous format: train_*, test_*, epoch, n_parameters.
        log_stats = {f'train_{k}': v for k, v in train_stats.items()}
        if test_stats is not None:
            log_stats.update({f'test_{k}': v for k, v in test_stats.items()})
        log_stats['epoch'] = epoch
        log_stats['n_parameters'] = n_parameters

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
コード例 #3
0
ファイル: supernet_train.py プロジェクト: ICCV2021/Autoformer
def main(args):
    """Train (or evaluate) an AutoFormer vision-transformer supernet.

    Applies the config file ``args.cfg`` via ``update_config_from_file``
    (presumably into the module-level ``cfg`` used below), builds the
    datasets/loaders and a ``Vision_TransformerSuper`` whose dimensions come
    from ``cfg.SUPERNET``, optionally a pretrained teacher model, and then
    either evaluates once (``args.eval``) or trains from ``args.start_epoch``
    up to ``args.epochs``, evaluating after every epoch. In ``'retrain'``
    mode with a ``RETRAIN`` config section, a fixed ``retrain_config`` is
    passed to training and evaluation alongside the search-space
    ``choices``.

    Args:
        args: Parsed command-line namespace. ``args.nb_classes``, ``args.lr``
            (batch-size scaled) and, when resuming a full training
            checkpoint, ``args.start_epoch`` are overwritten in place.
    """
    utils.init_distributed_mode(args)
    update_config_from_file(args.cfg)

    print(args)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so that each process
    # draws a different random stream)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train,
                                      num_replicas=num_tasks,
                                      rank=global_rank,
                                      shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Evaluation needs no gradients, so a larger batch is used.
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=int(
                                                      2 * args.batch_size),
                                                  sampler=sampler_val,
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print("Creating SuperVisionTransformer")
    print(cfg)
    model = Vision_TransformerSuper(
        img_size=args.input_size,
        patch_size=args.patch_size,
        embed_dim=cfg.SUPERNET.EMBED_DIM,
        depth=cfg.SUPERNET.DEPTH,
        num_heads=cfg.SUPERNET.NUM_HEADS,
        mlp_ratio=cfg.SUPERNET.MLP_RATIO,
        qkv_bias=True,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        gp=args.gp,
        num_classes=args.nb_classes,
        max_relative_position=args.max_relative_position,
        relative_position=args.relative_position,
        change_qkv=args.change_qkv,
        abs_pos=not args.no_abs_pos)

    # Search space handed to train_one_epoch / evaluate.
    choices = {
        'num_heads': cfg.SEARCH_SPACE.NUM_HEADS,
        'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
        'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM,
        'depth': cfg.SEARCH_SPACE.DEPTH
    }

    model.to(device)
    if args.teacher_model:
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    # EMA is not supported by this script: no EMA model is ever built.
    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: args.lr is specified for a total batch of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size(
    ) / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Whenever mixup *or* cutmix is active the Mixup transform produces soft
    # targets (smoothing is handled with mixup label transform), so a
    # soft-target loss is required; checking only args.mixup missed the
    # cutmix-only setup.
    if mixup_active:
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)

    # save config for later experiments. Only the main process writes, so
    # ranks cannot race between an exists() check and mkdir(); nothing is
    # written when no output directory was requested.
    if args.output_dir and utils.is_main_process():
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "config.yaml", 'w') as f:
            f.write(args_text)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            if args.model_ema:
                # model_ema is always None here and checkpoints from this
                # script do not store 'model_ema', so restoring would crash.
                if model_ema is not None and checkpoint.get('model_ema') is not None:
                    utils._load_checkpoint_for_ema(model_ema,
                                                   checkpoint['model_ema'])
                else:
                    print("Warning: no EMA model/state available; skipping EMA restore")

    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        retrain_config = {
            'layer_num': cfg.RETRAIN.DEPTH,
            'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH,
            'num_heads': cfg.RETRAIN.NUM_HEADS,
            'mlp_ratio': cfg.RETRAIN.MLP_RATIO
        }
    if args.eval:
        print(retrain_config)
        # NOTE(review): unlike the per-epoch evaluation below, this call
        # passes neither amp nor choices -- confirm evaluate()'s defaults.
        test_stats = evaluate(data_loader_val,
                              model,
                              device,
                              mode=args.mode,
                              retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        # Only the distributed path uses a sampler that needs set_epoch.
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            amp=args.amp,
            teacher_model=teacher_model,
            teach_loss=teacher_loss,
            choices=choices,
            mode=args.mode,
            retrain_config=retrain_config,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # 'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    },
                    checkpoint_path)

        test_stats = evaluate(data_loader_val,
                              model,
                              device,
                              amp=args.amp,
                              choices=choices,
                              mode=args.mode,
                              retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
コード例 #4
0
ファイル: train.py プロジェクト: rstrudel/segmenter
def main(
    log_dir,
    dataset,
    im_size,
    crop_size,
    window_size,
    window_stride,
    backbone,
    decoder,
    optimizer,
    scheduler,
    weight_decay,
    dropout,
    drop_path,
    batch_size,
    epochs,
    learning_rate,
    normalization,
    eval_freq,
    amp,
    resume,
):
    """Train a Segmenter model and periodically evaluate it on the validation split.

    Arguments left falsy (``im_size``, ``crop_size``, ``window_size``, ``window_stride``,
    ``batch_size``, ``epochs``, ``learning_rate``, ``normalization``) fall back to the
    entries of ``config.load_config()`` for the chosen dataset/backbone; ``eval_freq``
    falls back only when it is ``None``. ``batch_size`` is the *global* batch size and is
    divided by ``ptu.world_size`` to obtain the per-process batch size.

    Side effects: enables GPU/distributed mode, creates ``log_dir``, writes
    ``checkpoint.pth`` every epoch and appends to ``log.txt`` (rank 0 only), writes
    ``variant.yml`` (every rank), then tears down the process group and terminates the
    interpreter with exit status 0.
    """
    # start distributed mode
    ptu.set_gpu_mode(True)
    distributed.init_process()

    # set up configuration
    cfg = config.load_config()
    model_cfg = cfg["model"][backbone]
    dataset_cfg = cfg["dataset"][dataset]
    # every mask_transformer variant shares one decoder config entry
    if "mask_transformer" in decoder:
        decoder_cfg = cfg["decoder"]["mask_transformer"]
    else:
        decoder_cfg = cfg["decoder"][decoder]

    # model config: CLI values override dataset defaults
    if not im_size:
        im_size = dataset_cfg["im_size"]
    if not crop_size:
        crop_size = dataset_cfg.get("crop_size", im_size)
    if not window_size:
        window_size = dataset_cfg.get("window_size", im_size)
    if not window_stride:
        window_stride = dataset_cfg.get("window_stride", im_size)

    model_cfg["image_size"] = (crop_size, crop_size)
    model_cfg["backbone"] = backbone
    model_cfg["dropout"] = dropout
    model_cfg["drop_path_rate"] = drop_path
    decoder_cfg["name"] = decoder
    model_cfg["decoder"] = decoder_cfg

    # dataset config
    world_batch_size = dataset_cfg["batch_size"]
    num_epochs = dataset_cfg["epochs"]
    lr = dataset_cfg["learning_rate"]
    if batch_size:
        world_batch_size = batch_size
    if epochs:
        num_epochs = epochs
    if learning_rate:
        lr = learning_rate
    if eval_freq is None:
        eval_freq = dataset_cfg.get("eval_freq", 1)

    if normalization:
        model_cfg["normalization"] = normalization

    # experiment config (per-process batch size from the global one)
    batch_size = world_batch_size // ptu.world_size
    variant = dict(
        world_batch_size=world_batch_size,
        version="normal",
        resume=resume,
        dataset_kwargs=dict(
            dataset=dataset,
            image_size=im_size,
            crop_size=crop_size,
            batch_size=batch_size,
            normalization=model_cfg["normalization"],
            split="train",
            num_workers=10,
        ),
        algorithm_kwargs=dict(
            batch_size=batch_size,
            start_epoch=0,
            num_epochs=num_epochs,
            eval_freq=eval_freq,
        ),
        optimizer_kwargs=dict(
            opt=optimizer,
            lr=lr,
            weight_decay=weight_decay,
            momentum=0.9,
            clip_grad=None,
            sched=scheduler,
            epochs=num_epochs,
            min_lr=1e-5,
            poly_power=0.9,
            poly_step_size=1,
        ),
        net_kwargs=model_cfg,
        amp=amp,
        log_dir=log_dir,
        inference_kwargs=dict(
            im_size=im_size,
            window_size=window_size,
            window_stride=window_stride,
        ),
    )

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = log_dir / "checkpoint.pth"

    # dataset
    dataset_kwargs = variant["dataset_kwargs"]

    train_loader = create_dataset(dataset_kwargs)
    # validation: one full (uncropped) image per batch
    val_kwargs = dataset_kwargs.copy()
    val_kwargs["split"] = "val"
    val_kwargs["batch_size"] = 1
    val_kwargs["crop"] = False
    val_loader = create_dataset(val_kwargs)
    n_cls = train_loader.unwrapped.n_cls

    # model (net_kwargs is the same dict object stored in variant, so n_cls is recorded there too)
    net_kwargs = variant["net_kwargs"]
    net_kwargs["n_cls"] = n_cls
    model = create_segmenter(net_kwargs)
    model.to(ptu.device)

    # optimizer: the factory functions expect an argparse-style namespace
    optimizer_kwargs = variant["optimizer_kwargs"]
    optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"]
    optimizer_kwargs["iter_warmup"] = 0.0
    opt_args = argparse.Namespace(**optimizer_kwargs)
    optimizer = create_optimizer(opt_args, model)
    lr_scheduler = create_scheduler(opt_args, optimizer)
    num_iterations = 0
    amp_autocast = suppress
    loss_scaler = None
    if amp:
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()

    # resume
    if resume and checkpoint_path.exists():
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if loss_scaler and "loss_scaler" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["loss_scaler"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1
    else:
        sync_model(log_dir, model)

    if ptu.distributed:
        model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)

    # save config
    variant_str = yaml.dump(variant)
    print(f"Configuration:\n{variant_str}")
    with open(log_dir / "variant.yml", "w") as f:
        f.write(variant_str)

    # train
    start_epoch = variant["algorithm_kwargs"]["start_epoch"]
    num_epochs = variant["algorithm_kwargs"]["num_epochs"]
    eval_freq = variant["algorithm_kwargs"]["eval_freq"]

    # unwrap DDP for state_dict access and attribute lookups
    model_without_ddp = model
    if hasattr(model, "module"):
        model_without_ddp = model.module

    val_seg_gt = val_loader.dataset.get_gt_seg_maps()

    print(f"Train dataset length: {len(train_loader.dataset)}")
    print(f"Val dataset length: {len(val_loader.dataset)}")
    print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}")
    print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}")

    for epoch in range(start_epoch, num_epochs):
        # train for one epoch
        train_logger = train_one_epoch(
            model,
            train_loader,
            optimizer,
            lr_scheduler,
            epoch,
            amp_autocast,
            loss_scaler,
        )

        # save checkpoint (rank 0 only)
        if ptu.dist_rank == 0:
            snapshot = dict(
                model=model_without_ddp.state_dict(),
                optimizer=optimizer.state_dict(),
                n_cls=model_without_ddp.n_cls,
                lr_scheduler=lr_scheduler.state_dict(),
            )
            if loss_scaler is not None:
                snapshot["loss_scaler"] = loss_scaler.state_dict()
            snapshot["epoch"] = epoch
            torch.save(snapshot, checkpoint_path)

        # evaluate every `eval_freq` epochs and always on the last epoch
        eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1
        if eval_epoch:
            eval_logger = evaluate(
                model,
                val_loader,
                val_seg_gt,
                window_size,
                window_stride,
                amp_autocast,
            )
            print(f"Stats [{epoch}]:", eval_logger, flush=True)
            print("")

        # log stats (rank 0 only)
        if ptu.dist_rank == 0:
            train_stats = {
                k: meter.global_avg for k, meter in train_logger.meters.items()
            }
            val_stats = {}
            if eval_epoch:
                val_stats = {
                    k: meter.global_avg for k, meter in eval_logger.meters.items()
                }

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"val_{k}": v for k, v in val_stats.items()},
                "epoch": epoch,
                "num_updates": (epoch + 1) * len(train_loader),
            }

            with open(log_dir / "log.txt", "a") as f:
                f.write(json.dumps(log_stats) + "\n")

    distributed.barrier()
    distributed.destroy_process()
    # Training finished successfully: exit with status 0 (a non-zero status would make
    # shells and job schedulers treat a completed run as a failure).
    sys.exit(0)
コード例 #5
0
class TorchImageClassificationEstimator(BaseEstimator):
    """Torch Estimator implementation for Image Classification.

    Parameters
    ----------
    config : dict
        Config in nested dict.
    logger : logging.Logger
        Optional logger for this estimator, can be `None` when default setting is used.
    reporter : callable
        The reporter for metric checkpointing.
    net : torch.nn.Module
        The custom network. If defined, the model name in config will be ignored so your
        custom network will be used for training rather than pulling it from model zoo.
    """
    def __init__(self, config, logger=None, reporter=None, net=None, optimizer=None, problem_type=None):
        """Initialize the estimator from ``config``.

        ``config``, ``logger``, ``reporter`` and ``net`` are described in the class docstring.

        Parameters
        ----------
        optimizer : object, optional
            Not supported; when given, a warning is logged and the optimizer is built from
            the config instead.
        problem_type : str, optional
            Defaults to ``MULTICLASS``. Other methods of this class branch on ``REGRESSION``.
        """
        super().__init__(config, logger=logger, reporter=reporter, name=None)
        if problem_type is None:
            problem_type = MULTICLASS
        self._problem_type = problem_type
        self._feature_net = None
        self._custom_net = False

        # shortcuts to the config sub-sections
        self._img_cls_cfg = self._cfg.img_cls
        self._data_cfg = self._cfg.data
        self._optimizer_cfg = self._cfg.optimizer
        self._train_cfg = self._cfg.train
        self._augmentation_cfg = self._cfg.augmentation
        self._model_ema_cfg = self._cfg.model_ema
        self._misc_cfg = self._cfg.misc

        # resolve AMP mode ('native', 'apex' or None) from config and PyTorch / Apex availability
        self.use_amp = None
        if self._misc_cfg.amp:
            # `amp` chooses native amp before apex (APEX ver not actively maintained)
            if self._misc_cfg.native_amp and has_native_amp:
                self.use_amp = 'native'
            elif self._misc_cfg.apex_amp and has_apex:
                self.use_amp = 'apex'
            elif self._misc_cfg.apex_amp or self._misc_cfg.native_amp:
                # implicit concatenation keeps the message on one line without embedded indentation
                self._logger.warning('Neither APEX nor native Torch AMP is available, using float32. '
                                     'Install NVIDIA apex or upgrade to PyTorch 1.6')
        # FIXME: will provided model conflict with config provided?
        if net is not None:
            assert isinstance(net, nn.Module), f"given custom network {type(net)}, `torch.nn` expected"
            try:
                net.to('cpu')
                self._custom_net = True
            except ValueError as err:
                # best effort: the network is still used, but it is not flagged as custom
                self._logger.warning('Could not move custom network to cpu (%s); '
                                     'it will not be treated as a custom network.', err)
        self.net = net
        if optimizer is not None:
            self._logger.warning('Custom optimizer object not supported. Will follow the config instead.')
        self._optimizer = None

    def _fit(self, train_data, val_data, time_limit=math.inf):
        """Start a fresh training run.

        Resets the per-run bookkeeping (best accuracy, checkpoint name, epoch counters,
        elapsed time), builds the optimizer/scheduler/AMP state and the EMA model, then
        delegates to ``_resume_fit``.

        Returns
        -------
        dict
            The result of ``_resume_fit``, or ``{'time': elapsed_seconds}`` when the
            configured start epoch already reaches ``train.epochs``.
        """
        tic = time.time()
        self._cp_name = ''
        self._best_acc = -float('inf')
        self.epochs = self._train_cfg.epochs
        self.epoch = 0
        self.start_epoch = self._train_cfg.start_epoch
        self._time_elapsed = 0
        if max(self.start_epoch, self.epoch) >= self.epochs:
            # nothing to train: return a dict, the same shape as _resume_fit's early exit
            return {'time': self._time_elapsed}
        self._init_trainer()
        self._init_model_ema()
        self._time_elapsed += time.time() - tic
        return self._resume_fit(train_data, val_data, time_limit=time_limit)

    def _resume_fit(self, train_data, val_data, time_limit=math.inf):
        """Build the train/validation loaders and run the training loop.

        Wraps the network in ``DataParallel`` when GPUs were found and configures
        mixup/cutmix: a ``FastCollateMixup`` collate function when the prefetcher is on,
        otherwise a ``Mixup`` callable applied in ``train_one_epoch``.

        Returns
        -------
        dict
            The result of ``_train_loop``, or ``{'time': elapsed_seconds}`` when there are
            no epochs left to train.

        Raises
        ------
        ValueError
            For non-regression problems when ``classes`` or ``num_class`` is unknown.
        """
        tic = time.time()
        # TODO: regression not implemented
        if self._problem_type != REGRESSION and (not self.classes or not self.num_class):
            raise ValueError('This is a classification problem and we are not able to determine classes of dataset')

        if max(self.start_epoch, self.epoch) >= self.epochs:
            return {'time': self._time_elapsed}

        # wrap DP if possible; skip when an earlier call already wrapped the network so that
        # resuming never nests DataParallel(DataParallel(...))
        if self.found_gpu and not isinstance(self.net, torch.nn.DataParallel):
            self.net = torch.nn.DataParallel(self.net, device_ids=[int(i) for i in self.valid_gpus])
        self.net = self.net.to(self.ctx[0])

        # prepare dataset
        train_dataset = train_data.to_torch()
        val_dataset = val_data.to_torch()

        # setup mixup / cutmix
        self._collate_fn = None
        self._mixup_fn = None
        self.mixup_active = self._augmentation_cfg.mixup > 0 or self._augmentation_cfg.cutmix > 0. or self._augmentation_cfg.cutmix_minmax is not None
        if self.mixup_active:
            mixup_args = dict(
                mixup_alpha=self._augmentation_cfg.mixup, cutmix_alpha=self._augmentation_cfg.cutmix,
                cutmix_minmax=self._augmentation_cfg.cutmix_minmax, prob=self._augmentation_cfg.mixup_prob,
                switch_prob=self._augmentation_cfg.mixup_switch_prob, mode=self._augmentation_cfg.mixup_mode,
                label_smoothing=self._augmentation_cfg.smoothing, num_classes=self.num_class)
            if self._misc_cfg.prefetcher:
                self._collate_fn = FastCollateMixup(**mixup_args)
            else:
                self._mixup_fn = Mixup(**mixup_args)

        # create data loaders w/ augmentation pipeline
        train_interpolation = self._augmentation_cfg.train_interpolation
        if self._augmentation_cfg.no_aug or not train_interpolation:
            train_interpolation = self._data_cfg.interpolation
        train_loader = create_loader(
            train_dataset,
            input_size=self._data_cfg.input_size,
            batch_size=self._train_cfg.batch_size,
            is_training=True,
            use_prefetcher=self._misc_cfg.prefetcher,
            no_aug=self._augmentation_cfg.no_aug,
            scale=self._augmentation_cfg.scale,
            ratio=self._augmentation_cfg.ratio,
            hflip=self._augmentation_cfg.hflip,
            vflip=self._augmentation_cfg.vflip,
            color_jitter=self._augmentation_cfg.color_jitter,
            auto_augment=self._augmentation_cfg.auto_augment,
            interpolation=train_interpolation,
            mean=self._data_cfg.mean,
            std=self._data_cfg.std,
            num_workers=self._misc_cfg.num_workers,
            distributed=False,
            collate_fn=self._collate_fn,
            pin_memory=self._misc_cfg.pin_mem,
            use_multi_epochs_loader=self._misc_cfg.use_multi_epochs_loader
        )

        val_loader = create_loader(
            val_dataset,
            input_size=self._data_cfg.input_size,
            batch_size=self._data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size,
            is_training=False,
            use_prefetcher=self._misc_cfg.prefetcher,
            interpolation=self._data_cfg.interpolation,
            mean=self._data_cfg.mean,
            std=self._data_cfg.std,
            num_workers=self._misc_cfg.num_workers,
            distributed=False,
            crop_pct=self._data_cfg.crop_pct,
            pin_memory=self._misc_cfg.pin_mem,
        )

        self._time_elapsed += time.time() - tic
        return self._train_loop(train_loader, val_loader, time_limit=time_limit)

    def _train_loop(self, train_loader, val_loader, time_limit=math.inf):
        """Run the epoch loop: train, validate, save the best checkpoint and step the LR schedule.

        The loop stops early when the best validation accuracy reaches 1.0, when
        ``EarlyStopperOnPlateau`` advises stopping, or when ``time_limit`` (compared against
        ``self._time_elapsed``) is exceeded inside ``train_one_epoch``.

        Returns
        -------
        dict
            Always contains ``'time'`` and ``'checkpoint'`` (best checkpoint path, ``''`` if none
            was saved).
            - Time limit hit: ``'train_acc'`` and ``'valid_acc'``.
            - Training metric was accuracy: ``'train_acc'`` and ``'valid_acc'``.
            - Regression: ``'train_score'`` (train RMSE) and ``'valid_score'`` (best validation RMSE).
            - Otherwise (same-shape targets such as mixup): ``'train_score'`` (RMSE) and ``'valid_acc'``.
            ``'train_score'`` is ``None`` if the loop exits before any epoch completes.
        """
        start_tic = time.time()
        # setup loss function
        if self.mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy()
        elif self._augmentation_cfg.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=self._augmentation_cfg.smoothing)
        else:
            train_loss_fn = nn.CrossEntropyLoss()
        validate_loss_fn = nn.CrossEntropyLoss()
        train_loss_fn = train_loss_fn.to(self.ctx[0])
        validate_loss_fn = validate_loss_fn.to(self.ctx[0])
        eval_metric = self._misc_cfg.eval_metric
        if self._problem_type == REGRESSION:
            train_loss_fn = nn.MSELoss()
            validate_loss_fn = nn.MSELoss()
            eval_metric = 'rmse'
        early_stopper = EarlyStopperOnPlateau(
            patience=self._train_cfg.early_stop_patience,
            min_delta=self._train_cfg.early_stop_min_delta,
            baseline_value=self._train_cfg.early_stop_baseline,
            max_value=self._train_cfg.early_stop_max_value)

        # log the epoch the loop actually starts from
        first_epoch = max(self.start_epoch, self.epoch)
        self._logger.info('Start training from [Epoch %d]', first_epoch)

        # stays empty when the loop exits before completing an epoch, so the
        # result assembly below never touches an unbound name
        train_metrics = {}
        self._time_elapsed += time.time() - start_tic
        for self.epoch in range(first_epoch, self.epochs):
            epoch = self.epoch
            if self._best_acc >= 1.0:
                self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch))
                break
            should_stop, stop_message = early_stopper.get_early_stop_advice()
            if should_stop:
                self._logger.info('[Epoch {}] '.format(epoch) + stop_message)
                break
            train_metrics = self.train_one_epoch(
                epoch, self.net, train_loader, self._optimizer, train_loss_fn,
                lr_scheduler=self._lr_scheduler, output_dir=self._logdir,
                amp_autocast=self._amp_autocast, loss_scaler=self._loss_scaler, model_ema=self._model_ema, mixup_fn=self._mixup_fn, time_limit=time_limit)
            # reaching time limit, exit early
            if train_metrics['time_limit']:
                self._logger.warning(f'`time_limit={time_limit}` reached, exit early...')
                return {'train_acc': train_metrics['train_acc'], 'valid_acc': self._best_acc,
                        'time': self._time_elapsed, 'checkpoint': self._cp_name}
            post_tic = time.time()

            eval_metrics = self.validate(self.net, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast)

            # when an EMA model is kept on device, its metrics replace the raw model's
            if self._model_ema is not None and not self._model_ema_cfg.model_ema_force_cpu:
                ema_eval_metrics = self.validate(
                    self._model_ema.module, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast)
                eval_metrics = ema_eval_metrics

            if self._problem_type == REGRESSION:
                # lower RMSE is better: track its negation so "higher is better" everywhere
                val_acc = eval_metrics['rmse']
                if self._reporter:
                    self._reporter(epoch=epoch, acc_reward=-val_acc)
                early_stopper.update(-val_acc)

                if -val_acc > self._best_acc:
                    self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE)
                    self._logger.info('[Epoch %d] Current best rmse: %f vs previous %f, saved to %s',
                                      self.epoch, val_acc, -self._best_acc, self._cp_name)
                    self.save(self._cp_name)
                    self._best_acc = -val_acc
            else:
                val_acc = eval_metrics['top1']
                if self._reporter:
                    self._reporter(epoch=epoch, acc_reward=val_acc)
                early_stopper.update(val_acc)

                if val_acc > self._best_acc:
                    self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE)
                    self._logger.info('[Epoch %d] Current best top-1: %f vs previous %f, saved to %s',
                                      self.epoch, val_acc, self._best_acc, self._cp_name)
                    self.save(self._cp_name)
                    self._best_acc = val_acc

            if self._lr_scheduler is not None:
                # step LR for next epoch
                self._lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            self._time_elapsed += time.time() - post_tic

        if 'accuracy' in train_metrics:
            return {'train_acc': train_metrics['accuracy'], 'valid_acc': self._best_acc,
                    'time': self._time_elapsed, 'checkpoint': self._cp_name}
        # rmse
        if self._problem_type == REGRESSION:
            return {'train_score': train_metrics.get('rmse'), 'valid_score': -self._best_acc,
                    'time': self._time_elapsed, 'checkpoint': self._cp_name}
        # mixup
        return {'train_score': train_metrics.get('rmse'), 'valid_acc': self._best_acc,
                'time': self._time_elapsed, 'checkpoint': self._cp_name}

    def train_one_epoch(
            self, epoch, net, loader, optimizer, loss_fn,
            lr_scheduler=None, output_dir=None, amp_autocast=suppress,
            loss_scaler=None, model_ema=None, mixup_fn=None, time_limit=math.inf):
        """Train ``net`` for one epoch over ``loader``.

        The training metric is accuracy. It switches to RMSE for regression, or whenever the
        output and target shapes match (e.g. soft mixup/cutmix targets).

        Returns
        -------
        dict
            Normal completion: ``{<metric_name>: avg_score, 'train_loss': avg_loss,
            'time_limit': False}``, where ``<metric_name>`` is ``'accuracy'`` or ``'rmse'``.
            Time limit exceeded: ``{'train_acc': avg_score, 'train_loss': avg_loss,
            'time_limit': True}``, returned as soon as the limit is seen between batches.
        """
        start_tic = time.time()
        # turn mixup/cutmix off from `mixup_off_epoch` onwards: with the prefetcher the
        # mixing is done by the loader's collate function, otherwise by `mixup_fn`
        if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch:
            if self._misc_cfg.prefetcher and loader.mixup_enabled:
                loader.mixup_enabled = False
            elif mixup_fn is not None:
                mixup_fn.mixup_enabled = False

        second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        losses_m = AverageMeter()
        train_metric_score_m = AverageMeter()

        net.train()

        # global iteration counter passed to the per-iteration LR scheduler
        num_updates = epoch * len(loader)
        self._time_elapsed += time.time() - start_tic
        tic = time.time()
        last_tic = time.time()
        train_metric_name = 'accuracy'
        num_batches = 0
        for batch_idx, (inputs, target) in enumerate(loader):
            b_tic = time.time()
            if self._time_elapsed > time_limit:
                return {'train_acc': train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': True}
            if self._problem_type == REGRESSION:
                target = target.to(torch.float32)
            if not self._misc_cfg.prefetcher:
                # prefetcher would move data to cuda by default
                inputs, target = inputs.to(self.ctx[0]), target.to(self.ctx[0])
                if mixup_fn is not None:
                    inputs, target = mixup_fn(inputs, target)

            with amp_autocast():
                output = net(inputs)
                if self._problem_type == REGRESSION:
                    output = output.flatten()
                loss = loss_fn(output, target)
            if self._problem_type == REGRESSION or output.shape == target.shape:
                # regression, or targets shaped like the outputs (e.g. soft mixup labels):
                # top-1 accuracy does not apply, so report RMSE
                train_metric_name = 'rmse'
                train_metric_score = rmse(output, target)
            else:
                train_metric_score = accuracy(output, target)[0] / 100

            losses_m.update(loss.item(), inputs.size(0))
            train_metric_score_m.update(train_metric_score.item(), output.size(0))

            optimizer.zero_grad()
            if loss_scaler is not None:
                # the scaler runs backward, optional gradient clipping and the optimizer step
                loss_scaler(
                    loss, optimizer,
                    clip_grad=self._optimizer_cfg.clip_grad, clip_mode=self._optimizer_cfg.clip_mode,
                    parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                    create_graph=second_order)
            else:
                loss.backward(create_graph=second_order)
                if self._optimizer_cfg.clip_grad is not None:
                    dispatch_clip_grad(
                        model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                        value=self._optimizer_cfg.clip_grad, mode=self._optimizer_cfg.clip_mode)
                optimizer.step()

            if model_ema is not None:
                model_ema.update(net)

            if self.found_gpu:
                torch.cuda.synchronize()

            num_updates += 1
            if (batch_idx + 1) % self._misc_cfg.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)
                self._logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f',
                                  epoch, batch_idx,
                                  self._train_cfg.batch_size * self._misc_cfg.log_interval / (time.time() - last_tic),
                                  train_metric_name, train_metric_score_m.avg, lr)
                last_tic = time.time()

                if self._misc_cfg.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

            self._time_elapsed += time.time() - b_tic
            # `batch_idx` is zero-based; count completed batches for the throughput figure
            num_batches = batch_idx + 1

        epoch_time = time.time() - tic
        throughput = int(self._train_cfg.batch_size * num_batches / epoch_time) if epoch_time > 0 else 0
        self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score_m.avg)
        self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, epoch_time)

        end_time = time.time()
        # Lookahead-style optimizers keep slow weights that must be synced at epoch end
        if hasattr(optimizer, 'sync_lookahead'):
            optimizer.sync_lookahead()

        self._time_elapsed += time.time() - end_time

        return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': False}

    def validate(self, net, loader, loss_fn, amp_autocast=suppress, metric_name=None):
        """Evaluate ``net`` on ``loader`` with gradients disabled.

        For classification with ``misc.tta`` > 1, each group of ``tta`` consecutive samples
        is averaged into one prediction. Loss and accuracy are both computed on the reduced
        predictions.

        Parameters
        ----------
        net : torch.nn.Module
            Network to evaluate; put into ``eval()`` mode.
        loader : iterable
            Yields ``(input, target)`` batches.
        loss_fn : callable
            ``loss_fn(output, target)``.
        amp_autocast : callable
            Returns the context manager used around the forward pass.
        metric_name : str, optional
            For regression, only ``'rmse'`` is accepted (asserted per batch).

        Returns
        -------
        dict
            Regression: ``{'loss': ..., 'rmse': ...}``. Classification:
            ``{'loss': ..., 'top1': ..., 'top5': ...}``, where the accuracies are the values
            from ``accuracy`` divided by 100.
        """
        losses_m = AverageMeter()
        top1_m = AverageMeter()
        top5_m = AverageMeter()
        rmse_m = AverageMeter()

        net.eval()
        is_regression = self._problem_type == REGRESSION
        reduce_factor = self._misc_cfg.tta

        with torch.no_grad():
            for inputs, target in loader:
                if not self._misc_cfg.prefetcher:
                    inputs = inputs.to(self.ctx[0])
                    target = target.to(self.ctx[0])
                if is_regression:
                    # same target dtype as in train_one_epoch
                    target = target.to(torch.float32)

                with amp_autocast():
                    output = net(inputs)
                    # models may return (primary, aux, ...); unwrap before any reshaping
                    if isinstance(output, (tuple, list)):
                        output = output[0]
                    if is_regression:
                        output = output.flatten()

                # test-time augmentation reduction: must happen before metrics are computed
                if not is_regression and reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                if is_regression:
                    if metric_name:
                        assert metric_name == 'rmse', f'{metric_name} metric not supported for regression.'
                    val_metric_score = rmse(output, target)
                else:
                    val_metric_score = accuracy(output, target, topk=(1, min(5, self.num_class)))

                loss = loss_fn(output, target)
                reduced_loss = loss.data

                if self.found_gpu:
                    torch.cuda.synchronize()

                losses_m.update(reduced_loss.item(), inputs.size(0))
                if is_regression:
                    rmse_m.update(val_metric_score.item(), output.size(0))
                else:
                    acc1, acc5 = val_metric_score
                    acc1 /= 100
                    acc5 /= 100
                    top1_m.update(acc1.item(), output.size(0))
                    top5_m.update(acc5.item(), output.size(0))

        if is_regression:
            self._logger.info('[Epoch %d] validation: rmse=%f', self.epoch, rmse_m.avg)
            return {'loss': losses_m.avg, 'rmse': rmse_m.avg}
        self._logger.info('[Epoch %d] validation: top1=%f top5=%f', self.epoch, top1_m.avg, top5_m.avg)
        return {'loss': losses_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg}

    def _init_network(self, **kwargs):
        """Resolve compute devices and create (or adopt) the network, moving it to the first device.

        Keyword Args
        ------------
        load_only : bool, default False
            When True, a model-zoo network is created with ``pretrained=False``.

        Raises
        ------
        ValueError
            For non-regression problems when ``self.num_class`` is falsy.
        AssertionError
            If ``len(self.classes) != self.num_class``; if a user-provided network's last
            child ``out_features`` does not match ``num_class`` (or 1 for regression); or on
            torchscript incompatibilities (APEX AMP / SyncBatchNorm).

        Side effects: sets ``self.found_gpu``, ``self.ctx``, ``self.valid_gpus`` and
        ``self.net``; may rewrite config flags via ``update_cfg`` (sync_bn; AMP and prefetcher
        on CPU) and reset ``self.use_amp``; calls ``random_seed(misc.seed)`` and
        ``resolve_data_config``.
        """
        load_only = kwargs.get('load_only', False)
        if not self.num_class and self._problem_type != REGRESSION:
            raise ValueError('This is a classification problem and we are not able to create network '
                             'when `num_class` is unknown. It should be inferred from dataset or resumed '
                             'from saved states.')
        assert len(self.classes) == self.num_class

        # Disable syncBatchNorm as it's only supported on DDP
        if self._train_cfg.sync_bn:
            self._logger.info(
                'Disable Sync batch norm as it is not supported for now.')
            update_cfg(self._cfg, {'train': {'sync_bn': False}})

        # ctx: requested GPUs that validate, otherwise fall back to CPU
        self.found_gpu = False
        valid_gpus = []
        if self._cfg.gpus:
            valid_gpus = self._torch_validate_gpus(self._cfg.gpus)
            self.found_gpu = True
            if not valid_gpus:
                self.found_gpu = False
                self._logger.warning(
                    'No gpu detected, fallback to cpu. You can ignore this warning if this is intended.')
            elif len(valid_gpus) != len(self._cfg.gpus):
                self._logger.warning(
                    f'Loaded on gpu({valid_gpus}), different from gpu({self._cfg.gpus}).')
        self.ctx = [torch.device(f'cuda:{gid}') for gid in valid_gpus] if self.found_gpu else [torch.device('cpu')]
        self.valid_gpus = valid_gpus

        if not self.found_gpu and self.use_amp:
            self.use_amp = None
            self._logger.warning('Training on cpu. AMP disabled.')
            update_cfg(self._cfg, {'misc': {'amp': False, 'apex_amp': False, 'native_amp': False}})

        if not self.found_gpu and self._misc_cfg.prefetcher:
            self._logger.warning(
                'Training on cpu. Prefetcher disabled.')
            update_cfg(self._cfg, {'misc': {'prefetcher': False}})
            self._logger.warning(
                'Training on cpu. SyncBatchNorm disabled.')
            update_cfg(self._cfg, {'train': {'sync_bn': False}})

        random_seed(self._misc_cfg.seed)

        if not self.net:
            self.net = create_model(
                self._img_cls_cfg.model,
                pretrained=self._img_cls_cfg.pretrained and not load_only,
                num_classes=max(self.num_class, 1),
                global_pool=self._img_cls_cfg.global_pool_type,
                drop_rate=self._augmentation_cfg.drop,
                drop_path_rate=self._augmentation_cfg.drop_path,
                drop_block_rate=self._augmentation_cfg.drop_block,
                bn_momentum=self._train_cfg.bn_momentum,
                bn_eps=self._train_cfg.bn_eps,
                scriptable=self._misc_cfg.torchscript
            )

            self._logger.info('Model %s created, param count: %d',
                              safe_model_name(self._img_cls_cfg.model),
                              sum(m.numel() for m in self.net.parameters()))
        else:
            self._logger.info('Use user provided model. Neglect model in config.')
            # NOTE(review): assumes the output head is the last registered child and exposes
            # `out_features` (e.g. nn.Linear) — confirm for arbitrary custom networks
            out_features = list(self.net.children())[-1].out_features
            if self._problem_type != REGRESSION:
                assert out_features == self.num_class, f'Custom model out_feature {out_features} != num_class {self.num_class}.'
            else:
                assert out_features == 1, f'Regression problem expects num_out_feature == 1, got {out_features} instead.'

        resolve_data_config(self._cfg, model=self.net)

        self.net = self.net.to(self.ctx[0])

        # setup synchronized BatchNorm
        # NOTE(review): sync_bn is forced off above via update_cfg; this branch only runs if that
        # update does not reach self._train_cfg — confirm whether it is still reachable
        if self._train_cfg.sync_bn:
            if has_apex and self.use_amp != 'native':
                # Apex SyncBN preferred unless native amp is activated
                self.net = convert_syncbn_model(self.net)
            else:
                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
            self._logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

        if self._misc_cfg.torchscript:
            assert not self.use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
            assert not self._train_cfg.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
            self.net = torch.jit.script(self.net)

    def _init_trainer(self):
        """Create the optimizer (unless one exists), the AMP loss scaler and the LR scheduler.

        When fine-tuning a pretrained model-zoo network (not a user-provided one) with
        ``train.transfer_lr_mult`` or ``train.output_lr_mult`` different from 1, the two
        multipliers are forwarded to ``create_optimizer_v2`` as ``feature_lr_mult`` and
        ``fc_lr_mult``. ``self.epochs`` is overwritten with the epoch count returned by
        ``create_scheduler``.
        """
        if self._optimizer is None:
            optim_kwargs = optimizer_kwargs(cfg=self._cfg)
            if self._img_cls_cfg.pretrained and not self._custom_net \
                    and (self._train_cfg.transfer_lr_mult != 1 or self._train_cfg.output_lr_mult != 1):
                # adjust feature/last_fc learning rate multiplier in optimizer
                self._logger.debug(f'Reduce network lr multiplier to {self._train_cfg.transfer_lr_mult}, while keep ' +
                                   f'last FC layer lr_mult to {self._train_cfg.output_lr_mult}')
                optim_kwargs['feature_lr_mult'] = self._cfg.train.transfer_lr_mult
                optim_kwargs['fc_lr_mult'] = self._cfg.train.output_lr_mult
            # pass the (possibly augmented) kwargs; previously the multipliers were built and discarded
            self._optimizer = create_optimizer_v2(self.net, **optim_kwargs)
        self._init_loss_scaler()
        self._lr_scheduler, self.epochs = create_scheduler(self._cfg, self._optimizer)
        # NOTE(review): the second argument is self.epoch; if this is a timm-style
        # Scheduler.step(epoch, metric) it is consumed as the metric — confirm intent
        self._lr_scheduler.step(self.start_epoch, self.epoch)

    def _init_loss_scaler(self):
        """Pick the autocast context and loss scaler according to ``self.use_amp``.

        * ``'native'``: ``torch.cuda.amp.autocast`` with a ``NativeScaler``.
        * ``'apex'``: wraps net and optimizer with ``amp.initialize`` (opt level O1) and uses
          an ``ApexScaler``.
        * anything else: float32 training with a no-op ``suppress`` context and no scaler.
        """
        # Defaults first, so they are in place even if an AMP backend fails to initialise.
        self._amp_autocast, self._loss_scaler = suppress, None
        amp_mode = self.use_amp
        if amp_mode == 'native':
            self._amp_autocast = torch.cuda.amp.autocast
            self._loss_scaler = NativeScaler()
            self._logger.info('Using native Torch AMP. Training in mixed precision.')
            return
        if amp_mode == 'apex':
            self.net, self._optimizer = amp.initialize(self.net, self._optimizer, opt_level='O1')
            self._loss_scaler = ApexScaler()
            self._logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
            return
        self._logger.info('AMP not enabled. Training in float32.')

    def _init_model_ema(self):
        """Initialise ``self._model_ema`` (exponential moving average of the weights).

        EMA is currently switched off. If the config enables it, an info message is logged
        and ``update_cfg`` rewrites ``model_ema.model_ema`` to False. ``self._model_ema`` is
        then reset to ``None``. It is only replaced by a ``ModelEmaV2`` when
        ``self._model_ema_cfg.model_ema`` still reads truthy afterwards.
        """
        if self._model_ema_cfg.model_ema:
            # Disable for now
            self._logger.info('Disable EMA as it is not supported for now.')
            update_cfg(self._cfg, {'model_ema': {'model_ema': False}})
        self._model_ema = None
        if not self._model_ema_cfg.model_ema:
            return
        # The EMA copy must be created after cuda(), the DP wrapper and AMP,
        # but before SyncBN conversion and DDP wrapping.
        ema_device = 'cpu' if self._model_ema_cfg.model_ema_force_cpu else None
        self._model_ema = ModelEmaV2(self.net, decay=self._model_ema_cfg.model_ema_decay, device=ema_device)


    def evaluate(self, val_data, metric_name=None):
        """Evaluate the current network on ``val_data``.

        Public entry point; it delegates to ``_evaluate`` and returns its result unchanged.
        """
        outcome = self._evaluate(val_data, metric_name)
        return outcome

    def _evaluate(self, val_data, metric_name=None):
        """Validate the network on ``val_data`` with a non-distributed loader.

        The loss is ``nn.MSELoss`` for regression and ``nn.CrossEntropyLoss`` otherwise; it is
        moved to ``self.ctx[0]``. The loader batch size is
        ``validation_batch_size_multiplier * train batch_size``. The result of
        ``self.validate`` is returned as-is.
        """
        loss_cls = nn.MSELoss if self._problem_type == REGRESSION else nn.CrossEntropyLoss
        validate_loss_fn = loss_cls().to(self.ctx[0])
        torch_val_data = val_data.to_torch()
        data_cfg = self._data_cfg
        misc_cfg = self._misc_cfg
        val_loader = create_loader(
            torch_val_data,
            input_size=data_cfg.input_size,
            batch_size=data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size,
            is_training=False,
            use_prefetcher=misc_cfg.prefetcher,
            interpolation=data_cfg.interpolation,
            mean=data_cfg.mean,
            std=data_cfg.std,
            num_workers=misc_cfg.num_workers,
            distributed=False,
            crop_pct=data_cfg.crop_pct,
            pin_memory=misc_cfg.pin_mem,
        )
        return self.validate(self.net, val_loader, validate_loss_fn,
                             amp_autocast=self._amp_autocast, metric_name=metric_name)

    def _predict(self, x, **kwargs):
        """Predict on image path(s), a DataFrame of paths, or an image tensor.

        Parameters
        ----------
        x : str, pandas.DataFrame, list/tuple of str, or torch.Tensor
            A single image path; a DataFrame with an ``image`` column of paths; a sequence of
            paths; or a tensor of shape ``(n, 3, h, w)``.
        with_proba : bool, optional (keyword)
            Classification only. When set, return the full softmax vector per image
            (``image_proba``) instead of the top-k ``class``/``score``/``id`` rows.

        Returns
        -------
        pandas.DataFrame
            Classification: ``min(5, num_class)`` rows per image, or one ``image_proba`` row
            per image when ``with_proba``. Regression: one ``prediction`` row per image.
            Path-based inputs also get an ``image`` column (dropped again for a single ``str``).

        Raises
        ------
        AssertionError
            ``with_proba`` on a non-classification problem, or malformed input.
        ValueError
            Unsupported input type.
        """
        with_proba = kwargs.get('with_proba', False)
        is_classification = self._problem_type in [MULTICLASS, BINARY]
        if with_proba and not is_classification:
            raise AssertionError('with_proba is only supported for classification problems. Please use predict instead.')
        if isinstance(x, str):
            return self._predict((x,), **kwargs).drop(columns=['image'], errors='ignore')
        if isinstance(x, pd.DataFrame):
            assert 'image' in x.columns, "Expect column `image` for input images"
            df = self._predict(tuple(x['image']), **kwargs)
            return df.reset_index(drop=True)

        def _rows(output, extra):
            """Result rows for one sample's 1-D network output; ``extra`` is appended to each row."""
            if not is_classification:
                return [{'prediction': output.cpu().numpy().flatten(), **extra}]
            probs = nn.functional.softmax(output, dim=0).cpu().numpy().flatten()
            if with_proba:
                return [{'image_proba': probs.tolist(), **extra}]
            topk = min(5, self.num_class)
            topk_inds = output.topk(topk)[1].cpu().numpy().flatten()
            return [{'class': self.classes[ind], 'score': probs[ind], 'id': ind, **extra}
                    for ind in topk_inds]

        if isinstance(x, (list, tuple)):
            loader = create_loader(
                ImageListDataset(x),
                input_size=self._data_cfg.input_size,
                batch_size=self._train_cfg.batch_size,
                use_prefetcher=self._misc_cfg.prefetcher,
                interpolation=self._data_cfg.interpolation,
                mean=self._data_cfg.mean,
                std=self._data_cfg.std,
                num_workers=self._misc_cfg.num_workers,
                crop_pct=self._data_cfg.crop_pct
            )

            self.net.eval()

            results = []
            idx = 0
            with torch.no_grad():
                for inputs, _ in loader:
                    outputs = self.net(inputs.to(self.ctx[0]))
                    for output in outputs:
                        results.extend(_rows(output, {'image': x[idx]}))
                        idx += 1
            return pd.DataFrame(results)
        if not isinstance(x, torch.Tensor):
            raise ValueError('Input is not supported: {}'.format(type(x)))
        assert len(x.shape) == 4 and x.shape[1] == 3, f"Expect input to be (n, 3, h, w), given {x.shape}"
        # Inference mode for BN/dropout, matching the path-based branch (previously missing here).
        self.net.eval()
        results = []
        with torch.no_grad():
            outputs = self.net(x.to(self.ctx[0]))
            # Format each sample separately; the old code flattened the whole batch together,
            # which mis-indexed probabilities whenever n > 1.
            for output in outputs:
                results.extend(_rows(output, {}))
        return pd.DataFrame(results)


    def _predict_feature(self, x, **kwargs):
        """Extract backbone features (via ``forward_features``) for images.

        Parameters
        ----------
        x : str, pandas.DataFrame, list/tuple of str, or torch.Tensor
            A single image path; a DataFrame with an ``image`` column of paths; a non-empty
            sequence of paths; or an image tensor already in network input format.

        Returns
        -------
        pandas.DataFrame
            For path inputs: one row per image, with a flattened numpy ``image_feature`` and the
            ``image`` path (DataFrame inputs keep their original index). For a tensor input: a
            single row whose ``image_feature`` is the raw ``forward_features`` output.

        Raises
        ------
        AssertionError
            Empty or non-path sequence, or DataFrame without an ``image`` column.
        ValueError
            Unsupported input type.
        """
        if isinstance(x, str):
            return self._predict_feature((x,), **kwargs)
        if isinstance(x, pd.DataFrame):
            assert 'image' in x.columns, "Expect column `image` for input images"
            df = self._predict_feature(tuple(x['image']), **kwargs)
            df = df.set_index(x.index)
            df['image'] = x['image']
            return df

        def _forward_features_fn():
            # Built-in models expose forward_features directly; wrapped models (anything that
            # holds the real network as `.module`, e.g. DataParallel) expose it on the inner
            # module. Resolved up front so AttributeErrors raised *inside* forward_features are
            # no longer mistaken for "method missing".
            fn = getattr(self.net, 'forward_features', None)
            return fn if fn is not None else self.net.module.forward_features

        if isinstance(x, (list, tuple)):
            assert len(x) > 0 and isinstance(x[0], str), "expect image paths in list/tuple input"
            loader = create_loader(
                ImageListDataset(x),
                input_size=self._data_cfg.input_size,
                batch_size=self._train_cfg.batch_size,
                use_prefetcher=self._misc_cfg.prefetcher,
                interpolation=self._data_cfg.interpolation,
                mean=self._data_cfg.mean,
                std=self._data_cfg.std,
                num_workers=self._misc_cfg.num_workers,
                crop_pct=self._data_cfg.crop_pct
            )

            self.net.eval()
            forward_features = _forward_features_fn()

            results = []
            with torch.no_grad():
                for inputs, _ in loader:
                    features = forward_features(inputs.to(self.ctx[0]))
                    results.extend({'image_feature': f.cpu().numpy().flatten()} for f in features)
            df = pd.DataFrame(results)
            df['image'] = x
            return df
        if not isinstance(x, torch.Tensor):
            raise ValueError('Input is not supported: {}'.format(type(x)))
        # Inference mode and wrapper fallback, consistent with the path-based branch.
        self.net.eval()
        forward_features = _forward_features_fn()
        with torch.no_grad():
            feature = forward_features(x.to(self.ctx[0]))
        return pd.DataFrame([{'image_feature': feature}])

    def _reconstruct_state_dict(self, state_dict):
        """Return a copy of ``state_dict`` with a leading ``module.`` prefix removed from keys.

        Wrappers such as ``DataParallel`` prefix parameter names with ``module.``, while the
        bare network expects un-prefixed keys. Only the exact ``module.`` prefix is stripped.
        Keys that merely start with the letters ``module`` (e.g. ``module_list.0.weight``) are
        kept intact; the previous ``startswith('module')`` + ``k[7:]`` check mangled them.
        Key order is preserved.
        """
        prefix = 'module.'
        return {(key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}

    def save(self, filename):
        """Save the estimator to ``filename`` with ``torch.save``.

        The saved dict contains:

        * ``'estimator'``: the pickled estimator, with network, optimizer, EMA model, logger
          and reporter temporarily set to ``None``.
        * ``'model_state_dict'`` (built-in network) or ``'net_pickle'`` (user-provided network).
        * ``'optimizer_state_dict'``.
        * The loss scaler's state under its ``state_dict_key``, plus that key itself under
          ``'_loss_scaler_state_dict_key'``.
        * ``'ema_state_dict'``.

        The instance's attributes are restored afterwards, even if serialisation fails.
        """
        d = dict()
        current_states = self.__dict__.copy()
        try:
            if self.net is not None:
                if not self._custom_net:
                    # Store the unwrapped module's weights so keys carry no DataParallel prefix.
                    net = self.net.module if isinstance(self.net, torch.nn.DataParallel) else self.net
                    d['model_state_dict'] = get_state_dict(net, unwrap_model)
                else:
                    d['net_pickle'] = pickle.dumps(self.net)
                self.net = None
            if self._optimizer is not None:
                d['optimizer_state_dict'] = self._optimizer.state_dict()
                self._optimizer = None
            loss_scaler = getattr(self, '_loss_scaler', None)
            if loss_scaler:
                d[loss_scaler.state_dict_key] = loss_scaler.state_dict()
                d['_loss_scaler_state_dict_key'] = loss_scaler.state_dict_key
            # _model_ema may be absent on an instance restored by __setstate__.
            model_ema = getattr(self, '_model_ema', None)
            if model_ema is not None:
                d['ema_state_dict'] = get_state_dict(model_ema, unwrap_model)
                self._model_ema = None
            self._logger = None
            self._reporter = None
            d['estimator'] = self
            torch.save(d, filename)
        finally:
            # Put back everything that was nulled for pickling, even if torch.save raised.
            self.__dict__.update(current_states)

    @classmethod
    def load(cls, filename, ctx='auto'):
        """Load an estimator written by ``save``.

        Restores the network (state dict or pickled custom net), the optimizer, the loss
        scaler and the EMA weights when present. It then moves everything to a context chosen
        by ``_suggest_load_context`` from ``ctx`` and the saved context.

        Security: both ``torch.load`` and ``pickle.loads`` unpickle arbitrary objects; only
        load files from trusted sources.
        """
        d = torch.load(filename, map_location=torch.device('cpu'))
        est = d.pop('estimator')
        # logger
        est._logger = logging.getLogger(cls.__name__)
        est._logger.setLevel(logging.ERROR)
        try:
            fh = logging.FileHandler(est._log_file)
            est._logger.addHandler(fh)
        #pylint: disable=broad-except
        except Exception:
            # Best effort: a missing/invalid _log_file just means no file logging.
            pass
        model_state_dict = d.get('model_state_dict', None)
        net_pickle = d.get('net_pickle', None)
        if model_state_dict:
            est._init_network(load_only=True)
            net_state_dict = est._reconstruct_state_dict(model_state_dict)
            if isinstance(est.net, torch.nn.DataParallel):
                est.net.module.load_state_dict(net_state_dict)
            else:
                est.net.load_state_dict(net_state_dict)
        elif net_pickle:
            est.net = pickle.loads(net_pickle)
        optimizer_state_dict = d.get('optimizer_state_dict', None)
        if optimizer_state_dict:
            est._init_trainer()
            est._optimizer.load_state_dict(optimizer_state_dict)
        loss_scaler = getattr(est, '_loss_scaler', None)
        if loss_scaler:
            # save() stores the scaler state under its own state_dict_key and records that key
            # under '_loss_scaler_state_dict_key'. The old lookup used 'loss_scaler_state_dict',
            # which save() never writes, so the scaler state was never restored.
            scaler_key = d.get('_loss_scaler_state_dict_key', loss_scaler.state_dict_key)
            loss_scaler_dict = d.get(scaler_key, None)
            if loss_scaler_dict:
                loss_scaler.load_state_dict(loss_scaler_dict)
        ema_state_dict = d.get('ema_state_dict', None)
        est._init_model_ema()
        # _init_model_ema may leave _model_ema as None (EMA disabled); skip restoring then.
        if ema_state_dict and est._model_ema is not None:
            ema_state_dict = est._reconstruct_state_dict(ema_state_dict)
            if isinstance(est.net, torch.nn.DataParallel):
                est._model_ema.module.module.load_state_dict(ema_state_dict)
            else:
                est._model_ema.module.load_state_dict(ema_state_dict)
        new_ctx = _suggest_load_context(est.net, ctx, est.ctx)
        est.reset_ctx(new_ctx)
        est._logger.setLevel(logging.INFO)
        return est

    # pylint: disable=redefined-outer-name, reimported
    def __getstate__(self):
        """Return a picklable copy of the instance ``__dict__``.

        The network, EMA model, optimizer and loss scaler are removed from the copy. Their
        states go under ``'save_state'``; a user-provided ``_custom_net`` is pickled whole
        instead. Logger and reporter are set to ``None``.
        """
        d = self.__dict__.copy()
        # Bound before the try so it exists even if the torch import below fails
        # (previously that path raised NameError at the assignment after the except).
        save_state = {}
        try:
            import torch
            net = d.pop('net', None)
            model_ema = d.pop('_model_ema', None)
            optimizer = d.pop('_optimizer', None)
            loss_scaler = d.pop('_loss_scaler', None)
            if net is not None:
                if not self._custom_net:
                    # Store the unwrapped module's weights so keys carry no DataParallel prefix.
                    inner = net.module if isinstance(net, torch.nn.DataParallel) else net
                    save_state['state_dict'] = get_state_dict(inner, unwrap_model)
                else:
                    save_state['net_pickle'] = pickle.dumps(net)
            if optimizer is not None:
                save_state['optimizer'] = optimizer.state_dict()
            if loss_scaler is not None:
                save_state[loss_scaler.state_dict_key] = loss_scaler.state_dict()
            if model_ema is not None:
                save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
        except ImportError:
            pass
        d['save_state'] = save_state
        d['_logger'] = None
        d['_reporter'] = None
        return d

    def __setstate__(self, state):
        """Restore an instance pickled by ``__getstate__``.

        Re-creates the logger: level ERROR while restoring, INFO afterwards, with a
        best-effort file handler on ``_log_file``. When ``state['save_state']`` is non-empty,
        it also rebuilds the network, optimizer, loss-scaler state and EMA model from it.

        Security: ``pickle.loads`` on ``net_pickle`` runs arbitrary code; only unpickle
        trusted data.
        """
        save_state = state.pop('save_state', None)
        self.__dict__.update(state)
        # __getstate__ drops the EMA model; make sure the attribute exists so later code
        # (e.g. save()) can test it without an AttributeError.
        self.__dict__.setdefault('_model_ema', None)
        # logger
        self._logger = logging.getLogger(state.get('_name', self.__class__.__name__))
        self._logger.setLevel(logging.ERROR)
        try:
            fh = logging.FileHandler(self._log_file)
            self._logger.addHandler(fh)
        #pylint: disable=broad-except
        except Exception:
            # Best effort: a missing/invalid _log_file just means no file logging.
            pass
        if not save_state:
            self.net = None
            self._optimizer = None
            self._logger.setLevel(logging.INFO)
            return
        try:
            import torch
            self.net = None
            self._optimizer = None
            if self._custom_net:
                if save_state.get('net_pickle', None):
                    self.net = pickle.loads(save_state['net_pickle'])
            else:
                if save_state.get('state_dict', None):
                    self._init_network(load_only=True)
                    net_state_dict = self._reconstruct_state_dict(save_state['state_dict'])
                    if isinstance(self.net, torch.nn.DataParallel):
                        self.net.module.load_state_dict(net_state_dict)
                    else:
                        self.net.load_state_dict(net_state_dict)
            if save_state.get('optimizer', None):
                self._init_trainer()
                self._optimizer.load_state_dict(save_state['optimizer'])
            if hasattr(self, '_loss_scaler') and self._loss_scaler and self._loss_scaler.state_dict_key in save_state:
                loss_scaler_dict = save_state[self._loss_scaler.state_dict_key]
                self._loss_scaler.load_state_dict(loss_scaler_dict)
            if save_state.get('state_dict_ema', None):
                self._init_model_ema()
                # _init_model_ema may leave _model_ema as None (EMA disabled); skip restoring then.
                if self._model_ema is not None:
                    model_ema_dict = self._reconstruct_state_dict(save_state['state_dict_ema'])
                    if isinstance(self.net, torch.nn.DataParallel):
                        self._model_ema.module.module.load_state_dict(model_ema_dict)
                    else:
                        self._model_ema.module.load_state_dict(model_ema_dict)
        except ImportError:
            pass
        self._logger.setLevel(logging.INFO)
コード例 #6
0
def main(args):
    """Train (or, with ``args.eval``, only evaluate) a Swin Transformer classifier.

    The steps are:

    1. Build the datasets, samplers and loaders, plus an optional Mixup/CutMix transform.
    2. Build the model (wrapped in DDP when ``args.distributed``), the optimizer, the AMP
       loss scaler, the LR scheduler and the loss.
    3. Optionally resume from ``args.resume``.
    4. Train for ``args.epochs`` epochs. After each epoch, checkpoint to
       ``args.output_dir/checkpoint.pth``, evaluate, and append JSON stats to ``log.txt``
       (main process only).
    """
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # A DistributedSampler is used unconditionally. The original distributed check was
    # hard-wired to True, and its non-distributed branch was dead code.
    # NOTE(review): presumably world size / rank are 1 / 0 for single-process runs — confirm
    # utils.get_world_size() / get_rank() in that mode.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train,
        num_replicas=num_tasks,
        rank=global_rank,
        shuffle=True)
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print(
                'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                'This will slightly alter validation results as extra duplicate entries are added to achieve '
                'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = getattr(SwinTransformer, args.model)(num_classes=args.nb_classes,
                                                 drop_path_rate=args.drop_path)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling: the configured LRs are for a global batch size of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr

    linear_scaled_warmup_lr = args.warmup_lr * args.batch_size * utils.get_world_size() / 512.0
    args.warmup_lr = linear_scaled_warmup_lr

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if mixup_active:
        # Mixup *or* CutMix produce soft (mixed one-hot) targets, so a soft-target loss is
        # required; label smoothing is applied inside the mixup transform. Keying this on
        # args.mixup alone broke CutMix-only runs.
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # sampler_train is always a DistributedSampler, which only reshuffles when told the
        # epoch; gating this on args.distributed made single-process runs reuse one order.
        data_loader_train.sampler.set_epoch(epoch)

        # NOTE(review): stepping with epoch + 1 *before* the epoch trains epoch e at the
        # schedule value for e + 1 (timm's reference loop steps with epoch + 1 *after* the
        # epoch) — looks like an off-by-one; confirm before changing.
        lr_scheduler.step(epoch + 1)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            mixup_fn,
            set_training_mode=True  # model is put in train mode for every epoch
        )

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
コード例 #7
0
ファイル: supernet_train.py プロジェクト: penghouwen/nni
def main(args):
    """Train (or, with ``args.eval``, only validate) an AutoFormer vision-transformer supernet.

    The steps are:

    1. Read the search-space config file.
    2. Build the datasets, samplers and loaders, plus an optional Mixup/CutMix transform.
    3. Build the supernet and an optional pretrained teacher.
    4. Build the optimizer, the AMP loss scaler, the LR scheduler and the loss.
    5. Write the run arguments to ``output_dir/config.yaml`` and optionally resume.
    6. Hand everything to ``AFSupernetTrainer``.
    """

    utils.init_distributed_mode(args)
    update_config_from_file(args.cfg)

    print(args)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=args.batch_size // 2,
        sampler=sampler_val, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)
    print("Creating SuperVisionTransformer")
    print(cfg)
    model = Vision_TransformerSuper(img_size=args.input_size,
                                    patch_size=args.patch_size,
                                    embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH,
                                    num_heads=cfg.SUPERNET.NUM_HEADS, mlp_ratio=cfg.SUPERNET.MLP_RATIO,
                                    qkv_bias=True, drop_rate=args.drop,
                                    drop_path_rate=args.drop_path,
                                    gp=args.gp,
                                    num_classes=args.nb_classes,
                                    max_relative_position=args.max_relative_position,
                                    relative_position=args.relative_position,
                                    change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos)

    choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
               'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM, 'depth': cfg.SEARCH_SPACE.DEPTH}

    model.to(device)
    if args.teacher_model:
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    # EMA is not built by this script.
    model_ema = None

    model_without_ddp = model
    if args.distributed:

        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling: the configured LR is for a global batch size of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if mixup_active:
        # Mixup *or* CutMix produce soft (mixed one-hot) targets, so a soft-target loss is
        # required; label smoothing is applied inside the mixup transform. Keying this on
        # args.mixup alone broke CutMix-only runs.
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)

    # exist_ok avoids the check-then-create race when several ranks start together.
    output_dir.mkdir(parents=True, exist_ok=True)
    # save config for later experiments
    with open(file=output_dir / "config.yaml", mode='w') as f:
        f.write(args_text)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            # model_ema is always None here; only restore EMA weights when an EMA model
            # actually exists and the checkpoint carries them (previously None was passed in).
            if args.model_ema and model_ema is not None and 'model_ema' in checkpoint:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])

    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        retrain_config = {'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM]*cfg.RETRAIN.DEPTH,
                          'num_heads': cfg.RETRAIN.NUM_HEADS, 'mlp_ratio': cfg.RETRAIN.MLP_RATIO}

    trainer = AFSupernetTrainer(
        model, criterion, data_loader_train, data_loader_val,
        optimizer, device, args.epochs, loss_scaler,
        args.clip_grad, model_ema, mixup_fn,
        args.amp, teacher_model, teacher_loss, choices, args.mode, retrain_config, 0., output_dir, lr_scheduler,
    )
    if args.eval:
        trainer._validate_one_epoch(-1)
        return
    trainer.fit()
コード例 #8
0
ファイル: train_finetune.py プロジェクト: zengwang430521/PVT
def _load_checkpoint(path):
    """Load a checkpoint from an ``https`` URL or a local file.

    Args:
        path: URL (anything starting with ``'https'``) or local file path.

    Returns:
        The deserialized checkpoint mapped to CPU, or ``None`` when ``path``
        is a local path that does not exist (a notice is printed so the
        problem is visible in the logs instead of aborting the run).
    """
    if path.startswith('https'):
        return torch.hub.load_state_dict_from_url(path,
                                                  map_location='cpu',
                                                  check_hash=True)
    if not os.path.exists(path):
        print('NOTICE:' + path + ' does not exist!')
        return None
    return torch.load(path, map_location='cpu')


def _build_finetune_param_groups(model, missing_keys, skip_keys, lr,
                                 weight_decay, fine_factor):
    """Split ``model``'s parameters into four optimizer param groups.

    Parameters listed in ``missing_keys`` (not restored from the checkpoint)
    use the full ``lr``; all other parameters (restored from the checkpoint)
    use ``lr * fine_factor``. Within each half, names in ``skip_keys`` get
    ``weight_decay`` 0.

    Group order is fixed because ``main`` indexes groups 2 and 3:
        0: not restored, decayed      (lr,               weight_decay)
        1: not restored, no decay     (lr,               0)
        2: restored,     decayed      (lr * fine_factor, weight_decay)
        3: restored,     no decay     (lr * fine_factor, 0)

    Returns:
        A list of four param-group dicts suitable for a ``torch.optim``
        optimizer. Groups may be empty.
    """
    missing = set(missing_keys)
    skip = set(skip_keys)
    buckets = ([], [], [], [])
    # Single pass keeps each group's parameters in named_parameters() order.
    for name, param in model.named_parameters():
        index = (0 if name in missing else 2) + (1 if name in skip else 0)
        buckets[index].append(param)

    fine_lr = lr * fine_factor
    settings = ((lr, weight_decay), (lr, 0), (fine_lr, weight_decay),
                (fine_lr, 0))
    return [{
        "params": params,
        "lr": group_lr,
        'weight_decay': group_wd,
    } for params, (group_lr, group_wd) in zip(buckets, settings)]


def main(args):
    """Fine-tune (or evaluate) a classification model.

    Builds the train/val datasets and loaders, creates ``args.model`` via
    ``create_model``, optionally loads pretrained weights from
    ``args.finetune`` (switching to a four-group AdamW in which restored
    parameters use ``args.fine_factor`` times the base lr) and/or resumes
    from ``args.resume``. It then either evaluates once (``args.eval``) or
    trains for epochs ``args.start_epoch .. args.epochs - 1``. Each epoch it
    saves ``checkpoint.pth`` and appends a JSON line to ``log.txt`` under
    ``args.output_dir`` (when set).

    Note: ``args.lr`` is overwritten in place with the linearly scaled lr
    ``lr * batch_size * world_size / 512``; ``args.start_epoch`` and
    ``args.fp32_resume`` may also be modified.
    """
    utils.my_init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Fix the seed for reproducibility (offset per rank).
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val,
                                                          shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val,
                                                      shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )
    model.to(device)

    # EMA is disabled in this script; the variable still exists because
    # my_train_one_epoch takes it as an argument.
    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear lr scaling rule relative to a total batch size of 512.
    args.lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.mixup > 0.:
        # Smoothing is handled by the mixup label transform.
        base_criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        base_criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        base_criterion = torch.nn.CrossEntropyLoss()
    # Wrapped in DistillationLoss with no teacher and distillation type 'none'.
    criterion = DistillationLoss(base_criterion, None, 'none', 0, 0)

    output_dir = Path(args.output_dir)

    # True only once the four-group optimizer from
    # _build_finetune_param_groups is in use; the warm-up step in the training
    # loop indexes param groups 2 and 3, which only that optimizer guarantees.
    uses_finetune_groups = False
    if args.finetune:
        checkpoint = _load_checkpoint(args.finetune)
        if checkpoint is not None:
            state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
            # strict=False: keys absent from the checkpoint stay at their
            # fresh initialisation and are reported in missing_keys.
            # NOTE(review): shape mismatches (e.g. a different classifier
            # head size) still raise here.
            missing_keys = model_without_ddp.load_state_dict(
                state_dict, strict=False).missing_keys
            param_dicts = _build_finetune_param_groups(
                model_without_ddp,
                missing_keys,
                model_without_ddp.no_weight_decay(),
                args.lr,
                args.weight_decay,
                args.fine_factor,
            )
            # Replaces the optimizer/scheduler built above; always AdamW,
            # regardless of what create_optimizer produced from args.
            optimizer = torch.optim.AdamW(param_dicts,
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
            loss_scaler = NativeScaler()
            lr_scheduler, _ = create_scheduler(args, optimizer)
            uses_finetune_groups = True
            print('finetune from ' + args.finetune)

    if args.resume:
        checkpoint = _load_checkpoint(args.resume)
        if checkpoint is not None:
            if 'model' in checkpoint:
                model_without_ddp.load_state_dict(checkpoint['model'])
            else:
                model_without_ddp.load_state_dict(checkpoint)

            if (not args.eval and 'optimizer' in checkpoint
                    and 'lr_scheduler' in checkpoint
                    and 'epoch' in checkpoint):
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])

            print('resume from ' + args.resume)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    # Drop-path is reset to 0.0 until epoch `max_epoch_dp_warm_up`, then set
    # back to args.drop_path (no such phase for pvt_tiny / pvt_small).
    max_epoch_dp_warm_up = 100
    if 'pvt_tiny' in args.model or 'pvt_small' in args.model:
        max_epoch_dp_warm_up = 0
    if args.start_epoch < max_epoch_dp_warm_up:
        model_without_ddp.reset_drop_path(0.0)
    for epoch in range(args.start_epoch, args.epochs):
        # With fp32_resume, epochs start_epoch and start_epoch + 1 run with the
        # GradScaler disabled (and fp32=True passed to my_train_one_epoch).
        if args.fp32_resume and epoch > args.start_epoch + 1:
            args.fp32_resume = False
        use_amp = not args.fp32_resume
        # Replace the GradScaler only when its enabled state has to change,
        # so the dynamic loss scale (including state restored from a resume
        # checkpoint) carries over between epochs instead of being reset.
        if loss_scaler._scaler.is_enabled() != use_amp:
            loss_scaler._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        if epoch == max_epoch_dp_warm_up:
            model_without_ddp.reset_drop_path(args.drop_path)

        # During lr warm-up, set lr to 0 for the groups of parameters restored
        # from the finetune checkpoint (groups 2 and 3). Skipped when the
        # four-group optimizer is not in use, since the optimizer from
        # create_optimizer need not have those groups.
        if uses_finetune_groups and epoch < args.warmup_epochs:
            for group in optimizer.param_groups[2:4]:
                group['lr'] = 0

        if args.distributed:
            sampler_train.set_epoch(epoch)

        train_stats = my_train_one_epoch(model,
                                         criterion,
                                         data_loader_train,
                                         optimizer,
                                         device,
                                         epoch,
                                         loss_scaler,
                                         args.clip_grad,
                                         model_ema,
                                         mixup_fn,
                                         fp32=args.fp32_resume)

        lr_scheduler.step(epoch)
        if args.output_dir:
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                },
                output_dir / 'checkpoint.pth')

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))