def load_model():
    """Build an untrained VOneNet from the command-line FLAGS and place it on a device.

    All architecture and VOneBlock hyperparameters come from the module-level
    ``FLAGS`` object. They are forwarded to ``get_model`` with
    ``pretrained=False``.

    Device placement:
      * ``FLAGS.ngpus > 0`` and at least one CUDA device visible: the model is
        moved to the module-level ``device``. The multi-GPU and single-GPU
        branches do the same thing here. Presumably any multi-GPU wrapping is
        done inside ``get_model``; TODO confirm.
      * Otherwise, the model is replaced by ``model.module``. This assumes
        ``get_model`` returns a wrapper such as ``nn.DataParallel``; TODO confirm.

    Returns:
        The constructed model.
    """
    # Weights are mapped to CPU only when no GPUs were requested.
    map_location = None if FLAGS.ngpus > 0 else 'cpu'
    print('Getting VOneNet')
    # Previously FLAGS.simple_channels was passed for complex_channels (a
    # copy-paste slip). Use the dedicated flag, falling back to the old value
    # only if the parser does not define it.
    complex_channels = getattr(FLAGS, 'complex_channels', FLAGS.simple_channels)
    model = get_model(map_location=map_location,
                      model_arch=FLAGS.model_arch,
                      pretrained=False,
                      visual_degrees=FLAGS.visual_degrees,
                      stride=FLAGS.stride,
                      ksize=FLAGS.ksize,
                      sf_corr=FLAGS.sf_corr,
                      sf_max=FLAGS.sf_max,
                      sf_min=FLAGS.sf_min,
                      rand_param=FLAGS.rand_param,
                      gabor_seed=FLAGS.gabor_seed,
                      simple_channels=FLAGS.simple_channels,
                      complex_channels=complex_channels,
                      noise_mode=FLAGS.noise_mode,
                      noise_scale=FLAGS.noise_scale,
                      noise_level=FLAGS.noise_level,
                      k_exc=FLAGS.k_exc)

    if FLAGS.ngpus > 0 and torch.cuda.device_count() > 1:
        print('We have multiple GPUs detected')
        model = model.to(device)
    elif FLAGS.ngpus > 0 and torch.cuda.device_count() == 1:
        # `==`, not `is`: identity comparison with an int literal is unreliable.
        print('We run on GPU')
        model = model.to(device)
    else:
        print('No GPU detected!')
        # Strip the wrapper so the bare network runs on CPU.
        model = model.module
    return model
def voneresnet(model_name='resnet50'):
    """Wrap a VOne-ResNet network for activation extraction.

    Args:
        model_name: architecture name forwarded to ``vonenet.get_model``.

    Returns:
        A ``PytorchWrapper`` with identifier ``'vone' + model_name``. It uses
        224px preprocessing normalized with mean 0.5 and std 0.5 on every
        channel, and its ``image_size`` is set to 224.
    """
    from vonenet import get_model
    # get_model returns a wrapper; keep only the inner network.
    net = get_model(model_name).module

    from model_tools.activations.pytorch import load_preprocess_images
    half = (0.5, 0.5, 0.5)
    preprocess = functools.partial(
        load_preprocess_images,
        image_size=224,
        normalize_mean=half,
        normalize_std=half,
    )

    identifier = 'vone' + model_name
    result = PytorchWrapper(identifier=identifier, model=net, preprocessing=preprocess)
    result.image_size = 224
    return result
def vonecornet(model_name='cornets'):
    """Wrap a VOne-CORnet network in a stochastic temporal activation wrapper.

    Args:
        model_name: architecture name forwarded to ``vonenet.get_model``.

    Returns:
        A ``StochasticTemporalPytorchWrapper`` with identifier
        ``'vone' + model_name``. It uses 224px preprocessing normalized with
        mean 0.5 and std 0.5 on every channel, and its ``image_size`` is set
        to 224.
    """
    from vonenet import get_model
    # get_model returns a wrapper; keep only the inner network.
    net = get_model(model_name).module

    from model_tools.activations.pytorch import load_preprocess_images
    half = (0.5, 0.5, 0.5)
    preprocess = functools.partial(
        load_preprocess_images,
        image_size=224,
        normalize_mean=half,
        normalize_std=half,
    )

    from candidate_models.base_models.stochastic import StochasticTemporalPytorchWrapper
    identifier = 'vone' + model_name
    result = StochasticTemporalPytorchWrapper(
        identifier=identifier,
        model=net,
        preprocessing=preprocess,
    )
    result.image_size = 224
    return result
def val():
    """Evaluate a pretrained model on the ImageNet validation set and print accuracy.

    Builds the model with ``get_model(model_arch=FLAGS.model_arch,
    pretrained=True)`` and places it on a device:
      * ``FLAGS.ngpus > 0`` and at least one CUDA device visible: the model is
        moved to the module-level ``device``.
      * Otherwise, the model is replaced by ``model.module``. This assumes
        ``get_model`` returns a wrapper such as ``nn.DataParallel``; TODO confirm.

    It then runs ``ImageNetVal`` and prints the ``'top1'`` and ``'top5'``
    entries of the returned record.

    Returns:
        None.
    """
    model = get_model(model_arch=FLAGS.model_arch, pretrained=True)
    if FLAGS.ngpus == 0:
        print('Running on CPU')

    if FLAGS.ngpus > 0 and torch.cuda.device_count() > 1:
        print('Running on multiple GPUs')
        model = model.to(device)
    elif FLAGS.ngpus > 0 and torch.cuda.device_count() == 1:
        # `==`, not `is`: identity comparison with an int literal is unreliable.
        print('Running on single GPU')
        model = model.to(device)
    else:
        print('No GPU detected!')
        # Strip the wrapper so the bare network runs on CPU.
        model = model.module

    validator = ImageNetVal(model)
    record = validator()
    print(record['top1'])
    print(record['top5'])
    return