示例#1
0
    def __init__(self):
        """Configure the DenseNet wrapper and try to load its saved model.

        All hyperparameters are fixed constants. The model stored at
        ``self.model_filename`` is loaded with ``load_model``; if that raises
        ImportError/ValueError/OSError, ``download_model(self.name)`` is
        called and loading is retried once. Failures are reported with
        ``print`` rather than raised.

        After construction ``self._model`` holds the loaded model and
        ``self.param_count`` its ``count_params()`` value, or both are
        ``None`` when neither attempt succeeded.
        """
        self.name = 'densenet'
        self.model_filename = 'networks/models/densenet.h5'
        self.growth_rate = 12
        self.depth = 100
        self.compression = 0.5
        self.num_classes = 10
        self.img_rows, self.img_cols = 32, 32
        self.img_channels = 3
        self.batch_size = 64  # 64 or 32 or other
        self.epochs = 250
        self.iterations = 782
        self.weight_decay = 0.0001
        self.log_filepath = r'networks/models/densenet/'

        self.acc = 0.9467  # Precalculated result for cifar10

        # Pre-initialise so a failed load leaves well-defined attributes
        # instead of missing ones (which would otherwise surface later as an
        # AttributeError far from the real cause).
        self._model = None
        self.param_count = None

        try:
            self._model = load_model(self.model_filename)
            self.param_count = self._model.count_params()
            print('Successfully loaded', self.name)
        except (ImportError, ValueError, OSError):
            print('Failed to load', self.name)
            print('Downloading model')
            try:
                # Fetch the model file, then retry the local load once.
                download_model(self.name)
                self._model = load_model(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to download model')
    def __init__(self):
        """Configure the Wide ResNet wrapper and try to load its saved model.

        All hyperparameters are fixed constants. The model stored at
        ``self.model_filename`` is loaded with ``load_model``; if that raises
        ImportError/ValueError/OSError, ``download_model(self.name)`` is
        called and loading is retried once. Failures are reported with
        ``print`` rather than raised.

        After construction ``self._model`` holds the loaded model and
        ``self.param_count`` its ``count_params()`` value, or both are
        ``None`` when neither attempt succeeded.
        """
        self.name = 'wide_resnet'
        self.model_filename = 'networks/models/wide_resnet.h5'

        self.depth = 16
        self.wide = 8
        self.num_classes = 10
        self.img_rows, self.img_cols = 32, 32
        self.img_channels = 3
        self.batch_size = 128
        self.epochs = 200
        self.iterations = 391
        self.weight_decay = 0.0005
        self.log_filepath = r'networks/models/wide_resnet/'

        self.acc = 0.9534  # Precalculated result for cifar10

        # Pre-initialise so a failed load leaves well-defined attributes
        # instead of missing ones (which would otherwise surface later as an
        # AttributeError far from the real cause).
        self._model = None
        self.param_count = None

        try:
            self._model = load_model(self.model_filename)
            self.param_count = self._model.count_params()
            print('Successfully loaded', self.name)
        except (ImportError, ValueError, OSError):
            print('Failed to load', self.name)
            print('Downloading model')
            try:
                # Fetch the model file, then retry the local load once.
                download_model(self.name)
                self._model = load_model(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to download model')
    def __init__(self):
        """Configure the LeCun-net wrapper and try to load its saved model.

        All hyperparameters are fixed constants. The model stored at
        ``self.model_filename`` is loaded with ``load_model``; if that raises
        ImportError/ValueError/OSError, ``download_model(self.name)`` is
        called and loading is retried once. Failures are reported with
        ``print`` rather than raised.

        After construction ``self._model`` holds the loaded model and
        ``self.param_count`` its ``count_params()`` value, or both are
        ``None`` when neither attempt succeeded.
        """
        self.name               = 'lecun_net'
        self.model_filename     = 'networks/models/lecun_net.h5'
        self.num_classes        = 10
        self.input_shape        = 32, 32, 3
        self.batch_size         = 128
        self.epochs             = 200
        self.iterations         = 391
        self.weight_decay       = 0.0001
        self.log_filepath       = r'networks/models/lecun_net/'

        self.acc = 0.7488  # Precalculated result for cifar10

        # Pre-initialise so a failed load leaves well-defined attributes
        # instead of missing ones (which would otherwise surface later as an
        # AttributeError far from the real cause).
        self._model = None
        self.param_count = None

        try:
            self._model = load_model(self.model_filename)
            self.param_count = self._model.count_params()
            print('Successfully loaded', self.name)
        except (ImportError, ValueError, OSError):
            print('Failed to load', self.name)
            print('Downloading model')
            try:
                # Fetch the model file, then retry the local load once.
                download_model(self.name)
                self._model = load_model(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to download model')
示例#4
0
    def __init__(self, epochs=350, batch_size=128, load_weights=True):
        """Configure the pure-CNN wrapper and optionally load its saved model.

        Args:
            epochs: Value stored as ``self.epochs``.
            batch_size: Value stored as ``self.batch_size``.
            load_weights: When true, load the model stored at
                ``self.model_filename`` with ``load_model``; if that raises
                ImportError/ValueError/OSError, call
                ``download_model(self.name)`` and retry once. Failures are
                reported with ``print`` rather than raised.

        ``self._model`` and ``self.param_count`` always exist after
        construction: they hold the loaded model and its ``count_params()``
        value, or ``None`` when loading was skipped or failed.
        """
        self.name               = 'pure_cnn'
        self.model_filename     = 'networks/models/pure_cnn.h5'
        self.num_classes        = 10
        self.input_shape        = 32, 32, 3
        self.batch_size         = batch_size
        self.epochs             = epochs
        self.learn_rate         = 1.0e-4
        self.log_filepath       = r'networks/models/pure_cnn/'

        # Pre-initialise so both the skipped-load path and a failed load
        # leave well-defined attributes instead of missing ones.
        self._model = None
        self.param_count = None

        if load_weights:
            try:
                self._model = load_model(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to load', self.name)
                print('Downloading model')
                try:
                    # Fetch the model file, then retry the local load once.
                    download_model(self.name)
                    self._model = load_model(self.model_filename)
                    self.param_count = self._model.count_params()
                    print('Successfully loaded', self.name)
                except (ImportError, ValueError, OSError):
                    print('Failed to download model')
示例#5
0
    def __init__(self, epochs=200, batch_size=128, load_weights=True):
        """Configure the LeCun-net wrapper and optionally load its saved model.

        Args:
            epochs: Value stored as ``self.epochs``.
            batch_size: Value stored as ``self.batch_size``.
            load_weights: When true, load the model stored at
                ``self.model_filename`` with ``load_model``; if that raises
                ImportError/ValueError/OSError, call
                ``download_model(self.name)`` and retry once. Failures are
                reported with ``print`` rather than raised.

        ``self._model`` and ``self.param_count`` always exist after
        construction: they hold the loaded model and its ``count_params()``
        value, or ``None`` when loading was skipped or failed.
        """
        self.name = 'lecun_net'
        self.model_filename = 'networks/models/lecun_net.h5'
        self.num_classes = 10
        self.input_shape = 32, 32, 3
        self.batch_size = batch_size
        self.epochs = epochs
        self.iterations = 391
        self.weight_decay = 0.0001
        self.log_filepath = r'networks/models/lecun_net/'

        # Pre-initialise so both the skipped-load path and a failed load
        # leave well-defined attributes instead of missing ones.
        self._model = None
        self.param_count = None

        if load_weights:
            try:
                self._model = load_model(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to load', self.name)
                print('Downloading model')
                try:
                    # Fetch the model file, then retry the local load once.
                    download_model(self.name)
                    self._model = load_model(self.model_filename)
                    self.param_count = self._model.count_params()
                    print('Successfully loaded', self.name)
                except (ImportError, ValueError, OSError):
                    print('Failed to download model')
    def __init__(self):
        """Build the CapsNet model and try to load its saved weights.

        The network is built with ``CapsNetv1`` from the fixed configuration
        below. Its weights are then loaded from ``self.model_filename``; if
        that raises ImportError/ValueError/OSError,
        ``download_model(self.name)`` is called and loading is retried once.
        Failures are reported with ``print`` rather than raised.

        ``self._model`` is always the ``CapsNetv1`` result; if no weight load
        succeeded it simply lacks the saved weights. ``self.param_count`` is
        its ``count_params()`` value after a successful load, else ``None``.
        """
        self.name               = 'capsnet'
        self.model_filename     = 'networks/models/capsnet.h5'
        self.num_classes        = 10
        self.input_shape        = 32, 32, 3
        self.num_routes         = 3
        self.batch_size         = 128

        self._model = CapsNetv1(input_shape=self.input_shape,
                                n_class=self.num_classes,
                                n_route=self.num_routes)
        # Pre-initialise so a failed weight load leaves a well-defined
        # attribute instead of a missing one.
        self.param_count = None
        try:
            self._model.load_weights(self.model_filename)
            self.param_count = self._model.count_params()
            print('Successfully loaded', self.name)
        except (ImportError, ValueError, OSError):
            print('Failed to load', self.name)
            print('Downloading model')
            try:
                # Fetch the weights file, then retry the local load once.
                download_model(self.name)
                self._model.load_weights(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to download model')
示例#7
0
    def __init__(self, epochs=250, batch_size=64, load_weights=True):
        """Configure the DenseNet wrapper and optionally load its saved model.

        Args:
            epochs: Value stored as ``self.epochs``.
            batch_size: Value stored as ``self.batch_size``.
            load_weights: When true, load the model stored at
                ``self.model_filename`` with ``load_model``; if that raises
                ImportError/ValueError/OSError, call
                ``download_model(self.name)`` and retry once. Failures are
                reported with ``print`` rather than raised.

        ``self._model`` and ``self.param_count`` always exist after
        construction: they hold the loaded model and its ``count_params()``
        value, or ``None`` when loading was skipped or failed.
        """
        self.name = 'densenet'
        self.model_filename = 'networks/models/densenet.h5'
        self.growth_rate = 12
        self.depth = 100
        self.compression = 0.5
        self.num_classes = 10
        self.img_rows, self.img_cols = 32, 32
        self.img_channels = 3
        self.batch_size = batch_size
        self.epochs = epochs
        self.iterations = 782
        self.weight_decay = 0.0001
        self.log_filepath = r'networks/models/densenet/'

        # Pre-initialise so both the skipped-load path and a failed load
        # leave well-defined attributes instead of missing ones.
        self._model = None
        self.param_count = None

        if load_weights:
            try:
                self._model = load_model(self.model_filename)
                self.param_count = self._model.count_params()
                print('Successfully loaded', self.name)
            except (ImportError, ValueError, OSError):
                print('Failed to load', self.name)
                print('Downloading model')
                try:
                    # Fetch the model file, then retry the local load once.
                    download_model(self.name)
                    self._model = load_model(self.model_filename)
                    self.param_count = self._model.count_params()
                    print('Successfully loaded', self.name)
                except (ImportError, ValueError, OSError):
                    print('Failed to download model')