def get_model_config(model, dataset):
    """Map model name to model network configuration.

    A dataset whose ``name`` equals 'cifar10' is delegated entirely to
    ``get_cifar10_model_config``. For any other dataset, ``model`` is
    matched against a fixed list of names, and the corresponding model
    object is constructed.

    Args:
      model: model name, e.g. 'vgg16', 'inception3' or 'resnet50_v2'.
      dataset: object with a ``name`` attribute.

    Returns:
      A newly constructed model configuration object.

    Raises:
      KeyError: if ``model`` matches none of the known names.
    """
    if 'cifar10' == dataset.name:
        return get_cifar10_model_config(model)

    def resnet(layer_counts):
        # The ResNet variants share one class; the requested name (including
        # any '_v2' suffix) is forwarded together with the per-stage block
        # counts.
        return lambda: resnet_model.ResnetModel(model, layer_counts)

    # Ordered (name, factory) pairs. The factories are lambdas so that a model
    # module is only looked up once its name has actually matched.
    builders = (
        ('vgg11', lambda: vgg_model.Vgg11Model()),
        ('vgg16', lambda: vgg_model.Vgg16Model()),
        ('vgg19', lambda: vgg_model.Vgg19Model()),
        ('lenet', lambda: lenet_model.Lenet5Model()),
        ('googlenet', lambda: googlenet_model.GooglenetModel()),
        ('overfeat', lambda: overfeat_model.OverfeatModel()),
        ('alexnet', lambda: alexnet_model.AlexnetModel()),
        ('trivial', lambda: trivial_model.TrivialModel()),
        ('inception3', lambda: inception_model.Inceptionv3Model()),
        ('inception4', lambda: inception_model.Inceptionv4Model()),
        ('resnet50', resnet((3, 4, 6, 3))),
        ('resnet50_v2', resnet((3, 4, 6, 3))),
        ('resnet101', resnet((3, 4, 23, 3))),
        ('resnet101_v2', resnet((3, 4, 23, 3))),
        ('resnet152', resnet((3, 8, 36, 3))),
        ('resnet152_v2', resnet((3, 8, 36, 3))),
    )
    for candidate, build in builders:
        if model == candidate:
            return build()

    raise KeyError('Invalid model name \'%s\' for dataset \'%s\'' %
                   (model, dataset.name))
def get_model_config(model):
    """Map model name to model network configuration.

    ``model`` is compared against a fixed list of names, in order, and the
    first match's model object is constructed. Every ResNet depth listed in
    the final group is built by the single ``resnet_model.ResNet`` class,
    which receives the name itself.

    Args:
      model: model name, e.g. 'deep_mnist', 'vgg13' or 'resnet34'.

    Returns:
      A newly constructed model configuration object.

    Raises:
      KeyError: if ``model`` matches none of the known names.
    """
    # Ordered (name, factory) pairs. The factories are lambdas so that a model
    # module is only looked up once its name has actually matched.
    builders = (
        ('deep_mnist', lambda: deepmnist_model.DeepMNISTModel()),
        ('eng_acoustic_model', lambda: engacoustic_model.EngAcousticModel()),
        ('sensor_net', lambda: sensornet_model.SensorNetModel()),
        ('vgg11', lambda: vgg_model.Vgg11Model()),
        ('vgg13', lambda: vgg_model.Vgg13Model()),
        ('vgg16', lambda: vgg_model.Vgg16Model()),
        ('vgg19', lambda: vgg_model.Vgg19Model()),
        ('lenet', lambda: lenet_model.Lenet5Model()),
        ('googlenet', lambda: googlenet_model.GooglenetModel()),
        ('overfeat', lambda: overfeat_model.OverfeatModel()),
        ('alexnet', lambda: alexnet_model.AlexnetModel()),
        ('trivial', lambda: trivial_model.TrivialModel()),
        ('inception3', lambda: inception_model.Inceptionv3Model()),
        ('inception4', lambda: inception_model.Inceptionv4Model()),
    )
    for candidate, build in builders:
        if model == candidate:
            return build()

    resnet_names = ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                    'resnet152', 'resnet200', 'resnet269')
    if model in resnet_names:
        return resnet_model.ResNet(model)

    # Wrap ``model`` in a 1-tuple. A bare ``% model`` misbehaves when model is
    # itself a tuple: it raises TypeError, or misreports a 1-tuple's element
    # as the name, instead of raising the intended KeyError.
    raise KeyError('Invalid model name \'%s\'' % (model,))