Example #1
import copy
import random

import numpy as np
from ofa.model_zoo import ofa_specialized  # pip package `ofa` (once-for-all)


def evaluate_ofa_random_sample(path,
                               data_loader,
                               batch_size=100,
                               device='cuda:0',
                               ensemble=False):
    # NOTE: relies on a module-level `net_id` list and an `ensemble_validate`
    # helper defined elsewhere in the source file.
    net_acc = []
    for i, id in enumerate(net_id):
        acc = ""
        for j in range(2, len(id)):
            if id[j] == '.':
                # grab the four chars around the dot, e.g. '79.6'
                acc = id[j - 2] + id[j - 1] + id[j] + id[j + 1]
        net_acc.append(acc)
    # lexicographic argsort is fine here: every accuracy has two integer digits
    id = np.argsort(np.array(net_acc))
    new_net_id = copy.deepcopy(net_id)
    for i, sortid in enumerate(id):
        new_net_id[i] = net_id[sortid]
    print('new_net_id', new_net_id)
    n = len(net_id)
    best_acc = 0
    acc_list = []
    space = []
    best_team = []
    for k in range(20):
        nets = []
        team = []
        i = random.randint(0, n - 1)
        j = (i + random.randint(1, n - 1)) % n  # nonzero offset guarantees j != i
        print('i:{} j:{}'.format(i, j))
        team.append(j)
        team.append(i)
        net, image_size = ofa_specialized(net_id=new_net_id[j],
                                          pretrained=True)
        nets.append(net)
        net, image_size = ofa_specialized(net_id=new_net_id[i],
                                          pretrained=True)
        nets.append(net)
        acc = ensemble_validate(nets, path, image_size, data_loader,
                                batch_size, device)
        print('net i:{} netj:{} acc:{}'.format(new_net_id[i], new_net_id[j],
                                               acc))
        acc_list.append(acc)
        if acc > best_acc:
            best_acc = acc
            best_team = team
    avg_acc = np.mean(acc_list)
    std_acc = np.std(acc_list, ddof=1)
    var_acc = np.var(acc_list, ddof=1)  # ddof=1 to match the sample std above
    print('avg:{} std:{} var:{}'.format(avg_acc, std_acc, var_acc))
    print('best_random_team:{} best_acc:{}'.format(best_team, best_acc))
    space.append(best_acc)
    print('space:{}'.format(space))
    return new_net_id[best_team[0]], new_net_id[best_team[1]]
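
The only nontrivial step above is pulling the accuracy out of each net id and
sorting by it. A self-contained sketch with made-up ids (the real ids are
partially redacted in this listing) shows what those lines do:

import numpy as np

net_id = ['dummyA_top1@73.1', 'dummyB_top1@79.6', 'dummyC_top1@75.2']  # hypothetical
net_acc = []
for nid in net_id:
    acc = ''
    for j in range(2, len(nid)):
        if nid[j] == '.':
            acc = nid[j - 2] + nid[j - 1] + nid[j] + nid[j + 1]  # e.g. '73.1'
    net_acc.append(acc)
order = np.argsort(np.array(net_acc))  # lexicographic; fine for 2-digit accuracies
print([net_id[k] for k in order])      # dummyA (73.1), dummyC (75.2), dummyB (79.6)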
Example #2
    def __init__(self, data_dir=None):
        self.__net_id = "flops@[email protected]_finetune@75"
        __cwd = os.getcwd()
        if data_dir is not None:
            _p = os.path.join(data_dir, self.__class__.__module__, self.get_version())
            _p = os.path.abspath(_p)
            os.makedirs(_p, exist_ok=True)
            os.chdir(_p)
        try:
            # NOTE: hard-coded download dir of ofa_specialized is under CWD.
            self.__model, self.__image_size = ofa_specialized(self.__net_id, pretrained=True)
        finally:
            os.chdir(__cwd)
        mobi = self.__model
        assert isinstance(mobi, MobileNetV3)

        def _forward__fv(x):
            """Patch of bound method MobileNetV3.forward to extract the feature vectors without classification layer"""
            x = mobi.first_conv(x)
            for block in mobi.blocks:
                x = block(x)
            x = mobi.final_expand_layer(x)
            x = mobi.global_avg_pool(x)  # global average pooling
            x = mobi.feature_mix_layer(x)
            x = x.view(x.size(0), -1)
            # x = mobi.classifier(x)
            return x

        self.__model.forward = _forward__fv
        self.__model.eval()
        # feature dimensionality = width of the final 1x1 feature-mix conv
        self.__dim = self.__model.feature_mix_layer.conv.out_channels

        # to normalize according to model trained on imagenet
        self.__mean = [0.485, 0.456, 0.406]
        self.__std = [0.229, 0.224, 0.225]
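
A hedged usage sketch of the feature extractor above (the class name and its
public API are not shown, so this rebuilds the pieces directly; the net id is
redacted in the source and must be replaced with a valid one from
ofa.model_zoo, and 'example.jpg' is a placeholder):

import torch
from PIL import Image
from torchvision import transforms
from ofa.model_zoo import ofa_specialized

model, image_size = ofa_specialized('<valid-net-id>', pretrained=True)
# ...patch model.forward as in __init__ above, then:
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(int(image_size / 0.875)),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
x = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    fv = model(x)  # (1, dim) feature vector from the patched forward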
Example #3
def evaluate_ofa_best_acc_team(path,
                               data_loader,
                               batch_size=100,
                               device='cuda:0',
                               ensemble=False):
    # Uses the same module-level `net_id` list and `ensemble_validate`
    # helper as Example #1.
    net_acc = []
    for i, id in enumerate(net_id):
        acc = ""
        for j in range(2, len(id)):
            if id[j] == '.':
                acc = id[j - 2] + id[j - 1] + id[j] + id[j + 1]
        net_acc.append(acc)
    id = np.argsort(np.array(net_acc))
    new_net_id = copy.deepcopy(net_id)
    for i, sortid in enumerate(id):
        new_net_id[i] = net_id[sortid]
    print('new_net_id', new_net_id)
    n = len(net_id)
    best_acc = 0
    space = []
    best_team = []
    i = n - 1  # highest-accuracy id after sorting
    for j in range(18, n):  # NOTE: hard-coded search window over the top ids
        nets = []
        team = []
        team.append(j)
        team.append(i)
        net, image_size = ofa_specialized(net_id=new_net_id[j],
                                          pretrained=True)
        nets.append(net)
        net, image_size = ofa_specialized(net_id=new_net_id[i],
                                          pretrained=True)
        nets.append(net)
        acc = ensemble_validate(nets, path, image_size, data_loader,
                                batch_size, device)
        print('net i:{} netj:{} acc:{}'.format(new_net_id[i], new_net_id[j],
                                               acc))
        if acc > best_acc:
            best_acc = acc
            best_team = team

    print('space {} best_acc {}'.format(i + 1, best_acc))
    space.append(best_acc)
    print('space:{}'.format(space))
    return new_net_id[best_team[0]], new_net_id[best_team[1]]
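
The loop above fixes i = n - 1 (the strongest single net after sorting) and
only varies its partner j. A sketch of the fuller pairwise search the same
pieces would support (same assumed net_id list, ensemble_validate helper, and
data loader as above):

from itertools import combinations

best_acc, best_pair = 0.0, None
for i, j in combinations(range(len(new_net_id)), 2):
    nets = []
    for k in (i, j):
        net, image_size = ofa_specialized(net_id=new_net_id[k], pretrained=True)
        nets.append(net)
    acc = ensemble_validate(nets, path, image_size, data_loader,
                            batch_size, device)
    if acc > best_acc:
        best_acc, best_pair = acc, (i, j)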
Example #4
def main(args):
    save_folder = args.save_folder

    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(model_folder)

    setattr(args, 'model_folder', model_folder)

    logger = create_logger(model_folder, 'train', 'info')
    print_args(args, logger)

    # seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # ResNet
    if "ResNet" in args.model:
        depth_ = args.model.split('-')[1]

        # when it is dst, resnet block is special
        if args.prune_method != 'dst': p_type = None

        # map depth -> constructor so only the requested ResNet is built
        # (the original dict eagerly instantiated all four variants)
        res_dict = {
            '18': resnet18,
            '34': resnet34,
            '50': resnet50,
            '101': resnet101,
        }
        net = res_dict[depth_](pretrained=args.pretrained,
                               progress=True,
                               prune_type=p_type)

    #elif 'efficientnet' in args.model:
    #    net = EfficientNet.from_pretrained(args.model)
    elif args.model == 'efficientnet_b0':
        print('efficientnet-b0 load...')
        net = efficientnet_b0(pretrained=args.pretrained)

    # MobileNet
    elif args.model == "mobilenetv3-large-1.0":
        print('mobilenetv3-large-1.0')
        net = mobilenetv3_large_100(pretrained=args.pretrained)

    elif args.model == 'once-mobilenetv3-large-1.0':
        print('once-mobilenetv3-large-1.0')
        net, image_size = ofa_specialized(
            'note8_lat@[email protected]_finetune@25', pretrained=args.pretrained)

    elif args.model == 'mobilenetv2-120d':
        print('mobilenetv2-120d load...')
        net = mobilenetv2_120d(pretrained=args.pretrained)

    # optionally freeze the first conv layer
    if args.conv1_not_train:
        print('conv1 weight not train')
        if args.model == "mobilenetv3-large-1.0":
            for param in net.conv_stem.parameters():
                param.requires_grad = False
        elif "ResNet" in args.model:
            for param in net.conv1.parameters():
                param.requires_grad = False

        else:
            assert False, 'not ready'

    # custom pretrain path
    if args.pretrain_path:
        print('load custom pretrain weight...')
        net.load_state_dict(torch.load(args.pretrain_path))

    net2 = copy.deepcopy(net)  # for save removed_models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = nn.DataParallel(net)
    net.to(device)

    # KD
    if args.KD:
        print('knowledge distillation model load!')
        teacher_net, image_size = ofa_specialized(
            'flops@[email protected]_finetune@75', pretrained=True)  # 79.6%
        teacher_net = nn.DataParallel(teacher_net)
        teacher_net.to(device)

    # set trainer
    if args.KD:
        trainer = Trainer_KD(args, logger)
    else:
        trainer = Trainer(args, logger)

    # loss
    loss = nn.CrossEntropyLoss()

    # dataloader
    # OFA nets keep the image_size returned by ofa_specialized; others use 224
    if args.model != 'once-mobilenetv3-large-1.0':
        image_size = 224
    if args.dataset == 'imagenet':
        train_loader = torch.utils.data.DataLoader(
            datasets.ImageNet(
                '/data/imagenet/',
                split='train',
                download=False,
                transform=transforms.Compose([
                    transforms.RandomResizedCrop(image_size),
                    transforms.RandomHorizontalFlip(),  #ImageNetPolicy(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_worker,
            pin_memory=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.ImageNet(
                '/data/imagenet/',
                split='val',
                download=False,
                transform=transforms.Compose([
                    transforms.Resize(int(image_size / 0.875)),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_worker,
            pin_memory=True)

    # optimizer & scheduler
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.scheduler == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=eval(args.multi_step_epoch),  # e.g. "[30, 60, 90]"
            gamma=args.multi_step_gamma)
    elif args.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            mode='max',
            patience=3,
            verbose=True,
            factor=0.3,
            threshold=1e-4,
            min_lr=1e-6)

    # pruning
    if args.prune_method == 'global':
        if args.prune_type == 'group_filter':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:  # skip stem/depthwise convs
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()  # (out, ch, h, w)
                    tmp_pruned = tmp_pruned.view(original_size[0],
                                                 -1)  # (out, inp)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1,
                                                 args.block_size)  # out, -1, 4
                    tmp_pruned = tmp_pruned.abs().mean(2, keepdim=True).expand(
                        tmp_pruned.shape)  # out, -1, 4
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.2
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1,
                                                 args.block_size)  # out, -1, 4
                    tmp_pruned = tmp_pruned.abs().mean(2, keepdim=True).expand(
                        tmp_pruned.shape)  # out,-1, 4
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0],
                                                 -1)  # out, inp
                    #tmp_pruned = tmp_pruned[:, 0: conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(
                        original_size)  # out, ch, h, w

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)

        elif args.prune_type == 'group_channel':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()  # (out, ch, h, w)
                    tmp_pruned = tmp_pruned.view(original_size[0],
                                                 -1)  # (out, inp)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(
                        -1, args.block_size, tmp_pruned.shape[1])  # (-1, block, inp)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)  # mean |w| over each channel block
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.2
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(
                        -1, args.block_size, tmp_pruned.shape[1])  # (-1, block, inp)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)  # mean |w| over each channel block
                    tmp_pruned = tmp_pruned.ge(threshold)
                    #tmp_pruned = tmp_pruned.view(original_size[0], -1) # out, inp
                    #tmp_pruned = tmp_pruned[:, 0: conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(
                        original_size)  # out, ch, h, w

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)

        elif args.prune_type == 'filter':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.5
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned[:,
                                            0:conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(original_size)

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)

        elif args.prune_type == 'channel':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(0, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.5
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(0, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned[:,
                                            0:conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(original_size)

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)
        print(
            f'model pruned!! (sparsity: {args.sparsity:.2f}, '
            f'prune_method: {args.prune_method}, '
            f'prune_type: {args.prune_type}-level pruning)'
        )

    elif args.prune_method == 'uniform':
        assert False, 'uniform code is not ready'

    elif args.prune_method == 'dst':
        print(
            f'model pruned!! (prune_method: {args.prune_method}, '
            f'prune_type: {args.prune_type}-level pruning)'
        )

    elif args.prune_method is None:
        print('Not pruned model training started!')

    # Training
    if args.KD:
        trainer.train(net,
                      teacher_net,
                      loss,
                      device,
                      train_loader,
                      test_loader,
                      optimizer=optimizer,
                      scheduler=scheduler)
    else:
        trainer.train(net,
                      loss,
                      device,
                      train_loader,
                      test_loader,
                      optimizer=optimizer,
                      scheduler=scheduler)

    # save removed models: fold weight_orig * weight_mask back into plain
    # weights so the checkpoint loads without the prune re-parametrization
    filename = os.path.join(args.model_folder, 'pruned_models.pth')
    temp = torch.load(filename)
    temp_dict = OrderedDict()
    for i in temp:
        if 'orig' in i:
            value = temp[i] * temp[i.split('_orig')[0] + '_mask']
            temp_dict[i.split('module.')[1].split('_orig')[0]] = value
        elif 'mask' not in i:
            temp_dict[i.split('module.')[1]] = temp[i]
    net2.load_state_dict(temp_dict)
    save_model(net2, os.path.join(args.model_folder, 'removed_models.pth'))
    print('saved removed models')
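
All four pruning branches above share one pattern: score weights at some
granularity by mean |w|, pick a single global threshold with topk, and hand
the resulting mask to torch.nn.utils.prune. A toy, self-contained version of
the 'group_filter' case (layer sizes and sparsity are arbitrary):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(8, 16, 3)  # toy layer; weight shape (16, 8, 3, 3)
block_size, sparsity = 4, 0.5

w = conv.weight.data.clone().view(16, -1, block_size)   # (out, groups, block)
scores = w.abs().mean(2, keepdim=True).expand(w.shape)  # mean |w| per block
k = int(scores.numel() * (1 - sparsity))                # elements to keep
threshold = torch.topk(scores.flatten(), k, sorted=True).values[-1]
mask = scores.ge(threshold).reshape(conv.weight.shape)
prune.custom_from_mask(conv, name='weight', mask=mask)  # adds weight_orig/_mask
print('kept fraction:', conv.weight_mask.float().mean().item())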
Example #5
    metavar='NET',
    default='pixel1_lat@[email protected]_finetune@75',
    choices=specialized_network_list,
    help='OFA specialized networks: ' +
    ' | '.join(specialized_network_list) +
    ' (default: pixel1_lat@[email protected]_finetune@75)')

args = parser.parse_args()
if args.gpu == 'all':
    device_list = range(torch.cuda.device_count())
    args.gpu = ','.join(str(_) for _ in device_list)
else:
    device_list = [int(_) for _ in args.gpu.split(',')]
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

net, image_size = ofa_specialized(net_id=args.net, pretrained=True)
args.batch_size = args.batch_size * max(len(device_list), 1)  # scale by GPU count

data_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(
        osp.join(
            args.path,
            'val'),
        transforms.Compose(
            [
                transforms.Resize(int(math.ceil(image_size / 0.875))),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[
                        0.485,
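
The snippet is cut off mid-Normalize. For reference, a completed sketch of the
same validation loader, using the standard ImageNet statistics that also
appear in Examples #2 and #4 (the trailing DataLoader keyword values are
assumptions, since they fall after the cut):

data_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(
        osp.join(args.path, 'val'),
        transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=4,  # assumed; the original flag name is not visible
    pin_memory=True)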
Example #6
def evaluate_ofa_specialized(path,
                             data_loader,
                             batch_size=100,
                             device='cuda:0'):
    def select_platform_name():
        valid_platform_name = [
            'pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti',
            'v100', 'tx2', 'cpu', 'flops'
        ]

        print('Please select a hardware platform from {}!\n'.format(
            valid_platform_name))

        while True:
            platform_name = input().lower()
            if platform_name in valid_platform_name:
                return platform_name
            print('Platform name is invalid! Please select from {}!\n'.format(
                valid_platform_name))

    def select_netid(platform_name):
        platform_efficiency_map = {
            'pixel1': {
                143: 'pixel1_lat@[email protected]_finetune@75',
                132: 'pixel1_lat@[email protected]_finetune@75',
                79: 'pixel1_lat@[email protected]_finetune@75',
                58: 'pixel1_lat@[email protected]_finetune@75',
                40: 'pixel1_lat@[email protected]_finetune@25',
                28: 'pixel1_lat@[email protected]_finetune@25',
                20: 'pixel1_lat@[email protected]_finetune@25',
            },
            'pixel2': {
                62: 'pixel2_lat@[email protected]_finetune@25',
                50: 'pixel2_lat@[email protected]_finetune@25',
                35: 'pixel2_lat@[email protected]_finetune@25',
                25: 'pixel2_lat@[email protected]_finetune@25',
            },
            'note10': {
                64: 'note10_lat@[email protected]_finetune@75',
                50: 'note10_lat@[email protected]_finetune@75',
                41: 'note10_lat@[email protected]_finetune@75',
                30: 'note10_lat@[email protected]_finetune@75',
                22: 'note10_lat@[email protected]_finetune@25',
                16: 'note10_lat@[email protected]_finetune@25',
                11: 'note10_lat@[email protected]_finetune@25',
                8: 'note10_lat@[email protected]_finetune@25',
            },
            'note8': {
                65: 'note8_lat@[email protected]_finetune@25',
                49: 'note8_lat@[email protected]_finetune@25',
                31: 'note8_lat@[email protected]_finetune@25',
                22: 'note8_lat@[email protected]_finetune@25',
            },
            's7edge': {
                88: 's7edge_lat@[email protected]_finetune@25',
                58: 's7edge_lat@[email protected]_finetune@25',
                41: 's7edge_lat@[email protected]_finetune@25',
                29: 's7edge_lat@[email protected]_finetune@25',
            },
            'lg-g8': {
                24: 'LG-G8_lat@[email protected]_finetune@25',
                16: 'LG-G8_lat@[email protected]_finetune@25',
                11: 'LG-G8_lat@[email protected]_finetune@25',
                8: 'LG-G8_lat@[email protected]_finetune@25',
            },
            '1080ti': {
                27: '1080ti_gpu64@[email protected]_finetune@25',
                22: '1080ti_gpu64@[email protected]_finetune@25',
                15: '1080ti_gpu64@[email protected]_finetune@25',
                12: '1080ti_gpu64@[email protected]_finetune@25',
            },
            'v100': {
                11: 'v100_gpu64@[email protected]_finetune@25',
                9: 'v100_gpu64@[email protected]_finetune@25',
                6: 'v100_gpu64@[email protected]_finetune@25',
                5: 'v100_gpu64@[email protected]_finetune@25',
            },
            'tx2': {
                96: 'tx2_gpu16@[email protected]_finetune@25',
                80: 'tx2_gpu16@[email protected]_finetune@25',
                47: 'tx2_gpu16@[email protected]_finetune@25',
                35: 'tx2_gpu16@[email protected]_finetune@25',
            },
            'cpu': {
                17: 'cpu_lat@[email protected]_finetune@25',
                15: 'cpu_lat@[email protected]_finetune@25',
                11: 'cpu_lat@[email protected]_finetune@25',
                10: 'cpu_lat@[email protected]_finetune@25',
            },
            'flops': {
                595: 'flops@[email protected]_finetune@75',
                482: 'flops@[email protected]_finetune@75',
                389: 'flops@[email protected]_finetune@75',
            }
        }

        sub_efficiency_map = platform_efficiency_map[platform_name]
        if platform_name != 'flops':
            print(
                "Now, please specify a latency constraint for model specialization among",
                sorted(list(sub_efficiency_map.keys())),
                'ms. (Please just input the number.) \n')
        else:
            print(
                "Now, please specify a FLOPs constraint for model specialization among",
                sorted(list(sub_efficiency_map.keys())),
                'MFLOPs. (Please just input the number.) \n')

        while True:
            efficiency_constraint = input()
            if not efficiency_constraint.isdigit():
                print('Sorry, please input an integer! \n')
                continue
            efficiency_constraint = int(efficiency_constraint)
            if efficiency_constraint not in sub_efficiency_map:
                print('Sorry, please choose a value from: ',
                      sorted(list(sub_efficiency_map.keys())), '.\n')
                continue
            return sub_efficiency_map[efficiency_constraint]

    # platform_name = select_platform_name()
    # net_id = select_netid(platform_name)
    net_id = 'flops@[email protected]_finetune@75'
    print(net_id)
    # NOTE: assumes a locally modified ofa_specialized that also returns the
    # net config; the stock helper returns (net, image_size) only.
    net, image_size, net_config = ofa_specialized(net_id=net_id,
                                                  pretrained=True)

    # validate(net, path, image_size, data_loader, batch_size, device)

    return net_id, net_config
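
A hypothetical wiring of the helpers above (assuming select_platform_name and
select_netid were hoisted to module level rather than nested, and using the
validate helper referenced in the commented-out line):

platform_name = select_platform_name()  # e.g. 'note10'
net_id = select_netid(platform_name)    # latency/FLOPs budget -> net id
net, image_size = ofa_specialized(net_id=net_id, pretrained=True)
validate(net, path, image_size, data_loader, batch_size=100, device='cuda:0')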