Exemple #1
0
def get_imagenet_test_data_loader(args):
    """Build an in-order DataLoader over the ``val`` folder of ``args.data_path``.

    The evaluation preset (resize size, crop size, interpolation) is chosen by
    ``get_transform_params(args.arch)``. When ``args.cache_dataset`` is set, the
    dataset object is saved with ``utils.save_on_master`` and reloaded with
    ``torch.load`` on later runs. The cached object includes its transforms, so
    a stale cache keeps the old preset.

    Args:
        args: namespace providing ``data_path``, ``arch``, ``cache_dataset``,
            ``batch_size`` and ``workers``.

    Returns:
        A ``torch.utils.data.DataLoader`` that walks the validation set
        sequentially with ``pin_memory=True``.
    """
    val_root = os.path.join(args.data_path, "val")
    resize_size, crop_size, interpolation = get_transform_params(args.arch)
    cached_file = _get_cache_path(val_root)
    use_cache = args.cache_dataset

    if use_cache and os.path.exists(cached_file):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cached_file))
        dataset_test, _ = torch.load(cached_file)
    else:
        eval_preset = presets.ClassificationPresetEval(
            crop_size=crop_size,
            resize_size=resize_size,
            interpolation=interpolation,
        )
        dataset_test = torchvision.datasets.ImageFolder(val_root, eval_preset)
        if use_cache:
            print("Saving dataset_test to {}".format(cached_file))
            utils.mkdir(os.path.dirname(cached_file))
            utils.save_on_master((dataset_test, val_root), cached_file)

    print("Creating data loaders")
    return torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=torch.utils.data.SequentialSampler(dataset_test),
        num_workers=args.workers,
        pin_memory=True,
    )
Exemple #2
0
def load_data(traindir, valdir, cache_dataset, distributed):
    """Create the train/validation ``ImageFolder`` datasets and their samplers.

    Training images get ``RandomResizedCrop(224)`` + horizontal flip;
    validation images get ``Resize(256)`` + ``CenterCrop(224)``. Both are then
    converted to tensors and normalized with the ImageNet mean/std.

    Args:
        traindir: root folder of the training split.
        valdir: root folder of the validation split.
        cache_dataset: if truthy, reuse a dataset previously saved at
            ``_get_cache_path(<dir>)`` or save a freshly built one there. The
            cached object carries its transforms with it.
        distributed: if truthy, use ``DistributedSampler`` for both splits;
            otherwise random (train) / sequential (val) samplers.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def _cached_image_folder(root, tag, make_transform):
        # Load `root` from its cache file when allowed, else build and
        # (optionally) cache it. `tag` only appears in the log messages.
        cache_file = _get_cache_path(root)
        if cache_dataset and os.path.exists(cache_file):
            # Attention, as the transforms are also cached!
            print("Loading {} from {}".format(tag, cache_file))
            loaded, _ = torch.load(cache_file)
            return loaded
        built = torchvision.datasets.ImageFolder(root, make_transform())
        if cache_dataset:
            print("Saving {} to {}".format(tag, cache_file))
            utils.mkdir(os.path.dirname(cache_file))
            utils.save_on_master((built, root), cache_file)
        return built

    print("Loading training data")
    started = time.time()
    dataset = _cached_image_folder(
        traindir, "dataset_train",
        lambda: transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    print("Took", time.time() - started)

    print("Loading validation data")
    dataset_test = _cached_image_folder(
        valdir, "dataset_test",
        lambda: transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    print("Creating data loaders")
    if distributed:
        dist_sampler = torch.utils.data.distributed.DistributedSampler
        train_sampler, test_sampler = dist_sampler(dataset), dist_sampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
Exemple #3
0
def main(args):
    """Train a ``torchvision.models.detection`` model on a single device.

    Builds the ``"train"`` split with ``get_dataset`` and batches it with a
    ``RandomSampler`` wrapped in a ``BatchSampler`` that drops the last
    incomplete batch. It constructs ``torchvision.models.detection.<args.model>``
    and optimizes all trainable parameters with SGD under a ``MultiStepLR``
    schedule stepped once per epoch. When ``args.output_dir`` is set, the
    directory is created up front and ``model_<epoch>.pth`` is saved after
    every epoch.

    Args:
        args: parsed command-line namespace. Fields read here: ``dataset``,
            ``data_path``, ``batch_size``, ``workers``, ``model``,
            ``pretrained``, ``device``, ``lr``, ``momentum``,
            ``weight_decay``, ``lr_steps``, ``lr_gamma``, ``epochs``,
            ``print_freq`` and ``output_dir``.
    """
    if args.output_dir:
        # Create the checkpoint directory before training starts; otherwise
        # the first save at the end of epoch 0 fails if it does not exist.
        utils.mkdir(args.output_dir)

    # ---- dataset & data loader ----
    dataset, num_classes = get_dataset(args.dataset, "train",
                                       get_transform(train=True),
                                       args.data_path)

    train_sampler = torch.utils.data.RandomSampler(dataset)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)

    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes,
        pretrained=args.pretrained,
    )
    print(model)

    device = torch.device(args.device)
    model.to(device)

    # ---- optimizer & lr scheduler ----
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    # TODO: resume & distributed training

    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            # 'epoch' is stored so a future resume can continue from
            # checkpoint['epoch'] + 1.
            utils.save_on_master(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def load_data(traindir, valdir, args):
    """Create train/validation ``ImageFolder`` datasets using the preset classes.

    Image sizes depend on the model: ``inception_v3`` uses resize 342 / crop
    299, every other model uses resize 256 / crop 224. The training preset
    optionally picks up ``args.auto_augment`` and ``args.random_erase``
    (default ``None`` and ``0.0`` when absent).

    Args:
        traindir: root folder of the training split.
        valdir: root folder of the validation split.
        args: namespace providing ``model``, ``cache_dataset`` and
            ``distributed``. ``auto_augment`` and ``random_erase`` are optional.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    print("Loading data")
    if args.model == 'inception_v3':
        resize_size, crop_size = 342, 299
    else:
        resize_size, crop_size = 256, 224

    def _cached_image_folder(root, tag, make_transform):
        # Load `root` from its cache file when allowed, else build and
        # (optionally) cache it. `tag` only appears in the log messages.
        cache_file = _get_cache_path(root)
        if args.cache_dataset and os.path.exists(cache_file):
            # Attention, as the transforms are also cached!
            print("Loading {} from {}".format(tag, cache_file))
            loaded, _ = torch.load(cache_file)
            return loaded
        built = torchvision.datasets.ImageFolder(root, make_transform())
        if args.cache_dataset:
            print("Saving {} to {}".format(tag, cache_file))
            utils.mkdir(os.path.dirname(cache_file))
            utils.save_on_master((built, root), cache_file)
        return built

    def _train_preset():
        return presets.ClassificationPresetTrain(
            crop_size=crop_size,
            auto_augment_policy=getattr(args, "auto_augment", None),
            random_erase_prob=getattr(args, "random_erase", 0.0))

    def _eval_preset():
        return presets.ClassificationPresetEval(crop_size=crop_size,
                                                resize_size=resize_size)

    print("Loading training data")
    started = time.time()
    dataset = _cached_image_folder(traindir, "dataset_train", _train_preset)
    print("Took", time.time() - started)

    print("Loading validation data")
    dataset_test = _cached_image_folder(valdir, "dataset_test", _eval_preset)

    print("Creating data loaders")
    if args.distributed:
        dist_sampler = torch.utils.data.distributed.DistributedSampler
        train_sampler, test_sampler = dist_sampler(dataset), dist_sampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
def save_chk_point(model_without_ddp, optimizer, lr_scheduler, epoch, acc5):
    """Save a checkpoint only when ``acc5`` strictly beats the best seen so far.

    Relies on two module-level globals: ``args`` (its ``output_dir`` and the
    object itself are stored) and ``best_acc5`` (the running best, updated
    here). The file is named ``ckp_<acc5>.pth`` inside ``args.output_dir``.
    """
    global args
    global best_acc5
    # Written as `not >` (rather than `<=`) so a NaN acc5 is skipped exactly
    # as before.
    if not acc5 > best_acc5:
        return
    best_acc5 = acc5
    state = dict(
        model=model_without_ddp.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        epoch=epoch,
        args=args,
    )
    target = os.path.join(args.output_dir, 'ckp_{}.pth'.format(acc5))
    utils.save_on_master(state, target)
Exemple #6
0
def main(args):
    """Train (or, with ``args.test_only``, evaluate) a Cityscapes segmentation model.

    The network is a MobileNetV3 subnet described by ``args.arch`` (parsed with
    ``OFAArchitecture.from_legency_string``):

    * with ``args.pretrained``: the subnet is extracted from an
      ``OFAMobileNetV3`` supernet. Its weights come from the
      ``"state_dict_ema"`` entry of ``model_best.pth.tar`` in the current
      working directory, and ``convert2segmentation`` turns it into a
      segmentation model.
    * otherwise: the subnet comes from ``SPOSMobileNetV3Segmentation``.

    Training uses SGD with a polynomial (power 0.9) ``LambdaLR`` decay,
    evaluates on the validation split after every epoch, and saves
    ``<output_dir>/model_<epoch>.pth`` each epoch.

    Args:
        args: parsed command-line namespace (output_dir, device, data_root,
            train_list, val_list, batch_size, workers, distributed, gpu,
            pretrained, arch, aux_loss, lr, momentum, weight_decay, epochs,
            start_epoch, resume, test_only, print_freq).
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
    # dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))

    # ImageNet mean/std rescaled by 255. Presumably the T.* transforms work on
    # un-normalized 0-255 images -- TODO confirm.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # Augmentation ranges and the (height, width) training crop size.
    scale_min = 0.5
    scale_max = 1.75
    rotate_min = -1
    rotate_max = 1
    train_h = 512
    train_w = 1024
    # Label passed to the padding transforms; presumably the value excluded
    # from the loss/metrics -- confirm.
    ignore_label = 255

    train_transform = T.Compose([
        T.RandScale([scale_min, scale_max]),
        T.RandRotate([rotate_min, rotate_max], padding=mean, ignore_label=ignore_label),
        T.RandomGaussianBlur(),
        T.RandomHorizontalFlip(),
        T.Crop([train_h, train_w], crop_type='rand', padding=mean, ignore_label=ignore_label),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])

    dataset_train = dataset.CityscapesData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)

    # Validation: center crop to the same size used for training.
    val_transform = T.Compose([
        T.Crop([train_h, train_w], crop_type='center', padding=mean, ignore_label=ignore_label),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])
    dataset_test = dataset.CityscapesData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn, drop_last=True)

    # Evaluation runs one image per batch.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)
    # Hard-coded number of segmentation classes.
    num_classes = 19
    # model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,
    #                                                              aux_loss=args.aux_loss,
    #                                                              pretrained=args.pretrained)
    if args.pretrained:
        # NOTE(review): the supernet is built with a 1000-class head; presumably
        # convert2segmentation replaces it with a segmentation head -- confirm.
        supernet = OFAMobileNetV3(
            n_classes=1000,
            dropout_rate=0,
            width_mult_list=1.2,
            ks_list=[3, 5, 7],
            expand_ratio_list=[3, 4, 6],
            depth_list=[2, 3, 4],
        )
        arch = OFAArchitecture.from_legency_string(args.arch)
        supernet.set_active_subnet(ks=arch.ks, e=arch.ratios, d=arch.depths)
        model = supernet.get_active_subnet()

        # Weights path is hard-coded relative to the current working directory;
        # only the EMA weights ("state_dict_ema") are used.
        s = torch.load("model_best.pth.tar", map_location="cpu")
        model.load_state_dict(s["state_dict_ema"])
        model = convert2segmentation(model=model, begin_index_index=17)
        print("load pretrained model.")
    else:
        supernet = SPOSMobileNetV3Segmentation(width_mult=1.2)
        model = supernet.get_subnet(OFAArchitecture.from_legency_string(args.arch))

    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped model for parameter groups and checkpoints.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.pretrained:
        # NOTE(review): in this branch `classifier` params are only optimized
        # (at 10x LR) when args.aux_loss is set -- confirm that is intended.
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.stem.parameters() if p.requires_grad]},
        ]
        if args.aux_loss:
            params = [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]
            params_to_optimize.append({"params": params, "lr": args.lr * 10})
        optimizer = torch.optim.SGD(
            params_to_optimize,
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # NOTE(review): unlike the pretrained branch, the aux head here uses the
        # base LR (args.lr), not 10x -- confirm that is intended.
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.first_conv.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.blocks.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.remain_block.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.head.parameters() if p.requires_grad]},
        ]
        if args.aux_loss:
            params = [p for p in model_without_ddp.aux_head.parameters() if p.requires_grad]
            params_to_optimize.append({"params": params, "lr": args.lr})
        optimizer = torch.optim.SGD(
            params_to_optimize,
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Poly decay over len(data_loader) * epochs steps, which suggests the
    # scheduler is stepped once per iteration (presumably inside
    # train_one_epoch, which receives it) -- confirm.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    if args.resume:
        # In test-only mode the weights are loaded non-strictly and the
        # optimizer/scheduler/epoch state is ignored.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # `criterion` is not defined in this function; presumably a
        # module-level loss function -- verify it exists.
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        utils.save_on_master(
            {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            },
            os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #7
0
def main(args):
    """Train (or, with ``args.test_only``, evaluate) a torchvision segmentation model.

    Builds train/val datasets with ``get_dataset`` and the model
    ``torchvision.models.segmentation.<args.model>``. It optimizes backbone and
    classifier at ``args.lr``, plus the aux classifier at ``10 * args.lr`` when
    ``args.aux_loss`` is set, using SGD with a polynomial (power 0.9)
    ``LambdaLR`` decay. After every epoch it evaluates and saves
    ``<output_dir>/model_<epoch>.pth``.

    Args:
        args: parsed command-line namespace (output_dir, device, data_path,
            dataset, distributed, gpu, batch_size, workers, model, aux_loss,
            pretrained, lr, momentum, weight_decay, epochs, start_epoch,
            resume, test_only, print_freq).
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train",
                                       get_transform(train=True))
    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val",
                                  get_transform(train=False))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              collate_fn=utils.collate_fn,
                                              drop_last=True)

    # Evaluation runs one image per batch.
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    model = torchvision.models.segmentation.__dict__[args.model](
        num_classes=num_classes,
        aux_loss=args.aux_loss,
        pretrained=args.pretrained)
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped model for parameter groups and checkpoints.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # Backbone and classifier share the base LR; the aux classifier (if any)
    # gets 10x the base LR.
    params_to_optimize = [
        {
            "params": [
                p for p in model_without_ddp.backbone.parameters()
                if p.requires_grad
            ]
        },
        {
            "params": [
                p for p in model_without_ddp.classifier.parameters()
                if p.requires_grad
            ]
        },
    ]
    if args.aux_loss:
        params = [
            p for p in model_without_ddp.aux_classifier.parameters()
            if p.requires_grad
        ]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Poly decay over len(data_loader) * epochs steps, which suggests the
    # scheduler is stepped once per iteration (presumably inside
    # train_one_epoch, which receives it) -- confirm.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs))**0.9)

    if args.resume:
        # In test-only mode the weights are loaded non-strictly and the
        # optimizer/scheduler/epoch state is ignored.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'],
                                          strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        confmat = evaluate(model,
                           data_loader_test,
                           device=device,
                           num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # `criterion` is not defined in this function; presumably a
        # module-level loss function -- verify it exists.
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler,
                        device, epoch, args.print_freq)
        confmat = evaluate(model,
                           data_loader_test,
                           device=device,
                           num_classes=num_classes)
        print(confmat)
        utils.save_on_master(
            {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #8
0
def main(args):
    """Train (or, with ``args.eval``, evaluate) an image classifier, DeiT-style.

    Steps:
      1. Initialise distributed mode and seed RNGs per rank.
      2. Build train/val datasets with ``build_dataset``.
      3. Create the model with ``create_model`` and, optionally, an EMA copy.
      4. Wrap the model in DDP when ``args.distributed``.
      5. Scale the LR linearly with the global batch size (reference 512).
      6. Optionally resume from ``args.resume`` (a local path or an ``https``
         URL).
      7. Either evaluate once (``args.eval``) or train epochs
         ``args.start_epoch .. args.epochs - 1``. Each epoch saves
         ``<output_dir>/checkpoint.pth`` and appends JSON stats to
         ``<output_dir>/log.txt`` on the main process.

    Side effects on ``args``: ``args.nb_classes`` is set from the training
    dataset, ``args.lr`` is overwritten with the scaled LR, and
    ``args.start_epoch`` is updated on resume.
    """
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so every process draws
    # a different random stream)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # The training sampler is always distributed-aware, even in a
    # single-process run (utils' world size / rank then presumably being
    # 1 / 0). Because of that it must be re-seeded with set_epoch() every
    # epoch regardless of args.distributed; see the training loop.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train,
                                  num_replicas=num_tasks,
                                  rank=global_rank,
                                  shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Validation: 1.5x larger batches, no shuffling, and no sampler, so every
    # process iterates the full validation set.
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=int(
                                                      1.5 * args.batch_size),
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
    )

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: the configured LR refers to a global batch of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size(
    ) / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if mixup_active:
        # Mixup *or* CutMix (either alone) turns targets into soft label
        # distributions, so a soft-target loss is required whenever mixup_fn
        # is active -- not only when args.mixup > 0. Smoothing is already
        # handled by the mixup label transform.
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Optimizer/scheduler/epoch (and EMA) state is restored only when
        # training and only if the checkpoint actually contains it.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The sampler is always a distributed-style sampler (see above). It
        # must be re-seeded every epoch, including single-process runs;
        # otherwise it replays the same shuffle order each epoch.
        data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch, loss_scaler,
                                      args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # get_state_dict() calls .state_dict() on the (unwrapped) model and
            # therefore fails when EMA is disabled (model_ema is None).
            ema_state = (get_state_dict(model_ema)
                         if model_ema is not None else None)
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': ema_state,
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #9
0
def main(args):
    """Train (or, with ``args.test_only``, evaluate) a torchvision detection model.

    Builds train/val datasets with ``get_dataset`` and the model
    ``torchvision.models.detection.<args.model>`` (``weights`` /
    ``weights_backbone`` API). It trains with SGD (optionally Nesterov) or
    AdamW under a MultiStepLR or CosineAnnealingLR schedule stepped once per
    epoch, with optional AMP. After every epoch it saves both
    ``model_<epoch>.pth`` and ``checkpoint.pth`` to ``args.output_dir``, then
    evaluates.

    Raises:
        RuntimeError: if ``args.opt`` or ``args.lr_scheduler`` names an
            unsupported optimizer or scheduler.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.use_deterministic_algorithms(True)

    # Data loading code
    print("Loading data")

    dataset, num_classes = get_dataset(args.dataset, "train",
                                       get_transform(True, args),
                                       args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val",
                                  get_transform(False, args), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # With a non-negative group factor, batches are formed by
    # GroupedBatchSampler from the groups computed by
    # create_aspect_ratio_groups (presumably images of similar aspect ratio);
    # otherwise plain fixed-size batches, dropping the last incomplete one.
    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                  args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.batch_size,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)

    # Evaluation runs one image per batch.
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    print("Creating model")
    kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    # For "multiscale"/"lsj" augmentation the model is told to skip its own
    # resize step (presumably because the data pipeline already resizes).
    if args.data_augmentation in ["multiscale", "lsj"]:
        kwargs["_skip_resize"] = True
    # rpn_score_thresh is only forwarded to models whose name contains "rcnn".
    if "rcnn" in args.model:
        if args.rpn_score_thresh is not None:
            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
    model = torchvision.models.detection.__dict__[args.model](
        weights=args.weights,
        weights_backbone=args.weights_backbone,
        num_classes=num_classes,
        **kwargs)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped model for checkpoints.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # Without norm_weight_decay, all trainable params share args.weight_decay.
    # Otherwise the first group returned by split_normalization_params
    # (presumably the normalization-layer params) gets args.norm_weight_decay
    # and the second gets args.weight_decay; empty groups are dropped.
    if args.norm_weight_decay is None:
        parameters = [p for p in model.parameters() if p.requires_grad]
    else:
        param_groups = torchvision.ops._utils.split_normalization_params(model)
        wd_groups = [args.norm_weight_decay, args.weight_decay]
        parameters = [{
            "params": p,
            "weight_decay": w
        } for p, w in zip(param_groups, wd_groups) if p]

    # Any optimizer name starting with "sgd" selects SGD; Nesterov momentum is
    # enabled when the name also contains "nesterov".
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters,
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
    else:
        raise RuntimeError(
            f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")

    # Gradient scaler only when mixed precision is requested.
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )

    if args.resume:
        # NOTE(review): with args.amp set, the checkpoint must contain a
        # "scaler" entry, otherwise this raises KeyError.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        torch.backends.cudnn.deterministic = True
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq, scaler)
        lr_scheduler.step()
        if args.output_dir:
            # Same checkpoint written twice: a per-epoch snapshot and a rolling
            # "checkpoint.pth" that always holds the latest epoch.
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir,
                                         f"model_{epoch}.pth"))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
Exemple #10
0
def train_dino(args):
    """Run DINO self-supervised pre-training on an ``AllSen12MSDataset``.

    A student and a teacher network (a ``vits`` ViT or a torchvision model)
    are each wrapped in ``utils.MultiCropWrapper`` with a ``DINOHead``. The
    student is trained with ``DINOLoss`` through ``train_one_epoch``, which
    also receives the teacher and ``momentum_schedule``. The teacher gets no
    gradients. After every epoch ``checkpoint.pth`` (plus a numbered copy
    every ``args.saveckp_freq`` epochs) and a JSON line in ``log.txt`` are
    written to ``args.output_dir``. Training resumes from
    ``<output_dir>/checkpoint.pth`` via ``utils.restart_from_checkpoint``.

    Args:
        args: parsed command-line namespace. Fields read here include
            ``seed``, ``data_path``, ``arch``, ``patch_size``, ``out_dim``,
            ``optimizer``, ``lr``, ``min_lr``, ``epochs``, ``output_dir`` and
            the crop, temperature and schedule options.

    Raises:
        ValueError: if ``args.arch`` is neither a ``vits`` nor a torchvision
            model name, or if ``args.optimizer`` is not ``"adamw"``,
            ``"sgd"`` or ``"lars"``.
    """
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    # NOTE(review): the keyword is spelled ``tansform_coord``; presumably it
    # mirrors AllSen12MSDataset's own parameter name. Confirm before renaming.
    dataset = AllSen12MSDataset(args.data_path,
                                "train",
                                transform=transform,
                                tansform_coord=None,
                                classes=None,
                                seasons=None,
                                split_by_region=True,
                                download=False)

    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=0.1,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim

        # Rebuild the input layer of both ViTs for 13 input channels.
        student = utils.replace_input_layer(student, inchannels=13)
        teacher = utils.replace_input_layer(teacher, inchannels=13)

    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        # NOTE(review): unlike the ViT branch, the input layer is NOT replaced
        # for 13 channels here. Confirm these models can consume the dataset's
        # samples.
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        # Fail immediately: continuing would only crash later on ``embed_dim``.
        raise ValueError(f"Unknown architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(
        student,
        DINOHead(
            embed_dim,
            args.out_dim,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher,
                                                      device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student,
                                                  device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number +
        2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0,
                                    momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(
            params_groups)  # to use with convnet and large batches
    else:
        # Without this, ``optimizer`` would be unbound and fail much later.
        raise ValueError(
            f"Unknown optimizer: {args.optimizer}. "
            "Supported values are 'adamw', 'sgd' and 'lars'.")
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) /
        256.,  # linear scaling rule
        args.min_lr,
        args.epochs,
        len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print("Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        # Give the DistributedSampler a different shuffle each epoch.
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp,
                                      dino_loss, data_loader, optimizer,
                                      lr_schedule, wd_schedule,
                                      momentum_schedule, epoch, fp16_scaler,
                                      args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,  # the epoch to resume from
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict,
                             os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(
                save_dict,
                os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()}, 'epoch': epoch
        }
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #11
0
def main(args):
    """Train (or just evaluate) a torchvision detection model.

    The model is taken from ``torchvision.models.detection``, or from
    ``prototype.models.detection`` when ``args.prototype`` is set. It is
    trained with SGD and either a MultiStepLR or a CosineAnnealingLR
    schedule. After each epoch, when ``args.output_dir`` is set, the state is
    written to both ``model_<epoch>.pth`` and ``checkpoint.pth``. The model
    is then evaluated on the validation split.

    Raises:
        ImportError: ``--prototype`` was requested but the module is missing.
        ValueError: ``--weights`` was given without ``--prototype``.
        RuntimeError: ``args.lr_scheduler`` names an unsupported scheduler.
    """
    # ---- argument sanity checks -------------------------------------------
    if args.prototype and prototype is None:
        raise ImportError(
            "The prototype module couldn't be found. Please install the latest torchvision nightly."
        )
    if not args.prototype and args.weights:
        raise ValueError(
            "The weights parameter works only in prototype mode. Please pass the --prototype argument."
        )
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # ---- datasets -----------------------------------------------------------
    print("Loading data")
    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)

    # ---- samplers and loaders ----------------------------------------------
    print("Creating data loaders")
    if not args.distributed:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    else:
        dist_sampler_cls = torch.utils.data.distributed.DistributedSampler
        train_sampler = dist_sampler_cls(dataset)
        test_sampler = dist_sampler_cls(dataset_test)

    if args.aspect_ratio_group_factor < 0:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)
    else:
        # Batch images with similar aspect ratios together.
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=1,
        sampler=test_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
    )

    # ---- model --------------------------------------------------------------
    print("Creating model")
    model_kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    if "rcnn" in args.model and args.rpn_score_thresh is not None:
        model_kwargs["rpn_score_thresh"] = args.rpn_score_thresh

    if args.prototype:
        build_model = prototype.models.detection.__dict__[args.model]
        model = build_model(weights=args.weights, num_classes=num_classes, **model_kwargs)
    else:
        build_model = torchvision.models.detection.__dict__[args.model]
        model = build_model(pretrained=args.pretrained, num_classes=num_classes, **model_kwargs)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the bare module so checkpoints carry unprefixed keys.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # ---- optimisation -------------------------------------------------------
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.SGD(
        trainable,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    scheduler_builders = {
        "multisteplr": lambda: torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_steps, gamma=args.lr_gamma),
        "cosineannealinglr": lambda: torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs),
    }
    build_scheduler = scheduler_builders.get(args.lr_scheduler)
    if build_scheduler is None:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )
    lr_scheduler = build_scheduler()

    # ---- optional resume ----------------------------------------------------
    if args.resume:
        state = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        lr_scheduler.load_state_dict(state["lr_scheduler"])
        args.start_epoch = state["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(state["scaler"])

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    # ---- training loop ------------------------------------------------------
    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
        lr_scheduler.step()

        if args.output_dir:
            snapshot = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                snapshot["scaler"] = scaler.state_dict()
            # Per-epoch snapshot first, then the rolling "latest" checkpoint.
            for filename in (f"model_{epoch}.pth", "checkpoint.pth"):
                utils.save_on_master(snapshot, os.path.join(args.output_dir, filename))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    elapsed_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {elapsed_str}")
Exemple #12
0
def main():
    """Train a torchvision detection model configured by ``get_args()``.

    This builds the train/val datasets and loaders and instantiates
    ``torchvision.models.detection.<args.model>``. The model is optionally
    wrapped in DistributedDataParallel and/or DataParallel, then trained
    with SGD and a MultiStepLR schedule. After every epoch a checkpoint
    ``model_<epoch>.pth`` is written to ``args.output_dir`` (when set) and
    the model is evaluated. With ``args.test_only`` only one evaluation runs.
    With ``args.resume``, training continues after the epoch stored in the
    checkpoint. Checkpoints without an ``'epoch'`` entry restart from 0.
    """
    args = get_args()
    if args.output_dir:
        utils.mkdir(args.output_dir)
    utils.init_distributed_mode(args)

    # Data loading
    print("Loading data")
    dataset, num_classes = get_dataset(args.dataset, "train",
                                       get_transform(train=True))
    dataset_test, _ = get_dataset(args.dataset, "val",
                                  get_transform(train=False))

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                  args.b)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.b,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.b,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    # Model creating
    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes, pretrained=args.pretrained)

    device = torch.device(args.device)
    model.to(device)

    # Distribute
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # Parallel
    if args.parallel:
        print('Training parallel')
        # ``Module.cuda()`` works in place, so ``model_without_ddp`` still
        # refers to the bare model. Re-binding it to ``model.module`` would
        # store the DDP wrapper ("module."-prefixed keys) when distributed.
        model = torch.nn.DataParallel(model, device_ids=[args.gpu]).cuda()

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    # Resume training
    start_epoch = 0
    if args.resume:
        print('Resume training')
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # Continue after the saved epoch instead of restarting from 0. Older
        # checkpoints without an 'epoch' entry keep the previous behavior.
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    # Training
    print('Start training')
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            # Without this, DistributedSampler yields the same order every epoch.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'args': args,
                    'epoch': epoch,
                }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #13
0
def train_one_epoch(
        model,
        arch,
        optimizer,
        lr_scheduler,
        data_loader,
        device,
        epoch,
        print_freq,
        ngpus_per_node,
        model_without_ddp,
        args
    ):
    """Train a detection / instance-segmentation model over ``data_loader``.

    ``lr_scheduler`` is stepped once per iteration, not once per epoch.
    ``args.iter_num`` is a running iteration counter. It is incremented in
    place for every batch, including batches whose step raised and was
    skipped.

    On the process selected for saving, every 1000 iterations a checkpoint
    ``{checkpoint_dir}/{arch.__name__}_{iter_num}.pth`` is written. From
    iteration 5000 on, a debug image of predictions on ``./table_test/*`` is
    also written to ``{checkpoint_dir}/debug_image/{iter_num}.jpg``.

    Args:
        model: model to train (possibly a DDP wrapper); called as
            ``model(images, targets)`` and expected to return a loss dict.
        arch: callable whose ``__name__`` is used in checkpoint filenames.
        optimizer, lr_scheduler: optimizer and per-iteration scheduler.
        data_loader: yields ``(images, targets)`` batches.
        device: device images and targets are moved to.
        epoch: epoch index, stored in checkpoints.
        print_freq: logging frequency for ``MetricLogger.log_every``.
        ngpus_per_node: used to pick the saving process.
        model_without_ddp: unwrapped model whose state dict is checkpointed.
        args: namespace providing ``iter_num``, ``multiprocessing_distributed``
            and ``rank``.

    Raises:
        SystemExit: when the loss is non-finite, or once ``args.iter_num``
            reaches 50000 (hard stop).
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))

    for images, targets in metric_logger.log_every(
            iterable=data_loader,
            print_freq=print_freq,
            iter_num=args.iter_num
        ):

        images = list(image.to(device) for image in images)
        # Each target is a dict of tensors. An empty target looks like
        # {"boxes": tensor([]), "labels": tensor([], dtype=torch.int64),
        #  "masks": tensor([], dtype=torch.uint8), "iscrowd": tensor([], dtype=torch.int64)}
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            if not math.isfinite(loss_value):
                logger.fatal("Loss is {}, stopping training".format(loss_value))
                logger.fatal(loss_dict_reduced)
                # SystemExit derives from BaseException, so the handler
                # below does not swallow it.
                sys.exit(1)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            lr_scheduler.step()  # per-iteration schedule

            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        except Exception as e:
            # Deliberate best-effort: a failing batch is logged and skipped
            # instead of aborting the run.
            # NOTE(review): under DDP, a rank that skips backward() while the
            # others do not may stall gradient synchronisation.
            logger.warning(e, exc_info=True)

        args.iter_num += 1

        # Checkpoints and debug images are written only when not using
        # multiprocessing-distributed, or on ranks with rank % ngpus_per_node == 0.
        is_saving_process = (not args.multiprocessing_distributed
                             or args.rank % ngpus_per_node == 0)
        if is_saving_process and args.iter_num % 1000 == 0:
            utils.save_on_master({
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                    "iter_num": args.iter_num,
                    "args": args,
                },
                "{}/{}_{}.pth".format(checkpoint_dir, arch.__name__, args.iter_num)
            )

            os.makedirs("{}/debug_image/".format(checkpoint_dir), exist_ok=True)

            # Debug renders only start after the first 5000 iterations.
            if args.iter_num >= 5000:
                _save_debug_visualization(
                    model,
                    device,
                    "{}/debug_image/{}.jpg".format(checkpoint_dir, args.iter_num),
                )

        # hard stop
        if args.iter_num == 50000:
            logger.info("ITER NUM == 50k, training successfully!")
            raise SystemExit


def _save_debug_visualization(model, device, out_path, image_glob="./table_test/*",
                              score_thresh=0.5, images_per_row=10):
    """Draw ``model``'s predictions on sample images and save one tiled JPEG.

    Steps:
    1. Each file matched by ``image_glob`` is read with ``cv2.imread`` and
       resized to a height of 1300 px.
    2. It is run through ``model`` in eval mode without gradients.
    3. Detections scoring at least ``score_thresh`` are drawn with
       ``overlay_ann``.
    4. Images are concatenated side by side, ``images_per_row`` per row. A
       trailing partial row is kept.
    5. Rows are rescaled to their mean width and stacked vertically.
    6. The mosaic is shrunk to 40 % and written to ``out_path``.

    Unreadable files are skipped. If no image could be rendered, nothing is
    written. ``model`` is put back into train mode afterwards, even on error.
    """
    from barez import overlay_ann

    to_tensor = transforms.Compose([transforms.ToTensor()])
    rows = []
    row = None
    row_len = 0

    model.eval()
    try:
        for image_path in glob.glob(image_glob):
            image = cv2.imread(image_path)
            if image is None:  # cv2.imread returns None for unreadable files
                logger.warning("Skipping unreadable debug image: %s", image_path)
                continue
            scale = 1300 / image.shape[0]
            image = cv2.resize(image, None, fx=scale, fy=scale)
            image = to_tensor(image)

            with torch.no_grad():
                prediction = model([image.to(device)])

            canvas = torch.squeeze(image, 0).permute(1, 2, 0).mul(255).numpy().astype(np.uint8)
            for pred in prediction:
                for idx, mask in enumerate(pred['masks']):
                    score = pred['scores'][idx].item()
                    if score < score_thresh:
                        continue
                    m = mask[0].mul(255).byte().cpu().numpy()
                    box = list(map(int, pred["boxes"][idx].tolist()))
                    canvas = overlay_ann(canvas, m, box, "", score)

            row = canvas if row is None else np.concatenate((row, canvas), axis=1)
            row_len += 1
            if row_len == images_per_row:
                rows.append(row)
                row, row_len = None, 0
    finally:
        model.train()

    # Keep the trailing partial row instead of silently dropping those images.
    # This also avoids crashing when there are fewer than images_per_row images.
    if row is not None:
        rows.append(row)
    if not rows:
        logger.warning("No debug images rendered from %s; skipping %s", image_glob, out_path)
        return

    # Resize rows to one common width using an explicit dsize, so that widths
    # match exactly and the vertical concatenation cannot fail on rounding.
    target_w = max(1, int(round(np.mean([r.shape[1] for r in rows]))))
    resized_rows = []
    for r in rows:
        target_h = max(1, int(round(r.shape[0] * target_w / r.shape[1])))
        resized_rows.append(cv2.resize(r, (target_w, target_h)))  # dsize is (w, h)
    mosaic = np.concatenate(resized_rows, axis=0)

    mosaic = cv2.resize(mosaic, None, fx=0.4, fy=0.4)
    cv2.imwrite(out_path, mosaic)
Exemple #14
0
def main(args):
    """Train or evaluate a SiT model in self-supervised or fine-tuning mode.

    Modes:
    - ``args.training_mode == 'SSL'`` disables label smoothing, mixup,
      cutmix, random erasing and AutoAugment, and trains with ``MTL_loss``
      through ``train_SSL`` / ``evaluate_SSL``.
    - Any other mode uses ``train_finetune`` / ``evaluate_finetune``.
      Soft-target cross-entropy is used when ``training_mode == 'finetune'``
      and mixup is enabled; plain cross-entropy otherwise.

    Every ``args.validate_every`` epochs the model is checkpointed to
    ``<output_dir>/checkpoint.pth`` and evaluated. After every epoch the main
    process appends a JSON line to ``<output_dir>/log.txt``. The line holds
    the train statistics, plus test statistics only when evaluation ran in
    that epoch.
    """
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of Self-supervise training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of Self-supervise training")
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

        if args.SiT_LinearEvaluation == 1:
            print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
            args.SiT_LinearEvaluation = 0

    utils.print_args(args)

    device = torch.device(args.device)
    # Per-rank seed so each process draws different random augmentations.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)

    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train,
        batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True, collate_fn=collate_fn)

    data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False, collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model, pretrained=False, num_classes=args.nb_classes,
        drop_rate=args.drop, drop_path_rate=args.drop_path, representation_size=args.representation_size,
        drop_block_rate=None, training_mode=args.training_mode)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop head weights whose shapes do not match the new model.
        for k in ['rot_head.weight', 'rot_head.bias', 'contrastive_head.weight', 'contrastive_head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Interpolate the position embedding to the new patch grid. This
        # assumes the checkpoint's patch tokens form a square orig_size x
        # orig_size grid, as the reshape below requires.
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)

        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True

        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True

        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True

            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '', resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling with the global batch size (reference batch size 512).
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and args.mixup > 0.:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        # NOTE(review): evaluate_SSL is used here for every training_mode and
        # its result is indexed with 'acc1', while the training loop uses
        # evaluate_finetune for non-SSL modes. Confirm this is intended.
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)

        # Reset every epoch. Previously a non-validation epoch either logged
        # the last validation's stale stats or raised NameError when no
        # validation had run yet (e.g. after resuming mid-interval).
        test_stats = {}
        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # get_state_dict cannot unwrap None (EMA disabled).
                        'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device, epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)

                print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #15
0
def main(args):
    """Train and/or evaluate a video classification model on Kinetics-400.

    Builds (or loads from the on-disk cache) the train and validation
    ``Kinetics400`` datasets, wraps them in clip-sampling data loaders,
    creates the model, SGD optimizer and per-iteration LR schedule (with
    optional warmup), optionally resumes from ``args.resume``, and then either
    evaluates once (``args.test_only``) or runs the epoch loop, writing
    ``model_{epoch}.pth`` and ``checkpoint.pth`` to ``args.output_dir`` after
    every epoch.

    Args:
        args: parsed command-line namespace. Fields read here include
            ``prototype``, ``weights``, ``output_dir``, ``device``,
            ``data_path``, ``train_dir``, ``val_dir``, ``cache_dataset``,
            ``distributed``, ``clip_len``, ``clips_per_video``,
            ``batch_size``, ``workers``, ``model``, ``pretrained``,
            ``sync_bn``, ``lr``, ``world_size``, ``momentum``,
            ``weight_decay``, ``amp``, ``lr_milestones``, ``lr_gamma``,
            ``lr_warmup_epochs``, ``lr_warmup_method``, ``lr_warmup_decay``,
            ``gpu``, ``resume``, ``test_only``, ``start_epoch``, ``epochs``
            and ``print_freq``.

    Raises:
        ImportError: ``--prototype`` was requested but the prototype module
            is unavailable.
        ValueError: ``--weights`` was given without ``--prototype``.
        RuntimeError: ``args.lr_warmup_method`` is neither ``"linear"`` nor
            ``"constant"``.
    """
    if args.prototype and prototype is None:
        raise ImportError(
            "The prototype module couldn't be found. Please install the latest torchvision nightly."
        )
    if not args.prototype and args.weights:
        raise ValueError(
            "The weights parameter works only in prototype mode. Please pass the --prototype argument."
        )
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    def load_kinetics_split(root, transform, label):
        # Load one Kinetics-400 split from `root`, going through the dataset
        # cache when --cache-dataset is set. `label` ("train"/"test") only
        # selects the wording of the log messages.
        cache_path = _get_cache_path(root)
        if args.cache_dataset and os.path.exists(cache_path):
            print(f"Loading dataset_{label} from {cache_path}")
            # NOTE: torch.load unpickles the file; only load trusted caches.
            split, _ = torch.load(cache_path)
            # The cached dataset was pickled with whatever transform it was
            # built with; swap in the transform configured for this run.
            split.transform = transform
            return split
        if args.distributed:
            print(
                "It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster"
            )
        split = torchvision.datasets.Kinetics400(
            root,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_{label} to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((split, root), cache_path)
        return split

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)

    print("Loading training data")
    st = time.time()
    transform_train = presets.VideoClassificationPresetTrain((128, 171),
                                                             (112, 112))
    dataset = load_kinetics_split(traindir, transform_train, "train")
    print("Took", time.time() - st)

    print("Loading validation data")
    if not args.prototype:
        transform_test = presets.VideoClassificationPresetEval(
            resize_size=(128, 171), crop_size=(112, 112))
    elif args.weights:
        # Use the eval transforms bundled with the requested weights.
        weights = prototype.models.get_weight(args.weights)
        transform_test = weights.transforms()
    else:
        transform_test = prototype.transforms.Kinect400Eval(
            crop_size=(112, 112), resize_size=(128, 171))
    dataset_test = load_kinetics_split(valdir, transform_test, "test")

    print("Creating data loaders")
    # Random clips for training, evenly spaced clips for evaluation.
    train_sampler = RandomClipSampler(dataset.video_clips,
                                      args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips,
                                      args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    if not args.prototype:
        model = torchvision.models.video.__dict__[args.model](
            pretrained=args.pretrained)
    else:
        model = prototype.models.video.__dict__[args.model](
            weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # Base LR is scaled linearly with the number of processes.
    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    iters_per_epoch = len(data_loader)
    # Milestones are counted from the end of warmup, because SequentialLR
    # restarts the main scheduler when it switches over to it.
    lr_milestones = [
        iters_per_epoch * (m - args.lr_warmup_epochs)
        for m in args.lr_milestones
    ]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=args.lr_warmup_decay,
                total_iters=warmup_iters)
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer,
                factor=args.lr_warmup_decay,
                total_iters=warmup_iters)
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[warmup_iters])
    else:
        lr_scheduler = main_lr_scheduler

    # Keep a handle on the unwrapped network so checkpoints never carry the
    # DDP "module." prefix.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        # Checkpoints written by a non-AMP run have no "scaler" entry; in
        # that case keep the freshly constructed scaler instead of crashing.
        if args.amp and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle differently on every epoch across processes.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir,
                                         f"model_{epoch}.pth"))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
def main(args):
    """Train a point-cloud sequence classifier on MSRAction3D.

    Seeds all RNGs for repeatable runs, builds the train/test
    ``MSRAction3D`` datasets, instantiates ``Models.<args.model>`` (wrapped in
    ``nn.DataParallel`` when more than one GPU is visible), and trains with SGD
    plus a per-iteration warmup multi-step schedule. After every epoch it
    evaluates on the test split and, when ``args.output_dir`` is set, writes
    ``model_{epoch}.pth`` and ``checkpoint.pth``. Finally, it prints the
    training time and the best value returned by ``evaluate``, presumably
    top-1 accuracy.

    Args:
        args: parsed command-line namespace (``output_dir``, ``seed``,
            ``data_path``, ``clip_len``, ``num_points``, ``batch_size``,
            ``workers``, ``model``, ``radius``, ``nsamples``, ``lr``,
            ``momentum``, ``weight_decay``, ``lr_warmup_epochs``,
            ``lr_milestones``, ``lr_gamma``, ``resume``, ``start_epoch``,
            ``epochs``, ``print_freq``).
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    # Seed every RNG in play and disable cuDNN autotuning for repeatability.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda')

    # Data loading code
    print("Loading data")

    dataset = MSRAction3D(root=args.data_path,
                          frames_per_clip=args.clip_len,
                          step_between_clips=1,
                          num_points=args.num_points,
                          train=True)

    dataset_test = MSRAction3D(root=args.data_path,
                               frames_per_clip=args.clip_len,
                               step_between_clips=1,
                               num_points=args.num_points,
                               train=False)

    print("Creating data loaders")

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    Model = getattr(Models, args.model)
    model = Model(radius=args.radius,
                  nsamples=args.nsamples,
                  num_classes=dataset.num_classes)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = utils.WarmupMultiStepLR(optimizer,
                                           milestones=lr_milestones,
                                           gamma=args.lr_gamma,
                                           warmup_iters=warmup_iters,
                                           warmup_factor=1e-5)

    # Save and restore the *unwrapped* network so checkpoint keys do not
    # depend on how many GPUs were visible (DataParallel adds "module.").
    model_without_ddp = (model.module
                         if isinstance(model, nn.DataParallel) else model)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        # Older checkpoints may have been saved from the DataParallel
        # wrapper; strip its "module." prefix so they still load.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['model'].items()
        }
        model_without_ddp.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    print("Start training")
    start_time = time.time()
    acc = 0
    for epoch in range(args.start_epoch, args.epochs):
        # The scheduler is stepped per iteration inside train_one_epoch
        # (milestones above are expressed in iterations).
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq)

        # Track the best test score seen so far.
        acc = max(acc,
                  evaluate(model, criterion, data_loader_test, device=device))

        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    print('Accuracy {}'.format(acc))
Exemple #17
0
def main(args):
    """Train and/or evaluate a torchvision image classifier.

    Loads the ``train``/``val`` folders under ``args.data_path`` via
    ``load_data``, optionally adds batch-level MixUp/CutMix, builds the model,
    optimizer (SGD or RMSprop), main LR scheduler (StepLR, CosineAnnealingLR
    or ExponentialLR) with optional warmup, optional apex mixed precision,
    DDP, and an optional EMA copy of the weights. Then it either evaluates
    once (``args.test_only``) or trains epoch by epoch, evaluating the model
    (and EMA model) and writing ``model_{epoch}.pth`` / ``checkpoint.pth`` to
    ``args.output_dir`` after every epoch.

    Args:
        args: parsed command-line namespace.

    Raises:
        RuntimeError: apex was requested but could not be imported, or
            ``args.opt``, ``args.lr_scheduler`` or ``args.lr_warmup_method``
            names an unsupported choice.
    """
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args)

    # Optional batch-level MixUp/CutMix: after default collation, one of the
    # enabled transforms is chosen at random and applied to the batch.
    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(
            transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(
            transforms.RandomCutmix(num_classes,
                                    p=1.0,
                                    alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained,
                                                    num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        # "sgd_nesterov" (or any sgd* name containing "nesterov") enables it.
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov="nesterov" in opt_name)
    elif opt_name == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        eps=0.0316,
                                        alpha=0.9)
    else:
        raise RuntimeError(
            "Invalid optimizer {}. Only SGD and RMSprop are supported.".format(
                args.opt))

    # apex must wrap model and optimizer before DDP is applied below.
    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == 'steplr':
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosineannealinglr':
        # The cosine period covers only the post-warmup epochs.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs)
    elif args.lr_scheduler == 'exponentiallr':
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            "Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported.".format(args.lr_scheduler))

    if args.lr_warmup_epochs > 0:
        # Normalise case like args.lr_scheduler above so e.g. "Linear" works.
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == 'linear':
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=args.lr_warmup_decay,
                total_iters=args.lr_warmup_epochs)
        elif args.lr_warmup_method == 'constant':
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer,
                factor=args.lr_warmup_decay,
                total_iters=args.lr_warmup_epochs)
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
                "are supported.")
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[args.lr_warmup_epochs])
    else:
        lr_scheduler = main_lr_scheduler

    # Unwrapped handle used for checkpointing and EMA.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        model_ema = utils.ExponentialMovingAverage(model_without_ddp,
                                                   device=device,
                                                   decay=args.model_ema_decay)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if model_ema is not None:
            model_ema.load_state_dict(checkpoint['model_ema'])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args.print_freq, args.apex, model_ema)
        # Schedulers here are per-epoch (milestones/T_max are in epochs).
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema is not None:
            evaluate(model_ema,
                     criterion,
                     data_loader_test,
                     device=device,
                     log_suffix='EMA')
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            if model_ema is not None:
                checkpoint['model_ema'] = model_ema.state_dict()
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #18
0
def load_data(traindir, valdir, args):
    """Build the image-classification datasets and matching samplers.

    The train/eval image sizes and the eval interpolation mode are derived
    from ``args.model``. When ``args.cache_dataset`` is set, each split is
    read from (or written to) a pickled cache file; note that the cached
    object includes its transforms.

    Args:
        traindir: root folder of the training ImageFolder.
        valdir: root folder of the validation ImageFolder.
        args: namespace providing ``model``, ``cache_dataset``,
            ``distributed`` and, optionally, ``auto_augment`` and
            ``random_erase``.

    Returns:
        Tuple ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # Data loading code
    print("Loading data")

    # (resize, crop) per EfficientNet variant, keyed by the name suffix.
    efficientnet_sizes = {
        'b0': (256, 224),
        'b1': (256, 240),
        'b2': (288, 288),
        'b3': (320, 300),
        'b4': (384, 380),
        'b5': (456, 456),
        'b6': (528, 528),
        'b7': (600, 600),
    }
    if args.model == 'inception_v3':
        resize_size, crop_size = 342, 299
        interpolation = InterpolationMode.BILINEAR
    elif args.model.startswith('efficientnet_'):
        variant = args.model.replace('efficientnet_', '')
        resize_size, crop_size = efficientnet_sizes[variant]
        interpolation = InterpolationMode.BICUBIC
    else:
        resize_size, crop_size = 256, 224
        interpolation = InterpolationMode.BILINEAR

    def cached_image_folder(root, tag, make_transform):
        # ImageFolder for `root`, served from the cache file when enabled.
        # `make_transform` is only invoked on a cache miss.
        cache_file = _get_cache_path(root)
        if args.cache_dataset and os.path.exists(cache_file):
            # Attention, as the transforms are also cached!
            print(f"Loading dataset_{tag} from {cache_file}")
            cached, _ = torch.load(cache_file)
            return cached
        built = torchvision.datasets.ImageFolder(root, make_transform())
        if args.cache_dataset:
            print(f"Saving dataset_{tag} to {cache_file}")
            utils.mkdir(os.path.dirname(cache_file))
            utils.save_on_master((built, root), cache_file)
        return built

    def make_train_transform():
        return presets.ClassificationPresetTrain(
            crop_size=crop_size,
            auto_augment_policy=getattr(args, "auto_augment", None),
            random_erase_prob=getattr(args, "random_erase", 0.0))

    def make_eval_transform():
        return presets.ClassificationPresetEval(crop_size=crop_size,
                                                resize_size=resize_size,
                                                interpolation=interpolation)

    print("Loading training data")
    started = time.time()
    dataset = cached_image_folder(traindir, "train", make_train_transform)
    print("Took", time.time() - started)

    print("Loading validation data")
    dataset_test = cached_image_folder(valdir, "test", make_eval_transform)

    print("Creating data loaders")
    if not args.distributed:
        return (dataset, dataset_test,
                torch.utils.data.RandomSampler(dataset),
                torch.utils.data.SequentialSampler(dataset_test))
    return (dataset, dataset_test,
            torch.utils.data.distributed.DistributedSampler(dataset),
            torch.utils.data.distributed.DistributedSampler(dataset_test))
def main(args, layer_train_para, layer_names, layer_kernel_inc, pattern):
    """Pattern-prune selected conv layers of a torchvision classifier via ADMM.

    Modes (selected by flags on ``args``):
      * ``args.test_only``: pick a pattern per layer, apply it with
        ``ztNAS_add_kernel_mask`` and evaluate once.
      * ``args.retrain_only``: apply the pattern masks (and kernel-shape
        changes for ``layer_kernel_inc``), call ``re_train_one_epoch`` once,
        evaluate and return ``(acc1, acc5)``.
      * otherwise: run the ADMM training loop for epochs
        ``args.start_epoch .. args.epochs - 1`` (multiplying ``rho`` by 10
        after each epoch), then apply the masks, re-train once, evaluate and
        save a final checkpoint.

    Args:
        args: parsed command-line namespace.
        layer_train_para: parameter names (as yielded by
            ``model.named_parameters()``) that stay trainable; every other
            parameter is frozen.
        layer_names: names of ``nn.Conv2d`` modules to pattern-prune.
        layer_kernel_inc: names of ``nn.Conv2d`` modules whose kernel shape is
            changed with ``ztNAS_modify_kernel_shape`` (retrain-only mode).
        pattern: candidate masks forwarded to ``utils.get_layers_pattern`` and
            the ADMM training step. Presumably a dict mapping an index to a
            3x3 float mask, e.g. ``{0: torch.tensor([[0, 0, 0], [1, 1, 1],
            [1, 1, 1]], dtype=torch.float32), ...}`` — confirm with callers.

    Returns:
        ``(acc1, acc5)`` in retrain-only mode, otherwise ``None``.
    """
    if args.apex:
        # The version check is always false on Python 3; kept for parity.
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    # NOTE(review): this passes four positional arguments to load_data;
    # confirm it resolves to the (traindir, valdir, cache_dataset,
    # distributed) variant rather than load_data(traindir, valdir, args).
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)

    # Collect the Conv2d modules to prune (layers) and those whose kernel
    # shape will be modified (ki_layers), keyed by module name.
    # NOTE(review): a name in layer_names that is not a Conv2d in this model
    # will later raise KeyError on layers[layer_name].
    layers = {}
    ki_layers = {}
    for layer_name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d):
            if layer_name in layer_names:
                layers[layer_name] = layer
            if layer_name in layer_kernel_inc:
                ki_layers[layer_name] = layer

    # Freeze everything except the explicitly listed parameters.
    for name, param in model.named_parameters():
        if name in layer_train_para:
            param.requires_grad = True
        else:
            param.requires_grad = False

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Optimizer for the ADMM training phase.
    admm_optimizer = torch.optim.Adam(model.parameters(),
                                      lr=args.lr,
                                      eps=args.adam_epsilon)

    # Optimizer for masked re-training after pruning; it receives named
    # parameters.
    admm_re_train_optimizer = PruneAdam(model.named_parameters(),
                                        lr=args.lr,
                                        eps=args.adam_epsilon)

    # NOTE(review): only `optimizer` (SGD) is registered with apex here, and
    # lr_scheduler below also drives `optimizer`, yet training uses
    # admm_optimizer / admm_re_train_optimizer. The StepLR schedule therefore
    # does not affect the training LR — confirm this is intended.
    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        layer_pattern = utils.get_layers_pattern(model, layer_names, pattern,
                                                 device)
        utils.print_prune(model, layer_names, layer_pattern)

        for layer_name in layer_names:
            ztNAS_add_kernel_mask(model,
                                  layers[layer_name],
                                  layer_name,
                                  is_pattern=True,
                                  pattern=layer_pattern[layer_name].to(device))

        model.to(device)
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.retrain_only:
        # Sentinel epoch number passed to re_train_one_epoch in this mode.
        epoch = 999
        print("Start re-training")
        start_time = time.time()
        print("=" * 10, "Applying pruning model")
        layer_pattern = utils.get_layers_pattern(model, layer_names, pattern,
                                                 device)

        for layer_name in layer_names:
            ztNAS_add_kernel_mask(model,
                                  layers[layer_name],
                                  layer_name,
                                  is_pattern=True,
                                  pattern=layer_pattern[layer_name].to(device))

        for layer_name in layer_kernel_inc:
            ztNAS_modify_kernel_shape(model, ki_layers[layer_name], layer_name,
                                      2)

        model.to(device)

        print("=" * 10, "Retrain")

        re_train_one_epoch(model, criterion, admm_re_train_optimizer,
                           data_loader, device, epoch, args.print_freq,
                           layer_names, layer_pattern, data_loader_test,
                           args.exploration, args.apex)

        acc1, acc5 = evaluate(model,
                              criterion,
                              data_loader_test,
                              device=device,
                              exploration=args.exploration)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

        return acc1, acc5

    print("Start training")
    start_time = time.time()

    # ADMM auxiliary (Z) and dual (U) variables for the pruned layers.
    Z, U = utils.initialize_Z_and_U(model, layer_names)
    rho = args.rho
    # Keep `epoch` bound even when the loop below runs zero times (e.g. when
    # resuming a finished run); it is used by the re-train step and the final
    # checkpoint.
    epoch = args.start_epoch - 1
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # NOTE(review): `percent` is neither a parameter nor a local; it must
        # exist as a module-level global or this raises NameError.
        Z, U = train_one_epoch(model, criterion, admm_optimizer, data_loader,
                               device, epoch, args.print_freq, layer_names,
                               percent, pattern, Z, U, rho, args.apex)

        # Increase the ADMM penalty tenfold every epoch.
        rho = rho * 10
        lr_scheduler.step()

        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

        evaluate(model, criterion, data_loader_test, device=device)

    print("=" * 10, "Applying pruning model")
    layer_pattern = utils.get_layers_pattern(model, layer_names, pattern,
                                             device)

    for layer_name in layer_names:
        ztNAS_add_kernel_mask(model,
                              layers[layer_name],
                              layer_name,
                              is_pattern=True,
                              pattern=layer_pattern[layer_name].to(device))

    model.to(device)

    print("=" * 10, "Retrain")

    re_train_one_epoch(model, criterion, admm_re_train_optimizer, data_loader,
                       device, epoch, args.print_freq, layer_names,
                       layer_pattern, data_loader_test, args.exploration,
                       args.apex)

    evaluate(model, criterion, data_loader_test, device=device)

    if args.output_dir:
        # Note: overwrites model_{epoch}.pth from the last training epoch with
        # the pruned + re-trained weights.
        checkpoint = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch + 1,
            'args': args
        }
        utils.save_on_master(
            checkpoint,
            os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
        utils.save_on_master(checkpoint,
                             os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #20
0
def main(args):
    """Train or evaluate a ``Vision_TransformerSuper`` supernet.

    Updates the configuration object ``cfg`` from the file ``args.cfg``, builds
    the train/val datasets and (optionally distributed) loaders, creates the
    supernet, an optional teacher model, the optimizer, LR scheduler, loss
    scaler and criterion, optionally resumes from ``args.resume``, and then
    either evaluates once (``args.eval``) or trains from ``args.start_epoch``
    to ``args.epochs``. While training it saves ``<output_dir>/checkpoint.pth``
    every epoch and (main process only) appends per-epoch JSON stats to
    ``<output_dir>/log.txt``.

    Args:
        args: parsed command-line namespace. It is modified in place:
            ``args.nb_classes`` is set from the dataset, ``args.lr`` is
            rescaled by the linear scaling rule and ``args.start_epoch`` is
            overwritten when resuming.
    """
    utils.init_distributed_mode(args)
    update_config_from_file(args.cfg)

    print(args)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility; offset by rank so that each process
    # draws a different random stream
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train,
                                      num_replicas=num_tasks,
                                      rank=global_rank,
                                      shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Evaluation uses twice the training batch size.
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=int(2 * args.batch_size),
        sampler=sampler_val,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print("Creating SuperVisionTransformer")
    print(cfg)
    model = Vision_TransformerSuper(
        img_size=args.input_size,
        patch_size=args.patch_size,
        embed_dim=cfg.SUPERNET.EMBED_DIM,
        depth=cfg.SUPERNET.DEPTH,
        num_heads=cfg.SUPERNET.NUM_HEADS,
        mlp_ratio=cfg.SUPERNET.MLP_RATIO,
        qkv_bias=True,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        gp=args.gp,
        num_classes=args.nb_classes,
        max_relative_position=args.max_relative_position,
        relative_position=args.relative_position,
        change_qkv=args.change_qkv,
        abs_pos=not args.no_abs_pos)

    # Search space handed to the train/eval loops.
    choices = {
        'num_heads': cfg.SEARCH_SPACE.NUM_HEADS,
        'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
        'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM,
        'depth': cfg.SEARCH_SPACE.DEPTH
    }

    model.to(device)
    if args.teacher_model:
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    # EMA is disabled in this script: model_ema stays None throughout.
    model_ema = None

    model_without_ddp = model
    if args.distributed:
        # find_unused_parameters: presumably because a sampled sub-network
        # leaves part of the supernet without gradients — confirm.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: args.lr is given for a global batch of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    # exist_ok avoids a FileExistsError race when several ranks start at once.
    output_dir.mkdir(parents=True, exist_ok=True)
    # save config for later experiments (once, like log.txt below)
    if utils.is_main_process():
        with open(output_dir / "config.yaml", 'w') as f:
            f.write(args_text)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            # model_ema is always None here and the checkpoints written below
            # carry no 'model_ema' entry; only restore EMA weights when both
            # actually exist instead of crashing on resume.
            if args.model_ema and model_ema is not None and 'model_ema' in checkpoint:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])

    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        retrain_config = {
            'layer_num': cfg.RETRAIN.DEPTH,
            'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH,
            'num_heads': cfg.RETRAIN.NUM_HEADS,
            'mlp_ratio': cfg.RETRAIN.MLP_RATIO
        }
    if args.eval:
        print(retrain_config)
        # Same arguments as the per-epoch evaluation in the training loop, so
        # a one-off evaluation sees the same amp setting and search space.
        test_stats = evaluate(data_loader_val,
                              model,
                              device,
                              amp=args.amp,
                              choices=choices,
                              mode=args.mode,
                              retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            amp=args.amp,
            teacher_model=teacher_model,
            teach_loss=teacher_loss,
            choices=choices,
            mode=args.mode,
            retrain_config=retrain_config,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # 'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    },
                    checkpoint_path)

        test_stats = evaluate(data_loader_val,
                              model,
                              device,
                              amp=args.amp,
                              choices=choices,
                              mode=args.mode,
                              retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    """Train (or only evaluate) a ``torchvision.models`` image classifier.

    Validates the optional NVIDIA Apex mixed-precision setup, loads the
    ``train``/``val`` splits under ``args.data_path`` through ``load_data``,
    trains ``torchvision.models.<args.model>`` with SGD and ``StepLR``
    (optionally under ``DistributedDataParallel``), evaluates after each
    epoch and, when ``args.output_dir`` is set, writes both
    ``model_<epoch>.pth`` and ``checkpoint.pth`` every epoch.

    Args:
        args: parsed command-line namespace; ``args.start_epoch`` is
            overwritten when resuming from ``args.resume``.
    """
    if args.apex and sys.version_info < (3, 0):
        raise RuntimeError(
            "Apex currently only supports Python 3. Aborting.")
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    dataset, dataset_test, train_sampler, test_sampler = load_data(
        os.path.join(args.data_path, 'train'),
        os.path.join(args.data_path, 'val'),
        args.cache_dataset,
        args.distributed)

    def _loader(ds, smp):
        # Train and test loaders share batch size, workers and pinned memory.
        return torch.utils.data.DataLoader(ds,
                                           batch_size=args.batch_size,
                                           sampler=smp,
                                           num_workers=args.workers,
                                           pin_memory=True)

    data_loader = _loader(dataset, train_sampler)
    data_loader_test = _loader(dataset_test, test_sampler)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    # Keep a handle on the bare module so checkpoints have no "module." prefix.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        saved = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(saved['model'])
        optimizer.load_state_dict(saved['optimizer'])
        lr_scheduler.load_state_dict(saved['lr_scheduler'])
        args.start_epoch = saved['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    started = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args.print_freq, args.apex)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if not args.output_dir:
            continue
        snapshot = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'args': args
        }
        # Per-epoch snapshot first, then the rolling "latest" checkpoint.
        for file_name in ('model_{}.pth'.format(epoch), 'checkpoint.pth'):
            utils.save_on_master(snapshot,
                                 os.path.join(args.output_dir, file_name))

    elapsed = datetime.timedelta(seconds=int(time.time() - started))
    print('Training time {}'.format(str(elapsed)))
Exemple #22
0
def main(args):
    """Train (or only evaluate) a ResNet-family classifier on CUDA, CPU or Habana.

    Flow: optional Habana HMP conversion and Apex checks, distributed init,
    Habana module loading, seeding, data loading from ``<args.data_path>/train``
    and ``<args.data_path>/val``, model creation from the local
    ``resnet_models`` copy, optional ``torch.jit`` tracing
    (``args.run_trace_mode``), optional DDP wrapping, optional resume, then
    either a single evaluation (``args.test_only``) or the training loop with
    per-epoch evaluation and, with ``args.save_checkpoint``, checkpointing.

    On Habana the parameters are rearranged with ``permute_params(..., True)``
    while training. They and the optimizer momentum buffers are permuted back
    (``False``) around every checkpoint load/save. This is presumably so that
    checkpoints use the standard PyTorch layout — confirm against
    ``permute_params``/``permute_momentum``. Both are then re-permuted
    (``True``) before training continues.

    Args:
        args: parsed command-line namespace; ``args.start_epoch`` is
            overwritten when resuming.
    """
    if args.is_hmp:
        from hmp import hmp
        hmp.convert(opt_level=args.hmp_opt_level,
                    bf16_file_path=args.hmp_bf16,
                    fp32_file_path=args.hmp_fp32,
                    isVerbose=args.hmp_verbose)

    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.device == 'habana':
        # The Habana library loader lives in a "common" directory three levels
        # above this file; add it to sys.path before importing it.
        sys.path.append(
            os.path.realpath(
                os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             "../../../common")))
        from library_loader import load_habana_module
        load_habana_module()

    torch.manual_seed(args.seed)

    # `seed` is forwarded to the DataLoader worker init functions below; it is
    # None unless deterministic mode is requested.
    if args.deterministic:
        seed = args.seed
        if args.device == 'cuda':
            torch.cuda.manual_seed(seed)
    else:
        seed = None

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Limit the test(eval) phase batch size to a lower value to reduce overall device memory pressure
    test_batch_size = args.batch_size
    if args.batch_size > 32:
        test_batch_size = 32

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        worker_init_fn=dl_worker_init_fn(seed),
        pin_memory=True,
        drop_last=True)

    # NOTE(review): drop_last=True on the evaluation loader skips the final
    # partial batch, so accuracy covers slightly fewer images than
    # len(dataset_test). Presumably deliberate (fixed batch shapes) — confirm.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=test_batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        worker_init_fn=dl_worker_init_fn(seed),
        pin_memory=True,
        drop_last=True)

    print("Creating model")
    # Instead of importing the resnet model from the standard torchvision
    # package (torchvision.models.__dict__[args.model](...)), import it from a
    # local copy so that the model can be modified if necessary.
    model = resnet_models.__dict__[args.model](pretrained=args.pretrained)

    model.to(device)
    if args.channels_last:
        if (device == torch.device('cuda')):
            print('Converting model to channels_last format on CUDA')
            model.to(memory_format=torch.channels_last)
        elif (args.device == 'habana'):
            print('Converting model params to channels_last format on Habana')
            # TODO: model.to(device).to(memory_format=torch.channels_last)
            # does not change the params to channels_last for many components
            # (e.g. convolution), so such tensors are rearranged manually
            # (presumably by the permute_params call below).

    if (args.device == 'habana'):
        permute_params(model, True)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    model_for_eval = model
    if args.run_trace_mode:
        sample_trace_tensor = enable_tracing(device)

        if args.channels_last:
            sample_trace_tensor = sample_trace_tensor.contiguous(
                memory_format=torch.channels_last)
        # Create traced model for eval (traced while in eval mode)
        model.eval()
        model_for_eval = torch.jit.trace(model,
                                         sample_trace_tensor,
                                         check_trace=False)
        # Create traced model for train (traced while in train mode); it
        # becomes the training model after the optional DDP wrapping below.
        model.train()
        model = torch.jit.trace(model, sample_trace_tensor, check_trace=False)

    # TBD: pass the right module for ddp
    model_without_ddp = model

    if args.distributed:
        if args.device == 'habana':
            model = torch.nn.parallel.DistributedDataParallel(
                model, bucket_cap_mb=100, broadcast_buffers=False)
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_for_train = model

    if args.resume:
        # Undo the Habana layout so the checkpoint's state dict matches the
        # parameters, then restore the training layout afterwards.
        if (args.device == 'habana'):
            permute_params(model_without_ddp, False)
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if (args.device == 'habana'):
            permute_momentum(optimizer, True)
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if (args.device == 'habana'):
            permute_params(model_without_ddp, True)

    if args.test_only:
        evaluate(model_for_eval,
                 criterion,
                 data_loader_test,
                 device=device,
                 print_freq=args.print_freq)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_one_epoch(model_for_train,
                        criterion,
                        optimizer,
                        data_loader,
                        device,
                        epoch,
                        print_freq=args.print_freq,
                        apex=args.apex)
        lr_scheduler.step()
        evaluate(model_for_eval,
                 criterion,
                 data_loader_test,
                 device=device,
                 print_freq=args.print_freq)

        if (args.output_dir and args.save_checkpoint):
            if args.device == 'habana':
                permute_params(model_without_ddp, False)
                #Use this model only to copy the state_dict of the actual model
                copy_model = resnet_models.__dict__[args.model](
                    pretrained=args.pretrained)

                copy_model.load_state_dict(model_without_ddp.state_dict())
                permute_momentum(optimizer, False)
                # Move optimizer state to CPU for serialization.
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to('cpu')

                checkpoint = {
                    'model': copy_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }
                utils.save_on_master(
                    checkpoint,
                    os.path.join(args.output_dir,
                                 'model_{}.pth'.format(epoch)))
                utils.save_on_master(
                    checkpoint, os.path.join(args.output_dir,
                                             'checkpoint.pth'))

                # Move optimizer state back to the device and restore the
                # training layout for BOTH momentum and parameters (mirrors
                # the resume path). Previously the momentum buffers were left
                # un-permuted for the rest of training.
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to('habana')
                permute_momentum(optimizer, True)
                permute_params(model_without_ddp, True)

            else:
                checkpoint = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }
                utils.save_on_master(
                    checkpoint,
                    os.path.join(args.output_dir,
                                 'model_{}.pth'.format(epoch)))
                utils.save_on_master(
                    checkpoint, os.path.join(args.output_dir,
                                             'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #23
0
def main(args):
    """Train (or only evaluate) a semantic-segmentation model with optional deconvolution.

    When ``args.deconv`` is set it is replaced by a ``functools.partial`` of
    ``FastDeconv`` built from ``args`` and handed to the model factory.
    Optional TensorBoard logging goes to ``args.log_dir`` through
    ``args.writer``. Training uses SGD with a per-iteration polynomial LR
    decay (power 0.9). After the first epoch and every ``args.eval_freq``
    epochs it evaluates and saves ``<args.log_dir>/model.pth``.

    Args:
        args: parsed command-line namespace. The function sets
            ``args.log_dir`` and ``args.colormap``, overwrites
            ``args.start_epoch`` when resuming, and, when enabled, replaces
            ``args.deconv`` and sets ``args.writer``.
    """
    # n_iter is a global iteration counter; it is fast-forwarded together with
    # the LR scheduler when resuming.
    global n_iter

    args.log_dir = save_path_formatter(args)
    if args.deconv:
        args.deconv = partial(FastDeconv,
                              bias=args.bias,
                              eps=args.eps,
                              n_iter=args.deconv_iter,
                              block=args.block,
                              sampling_stride=args.stride)

    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        args.writer = SummaryWriter(args.log_dir, flush_secs=30)

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)
    transform = get_transform(mode='train', base_size=args.base_size)
    dataset, num_classes = get_dataset(args.dataset,
                                       "train",
                                       transform=transform)

    transform = get_transform(mode='test', base_size=args.base_size)
    dataset_test, _ = get_dataset(args.dataset, "val", transform=transform)

    if args.dataset == 'cityscapes':
        args.colormap = np.asarray([
            dataset.classes[i].color
            for i in range(max(dataset.new_classes) + 1)
        ])
    else:
        args.colormap = create_mapillary_vistas_label_colormap()

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              collate_fn=utils.collate_fn,
                                              drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    model = models.segmentation.__dict__[args.model](
        num_classes=num_classes,
        aux_loss=args.aux_loss,
        pretrained=args.pretrained,
        deconv=args.deconv,
        pretrained_backbone=args.pretrained_backbone)

    model.to(device)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    resume_optimizer_state = None
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        # The saved 'epoch' is the last epoch that finished training, so
        # continue with the next one instead of repeating it.
        args.start_epoch = checkpoint['epoch'] + 1
        # The optimizer is built further down; keep its saved state until then.
        resume_optimizer_state = checkpoint.get('optimizer')
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        del checkpoint

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    elif args.device == 'cuda':
        model = torch.nn.DataParallel(model).cuda()

    if args.test_only:
        confmat = evaluate(model,
                           data_loader_test,
                           device=device,
                           num_classes=num_classes)
        print(confmat)
        if args.tensorboard:
            args.writer.close()
        return

    params_to_optimize = [
        {
            "params": [
                p for p in model_without_ddp.backbone.parameters()
                if p.requires_grad
            ]
        },
        {
            "params": [
                p for p in model_without_ddp.classifier.parameters()
                if p.requires_grad
            ]
        },
    ]
    if args.aux_loss:
        # The auxiliary head trains with a 10x learning rate.
        params = [
            p for p in model_without_ddp.aux_classifier.parameters()
            if p.requires_grad
        ]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Polynomial ("poly") decay evaluated per iteration over the whole run.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs))**0.9)

    if resume_optimizer_state is not None:
        # Restore momentum buffers that the checkpoint stores.
        optimizer.load_state_dict(resume_optimizer_state)
        del resume_optimizer_state

    if args.resume:
        # Replay the per-iteration schedule (and the global counter) for every
        # epoch that has already been completed.
        total_steps = len(data_loader) * args.start_epoch
        for _ in range(total_steps):
            n_iter = n_iter + 1
            lr_scheduler.step()

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # NOTE(review): `criterion` is not defined in this function; presumably
        # a module-level loss function — confirm.
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler,
                        device, epoch, args.print_freq)

        if epoch == 0 or (epoch + 1) % args.eval_freq == 0:
            confmat = evaluate(model,
                               data_loader_test,
                               device=device,
                               num_classes=num_classes)

            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    #'args': args
                },
                #os.path.join(args.log_dir, 'model_{}.pth'.format(epoch)))
                os.path.join(args.log_dir, 'model.pth'))

            print(confmat)

            acc_global, acc, iu = confmat.compute()
            acc_global = acc_global.item() * 100
            iu = iu.mean().item() * 100

            if args.tensorboard:
                args.writer.add_scalar('Acc/Test', acc_global, epoch + 1)
                args.writer.add_scalar('IOU/Test', iu, epoch + 1)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # args.writer only exists when TensorBoard logging was enabled.
    if args.tensorboard:
        args.writer.close()
Exemple #24
0
def main(args):
    """Train (or only evaluate) a ``torchvision.models.detection`` model.

    Training batches come from aspect-ratio groups when
    ``args.aspect_ratio_group_factor >= 0`` and from a plain
    ``BatchSampler`` otherwise. Keypoint models are created with
    ``num_keypoints=6``. Optimisation uses SGD with ``MultiStepLR``. After
    every epoch the state is written to ``<args.output_dir>/model77.pth``,
    overwriting the previous one, and the model is evaluated.

    Args:
        args: parsed command-line namespace; ``args.start_epoch`` is
            overwritten when resuming from ``args.resume``.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    print("Loading data")
    dataset, num_classes = get_dataset(
        args.dataset, "train", get_transform(train=True), args.data_path)
    dataset_test, _ = get_dataset(
        args.dataset, "val", get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if not args.distributed:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    else:
        distributed_sampler = torch.utils.data.distributed.DistributedSampler
        train_sampler = distributed_sampler(dataset)
        test_sampler = distributed_sampler(dataset_test)

    group_by_aspect_ratio = args.aspect_ratio_group_factor >= 0
    if group_by_aspect_ratio:
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(
            train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)
    # Evaluation runs one image per batch.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=1,
        sampler=test_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)

    print("Creating model")
    model_kwargs = {"num_keypoints": 6} if "keypoint" in args.model else {}
    # Disabled override: for "rcnn" models, model_kwargs["rpn_score_thresh"] = 0.0
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes, pretrained=args.pretrained, **model_kwargs)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable_params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    if args.resume:
        saved = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(saved['model'])
        optimizer.load_state_dict(saved['optimizer'])
        lr_scheduler.load_state_dict(saved['lr_scheduler'])
        args.start_epoch = saved['epoch'] + 1

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    started = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            state = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args,
                'epoch': epoch,
            }
            utils.save_on_master(state,
                                 os.path.join(args.output_dir, 'model77.pth'))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    elapsed = datetime.timedelta(seconds=int(time.time() - started))
    print('Training time {}'.format(str(elapsed)))
Exemple #25
0
def main(args):
    """Train (or evaluate) a TimeCycle video model.

    Steps:
      1. validate the apex setup and initialise (optionally distributed) execution;
      2. build the train/val video datasets (``Kinetics400`` when
         ``args.data_path`` contains "kinetics", ``VideoList`` otherwise),
         optionally loading/saving them from an on-disk cache;
      3. wrap them in samplers and data loaders;
      4. build ``timecycle.TimeCycle`` with an Adam optimizer and a
         per-iteration ``WarmupMultiStepLR`` schedule, optionally resuming
         from ``args.resume``;
      5. evaluate once (``args.test_only``) or train for ``args.epochs``
         epochs, checkpointing every epoch when ``args.output_dir`` is set.

    Args:
        args: parsed command-line namespace (fields read include
            ``data_path``, ``clip_len``, ``frame_skip``, ``batch_size``,
            ``workers``, ``epochs`` and the other attributes referenced below).
    """
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    vis = utils.Visualize(args)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    # With --fast_test the validation split doubles as training data.
    traindir = os.path.join(
        args.data_path, 'train_256' if not args.fast_test else 'val_256_bob')
    valdir = os.path.join(args.data_path, 'val_256_bob')
    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                            std=[0.22803, 0.22145, 0.216989])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)

    frame_transform_train = utils.make_frame_transform(args.frame_transforms)

    transform_train = torchvision.transforms.Compose([
        frame_transform_train,
        T.ToFloatTensorInZeroOne(),
        T.Resize((256, 256)),
        normalize,
    ])

    # Built alongside transform_train so that make_dataset's closure never
    # refers to a name that has not been bound yet.
    transform_test = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((256, 256)),
        normalize,
    ])

    def make_dataset(is_train):
        """Return the training (``is_train=True``) or validation video dataset."""
        _transform = transform_train if is_train else transform_test

        if 'kinetics' in args.data_path.lower():
            return Kinetics400(traindir if is_train else valdir,
                               frames_per_clip=args.clip_len,
                               step_between_clips=1,
                               # Each split gets its own transform; the
                               # validation split must not receive the
                               # training augmentations.
                               transform=_transform,
                               # One-element tuple, not the bare string 'mp4'.
                               extensions=('mp4',),
                               frame_rate=args.frame_skip)
        return VideoList(
            args,
            is_train,
            frame_gap=args.frame_skip,
            transform=_transform,
        )

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
        # The cache also pickled the old transform; replace it with the current one.
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset = make_dataset(is_train=True)

        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    if hasattr(dataset, 'video_clips'):
        dataset.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset_test = make_dataset(is_train=False)

        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    # Check the validation dataset itself before touching its video_clips.
    if hasattr(dataset_test, 'video_clips'):
        dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)

    def make_data_sampler(is_train, dataset):
        """Clip sampler for clip-based datasets; else RandomSampler (train) / None (val)."""
        if hasattr(dataset, 'video_clips'):
            _sampler = RandomClipSampler if is_train else UniformClipSampler
            return _sampler(dataset.video_clips, args.clips_per_video)
        return torch.utils.data.sampler.RandomSampler(
            dataset) if is_train else None

    print("Creating data loaders")
    train_sampler = make_data_sampler(True, dataset)
    test_sampler = make_data_sampler(False, dataset_test)

    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        # make_data_sampler returns None for non-clip validation sets; there
        # is no sampler to wrap then, so keep the loader's default order.
        if test_sampler is not None:
            test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=collate_fn)

    print("Creating model")
    import resnet  # noqa: F401 -- unused here; kept in case importing it has side effects
    import timecycle as tc
    model = tc.TimeCycle(args)

    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # NOTE(review): args.lr is not used -- Adam runs with a fixed 3e-4. The
    # previously commented-out SGD variant scaled args.lr by args.world_size;
    # confirm which learning rate is intended.
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=lr_milestones,
                                     gamma=args.lr_gamma,
                                     warmup_iters=warmup_iters,
                                     warmup_factor=1e-5)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.data_parallel:
        model = torch.nn.parallel.DataParallel(model)
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model,
                        criterion,
                        optimizer,
                        lr_scheduler,
                        data_loader,
                        device,
                        epoch,
                        args.print_freq,
                        args.apex,
                        vis=vis)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            # One per-epoch snapshot plus a rolling "latest" checkpoint.
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #26
0
def main(args):
    """Train a torchvision classifier, optionally with lottery-ticket pruning.

    Loads ImageFolder train/val splits via ``load_data`` and truncates each to
    its first 1024 images (a debugging subset).

    When ``args.prune`` is set:
      * the model's modules are swapped for prunable ones by ``DG_Pruner``;
      * RigL importance is registered;
      * Conv+BN pairs found by ``get_modules_to_fuse`` are attached to the
        prunable convolutions.

    Training then runs ``dgPruner.num_stages() + 1`` stages (exactly one stage
    when not pruning). Every stage after the first first reloads the
    checkpoint returned by ``dgPruner.rewind_masked_checkpoint()``.
    ``args.test_only`` evaluates once and returns.
    """
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args.cache_dataset, args.distributed)

    # Debug subset: keep the first `subset_size` images of each split.
    # - Slicing also works for splits smaller than the subset size.
    # - The test split keeps its *own* samples; it used to be overwritten with
    #   the training samples, so evaluation measured training accuracy.
    # NOTE(review): the samplers were built by load_data on the full datasets;
    # a DistributedSampler fixes its length at construction and may index past
    # the truncated lists -- confirm for distributed runs.
    subset_size = 1024
    dataset.samples = dataset.samples[:subset_size]
    dataset.targets = dataset.targets[:subset_size]
    dataset_test.samples = dataset_test.samples[:subset_size]
    dataset_test.targets = dataset_test.targets[:subset_size]

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
    model.to(device)

    # Mehrdad: Fuse
    from DG_Prune.FuseHook import Fuse_Hook, get_modules_to_fuse
    # Mehrdad: Prune
    from DG_Prune import DG_Pruner, TaylorImportance, MagnitudeImportance, RigLImportance, PrunableConv2d
    dgPruner = None
    if args.prune:
        dgPruner = DG_Pruner()
        model = dgPruner.swap_prunable_modules(model)
        pruners = dgPruner.pruners_from_file('DG_Prune/rigl_resnet50.json')
        hooks = dgPruner.add_custom_pruning(model, RigLImportance)

        # Conv+BN discovery only applies to prunable convs, and attaching
        # requires a pruner, so it lives inside the pruning branch (it used to
        # run unconditionally and crash on dgPruner=None).
        fuse_type_list = [[PrunableConv2d, nn.BatchNorm2d]]
        try:
            first_images, _ = next(iter(data_loader_test))
        except StopIteration:
            raise RuntimeError(
                "Validation loader is empty; a sample image is needed to "
                "locate Conv+BN pairs") from None
        # NOTE(review): the sample stays on CPU as before; if
        # get_modules_to_fuse runs a forward pass on a CUDA model it may
        # need .to(device) -- confirm.
        sample_image = first_images[0].unsqueeze(0)
        modules_to_fuse = get_modules_to_fuse(model, fuse_type_list, sample_image)
        dgPruner.attach_bn_to_prunables(model, modules_to_fuse)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    def _cosine_lr_factor(x):
        """Cosine-annealed LR multiplier: 1 at epoch 0, args.lrf at args.epochs."""
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (
            1 - args.lrf) + args.lrf

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=_cosine_lr_factor)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model,
                 criterion,
                 data_loader_test,
                 device=device,
                 print_freq=args.print_freq,
                 dgPruner=dgPruner,
                 output_dir=args.output_dir)
        return

    print("Start training")
    start_time = time.time()
    # Without pruning there is a single (ordinary) training stage.
    num_lth_stages = dgPruner.num_stages() if dgPruner is not None else 0
    for lth_stage in range(num_lth_stages + 1):
        if lth_stage != 0:
            # Later LTH stages restart from the rewound, masked weights.
            checkpoint = dgPruner.rewind_masked_checkpoint()
            model_without_ddp.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            dgPruner.dump_sparsity_stat(model, args.output_dir,
                                        lth_stage * 10000)

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            # NOTE(review): dgPruner is None when not pruning; presumably
            # train_one_epoch/evaluate accept that -- confirm.
            train_metrics = train_one_epoch(model,
                                            criterion,
                                            optimizer,
                                            data_loader,
                                            device,
                                            epoch,
                                            args.print_freq,
                                            args.apex,
                                            dgPruner=dgPruner,
                                            output_dir=args.output_dir)
            lr_scheduler.step()
            eval_metrics = evaluate(model,
                                    criterion,
                                    data_loader_test,
                                    device=device,
                                    print_freq=args.print_freq,
                                    dgPruner=dgPruner,
                                    output_dir=args.output_dir)

            if args.output_dir:
                checkpoint = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }
                utils.save_on_master(
                    checkpoint,
                    os.path.join(args.output_dir,
                                 'model_{}.pth'.format(epoch)))
                utils.save_on_master(
                    checkpoint, os.path.join(args.output_dir,
                                             'checkpoint.pth'))
            # Mehrdad: LTH, pruning in the end
            if args.prune:
                is_last_epoch = epoch == args.epochs - 1
                if is_last_epoch:
                    dgPruner.prune_n_reset(epoch)
                    dgPruner.dump_sparsity_stat(model, args.output_dir, epoch)
                    dgPruner.apply_mask_to_weight()
                # Snapshot taken after pruning so masks are reflected.
                checkpoint = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }

                # The rewind point is only recorded during the first stage.
                if lth_stage == 0 and epoch == dgPruner.rewind_epoch(
                        args.epochs):
                    dgPruner.save_rewind_checkpoint(checkpoint)
                if is_last_epoch:
                    dgPruner.save_final_checkpoint(checkpoint)

            update_summary(epoch, train_metrics, eval_metrics,
                           os.path.join(args.output_dir, 'summary.csv'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Exemple #27
0
def main(args):
    """Train or evaluate a DeiT-style image classifier.

    Builds the datasets, samplers and loaders, the (optional) Mixup/CutMix
    transform, the model (optionally finetuned from ``args.finetune`` with
    position-embedding interpolation), an EMA copy, optimizer, LR scheduler,
    loss (optionally wrapped for knowledge distillation), then either
    evaluates once (``args.eval``) or trains for ``args.epochs`` epochs,
    writing a checkpoint and a JSON log line per epoch when
    ``args.output_dir`` is set.
    """
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Distributed-style samplers are used unconditionally (in a single
    # process they cover the whole dataset on their own).
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train,
                                  num_replicas=num_tasks,
                                  rank=global_rank,
                                  shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print(
                'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                'This will slightly alter validation results as extra duplicate entries are added to achieve '
                'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop classifier heads whose shape differs (e.g. another class count).
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                        embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                     size=(new_size, new_size),
                                                     mode='bicubic',
                                                     align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule relative to a global batch size of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size(
    ) / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if mixup_active:
        # Mixup *or* CutMix produce soft targets (with smoothing already
        # applied by mixup_fn), so a soft-target loss is required whenever
        # either is enabled -- not only when args.mixup > 0.
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.teacher_path,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(criterion, teacher_model,
                                 args.distillation_type,
                                 args.distillation_alpha,
                                 args.distillation_tau)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                # Checkpoints written without EMA store None under 'model_ema'.
                if checkpoint.get('model_ema') is not None:
                    utils._load_checkpoint_for_ema(model_ema,
                                                   checkpoint['model_ema'])
                else:
                    print("Warning: no EMA weights in checkpoint; "
                          "model_ema keeps its initial weights")
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is always a DistributedSampler/RASampler, whose
        # shuffle order depends on the epoch set here; without this call the
        # same order repeats every epoch (also in single-process runs).
        data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            # keep in eval mode during finetuning
            set_training_mode=not args.finetune,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_path = output_dir / ('checkpoint_%04d.pth' % epoch)
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    # get_state_dict cannot handle a missing EMA model.
                    'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        log_stats = {f'train_{k}': v for k, v in train_stats.items()}
        if not args.train_without_eval:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')
            log_stats.update({f'test_{k}': v for k, v in test_stats.items()})
        log_stats['epoch'] = epoch
        log_stats['n_parameters'] = n_parameters

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    """Run quantization-aware training, post-training quantization, or evaluation.

    The mode is chosen from ``args``:

    * ``args.post_training_quantize``: calibrate the pretrained fp32 model on
      the first ``batch_size * num_calibration_batches`` training samples,
      convert it to a quantized model, save its state dict to
      ``<output_dir>/quantized_post_train_model.pth`` (main process only, when
      ``output_dir`` is set), evaluate it and return.
    * ``args.test_only``: evaluate the model built with ``quantize=True`` and
      return.
    * otherwise: quantization-aware training (QAT) from ``args.start_epoch``
      up to ``args.epochs``. After every epoch the fake-quantized model is
      evaluated, a converted copy is evaluated on the CPU, and (when
      ``output_dir`` is set) checkpoints are written.

    Args:
        args: parsed command-line namespace holding the run configuration.

    Raises:
        RuntimeError: if post-training quantization is requested in
            distributed mode, or ``args.backend`` is not one of
            ``torch.backends.quantized.supported_engines``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed "
                           "on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
                                                                   args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model", args.model)
    # When training quantized models, we always start from a pre-trained fp32
    # reference model; only in test-only mode is an already-quantized model requested.
    model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    model.to(device)

    # QAT is the default mode. The optimizer and LR scheduler exist only in this mode.
    is_qat = not (args.test_only or args.post_training_quantize)
    if is_qat:
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
        torch.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay)

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=args.lr_step_size,
                                                       gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    # Keep a handle on the unwrapped module; checkpoints are saved from and
    # loaded into it rather than the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Optimizer/scheduler state and the epoch counter only matter (and the
        # optimizer/scheduler only exist) when continuing QAT. Test-only and
        # post-training runs restore just the weights.
        if is_qat:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(
            dataset,
            indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
            pin_memory=True)
        model.eval()
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
        torch.quantization.prepare(model, inplace=True)
        # Calibrate first: running data through the prepared model lets the
        # observers collect statistics before conversion.
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.quantization.convert(model, inplace=True)
        if args.output_dir:
            print('Saving quantized model')
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir,
                           'quantized_post_train_model.pth'))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.quantization.enable_observer)
    model.apply(torch.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print('Starting training for epoch', epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        with torch.no_grad():
            # After the configured number of epochs, freeze quantization
            # observers and batch-norm statistics for the remaining epochs.
            if epoch >= args.num_observer_update_epochs:
                print('Disabling observer for subseq epochs, epoch = ', epoch)
                model.apply(torch.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print('Freezing BN for subseq epochs, epoch = ', epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print('Evaluate QAT model')

            evaluate(model, criterion, data_loader_test, device=device)
            # Convert a CPU copy so the training model keeps its fake-quant modules.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device('cpu'))
            torch.quantization.convert(quantized_eval_model, inplace=True)

            print('Evaluate Quantized model')
            evaluate(quantized_eval_model, criterion, data_loader_test,
                     device=torch.device('cpu'))

        model.train()

        if args.output_dir:
            print('Saving models after epoch ', epoch)
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'eval_model': quantized_eval_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
# Example #29
# 0
def main(args):
    """Train (or only evaluate) a torchvision image classifier.

    ImageFolder datasets are built from ``<data_path>/train`` and
    ``<data_path>/val``. When ``args.cache_dataset`` is set they are loaded
    from, or saved to, the path given by ``_get_cache_path``. The model named
    ``args.model`` is trained with SGD and a StepLR schedule from
    ``args.start_epoch`` up to ``args.epochs``, and evaluated after every epoch.
    When ``args.output_dir`` is set, each epoch writes ``model_<epoch>.pth``
    and overwrites ``checkpoint.pth``. With ``args.test_only`` the model is
    evaluated once and the function returns.

    Args:
        args: parsed command-line namespace holding the run configuration.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    print("Loading data")
    train_root = os.path.join(args.data_path, 'train')
    val_root = os.path.join(args.data_path, 'val')
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    print("Loading training data")
    t0 = time.time()
    train_cache = _get_cache_path(train_root)
    if not (args.cache_dataset and os.path.exists(train_cache)):
        train_tf = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            imagenet_normalize,
        ])
        dataset = torchvision.datasets.ImageFolder(train_root, train_tf)
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(train_cache))
            utils.mkdir(os.path.dirname(train_cache))
            utils.save_on_master((dataset, train_root), train_cache)
    else:
        # NOTE: the cached object also carries its transforms, so edits to the
        # transform pipeline above are ignored until the cache file is removed.
        print("Loading dataset_train from {}".format(train_cache))
        dataset, _ = torch.load(train_cache)
    print("Took", time.time() - t0)

    print("Loading validation data")
    val_cache = _get_cache_path(val_root)
    if not (args.cache_dataset and os.path.exists(val_cache)):
        val_tf = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            imagenet_normalize,
        ])
        dataset_test = torchvision.datasets.ImageFolder(val_root, val_tf)
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(val_cache))
            utils.mkdir(os.path.dirname(val_cache))
            utils.save_on_master((dataset_test, val_root), val_cache)
    else:
        # NOTE: as above, the cached transforms are reused as-is.
        print("Loading dataset_test from {}".format(val_cache))
        dataset_test, _ = torch.load(val_cache)

    print("Creating data loaders")
    if args.distributed:
        dist_sampler = torch.utils.data.distributed.DistributedSampler
        train_sampler = dist_sampler(dataset)
        test_sampler = dist_sampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # Training and evaluation loaders share every option except the sampler.
    loader_opts = dict(batch_size=args.batch_size,
                       num_workers=args.workers,
                       pin_memory=True)
    data_loader = torch.utils.data.DataLoader(dataset, sampler=train_sampler, **loader_opts)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, sampler=test_sampler, **loader_opts)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped module; checkpoints are saved from and
    # loaded into it rather than the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    if args.resume:
        saved = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(saved['model'])
        optimizer.load_state_dict(saved['optimizer'])
        lr_scheduler.load_state_dict(saved['lr_scheduler'])
        args.start_epoch = saved['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args.print_freq)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if not args.output_dir:
            continue
        state = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'args': args
        }
        # A per-epoch snapshot first, then the rolling "latest" checkpoint.
        for file_name in ('model_{}.pth'.format(epoch), 'checkpoint.pth'):
            utils.save_on_master(state, os.path.join(args.output_dir, file_name))

    elapsed = time.time() - start_time
    print('Training time {}'.format(str(datetime.timedelta(seconds=int(elapsed)))))
def main(args):
    """Train (or only evaluate) a torchvision detection model.

    Datasets come from ``get_dataset``. For ``args.dataset == 'custom_voc'``
    the train/test sets are read from ``args.train_data_path`` and
    ``args.test_data_path`` with no image-set name. Otherwise the
    'train'/'test' (coco) or 'trainval'/'val' (other) splits are read from
    ``args.data_path``.

    Training runs SGD with a MultiStepLR schedule. When ``args.output_dir`` is
    set, a checkpoint ``model_<dataset>_<epoch>.pth`` is written after each
    epoch, and the model is evaluated after every epoch. ``args.resume``
    restores the model, optimizer, scheduler and epoch counter. With
    ``args.test_only`` the restored model is evaluated once and the function
    returns.

    Args:
        args: parsed command-line namespace holding the run configuration.

    Raises:
        Exception: if ``args.test_only`` is set without ``args.resume``.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    # Custom Pascal-VOC-format datasets are selected with dataset=custom_voc.
    if args.dataset == 'custom_voc':
        # A custom Pascal dataset needs no image_set argument, so None is passed.
        dataset, num_classes = get_dataset(args.dataset, None,
                                           get_transform(train=True),
                                           args.train_data_path)
        dataset_test, _ = get_dataset(args.dataset, None,
                                      get_transform(train=False),
                                      args.test_data_path)
    else:
        dataset, num_classes = get_dataset(
            args.dataset, "train" if args.dataset == 'coco' else 'trainval',
            get_transform(train=True), args.data_path)
        dataset_test, _ = get_dataset(
            args.dataset, "test" if args.dataset == 'coco' else 'val',
            get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        # Batch together images from the same aspect-ratio group.
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                  args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.batch_size,
                                                            drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    def _evaluate(eval_model):
        """Run the dataset-specific evaluator of ``args.dataset`` on the test loader."""
        if args.dataset == 'coco':
            coco_evaluate(eval_model, data_loader_test, device=device)
        elif args.dataset == 'voc':
            voc_evaluate(eval_model, data_loader_test, device=device)
        elif args.dataset == 'custom_voc':
            custom_voc_evaluate(eval_model, data_loader_test, device=device)
        else:
            print(
                f'No evaluation method available for the dataset {args.dataset}'
            )

    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](
        num_classes=num_classes, pretrained=args.pretrained)
    model.to(device)

    # Keep a handle on the unwrapped module; checkpoints are saved from and
    # loaded into it rather than the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    start_epoch = getattr(args, 'start_epoch', 0)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Resuming training needs the optimizer and LR-schedule state as well
        # as the model weights.
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # Continue with the epoch after the saved one. Checkpoints written
        # before 'epoch' was stored lack the key. Their scheduler was stepped
        # once per finished epoch, so its restored last_epoch gives the same value.
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = lr_scheduler.last_epoch

    # Test-only runs also take the checkpoint via --resume. Besides resuming
    # training, a saved state_dict can be loaded for inference (see
    # https://pytorch.org/tutorials/beginner/saving_loading_models.html).
    if args.test_only:
        if not args.resume:
            raise Exception('需要checkpoints模型用于推理!')
        # The weights were already restored from args.resume above.
        _evaluate(model_without_ddp)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            utils.save_on_master(
                {
                    # Stores the network parameters only, not the module structure.
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                },
                os.path.join(args.output_dir,
                             'model_{}_{}.pth'.format(args.dataset, epoch)))

        # evaluate after every epoch
        _evaluate(model)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))