Example #1
def test_model():
    """Evaluates the model."""

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Compute precise time
    if cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        loss_fun = losses.get_loss_fun()
        bu.compute_precise_time(model, loss_fun)
        nu.reset_bn_stats(model)

    # Load model weights
    cu.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))

    # Create data loaders
    test_loader = loader.construct_test_loader()

    # Create meters
    test_meter = TestMeter(len(test_loader))

    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
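
A minimal usage sketch (not from the original source): test_model() assumes a
populated global cfg. With the yacs-style cfg used by pycls (merge_from_file
and freeze are standard yacs CfgNode methods, also used in Example #4 below),
an entry point could look like this; the config path is a hypothetical
placeholder.

if __name__ == "__main__":
    from pycls.core.config import cfg
    cfg.merge_from_file("configs/example.yaml")  # hypothetical path
    cfg.freeze()
    test_model()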
Example #2
def train_model():
    """Trains the model."""

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)

    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_last_checkpoint()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model,
                                              optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        cu.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))

    # Compute precise time
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        bu.compute_precise_time(model, loss_fun)
        nu.reset_bn_stats(model)

    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()

    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))

    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))

    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            nu.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if cu.is_checkpoint_epoch(cur_epoch):
            checkpoint_file = cu.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        if is_eval_epoch(cur_epoch):
            test_epoch(test_loader, model, test_meter, cur_epoch)
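
train_model() calls an is_eval_epoch() helper that is not shown in this
snippet. A plausible sketch, assuming a cfg.TRAIN.EVAL_PERIOD option
(hypothetical here; the actual helper in the source may differ):

def is_eval_epoch(cur_epoch):
    """Evaluate every EVAL_PERIOD epochs and always on the last epoch."""
    return ((cur_epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0 or
            (cur_epoch + 1) == cfg.OPTIM.MAX_EPOCH)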
Example #3
def test_model():
    """Evaluates the model."""

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Load model weights
    cu.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))

    # Create data loaders
    test_loader = loader.construct_test_loader()

    # Create meters
    test_meter = TestMeter(len(test_loader))

    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
Example #4
def main():
    # load pretrained model
    checkpoint = torch.load(args.checkpoint_path)

    try:
        model_arch = checkpoint['model_name']
        patch_size = checkpoint['patch_size']
        prime_size = checkpoint['patch_size']
        flops = checkpoint['flops']
        model_flops = checkpoint['model_flops']
        policy_flops = checkpoint['policy_flops']
        fc_flops = checkpoint['fc_flops']
        anytime_classification = checkpoint['anytime_classification']
        budgeted_batch_classification = checkpoint[
            'budgeted_batch_classification']
        dynamic_threshold = checkpoint['dynamic_threshold']
        maximum_length = len(checkpoint['flops'])
    except KeyError:
        print(
            'Error: \n'
            'Please provide essential information '
            'for customized models (as we have done '
            'in pre-trained models)!\n'
            'At least the following information should be given: \n'
            '--model_name: name of the backbone CNNs (e.g., resnet50, densenet121)\n'
            '--patch_size: size of image patches (i.e., H\' or W\' in the paper)\n'
            '--flops: a list containing the Multiply-Adds corresponding to each '
            'length of the input sequence during inference')
        return

    model_configuration = model_configurations[model_arch]

    if args.eval_mode > 0:
        # create model
        if 'resnet' in model_arch:
            model = resnet.resnet50(pretrained=False)
            model_prime = resnet.resnet50(pretrained=False)

        elif 'densenet' in model_arch:
            model = getattr(densenet, model_arch)(pretrained=False)
            model_prime = getattr(densenet, model_arch)(pretrained=False)

        elif 'efficientnet' in model_arch:
            model = create_model(model_arch,
                                 pretrained=False,
                                 num_classes=1000,
                                 drop_rate=0.3,
                                 drop_connect_rate=0.2)
            model_prime = create_model(model_arch,
                                       pretrained=False,
                                       num_classes=1000,
                                       drop_rate=0.3,
                                       drop_connect_rate=0.2)

        elif 'mobilenetv3' in model_arch:
            model = create_model(model_arch,
                                 pretrained=False,
                                 num_classes=1000,
                                 drop_rate=0.2,
                                 drop_connect_rate=0.2)
            model_prime = create_model(model_arch,
                                       pretrained=False,
                                       num_classes=1000,
                                       drop_rate=0.2,
                                       drop_connect_rate=0.2)

        elif 'regnet' in model_arch:
            import pycls.core.model_builder as model_builder
            from pycls.core.config import cfg
            cfg.merge_from_file(model_configuration['cfg_file'])
            cfg.freeze()

            model = model_builder.build_model()
            model_prime = model_builder.build_model()

        traindir = args.data_url + 'train/'
        valdir = args.data_url + 'val/'

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_set = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(
                    model_configuration['image_size'],
                    interpolation=model_configuration['dataset_interpolation']
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(), normalize
            ]))
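        # The training loader is only consumed when eval_mode == 2, to
        # calibrate the dynamic thresholds; it samples a fixed 200k-image
        # subset (the tail of a random permutation) rather than the full set.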
        train_set_index = torch.randperm(len(train_set))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=256,
            num_workers=32,
            pin_memory=False,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_set_index[-200000:]))

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(
                    int(model_configuration['image_size'] /
                        model_configuration['crop_pct']),
                    interpolation=model_configuration['dataset_interpolation']
                ),
                transforms.CenterCrop(model_configuration['image_size']),
                transforms.ToTensor(), normalize
            ])),
                                                 batch_size=256,
                                                 shuffle=False,
                                                 num_workers=16,
                                                 pin_memory=False)

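        # State fed to the policy network: one feature vector per spatial
        # position of the backbone's final feature map over a patch, i.e.
        # channels * ceil(patch/32)^2 (assuming a stride-32 backbone).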
        state_dim = model_configuration['feature_map_channels'] * math.ceil(
            patch_size / 32) * math.ceil(patch_size / 32)

        memory = Memory()
        policy = ActorCritic(model_configuration['feature_map_channels'],
                             state_dim,
                             model_configuration['policy_hidden_dim'],
                             model_configuration['policy_conv'])
        fc = Full_layer(model_configuration['feature_num'],
                        model_configuration['fc_hidden_dim'],
                        model_configuration['fc_rnn'])

        model = nn.DataParallel(model.cuda())
        model_prime = nn.DataParallel(model_prime.cuda())
        policy = policy.cuda()
        fc = fc.cuda()

        model.load_state_dict(checkpoint['model_state_dict'])
        model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
        fc.load_state_dict(checkpoint['fc'])
        policy.load_state_dict(checkpoint['policy'])

        budgeted_batch_flops_list = []
        budgeted_batch_acc_list = []

        print('generate logits on test samples...')
        test_logits, test_targets, anytime_classification = generate_logits(
            model_prime, model, fc, memory, policy, val_loader, maximum_length,
            prime_size, patch_size, model_arch)

        if args.eval_mode == 2:
            print('generate logits on training samples...')
            dynamic_threshold = torch.zeros([39, maximum_length])
            train_logits, train_targets, _ = generate_logits(
                model_prime, model, fc, memory, policy, train_loader,
                maximum_length, prime_size, patch_size, model_arch)

        for p in range(1, 40):
            print('inference: {}/39'.format(p))

            _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
            # Exponential prior over exit lengths k = 1..maximum_length
            probs = torch.exp(torch.log(_p) *
                              torch.arange(1, maximum_length + 1))
            probs /= probs.sum()

            if args.eval_mode == 2:
                dynamic_threshold[p - 1] = dynamic_find_threshold(
                    train_logits, train_targets, probs)

            acc_step, flops_step = dynamic_evaluate(test_logits, test_targets,
                                                    flops,
                                                    dynamic_threshold[p - 1])

            budgeted_batch_acc_list.append(acc_step)
            budgeted_batch_flops_list.append(flops_step)

        budgeted_batch_classification = [
            budgeted_batch_flops_list, budgeted_batch_acc_list
        ]

    print('model_arch :', model_arch)
    print('patch_size :', patch_size)
    print('flops :', flops)
    print('model_flops :', model_flops)
    print('policy_flops :', policy_flops)
    print('fc_flops :', fc_flops)
    print('anytime_classification :', anytime_classification)
    print('budgeted_batch_classification :', budgeted_batch_classification)
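
The budgeted-batch loop above sweeps q = p / 20 and weights the possible exit
lengths with an exponential prior, probs_k ∝ q^k for k = 1..L, normalized to
sum to one (decaying for q < 1, increasing for q > 1). A standalone worked
example of that computation (the q and L values are illustrative):

import torch

q = torch.tensor(0.5)  # corresponds to p = 10 in the loop above
L = 5                  # stands in for maximum_length
probs = torch.exp(torch.log(q) * torch.arange(1, L + 1))
probs /= probs.sum()
print(probs)  # tensor([0.5161, 0.2581, 0.1290, 0.0645, 0.0323])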
Example #5
def main():

    if not os.path.isdir(args.work_dirs):
        mkdir_p(args.work_dirs)

    record_path = args.work_dirs + '/GF-' + str(args.model_arch) \
                  + '_patch-size-' + str(args.patch_size) \
                  + '_T' + str(args.T) \
                  + '_train-stage' + str(args.train_stage)
    if not os.path.isdir(record_path):
        mkdir_p(record_path)
    record_file = record_path + '/record.txt'


    # *create model* #
    model_configuration = model_configurations[args.model_arch]
    if 'resnet' in args.model_arch:
        model_arch = 'resnet'
        model = resnet.resnet50(pretrained=False)
        model_prime = resnet.resnet50(pretrained=False)
    elif 'densenet' in args.model_arch:
        model_arch = 'densenet'
        model = getattr(densenet, args.model_arch)(pretrained=False)
        model_prime = getattr(densenet, args.model_arch)(pretrained=False)
    elif 'efficientnet' in args.model_arch:
        model_arch = 'efficientnet'
        model = create_model(args.model_arch, pretrained=False, num_classes=1000,
                             drop_rate=0.3, drop_connect_rate=0.2)
        model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000,
                                   drop_rate=0.3, drop_connect_rate=0.2)
    elif 'mobilenetv3' in args.model_arch:
        model_arch = 'mobilenetv3'
        model = create_model(args.model_arch, pretrained=False, num_classes=1000,
                             drop_rate=0.2, drop_connect_rate=0.2)
        model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000,
                                   drop_rate=0.2, drop_connect_rate=0.2)
    elif 'regnet' in args.model_arch:
        model_arch = 'regnet'
        import pycls.core.model_builder as model_builder
        from pycls.core.config import cfg
        cfg.merge_from_file(model_configuration['cfg_file'])
        cfg.freeze()
        model = model_builder.build_model()
        model_prime = model_builder.build_model()

    fc = Full_layer(model_configuration['feature_num'],
                    model_configuration['fc_hidden_dim'],
                    model_configuration['fc_rnn'])

    if args.train_stage == 1:
        model.load_state_dict(torch.load(args.model_path))
        model_prime.load_state_dict(torch.load(args.model_prime_path))
    else:
        checkpoint = torch.load(args.checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
        fc.load_state_dict(checkpoint['fc'])

    train_configuration = train_configurations[model_arch]

    if args.train_stage != 2:
        # One optimizer over all trainable groups; model_prime is included
        # only when the configuration asks to train it.
        param_groups = [{'params': model.parameters()},
                        {'params': fc.parameters()}]
        if train_configuration['train_model_prime']:
            param_groups.insert(1, {'params': model_prime.parameters()})
        optimizer = torch.optim.SGD(param_groups,
                                    lr=0,  # specified in adjust_learning_rate()
                                    momentum=train_configuration['momentum'],
                                    nesterov=train_configuration['Nesterov'],
                                    weight_decay=train_configuration['weight_decay'])
        training_epoch_num = train_configuration['epoch_num']
    else:
        optimizer = None
        training_epoch_num = 15
    criterion = nn.CrossEntropyLoss().cuda()

    model = nn.DataParallel(model.cuda())
    model_prime = nn.DataParallel(model_prime.cuda())
    fc = fc.cuda()

    traindir = args.data_url + 'train/'
    valdir = args.data_url + 'val/'
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_set = datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    train_set_index = torch.randperm(len(train_set))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, num_workers=32, pin_memory=False,
                                               sampler=torch.utils.data.sampler.SubsetRandomSampler(
                                                   train_set_index[:]))

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize, ])),
        batch_size=train_configuration['batch_size'], shuffle=False, num_workers=32, pin_memory=False)

    if args.train_stage != 1:
        state_dim = model_configuration['feature_map_channels'] * math.ceil(args.patch_size / 32) * math.ceil(args.patch_size / 32)
        ppo = PPO(model_configuration['feature_map_channels'], state_dim,
                  model_configuration['policy_hidden_dim'], model_configuration['policy_conv'])

        if args.train_stage == 3:
            ppo.policy.load_state_dict(checkpoint['policy'])
            ppo.policy_old.load_state_dict(checkpoint['policy'])

    else:
        ppo = None
    memory = Memory()

    if args.resume:
        resume_ckp = torch.load(args.resume)

        start_epoch = resume_ckp['epoch']
        print('resume from epoch: {}'.format(start_epoch))

        model.module.load_state_dict(resume_ckp['model_state_dict'])
        model_prime.module.load_state_dict(resume_ckp['model_prime_state_dict'])
        fc.load_state_dict(resume_ckp['fc'])

        if optimizer:
            optimizer.load_state_dict(resume_ckp['optimizer'])

        if ppo:
            ppo.policy.load_state_dict(resume_ckp['policy'])
            ppo.policy_old.load_state_dict(resume_ckp['policy'])
            ppo.optimizer.load_state_dict(resume_ckp['ppo_optimizer'])

        best_acc = resume_ckp['best_acc']
    else:
        start_epoch = 0
        best_acc = 0

    for epoch in range(start_epoch, training_epoch_num):
        if args.train_stage != 2:
            print('Training Stage: {}, lr:'.format(args.train_stage))
            adjust_learning_rate(optimizer, train_configuration,
                                 epoch, training_epoch_num, args)
        else:
            print('Training Stage: {}, train ppo only'.format(args.train_stage))

        train(model_prime, model, fc, memory, ppo, optimizer, train_loader, criterion,
              args.print_freq, epoch, train_configuration['batch_size'], record_file, train_configuration, args)

        acc = validate(model_prime, model, fc, memory, ppo, optimizer, val_loader, criterion,
                       args.print_freq, epoch, train_configuration['batch_size'], record_file, train_configuration, args)

        if acc > best_acc:
            best_acc = acc
            is_best = True
        else:
            is_best = False

        save_checkpoint({
            'epoch': epoch + 1,
            'model_state_dict': model.module.state_dict(),
            'model_prime_state_dict': model_prime.module.state_dict(),
            'fc': fc.state_dict(),
            'acc': acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict() if optimizer else None,
            'ppo_optimizer': ppo.optimizer.state_dict() if ppo else None,
            'policy': ppo.policy.state_dict() if ppo else None,
        }, is_best, checkpoint=record_path)
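
save_checkpoint() is an external helper. A minimal sketch that matches the
call above (state dict, is_best flag, checkpoint directory); the file names
are illustrative assumptions, not necessarily those of the source repo:

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'):
    """Persist the latest training state; keep a copy of the best one."""
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))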