import torch.nn as nn from utils.options import args import utils.common as utils import time from data import cifar10, imagenet_dali, imagenet from importlib import import_module device = torch.device( f"cuda:{args.gpus[0]}") if torch.cuda.is_available() else 'cpu' loss_func = nn.CrossEntropyLoss() # Data print('==> Preparing data..') if args.data_set == 'cifar10': testLoader = cifar10.Data(args).testLoader else: #imagenet if device != 'cpu': testLoader = imagenet_dali.get_imagenet_iter_dali( 'val', args.data_path, args.eval_batch_size, num_threads=4, crop=224, device_id=args.gpus[0], num_gpus=1) else: testLoader = imagenet.Data(args).testLoader def test(model, topk=(1, )):
logger = utils.get_logger(os.path.join(args.job_dir + 'logger.log')) loss_func = nn.CrossEntropyLoss() conv_num_cfg = { 'vgg16': 13, 'resnet56': 27, 'resnet110': 54, 'googlenet': 9, 'densenet': 36, } food_dimension = conv_num_cfg[args.cfg] # Data print('==> Loading Data..') if args.data_set == 'cifar10': loader = cifar10.Data(args) elif args.data_set == 'cifar100': loader = cifar100.Data(args) else: loader = imagenet.Data(args) # Model print('==> Loading Model..') if args.arch == 'vgg_cifar': origin_model = import_module(f'model.{args.arch}').VGG(args.cfg).to(device) elif args.arch == 'resnet_cifar': origin_model = import_module(f'model.{args.arch}').resnet( args.cfg).to(device) elif args.arch == 'googlenet': origin_model = import_module(f'model.{args.arch}').googlenet().to(device) elif args.arch == 'densenet':