def load_model_by_name(self, model_name, logits=False, input_range_type=1, pre_filter=lambda x: x):
    """Build a CIFAR-10 Keras model by name and load its trained weights.

    :param model_name: one of "cnn2", "cnn2_adv_trained", "cnn1", "densenet",
        "resnet20", "resnet32", "resnet44", "resnet56", "resnet110", "lenet",
        "distillation".
    :param logits: return logits (input of the softmax layer) if True;
        return the softmax output otherwise.
    :param input_range_type: {1: [0,1], 2: [-0.5, 0.5], 3: [-1, 1]...}
    :param pre_filter: callable forwarded to the model builder.
    :return: the built model, after ``load_weights`` has been called on it.
    :raises NotImplementedError: if ``model_name`` is not a supported model.

    NOTE(review): ``logits``, ``input_range_type`` and ``pre_filter`` are only
    forwarded to the cnn1/cnn2/densenet/lenet builders; the resnet* and
    distillation builders are called with no arguments. Confirm their
    defaults match what callers expect.
    """
    builder_kwargs = dict(logits=logits, input_range_type=input_range_type, pre_filter=pre_filter)

    # Single source of truth for the supported names. The validation and the
    # dispatch can no longer drift apart, which could previously leave `model`
    # unbound. Lambdas keep builder-name lookup lazy, so only the selected
    # builder is resolved and called.
    builders = {
        "cnn2": lambda: cnn2_cifar10_model(**builder_kwargs),
        "cnn2_adv_trained": lambda: cnn2_cifar10_model(**builder_kwargs),
        "cnn1": lambda: cnn1_cifar10_model(**builder_kwargs),
        "densenet": lambda: densenet_cifar10_model(**builder_kwargs),
        "lenet": lambda: lenet_cifar10_model(**builder_kwargs),
        "resnet20": lambda: resnet20_cifar10_model(),
        "resnet32": lambda: resnet32_cifar10_model(),
        "resnet44": lambda: resnet44_cifar10_model(),
        "resnet56": lambda: resnet56_cifar10_model(),
        "resnet110": lambda: resnet110_cifar10_model(),
        "distillation": lambda: distillation_cifar10_model(),
    }

    build = builders.get(model_name)
    if build is None:
        raise NotImplementedError("Undefined model [%s] for %s." % (model_name, self.dataset_name))

    self.model_name = model_name
    # The weights file name is derived from the dataset and model name, so the
    # adversarially-trained cnn2 shares the cnn2 graph but loads its own weights.
    model_weights_fpath = "%s_%s.keras_weights.h5" % (self.dataset_name, model_name)
    model_weights_fpath = os.path.join('downloads/trained_models', model_weights_fpath)

    model = build()
    if model_name == "densenet":
        # DenseNet weights come from a dedicated helper instead of the default path.
        model_weights_fpath = get_densenet_weights_path(self.dataset_name)
    print("\n===Defined TensorFlow model graph.")

    model.load_weights(model_weights_fpath)
    print("---Loaded CIFAR-10-%s model.\n" % model_name)
    return model
def load_model_by_name(self, model_name, logits=False, input_range_type=1, pre_filter=lambda x: x):
    """Construct the named CIFAR-10 model and load its saved Keras weights.

    :param model_name: one of "cleverhans", "cleverhans_adv_trained",
        "carlini", "densenet".
    :param logits: if True the model outputs logits (the softmax input);
        otherwise it outputs softmax probabilities.
    :param input_range_type: {1: [0,1], 2: [-0.5, 0.5], 3: [-1, 1]...}
    :param pre_filter: callable forwarded to the model builder.
    :return: the built model, with weights loaded from
        ``PARENT_DIR/downloads/trained_models/<dataset>_<model>.keras_weights.h5``.
    :raises NotImplementedError: if ``model_name`` is not a supported model.
    """
    supported = ("cleverhans", "cleverhans_adv_trained", "carlini", "densenet")
    if model_name not in supported:
        raise NotImplementedError("Undefined model [%s] for %s." % (model_name, self.dataset_name))

    self.model_name = model_name
    weights_file = "%s_%s.keras_weights.h5" % (self.dataset_name, model_name)
    weights_path = os.path.join(PARENT_DIR, 'downloads/trained_models', weights_file)

    # Choose the builder; every builder takes the same keyword arguments.
    if model_name == "densenet":
        # Imported here so the DenseNet module is only loaded when requested.
        from models.densenet_models import densenet_cifar10_model as build
    elif model_name == "carlini":
        build = carlini_cifar10_model
    else:
        # "cleverhans" and "cleverhans_adv_trained" share one architecture;
        # only the weights file differs.
        build = cleverhans_cifar10_model
    model = build(logits=logits, input_range_type=input_range_type, pre_filter=pre_filter)
    print("\n===Defined TensorFlow model graph.")

    model.load_weights(weights_path)
    print("---Loaded CIFAR-10-%s model.\n" % model_name)
    return model