示例#1
0
    def load_model_by_name(self,
                           model_name,
                           logits=False,
                           input_range_type=1,
                           pre_filter=lambda x: x):
        """Build a CIFAR-10 Keras model by name and load its trained weights.

        :params model_name: one of "cnn2", "cnn2_adv_trained", "cnn1",
            "densenet", "resnet20", "resnet32", "resnet44", "resnet56",
            "resnet110", "lenet", "distillation".
        :params logits: return logits(input of softmax layer) if True; return softmax output otherwise.
        :params input_range_type: {1: [0,1], 2:[-0.5, 0.5], 3:[-1, 1]...}
        :params pre_filter: callable forwarded unchanged to the model builder.
        :returns: the built model after ``model.load_weights`` has been called.
        :raises NotImplementedError: if ``model_name`` is not supported.

        Note: ``logits``, ``input_range_type`` and ``pre_filter`` are only
        forwarded to the cnn2/cnn2_adv_trained, cnn1, densenet and lenet
        builders. The resnet* and distillation builders are called without
        arguments, so those options have no effect for them.

        Side effect: sets ``self.model_name`` once the name is accepted.
        """
        options = dict(logits=logits,
                       input_range_type=input_range_type,
                       pre_filter=pre_filter)
        # A single table acts as both the whitelist and the dispatch, so the
        # set of accepted names can never drift from the set that can be built.
        # Lambdas defer the builder lookup until the chosen model is built.
        builders = {
            "cnn2": lambda: cnn2_cifar10_model(**options),
            "cnn2_adv_trained": lambda: cnn2_cifar10_model(**options),
            "cnn1": lambda: cnn1_cifar10_model(**options),
            "densenet": lambda: densenet_cifar10_model(**options),
            "lenet": lambda: lenet_cifar10_model(**options),
            "resnet20": lambda: resnet20_cifar10_model(),
            "resnet32": lambda: resnet32_cifar10_model(),
            "resnet44": lambda: resnet44_cifar10_model(),
            "resnet56": lambda: resnet56_cifar10_model(),
            "resnet110": lambda: resnet110_cifar10_model(),
            "distillation": lambda: distillation_cifar10_model(),
        }
        if model_name not in builders:
            raise NotImplementedError("Undefined model [%s] for %s." %
                                      (model_name, self.dataset_name))
        self.model_name = model_name

        # Generic weights location: downloads/trained_models/<dataset>_<model>.keras_weights.h5
        model_weights_fpath = os.path.join(
            'downloads/trained_models',
            "%s_%s.keras_weights.h5" % (self.dataset_name, model_name))

        model = builders[model_name]()
        if model_name == "densenet":
            # DenseNet resolves its weights file through a dedicated helper
            # instead of the generic naming scheme above.
            model_weights_fpath = get_densenet_weights_path(self.dataset_name)
        print("\n===Defined TensorFlow model graph.")
        model.load_weights(model_weights_fpath)
        print("---Loaded CIFAR-10-%s model.\n" % model_name)
        return model
示例#2
0
    def load_model_by_name(self,
                           model_name,
                           logits=False,
                           input_range_type=1,
                           pre_filter=lambda x: x):
        """Build a CIFAR-10 Keras model by name and load its trained weights.

        :params model_name: one of "cleverhans", "cleverhans_adv_trained",
            "carlini", "densenet".
        :params logits: return logits(input of softmax layer) if True; return softmax output otherwise.
        :params input_range_type: {1: [0,1], 2:[-0.5, 0.5], 3:[-1, 1]...}
        :params pre_filter: callable forwarded unchanged to the model builder.
        :returns: the built model after ``model.load_weights`` has been called
            on ``PARENT_DIR/downloads/trained_models/<dataset>_<model>.keras_weights.h5``.
        :raises NotImplementedError: if ``model_name`` is not supported.

        Side effect: sets ``self.model_name`` once the name is accepted.
        """
        options = dict(logits=logits,
                       input_range_type=input_range_type,
                       pre_filter=pre_filter)

        def build_densenet():
            # The DenseNet module is imported only when DenseNet is actually
            # requested, as in the original branch-local import.
            from models.densenet_models import densenet_cifar10_model
            return densenet_cifar10_model(**options)

        # A single table acts as both the whitelist and the dispatch, so the
        # set of accepted names can never drift from the set that can be built.
        # Lambdas defer the builder lookup until the chosen model is built.
        builders = {
            "cleverhans": lambda: cleverhans_cifar10_model(**options),
            "cleverhans_adv_trained":
                lambda: cleverhans_cifar10_model(**options),
            "carlini": lambda: carlini_cifar10_model(**options),
            "densenet": build_densenet,
        }
        if model_name not in builders:
            raise NotImplementedError("Undefined model [%s] for %s." %
                                      (model_name, self.dataset_name))
        self.model_name = model_name

        model_weights_fpath = os.path.join(
            PARENT_DIR, 'downloads/trained_models',
            "%s_%s.keras_weights.h5" % (self.dataset_name, model_name))

        model = builders[model_name]()
        print("\n===Defined TensorFlow model graph.")
        model.load_weights(model_weights_fpath)
        print("---Loaded CIFAR-10-%s model.\n" % model_name)
        return model