Пример #1
0
def _densenet(arch, growth_rate, block_config, num_init_features, pretrained=False,
              progress=True, imagenet_pretrained=False, num_classes=1, lin_features=512,
              dropout_prob=0.5, bn_final=False, concat_pool=True, **kwargs):
    """Build a ``DenseNet`` backbone and wrap it with ``cnn_model``.

    Args:
        arch (str): key used to look up checkpoint URLs in ``imagenet_urls`` and ``model_urls``
        growth_rate, block_config, num_init_features: forwarded positionally to ``DenseNet``
        pretrained (bool): if True, load the ``model_urls[arch]`` checkpoint into the final model
        progress (bool): forwarded to ``load_state_dict_from_url``
        imagenet_pretrained (bool): if True, load the ``imagenet_urls[arch]`` checkpoint, minus
            its classifier weights, into the backbone. Cannot be combined with ``pretrained``.
        num_classes (int): forwarded to both ``DenseNet`` and ``cnn_model``
        lin_features, dropout_prob, bn_final, concat_pool: forwarded to ``cnn_model``
        **kwargs: extra keyword arguments forwarded to ``DenseNet``

    Returns:
        The model returned by ``cnn_model``.

    Raises:
        ValueError: if both ``imagenet_pretrained`` and ``pretrained`` are True.
        KeyError: if the ImageNet checkpoint has unexpected keys, or misses any key
            other than those under ``classifier.``.
    """
    backbone = DenseNet(growth_rate, block_config, num_init_features, num_classes=num_classes, **kwargs)

    if imagenet_pretrained:
        if pretrained:
            raise ValueError('imagenet_pretrained cannot be set to True if pretrained=True')
        imagenet_weights = _update_state_dict(
            load_state_dict_from_url(imagenet_urls[arch], progress=progress))
        # The checkpoint's classifier parameters are discarded before loading.
        imagenet_weights.pop('classifier.weight', None)
        imagenet_weights.pop('classifier.bias', None)
        missing, unexpected = backbone.load_state_dict(imagenet_weights, strict=False)
        # Only the classifier entries removed above are allowed to be absent.
        if any(unexpected) or not all(elt.startswith('classifier.') for elt in missing):
            raise KeyError(f"Missing parameters: {missing}\nUnexpected parameters: {unexpected}")

    # Cut at last conv layers
    head_in_features = backbone.classifier.in_features
    model = cnn_model(backbone, model_cut, head_in_features, num_classes,
                      lin_features, dropout_prob, bn_final=bn_final, concat_pool=concat_pool)

    if pretrained:
        model.load_state_dict(load_state_dict_from_url(model_urls[arch], progress=progress))

    return model
Пример #2
0
def densenet121(pretrained=False, **kwargs):
    r"""Construct a Densenet-121 model, as described in
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the ``model_urls['densenet121']`` checkpoint
            into the model before returning it
        **kwargs: extra keyword arguments forwarded to ``Orig_DenseNet``
    """
    model = Orig_DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    if not pretrained:
        return model

    # Module names may no longer contain '.', but checkpoints written for the
    # previous _DenseLayer still use keys such as 'norm.1', 'relu.1', 'conv.1',
    # 'norm.2', 'relu.2', 'conv.2'. This pattern picks those keys out so the
    # inner dot can be removed.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_urls['densenet121'])
    # Iterate over a snapshot of the keys because entries are renamed in place.
    for old_name in list(state_dict.keys()):
        hit = legacy_key.match(old_name)
        if not hit:
            continue
        state_dict[hit.group(1) + hit.group(2)] = state_dict.pop(old_name)
    model.load_state_dict(state_dict)
    return model
Пример #3
0
# Optimizers that carry an ``iter_count`` and are rebuilt at every scheduled epoch.
_RESTART_OPTIMIZERS = ('srsgd', 'sradam', 'sradamw', 'srradam')


def _make_restart_optimizer(name, params, lr, iter_count, restarting_iter):
    """Construct one of the optimizers listed in ``_RESTART_OPTIMIZERS``.

    Args:
        name (str): lower-cased optimizer name
        params: parameters to optimize (e.g. ``model.parameters()``)
        lr (float): learning rate for the new optimizer
        iter_count (int): iteration counter handed to the optimizer
        restarting_iter: restart setting taken from ``args.restart_schedule``

    Raises:
        ValueError: if ``name`` is not one of ``_RESTART_OPTIMIZERS``.
    """
    if name == 'srsgd':
        return SGD_Adaptive(params,
                            lr=lr,
                            weight_decay=args.weight_decay,
                            iter_count=iter_count,
                            restarting_iter=restarting_iter)
    if name == 'sradam':
        return SRNAdam(params,
                       lr=lr,
                       betas=(args.beta1, args.beta2),
                       iter_count=iter_count,
                       weight_decay=args.weight_decay,
                       restarting_iter=restarting_iter)
    if name == 'sradamw':
        return SRAdamW(params,
                       lr=lr,
                       betas=(args.beta1, args.beta2),
                       iter_count=iter_count,
                       weight_decay=args.weight_decay,
                       warmup=0,
                       restarting_iter=restarting_iter)
    if name == 'srradam':
        # NOTE: need to double-check this (carried over from the original code)
        return SRRAdam(params,
                       lr=lr,
                       betas=(args.beta1, args.beta2),
                       iter_count=iter_count,
                       weight_decay=args.weight_decay,
                       warmup=0,
                       restarting_iter=restarting_iter)
    raise ValueError('Unsupported restart optimizer: %r' % name)


def _make_optimizer(name, params):
    """Construct the initial optimizer for the lower-cased optimizer ``name``.

    Hyper-parameters are read from the module-level ``args``.

    Raises:
        ValueError: if ``name`` is not a supported optimizer.
    """
    if name == 'sgd':
        return optim.SGD(params,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    if name == 'adamw':
        return AdamW(params,
                     lr=args.lr,
                     betas=(args.beta1, args.beta2),
                     weight_decay=args.weight_decay,
                     warmup=0)
    if name == 'radam':
        return RAdam(params,
                     lr=args.lr,
                     betas=(args.beta1, args.beta2),
                     weight_decay=args.weight_decay)
    if name in ('lsadam', 'lsradam'):
        optimizer_cls = LSAdamW if name == 'lsadam' else LSRAdam
        return optimizer_cls(params,
                             lr=args.lr * ((1. + 4. * args.sigma)**(0.25)),
                             betas=(args.beta1, args.beta2),
                             weight_decay=args.weight_decay,
                             sigma=args.sigma)
    if name in _RESTART_OPTIMIZERS:
        return _make_restart_optimizer(
            name,
            params,
            lr=args.lr,
            iter_count=1,
            restarting_iter=args.restart_schedule[0])
    raise ValueError('Unsupported optimizer: %r' % args.optimizer)


def main():
    """Run distributed ImageNet training or evaluation, configured by the global ``args``.

    The function, in order:
      1. builds the train/val data loaders from ``args.data``;
      2. creates the model selected by ``args.arch`` and wraps it with ``DDP``;
      3. creates the loss and the optimizer selected by ``args.optimizer``;
      4. optionally restores model/optimizer/bookkeeping from ``args.resume``;
      5. either evaluates once (``args.evaluate``) or trains for the remaining
         epochs, logging, writing scalars and saving a checkpoint each epoch.

    Only the process with ``args.local_rank == 0`` writes logs, scalars,
    checkpoints and the shared results file.

    Raises:
        FileNotFoundError: if ``args.resume`` is set but is not an existing file.
        ValueError: if ``args.optimizer`` names an unsupported optimizer.
    """
    global best_top1, best_top5

    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
    valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

    # A sampler is always supplied here, so ``shuffle`` evaluates to False and
    # ordering is left to the sampler.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.train_batch,
                                               shuffle=(train_sampler is None),
                                               pin_memory=True,
                                               num_workers=8,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(valid_data,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=8)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    elif args.arch == 'densenet264':
        model = DenseNet(growth_rate=32,
                         block_config=(6, 12, 64, 48),
                         num_init_features=64,
                         bn_size=4,
                         drop_rate=0,
                         num_classes=1000,
                         memory_efficient=False)
    elif args.arch == 'resnet200':
        model = ResNet(block=Bottleneck,
                       layers=[3, 24, 36, 3],
                       num_classes=1000,
                       zero_init_residual=False,
                       groups=1,
                       width_per_group=64,
                       replace_stride_with_dilation=None,
                       norm_layer=None)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # For alexnet/vgg only the ``features`` sub-module is wrapped with DDP;
    # every other architecture is wrapped as a whole.
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = DDP(model.features)
        model.cuda()
    else:
        model = model.cuda()
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer_name = args.optimizer.lower()
    uses_restarts = optimizer_name in _RESTART_OPTIMIZERS
    # Restart optimizers start with iter_count=1; the value is later replaced
    # by the checkpoint (on resume) and by the value returned from train().
    iter_count = 1
    optimizer = _make_optimizer(optimizer_name, model.parameters())

    schedule_index = 1
    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('Error: no checkpoint file found!')
        # Tensors are loaded onto the CPU first.
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        best_top1 = checkpoint['best_top1']
        best_top5 = checkpoint['best_top5']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if uses_restarts:
            iter_count = optimizer.param_groups[0]['iter_count']
        schedule_index = checkpoint['schedule_index']
        state['lr'] = optimizer.param_groups[0]['lr']

    log_path = os.path.join(args.checkpoint, 'log.txt')
    # NOTE(review): args.checkpoint was checked as a directory above and
    # args.resume as a file, so this equality looks like it can never hold;
    # possibly os.path.dirname(args.resume) was intended - confirm.
    if args.resume and args.checkpoint == args.resume:
        logger = LoggerDistributed(log_path,
                                   rank=args.local_rank,
                                   title=title,
                                   resume=True)
    else:
        logger = LoggerDistributed(log_path,
                                   rank=args.local_rank,
                                   title=title)
        if args.local_rank == 0:
            logger.set_names([
                'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Top1',
                'Valid Top1', 'Train Top5', 'Valid Top5'
            ])

    if args.local_rank == 0:
        logger.file.write('    Total params: %.2fM' %
                          (sum(p.numel()
                               for p in model.parameters()) / 1000000.0))

    if args.evaluate:
        if args.local_rank == 0:
            logger.file.write('\nEvaluation only')
        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               start_epoch, use_cuda, logger)
        if args.local_rank == 0:
            logger.file.write(
                ' Test Loss:  %.8f, Test Top1:  %.2f, Test Top5: %.2f' %
                (test_loss, test_top1, test_top5))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        # Shuffle the sampler.
        train_loader.sampler.set_epoch(epoch + args.manualSeed)

        if uses_restarts:
            # Restart optimizers are rebuilt at each scheduled epoch with a
            # decayed learning rate and the next restart setting.
            if epoch in args.schedule:
                optimizer = _make_restart_optimizer(
                    optimizer_name,
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    iter_count=iter_count,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1
                # adjust_learning_rate() is not called on this path, so refresh
                # the logged learning rate from the new optimizer, the same
                # way the resume path does.
                state['lr'] = optimizer.param_groups[0]['lr']
        else:
            adjust_learning_rate(optimizer, epoch)

        if args.local_rank == 0:
            logger.file.write('\nEpoch: [%d | %d] LR: %f' %
                              (epoch + 1, args.epochs, state['lr']))

        if uses_restarts:
            # For restart optimizers train() also returns the updated iter_count.
            train_loss, train_top1, train_top5, iter_count = train(
                train_loader, model, criterion, optimizer, epoch, use_cuda,
                logger)
        else:
            train_loss, train_top1, train_top5 = train(train_loader, model,
                                                       criterion, optimizer,
                                                       epoch, use_cuda, logger)

        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               epoch, use_cuda, logger)

        # append logger file
        if args.local_rank == 0:
            logger.append([
                state['lr'], train_loss, test_loss, train_top1, test_top1,
                train_top5, test_top5
            ])
            epoch_scalars = (
                ('train_loss', train_loss),
                ('test_loss', test_loss),
                ('train_top1', train_top1),
                ('test_top1', test_top1),
                ('train_top5', train_top5),
                ('test_top5', test_top5),
            )
            for tag, value in epoch_scalars:
                writer.add_scalars(tag, {args.model_name: value}, epoch)

        # save model
        is_best = test_top1 > best_top1
        best_top1 = max(test_top1, best_top1)
        best_top5 = max(test_top5, best_top5)
        if args.local_rank == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'schedule_index': schedule_index,
                    'state_dict': model.state_dict(),
                    'top1': test_top1,
                    'top5': test_top5,
                    'best_top1': best_top1,
                    'best_top5': best_top5,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                epoch,
                checkpoint=args.checkpoint)

            # Record an intermediate summary at the last scheduled epoch.
            if epoch == args.schedule[-1]:
                logger.file.write('Best top1: %f at epoch %i' %
                                  (best_top1, epoch))
                logger.file.write('Best top5: %f at epoch %i' %
                                  (best_top5, epoch))
                print('Best top1: %f at epoch %i' % (best_top1, epoch))
                print('Best top5: %f at epoch %i' % (best_top5, epoch))
                # The results file is appended to under an exclusive flock.
                with open("./all_results_imagenet.txt", "a") as f:
                    fcntl.flock(f, fcntl.LOCK_EX)
                    f.write("%s\n" % args.checkpoint)
                    f.write("best_top1 %f, best_top5 %f at epoch %i\n\n" %
                            (best_top1, best_top5, epoch))
                    fcntl.flock(f, fcntl.LOCK_UN)

    if args.local_rank == 0:
        logger.file.write('Best top1: %f' % best_top1)
        logger.file.write('Best top5: %f' % best_top5)
        logger.close()
        logger.plot()
        savefig(os.path.join(args.checkpoint, 'log.eps'))
        print('Best top1: %f' % best_top1)
        print('Best top5: %f' % best_top5)
        with open("./all_results_imagenet.txt", "a") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.write("%s\n" % args.checkpoint)
            f.write("best_top1 %f, best_top5 %f\n\n" % (best_top1, best_top5))
            fcntl.flock(f, fcntl.LOCK_UN)