Example #1
def main():
    # Scale learning rate based on global batch size.
    if not args.no_scale_lr:
        scale = float(args.batch_size * args.world_size) / 128.0
        args.learning_rate = scale * args.learning_rate

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info('args = %s', args)

    # Get data loaders.
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if 'lmdb' in args.data:
        train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
        valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)
    else:
        train_data = dset.ImageFolder(traindir, transform=train_transform)
        valid_data = dset.ImageFolder(validdir, transform=val_transform)

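    # DistributedSampler shards the training set across ranks and handles shuffling
    # (re-seeded via set_epoch each epoch), so the DataLoader keeps shuffle=False.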
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8,
                                              sampler=train_sampler)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8)

    # Set up the network.
    if os.path.isfile(args.genotype):
        logging.info('Loading genotype from: %s' % args.genotype)
        genotype = torch.load(args.genotype, map_location='cpu')
    else:
        logging.info('Loading genotype: %s' % args.genotype)
        genotype = eval('genotypes.%s' % args.genotype)
    if not isinstance(genotype, list):
        genotype = [genotype]

    # If num channels not provided, find the max under 600M MAdds.
    if args.init_channels < 0:
        if args.local_rank == 0:
            flops, num_params, init_channels = find_max_channels(
                genotype, args.layers, args.max_M_flops * 1e6)
            logging.info('Num flops = %.2fM', flops / 1e6)
            logging.info('Num params = %.2fM', num_params / 1e6)
        else:
            init_channels = 0
        # All-reduce with world_size=1 skips averaging, so the sum broadcasts rank 0's value to every rank.
        init_channels = torch.Tensor([init_channels]).cuda()
        init_channels = utils.reduce_tensor(init_channels, 1)
        args.init_channels = int(init_channels.item())
    logging.info('Num channels = %d', args.init_channels)

    # Create model and loss.
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary,
                    genotype)
    model = model.cuda()
    model = DDP(model, delay_allreduce=True)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    logging.info('param size = %fM', utils.count_parameters_in_M(model))

    # Set up network weights optimizer.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.min_learning_rate)

    # Train.
    global_step = 0
    best_acc_top1 = 0
    for epoch in range(args.epochs):
        # Shuffle the sampler, update lrs.
        train_queue.sampler.set_epoch(epoch + args.seed)
        # Step the cosine schedule only after warm-up; train() receives
        # warmup_epochs and the base LR to handle the warm-up itself.
        if epoch >= args.warmup_epochs:
            scheduler.step()
        model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # Training.
        train_acc, train_obj, global_step = train(train_queue, model,
                                                  criterion_smooth, optimizer,
                                                  epoch, args.learning_rate,
                                                  args.warmup_epochs,
                                                  global_step)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('train/acc', train_acc, global_step)

        # Validation.
        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        writer.add_scalar('val/acc_top1', valid_acc_top1, global_step)
        writer.add_scalar('val/acc_top5', valid_acc_top5, global_step)
        writer.add_scalar('val/loss', valid_obj, global_step)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        if args.local_rank == 0:
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc_top1': best_acc_top1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save)
Example #2
def main():
    start = time.time()
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(config.local_rank % len(config.gpus))
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    config.world_size = torch.distributed.get_world_size()
    config.total_batch = config.world_size * config.batch_size

    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.benchmark = True

    CLASSES = 1000
    channels = [32, 16, 24, 40, 80, 96, 192, 320, 1280]
    steps = [1, 1, 2, 3, 4, 3, 3, 1, 1]
    strides = [2, 1, 2, 2, 1, 2, 1, 1, 1]

    criterion = nn.CrossEntropyLoss()
    criterion_latency = LatencyLoss(channels[2:9], steps[2:8], strides[2:8])
    criterion = criterion.cuda()
    criterion_latency = criterion_latency.cuda()
    model = Network(channels, steps, strides, CLASSES, criterion)
    model = model.cuda()
    #model = DDP(model, delay_allreduce=True)
    # The custom loss cannot use model.parameters() on an apex-wrapped model
    # (https://github.com/NVIDIA/apex/issues/457, https://github.com/NVIDIA/apex/issues/107),
    # so wrap with torch.nn.parallel.DistributedDataParallel instead.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank], output_device=config.local_rank)
    logger.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                config.w_lr,
                                momentum=config.w_momentum,
                                weight_decay=config.w_weight_decay)

    train_data = get_imagenet_iter_torch(
        type='train',
        # image_dir="/googol/atlas/public/cv/ILSVRC/Data/"
        # use soft link `mkdir ./data/imagenet && ln -s /googol/atlas/public/cv/ILSVRC/Data/CLS-LOC/* ./data/imagenet/`
        image_dir=config.data_path + "/" + config.dataset.lower(),
        batch_size=config.batch_size,
        num_threads=config.workers,
        world_size=config.world_size,
        local_rank=config.local_rank,
        crop=224,
        device_id=config.local_rank,
        num_gpus=config.gpus,
        portion=config.train_portion)
    valid_data = get_imagenet_iter_torch(
        type='train',
        # image_dir="/googol/atlas/public/cv/ILSVRC/Data/"
        # use soft link `mkdir ./data/imagenet && ln -s /googol/atlas/public/cv/ILSVRC/Data/CLS-LOC/* ./data/imagenet/`
        image_dir=config.data_path + "/" + config.dataset.lower(),
        batch_size=config.batch_size,
        num_threads=config.workers,
        world_size=config.world_size,
        local_rank=config.local_rank,
        crop=224,
        device_id=config.local_rank,
        num_gpus=config.gpus,
        portion=config.val_portion)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(config.epochs), eta_min=config.w_lr_min)

    if len(config.gpus) > 1:
        architect = Architect(model.module, config)
    else:
        architect = Architect(model, config)

    best_top1 = 0.
    for epoch in range(config.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logger.info('epoch %d lr %e', epoch, lr)

        #print(F.softmax(model.alphas_normal, dim=-1))
        #print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_top1, train_loss = train(train_data, valid_data, model,
                                       architect, criterion, criterion_latency,
                                       optimizer, lr, epoch, writer)
        logger.info('Train top1 %f', train_top1)

        # validation
        top1 = 0
        if config.epochs - epoch <= 1:
            top1, loss = infer(valid_data, model, epoch, criterion, writer)
            logger.info('valid top1 %f', top1)

        if len(config.gpus) > 1:
            genotype = model.module.genotype()
        else:
            genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # Save the genotype as an image.
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal")
        plot(genotype.reduce, plot_path + "-reduce")
        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    utils.time(time.time() - start)
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
Example #3
def main():
    start = time.time()
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(config.local_rank % len(config.gpus))
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    config.world_size = torch.distributed.get_world_size()
    config.total_batch = config.world_size * config.batch_size

    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.benchmark = True

    CLASSES = 1000
    num_training_samples = 1281167
    num_batches = num_training_samples // config.batch_size
    model_name = config.arch

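    # If channel selection is scheduled to start later (epoch_start_cs != -1),
    # train with all channels enabled until then; the flag is cleared in the
    # epoch loop once epoch_start_cs has passed.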
    if config.epoch_start_cs != -1:
        config.use_all_channels = True

    ### Model
    if model_name == 'ShuffleNas_fixArch':
        architecture = [
            0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2
        ]
        scale_ids = [
            6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4, 3
        ]
        # we rescale the channels uniformly to adjust the FLOPs.
        model = get_shufflenas_oneshot(
            architecture,
            scale_ids,
            use_se=config.use_se,
            n_class=CLASSES,
            last_conv_after_pooling=config.last_conv_after_pooling,
            channels_layout=config.channels_layout)
    elif model_name == 'ShuffleNas':
        model = get_shufflenas_oneshot(
            use_all_blocks=config.use_all_blocks,
            use_se=config.use_se,
            n_class=CLASSES,
            last_conv_after_pooling=config.last_conv_after_pooling,
            channels_layout=config.channels_layout)
    else:
        raise NotImplementedError

    model = model.cuda()
    #model.apply(utils.weights_init)
    #model = DDP(model, delay_allreduce=True)
    # The custom loss cannot use model.parameters() on an apex-wrapped model
    # (https://github.com/NVIDIA/apex/issues/457, https://github.com/NVIDIA/apex/issues/107),
    # so wrap with torch.nn.parallel.DistributedDataParallel instead.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank], output_device=config.local_rank)
    if model_name == 'ShuffleNas_fixArch':
        logger.info("param size = %fMB", utils.count_parameters_in_MB(model))
    else:
        logger.info("Train Supernet")
    # Loss
    if config.label_smoothing:
        criterion = CrossEntropyLabelSmooth(CLASSES, config.label_smooth)
    else:
        criterion = nn.CrossEntropyLoss()

    weight = model.parameters()
    # Optimizer
    w_optimizer = torch.optim.SGD(weight,
                                  config.w_lr,
                                  momentum=config.w_momentum,
                                  weight_decay=config.w_weight_decay)

    w_schedule = utils.Schedule(w_optimizer)

    train_data = get_imagenet_iter_torch(
        type='train',
        # image_dir="/googol/atlas/public/cv/ILSVRC/Data/"
        # use soft link `mkdir ./data/imagenet && ln -s /googol/atlas/public/cv/ILSVRC/Data/CLS-LOC/* ./data/imagenet/`
        image_dir=config.data_path + "/" + config.dataset.lower(),
        batch_size=config.batch_size,
        num_threads=config.workers,
        world_size=config.world_size,
        local_rank=config.local_rank,
        crop=224,
        device_id=config.local_rank,
        num_gpus=config.gpus,
        portion=config.train_portion)
    valid_data = get_imagenet_iter_torch(
        type='train',
        # image_dir="/googol/atlas/public/cv/ILSVRC/Data/"
        # use soft link `mkdir ./data/imagenet && ln -s /googol/atlas/public/cv/ILSVRC/Data/CLS-LOC/* ./data/imagenet/`
        image_dir=config.data_path + "/" + config.dataset.lower(),
        batch_size=config.batch_size,
        num_threads=config.workers,
        world_size=config.world_size,
        local_rank=config.local_rank,
        crop=224,
        device_id=config.local_rank,
        num_gpus=config.gpus,
        portion=config.val_portion)

    best_top1 = 0.
    for epoch in range(config.epochs):
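        # Linear LR warm-up for the first warmup_epochs, then a cosine schedule decaying to w_lr_min.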
        if epoch < config.warmup_epochs:
            lr = w_schedule.update_schedule_linear(epoch, config.w_lr,
                                                   config.w_weight_decay,
                                                   config.batch_size)
        else:
            w_scheduler = w_schedule.get_schedule_cosine(
                config.w_lr_min, config.epochs)
            w_scheduler.step()
            lr = w_scheduler.get_lr()[0]
        logger.info('epoch %d lr %e', epoch, lr)

        if epoch > config.epoch_start_cs:
            config.use_all_channels = False
        # training
        train_top1, train_loss = train(train_data, valid_data, model,
                                       criterion, w_optimizer, lr, epoch,
                                       writer, model_name)
        logger.info('Train top1 %f', train_top1)

        # validation
        top1 = 0
        if epoch % 10 == 0:
            top1, loss = infer(valid_data, model, epoch, criterion, writer,
                               model_name)
            logger.info('valid top1 %f', top1)

        # save
        if best_top1 < top1:
            best_top1 = top1
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    utils.time(time.time() - start)
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
Example #4
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set seeds
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('args = %s', args)

    # Get data loaders.
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')

    # data augmentation
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = dset.ImageFolder(traindir, transform=train_transform)
    valid_data = dset.ImageFolder(validdir, transform=val_transform)

    # Split the original validation set into validation and test subsets.
    valid_data, test_data = utils.dataset_split(valid_data, len(valid_data))

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8,
                                              sampler=train_sampler)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=8)

    # Create model and loss.
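    # Use a per-rank torch.hub cache directory (keyed by local_rank) before loading the torchvision ResNet-50.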
    torch.hub.set_dir('/tmp/hub_cache_%d' % args.local_rank)
    model = torch.hub.load('pytorch/vision:v0.4.2',
                           'resnet50',
                           pretrained=False)
    model = model.cuda()
    model = DDP(model, delay_allreduce=True)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    # Set up network weights optimizer.
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'fromage':
        optimizer = Fromage(model.parameters(), args.learning_rate)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.learning_rate,
                                     weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                gamma=0.1,
                                                step_size=30)

    # Train.
    global_step = 0
    best_acc_top1 = 0
    for epoch in range(args.epochs):
        # Shuffle the sampler, update lrs.
        train_queue.sampler.set_epoch(epoch + args.seed)

        # Training.
        train_acc_top1, train_acc_top5, train_obj, global_step = train(
            train_queue, model, criterion_smooth, optimizer, global_step)
        logging.info('epoch %d train_acc %f', epoch, train_acc_top1)
        writer.add_scalar('train/loss', train_obj, global_step)
        writer.add_scalar('train/acc_top1', train_acc_top1, global_step)
        writer.add_scalar('train/acc_top5', train_acc_top5, global_step)
        writer.add_scalar('train/lr',
                          optimizer.state_dict()['param_groups'][0]['lr'],
                          global_step)

        # Validation.
        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        writer.add_scalar('val/acc_top1', valid_acc_top1, global_step)
        writer.add_scalar('val/acc_top5', valid_acc_top5, global_step)
        writer.add_scalar('val/loss', valid_obj, global_step)

        # Test
        test_acc_top1, test_acc_top5, test_obj = infer(test_queue, model,
                                                       criterion)
        logging.info('test_acc_top1 %f', test_acc_top1)
        logging.info('test_acc_top5 %f', test_acc_top5)
        writer.add_scalar('test/acc_top1', test_acc_top1, global_step)
        writer.add_scalar('test/acc_top5', test_acc_top5, global_step)
        writer.add_scalar('test/loss', test_obj, global_step)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        if args.local_rank == 0:
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc_top1': best_acc_top1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save)

        # Update LR.
        scheduler.step()

    writer.flush()
Example #5
def main():

    cudnn.benchmark = True
    cudnn.enabled = True
    logger.info("args = %s", args)

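    # Parse the layer-wise compression-rate string: '+' separates entries and
    # '*N' repeats the preceding rate N times, e.g. '0.5+0.3*2' -> [0.5, 0.3, 0.3].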
    if args.compress_rate:
        import re
        cprate_str = args.compress_rate
        cprate_str_list = cprate_str.split('+')
        pat_cprate = re.compile(r'\d+\.\d*')
        pat_num = re.compile(r'\*\d+')
        cprate = []
        for x in cprate_str_list:
            num = 1
            find_num = re.findall(pat_num, x)
            if find_num:
                assert len(find_num) == 1
                num = int(find_num[0].replace('*', ''))
            find_cprate = re.findall(pat_cprate, x)
            assert len(find_cprate) == 1
            cprate += [float(find_cprate[0])] * num

        compress_rate = cprate

    # load model
    logger.info('compress_rate:' + str(compress_rate))
    logger.info('==> Building model..')
    model = eval(args.arch)(compress_rate=compress_rate).cuda()
    logger.info(model)

    #calculate model size
    input_image_size = 32
    input_image = torch.randn(1, 3, input_image_size, input_image_size).cuda()
    flops, params = profile(model, inputs=(input_image, ))
    logger.info('Params: %.2f' % (params))
    logger.info('Flops: %.2f' % (flops))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_decay_step = list(map(int, args.lr_decay_step.split(',')))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_decay_step,
                                                     gamma=0.1)

    start_epoch = 0
    best_top1_acc = 0

    # load the checkpoint if it exists
    checkpoint_tar = os.path.join(args.job_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logger.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch']
        best_top1_acc = checkpoint['best_top1_acc']
        model.load_state_dict(checkpoint['state_dict'])
        logger.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_tar, checkpoint['epoch']))
    else:
        if args.use_pretrain:
            logger.info('resuming from pretrain model')
            origin_model = eval(args.arch)(compress_rate=[0.] * 100).cuda()
            ckpt = torch.load(args.pretrain_dir, map_location='cuda:0')

            #if args.arch=='resnet_56':
            #    origin_model.load_state_dict(ckpt['state_dict'],strict=False)
            if args.arch == 'densenet_40' or args.arch == 'resnet_110':
                new_state_dict = OrderedDict()
                for k, v in ckpt['state_dict'].items():
                    new_state_dict[k.replace('module.', '')] = v
                origin_model.load_state_dict(new_state_dict)
            else:
                origin_model.load_state_dict(ckpt['state_dict'])

            oristate_dict = origin_model.state_dict()

            if args.arch == 'googlenet':
                load_google_model(model, oristate_dict, args.random_rule)
            elif args.arch == 'vgg_16_bn':
                load_vgg_model(model, oristate_dict, args.random_rule)
            elif args.arch == 'resnet_56':
                load_resnet_model(model, oristate_dict, args.random_rule, 56)
            elif args.arch == 'resnet_110':
                load_resnet_model(model, oristate_dict, args.random_rule, 110)
            elif args.arch == 'densenet_40':
                load_densenet_model(model, oristate_dict, args.random_rule)
            else:
                raise NotImplementedError
        else:
            logger.info('training from scratch')

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load training data
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir,
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)
    testset = torchvision.datasets.CIFAR10(root=args.data_dir,
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    val_loader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2)

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:
        train_obj, train_top1_acc, train_top5_acc = train(
            epoch, train_loader, model, criterion, optimizer, scheduler)
        valid_obj, valid_top1_acc, valid_top5_acc = validate(
            epoch, val_loader, model, criterion, args)

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            is_best = True

        utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_top1_acc': best_top1_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.job_dir)

        epoch += 1
        logger.info("=> Best accuracy {:.3f}".format(best_top1_acc))
Example #6
def main():
    if args.load_checkpoint:
        args.save = Path(args.load_checkpoint) / 'eval-imagenet-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
    else:
        args.save = Path('logs') / 'eval-imagenet-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
    utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(args.save / 'log.txt')
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    model = eval(args.model)

    # flops, params = profile(model, input_size=(1, 3, 224, 224))
    # print("flops" + str(flops) + " params" + str(params))
    if args.load_checkpoint:
        dictionary = torch.load(args.load_checkpoint)
        start_epoch = dictionary['epoch'] if args.start_epoch == -1 else args.start_epoch
        model.load_state_dict(dictionary['state_dict'])
    else:
        start_epoch = 0 if args.start_epoch == -1 else args.start_epoch

    direct_model = model

    if args.gpu:
        model = nn.DataParallel(model)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # if args.load_checkpoint:
    #   optimizer.load_state_dict(dictionary['optimizer'])
    #   del dictionary

    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.num_workers)

    if args.eval:
        direct_model.drop_path_prob = 0
        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, args.gpu)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        return

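    # args.period is a comma-separated list of cosine cycle lengths (in epochs);
    # `totals` holds their cumulative end epochs so the training loop can restart
    # the cosine schedule at each cycle boundary.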
    if args.period is not None:
        periods = args.period.split(',')
        periods = [int(p) for p in periods]
        totals = []
        total = 0
        for p in periods:
            total += p
            totals.append(total)
        scheduler = CosineAnnealingLR(optimizer, periods[0])
    else:
        periods = None
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)

    best_acc_top1 = 0
    for epoch in range(start_epoch, args.epochs):
        if args.period is None:
            scheduler.step(epoch)
        else:
            assert len(periods) > 0
            index = bisect.bisect_left(totals, epoch)
            scheduler.T_max = periods[index]
            if index == 0:
                e = epoch
            else:
                e = epoch - totals[index - 1]
            scheduler.step(e % periods[index])
            logging.info("schedule epoch:" + str(e % periods[index]))
            logging.info("schedule period:" + str(periods[index]))
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        direct_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, optimizer, args.gpu)
        logging.info('train_acc %f', train_acc)

        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, args.gpu)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.module.state_dict(),
            'best_acc_top1': best_acc_top1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)
Example #7
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    if is_wandb_used:
        wandb.init(
            project="automl-gradient-based-nas",
            name="ImageNet:" + str(args.arch),
            config=args,
            entity="automl"
        )

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    if args.parallel:
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)

    best_acc_top1 = 0
    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc_top1': best_acc_top1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)
Example #8
def main():

    cudnn.benchmark = True
    cudnn.enabled = True
    logger.info("args = %s", args)

    if args.compress_rate:
        import re
        cprate_str = args.compress_rate
        cprate_str_list = cprate_str.split('+')
        pat_cprate = re.compile(r'\d+\.\d*')
        pat_num = re.compile(r'\*\d+')
        cprate = []
        for x in cprate_str_list:
            num = 1
            find_num = re.findall(pat_num, x)
            if find_num:
                assert len(find_num) == 1
                num = int(find_num[0].replace('*', ''))
            find_cprate = re.findall(pat_cprate, x)
            assert len(find_cprate) == 1
            cprate += [float(find_cprate[0])] * num

        compress_rate = cprate

    # load model
    logger.info('compress_rate:' + str(compress_rate))
    logger.info('==> Building model..')
    model = eval(args.arch)(compress_rate=compress_rate).cuda()
    logger.info(model)

    #calculate model size
    input_image_size = 32
    input_image = torch.randn(1, 3, input_image_size, input_image_size).cuda()
    flops, params = profile(model, inputs=(input_image,))
    logger.info('Params: %.2f' % (params))
    logger.info('Flops: %.2f' % (flops))

    # load training data
    train_loader, val_loader = cifar10.load_data(args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    if args.test_only:
        if os.path.isfile(args.test_model_dir):
            logger.info('loading checkpoint {} ..........'.format(args.test_model_dir))
            checkpoint = torch.load(args.test_model_dir)
            model.load_state_dict(checkpoint['state_dict'])
            valid_obj, valid_top1_acc, valid_top5_acc = validate(0, val_loader, model, criterion, args)
        else:
            logger.info('please specify a checkpoint file')
        return

    if len(args.gpu) > 1:
        device_id = []
        for i in range((len(args.gpu) + 1) // 2):
            device_id.append(i)
        model = nn.DataParallel(model, device_ids=device_id).cuda()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_decay_step = list(map(int, args.lr_decay_step.split(',')))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_decay_step, gamma=0.1)

    start_epoch = 0
    best_top1_acc = 0

    # load the checkpoint if it exists
    checkpoint_dir = os.path.join(args.job_dir, 'checkpoint.pth.tar')
    if args.resume:
        logger.info('loading checkpoint {} ..........'.format(checkpoint_dir))
        checkpoint = torch.load(checkpoint_dir)
        start_epoch = checkpoint['epoch'] + 1
        best_top1_acc = checkpoint['best_top1_acc']

        # deal with the single-multi GPU problem
        new_state_dict = OrderedDict()
        tmp_ckpt = checkpoint['state_dict']
        if len(args.gpu) > 1:
            for k, v in tmp_ckpt.items():
                new_state_dict['module.' + k.replace('module.', '')] = v
        else:
            for k, v in tmp_ckpt.items():
                new_state_dict[k.replace('module.', '')] = v

        model.load_state_dict(new_state_dict)
        logger.info("loaded checkpoint {} epoch = {}".format(checkpoint_dir, checkpoint['epoch']))
    else:
        if args.use_pretrain:
            logger.info('resuming from pretrain model')
            origin_model = eval(args.arch)(compress_rate=[0.] * 100).cuda()
            ckpt = torch.load(args.pretrain_dir, map_location='cuda:0')

            #if args.arch=='resnet_56':
            #    origin_model.load_state_dict(ckpt['state_dict'],strict=False)
            if args.arch == 'densenet_40' or args.arch == 'resnet_110':
                new_state_dict = OrderedDict()
                for k, v in ckpt['state_dict'].items():
                    new_state_dict[k.replace('module.', '')] = v
                origin_model.load_state_dict(new_state_dict)
            else:
                origin_model.load_state_dict(ckpt['state_dict'])

            oristate_dict = origin_model.state_dict()

            if args.arch == 'googlenet':
                load_google_model(model, oristate_dict)
            elif args.arch == 'vgg_16_bn':
                load_vgg_model(model, oristate_dict)
            elif args.arch == 'resnet_56':
                load_resnet_model(model, oristate_dict, 56)
            elif args.arch == 'resnet_110':
                load_resnet_model(model, oristate_dict, 110)
            elif args.arch == 'densenet_40':
                load_densenet_model(model, oristate_dict)
            else:
                raise NotImplementedError
        else:
            logger.info('training from scratch')

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:
        train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model, criterion, optimizer, scheduler)
        valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model, criterion, args)

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            is_best = True

        utils.save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_top1_acc': best_top1_acc,
            'optimizer': optimizer.state_dict(),
            }, is_best, args.job_dir)

        epoch += 1
        logger.info("=> Best accuracy {:.3f}".format(best_top1_acc))
Example #9
    def train_epochs(epochs_to_train,
                     iteration,
                     args=args,
                     model=model_init,
                     optimizer=optimizer_init,
                     scheduler=scheduler_init,
                     train_queue=train_queue,
                     valid_queue=valid_queue,
                     train_transform=train_transform,
                     valid_transform=valid_transform,
                     architect=architect_init,
                     criterion=criterion,
                     primitives=primitives,
                     analyser=analyser_init,
                     la_tracker=la_tracker,
                     errors_dict=errors_dict,
                     start_epoch=-1):

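        # The default arguments bind objects from the enclosing scope; the
        # adaptive-regularization branch below rebuilds them and calls this
        # function recursively with the replacements.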
        logging.info('STARTING ITERATION: %d', iteration)
        logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)

        la_tracker.stop_search = False

        if epochs_to_train - start_epoch - 1 <= 0:
            return model.genotype(), -1
        for epoch in range(start_epoch + 1, epochs_to_train):
            # set the epoch to the right one
            #epoch += args.epochs - epochs_to_train

            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            if args.drop_path_prob != 0:
                model.drop_path_prob = (
                    args.drop_path_prob * epoch / (args.epochs - 1))
                train_transform.transforms[-1].cutout_prob = (
                    args.cutout_prob * epoch / (args.epochs - 1))
                logging.info('epoch %d lr %e drop_prob %e cutout_prob %e',
                             epoch, lr, model.drop_path_prob,
                             train_transform.transforms[-1].cutout_prob)
            else:
                logging.info('epoch %d lr %e', epoch, lr)

            # training
            train_acc, train_obj = train(epoch, primitives, train_queue,
                                         valid_queue, model, architect,
                                         criterion, optimizer, lr, analyser,
                                         la_tracker, iteration)
            logging.info('train_acc %f', train_acc)

            # validation
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

            # update the errors dictionary
            errors_dict['train_acc'].append(100 - train_acc)
            errors_dict['train_loss'].append(train_obj)
            errors_dict['valid_acc'].append(100 - valid_acc)
            errors_dict['valid_loss'].append(valid_obj)

            genotype = model.genotype()

            logging.info('genotype = %s', genotype)

            print(F.softmax(model.alphas_normal, dim=-1))
            print(F.softmax(model.alphas_reduce, dim=-1))

            state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'alphas_normal': model.alphas_normal.data,
                'alphas_reduce': model.alphas_reduce.data,
                'arch_optimizer': architect.optimizer.state_dict(),
                'lr': lr,
                'ev': la_tracker.ev,
                'ev_local_avg': la_tracker.ev_local_avg,
                'genotypes': la_tracker.genotypes,
                'la_epochs': la_tracker.la_epochs,
                'la_start_idx': la_tracker.la_start_idx,
                'la_end_idx': la_tracker.la_end_idx,
                #'scheduler': scheduler.state_dict(),
            }

            utils.save_checkpoint(state, False, args.save, epoch, args.task_id)

            if not args.compute_hessian:
                ev = -1
            else:
                ev = la_tracker.ev[-1]
            params = {
                'iteration': iteration,
                'epoch': epoch,
                'wd': args.weight_decay,
                'ev': ev,
            }

            schedule_of_params.append(params)

            # limit the number of iterations based on the maximum regularization
            # value predefined by the user
            final_iteration = round(
                np.log(args.max_weight_decay) / np.log(args.weight_decay),
                1) == 1.  # stop once weight_decay has been increased to its maximum value

            if la_tracker.stop_search and not final_iteration:
                if args.early_stop == 1:
                    # set the following to the values they had at stop_epoch
                    errors_dict['valid_acc'] = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    genotype = la_tracker.stop_genotype
                    valid_acc = 100 - errors_dict['valid_acc'][
                        la_tracker.stop_epoch]
                    logging.info(
                        'Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('Validation accuracy at stop epoch: %f',
                                 valid_acc)
                    logging.info('Genotype at stop epoch: %s', genotype)
                    break

                elif args.early_stop == 2:
                    # simulate early stopping and continue search afterwards
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    logging.info(
                        '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('(SIM) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(SIM) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    with open(
                            os.path.join(args.save,
                                         'arch_early_{}'.format(args.task_id)),
                            'w') as file:
                        file.write(str(simulated_genotype))

                    utils.write_yaml_results(args,
                                             'early_' + args.results_file_arch,
                                             str(simulated_genotype))
                    utils.write_yaml_results(args, 'early_stop_epochs',
                                             la_tracker.stop_epoch)

                    args.early_stop = 0

                elif args.early_stop == 3:
                    # adjust regularization
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    stop_epoch = la_tracker.stop_epoch
                    start_again_epoch = stop_epoch - args.extra_rollback_epochs
                    logging.info(
                        '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                        stop_epoch, epoch)
                    logging.info('(ADA) Rolling back to epoch %d',
                                 start_again_epoch)
                    logging.info(
                        '(ADA) Restoring model parameters and continuing for %d epochs',
                        epochs_to_train - start_again_epoch - 1)

                    if iteration == 1:
                        logging.info(
                            '(ADA) Saving the architecture at the early stop epoch and '
                            'continuing with the adaptive regularization strategy'
                        )
                        utils.write_yaml_results(
                            args, 'early_' + args.results_file_arch,
                            str(simulated_genotype))

                    del model
                    del architect
                    del optimizer
                    del scheduler
                    del analyser

                    model_new = Network(args.init_channels,
                                        args.n_classes,
                                        args.layers,
                                        criterion,
                                        primitives,
                                        steps=args.nodes)
                    model_new = model_new.cuda()

                    optimizer_new = torch.optim.SGD(
                        model_new.parameters(),
                        args.learning_rate,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

                    architect_new = Architect(model_new, args)

                    analyser_new = Analyzer(args, model_new)

                    la_tracker = utils.EVLocalAvg(args.window,
                                                  args.report_freq_hessian,
                                                  args.epochs)

                    lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                               architect_new, args.save,
                                               la_tracker, start_again_epoch,
                                               args.task_id)

                    args.weight_decay *= args.mul_factor
                    for param_group in optimizer_new.param_groups:
                        param_group['weight_decay'] = args.weight_decay

                    scheduler_new = CosineAnnealingLR(
                        optimizer_new,
                        float(args.epochs),
                        eta_min=args.learning_rate_min)

                    logging.info('(ADA) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(ADA) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    logging.info(
                        '(ADA) Adjusting L2 regularization to the new value: %f',
                        args.weight_decay)

                    genotype, valid_acc = train_epochs(args.epochs,
                                                       iteration + 1,
                                                       model=model_new,
                                                       optimizer=optimizer_new,
                                                       architect=architect_new,
                                                       scheduler=scheduler_new,
                                                       analyser=analyser_new,
                                                       start_epoch=start_epoch)
                    args.early_stop = 0
                    break

        return genotype, valid_acc
Example #10
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)
    num_gpus = torch.cuda.device_count()
    genotype = eval("genotypes.%s" % args.arch)
    print('---------Genotype---------')
    logging.info(genotype)
    print('--------------------------')
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary,
                    genotype)
    if num_gpus > 1:
        model = nn.DataParallel(model)
        model = model.cuda()
    else:
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=args.workers)

    #    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    best_acc_top1 = 0
    best_acc_top5 = 0
    lr = args.learning_rate
    for epoch in range(args.epochs):
        if args.lr_scheduler == 'cosine':
            scheduler.step()
            current_lr = scheduler.get_lr()[0]
        elif args.lr_scheduler == 'linear':
            current_lr = adjust_lr(optimizer, epoch)
        else:
            logging.info('Unsupported lr_scheduler "%s", exiting', args.lr_scheduler)
            sys.exit(1)
        logging.info('Epoch: %d lr %e', epoch, current_lr)
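        # Linear learning-rate warm-up over the first 5 epochs when the batch size is large.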
        if epoch < 5 and args.batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * (epoch + 1) / 5.0
            logging.info('Warming-up Epoch: %d, LR: %e', epoch,
                         lr * (epoch + 1) / 5.0)
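        # Linearly ramp the drop-path probability up over the course of training.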
        if num_gpus > 1:
            model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        else:
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        epoch_start = time.time()
        train_acc, train_obj = train(train_queue, model, criterion_smooth,
                                     optimizer)
        logging.info('Train_acc: %f', train_acc)

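        # Validate with the plain (un-smoothed) cross-entropy loss.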
        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        logging.info('Valid_acc_top1: %f', valid_acc_top1)
        logging.info('Valid_acc_top5: %f', valid_acc_top5)
        epoch_duration = time.time() - epoch_start
        logging.info('Epoch time: %ds.', epoch_duration)
        is_best = False
        if valid_acc_top5 > best_acc_top5:
            best_acc_top5 = valid_acc_top5
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True
        logging.info('Best_acc_top1: %f', best_acc_top1)
        logging.info('Best_acc_top5: %f', best_acc_top5)
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc_top1': best_acc_top1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)
Example #11
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

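    # Look up the architecture genotype by name in the genotypes module.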
    genotype = eval("genotypes.%s" % args.arch)

    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary,
                    genotype)
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
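    # Note: the 64x64 random crop and these normalization statistics appear to correspond to a Tiny-ImageNet-style dataset rather than full ImageNet.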
    normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
                                     std=[0.2302, 0.2265, 0.2262])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir, transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=4)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size // 2,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

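    # Cosine-anneal the learning rate from args.learning_rate down over args.epochs epochs.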
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    best_acc_top1 = 0
    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc_top1': best_acc_top1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)
Example #12
def main():

    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logger.info("args = %s", args)

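    # Parse the per-layer compression rates; e.g. '0.1+0.3*8' expands to [0.1] followed by eight 0.3 entries.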
    if args.compress_rate:
        import re
        cprate_str = args.compress_rate
        cprate_str_list = cprate_str.split('+')
        pat_cprate = re.compile(r'\d+\.\d*')
        pat_num = re.compile(r'\*\d+')
        cprate = []
        for x in cprate_str_list:
            num = 1
            find_num = re.findall(pat_num, x)
            if find_num:
                assert len(find_num) == 1
                num = int(find_num[0].replace('*', ''))
            find_cprate = re.findall(pat_cprate, x)
            assert len(find_cprate) == 1
            cprate += [float(find_cprate[0])] * num

        compress_rate = cprate
    else:
        raise ValueError(
            'args.compress_rate must be set, e.g. "0.1+0.3*8" for [0.1] + [0.3] * 8')

    # load model
    logger.info('compress_rate:' + str(compress_rate))
    logger.info('==> Building model..')
    model = eval(args.arch)(compress_rate=compress_rate).cuda()
    logger.info(model)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = utils.CrossEntropyLabelSmooth(CLASSES,
                                                     args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    # load training data
    print('==> Preparing data..')
    if args.use_dali:

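        # Build NVIDIA DALI iterators for the train/val splits so decoding and augmentation run on the GPU.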
        def get_data_set(split='train'):
            if split == 'train':
                return imagenet_dali.get_imagenet_iter_dali('train',
                                                            args.data_dir,
                                                            args.batch_size,
                                                            num_threads=4,
                                                            crop=224,
                                                            device_id=0,
                                                            num_gpus=1)
            else:
                return imagenet_dali.get_imagenet_iter_dali('val',
                                                            args.data_dir,
                                                            args.batch_size,
                                                            num_threads=4,
                                                            crop=224,
                                                            device_id=0,
                                                            num_gpus=1)

        train_loader = get_data_set('train')
        val_loader = get_data_set('val')
    else:
        data_tmp = imagenet.Data(args)
        train_loader = data_tmp.train_loader
        val_loader = data_tmp.test_loader

    # calculate model size
    input_image_size = 224
    input_image = torch.randn(1, 3, input_image_size, input_image_size).cuda()
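    # profile() is assumed to be thop-style, returning (FLOPs, params) for one forward pass.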
    flops, params = profile(model, inputs=(input_image, ))
    logger.info('Params: %.2f' % (params))
    logger.info('Flops: %.2f' % (flops))

    if args.test_only:
        if os.path.isfile(args.test_model_dir):
            logger.info('loading checkpoint {} ..........'.format(
                args.test_model_dir))
            checkpoint = torch.load(args.test_model_dir)
            if 'state_dict' in checkpoint:
                tmp_ckpt = checkpoint['state_dict']
            else:
                tmp_ckpt = checkpoint

            new_state_dict = OrderedDict()
            for k, v in tmp_ckpt.items():
                new_state_dict[k.replace('module.', '')] = v
            model.load_state_dict(new_state_dict)

            valid_obj, valid_top1_acc, valid_top5_acc = validate(
                0, val_loader, model, criterion, args)
        else:
            logger.info('please specify a checkpoint file')

        return

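    # args.gpu is expected to be a comma-separated string of device ids; wrap in DataParallel when more than one is given.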
    if len(args.gpu) > 1:
        device_id = []
        for i in range((len(args.gpu) + 1) // 2):
            device_id.append(i)
        model = nn.DataParallel(model, device_ids=device_id).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Alternative learning rate schedulers (disabled in this example):
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1.0 - step / args.epochs), last_epoch=-1)
    # if args.lr_type == 'multi_step':
    #     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epochs // 4, args.epochs // 2, args.epochs // 4 * 3], gamma=0.1)
    # elif args.lr_type == 'cos':
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=0.0004)
    start_epoch = 0
    best_top1_acc = 0
    best_top5_acc = 0

    # load the checkpoint if it exists
    checkpoint_dir = os.path.join(args.job_dir, 'checkpoint.pth.tar')
    if args.resume:
        logger.info('loading checkpoint {} ..........'.format(checkpoint_dir))
        checkpoint = torch.load(checkpoint_dir)
        start_epoch = checkpoint['epoch'] + 1
        best_top1_acc = checkpoint['best_top1_acc']
        if 'best_top5_acc' in checkpoint:
            best_top5_acc = checkpoint['best_top5_acc']

        # Reconcile 'module.' prefixes between single-GPU and DataParallel (multi-GPU) checkpoints.
        new_state_dict = OrderedDict()
        tmp_ckpt = checkpoint['state_dict']
        if len(args.gpu) > 1:
            for k, v in tmp_ckpt.items():
                new_state_dict['module.' + k.replace('module.', '')] = v
        else:
            for k, v in tmp_ckpt.items():
                new_state_dict[k.replace('module.', '')] = v

        model.load_state_dict(new_state_dict)
        logger.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_dir, checkpoint['epoch']))
    else:
        if args.use_pretrain:
            logger.info('resuming from pretrain model')
            origin_model = eval(args.arch)(compress_rate=[0.] * 100).cuda()
            ckpt = torch.load(args.pretrain_dir)
            if args.arch == 'mobilenet_v1':
                origin_model.load_state_dict(ckpt['state_dict'])
            else:
                origin_model.load_state_dict(ckpt)
            oristate_dict = origin_model.state_dict()
            if args.arch == 'resnet_50':
                load_resnet_model(model, oristate_dict)
            elif args.arch == 'mobilenet_v2':
                load_mobilenetv2_model(model, oristate_dict)
            elif args.arch == 'mobilenet_v1':
                load_mobilenetv1_model(model, oristate_dict)
            else:
                raise NotImplementedError('unsupported arch: %s' % args.arch)
        else:
            logger.info('training from scratch')

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:

        train_obj, train_top1_acc, train_top5_acc = train(
            epoch, train_loader, model, criterion_smooth, optimizer)
        valid_obj, valid_top1_acc, valid_top5_acc = validate(
            epoch, val_loader, model, criterion, args)
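        # DALI iterators must be reset before they can be iterated again in the next epoch.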
        if args.use_dali:
            train_loader.reset()
            val_loader.reset()

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            best_top5_acc = valid_top5_acc
            is_best = True

        utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_top1_acc': best_top1_acc,
                'best_top5_acc': best_top5_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.job_dir)

        epoch += 1
        logger.info("=>Best accuracy Top1: {:.3f}, Top5: {:.3f}".format(
            best_top1_acc, best_top5_acc))

    training_time = (time.time() - start_t) / 3600
    logger.info('total training time = {:.2f} hours'.format(training_time))
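The ImageNet examples above train with a CrossEntropyLabelSmooth criterion while validating with plain cross-entropy. For reference, below is a minimal sketch of the usual label-smoothing loss; the class name and constructor arguments mirror the calls above, but this is an assumed implementation, not necessarily the exact utility bundled with these repositories.

import torch
import torch.nn as nn


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with uniform label smoothing (assumed implementation)."""

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # Smooth the one-hot targets toward the uniform distribution,
        # then take the expected negative log-likelihood.
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        soft_targets = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-soft_targets * log_probs).mean(0).sum()
        return loss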