Ejemplo n.º 1
0
def predict(image):
    if not is_model_exists():
        raise Exception('В начале обучите данные! Модель не найдена!')

    # Инициализируем наши параметры
    params = Params()

    # ПОДГОТОВКА ВХОДНЫХ ДАННЫХ
    # ----------
    # Считываем изображение, которое необходимо распознать
    # Ввиду того, что вход НС имеет вид [None, image_height, image_width, num_channels]
    # мы преобразуем наши данные к нужной форме
    images = prepare_image_for_predict(image, params.base_params.image_size)
    x_batch = images.reshape(1, params.base_params.image_height,
                             params.base_params.image_width,
                             params.base_params.num_channels)

    # ВОССТАНОВЛЕНИЕ МОДЕЛИ
    # ----------
    session = tf.Session()

    # Загружаем/восстанавливаем сохраненную обученную модель
    saver = tf.train.import_meta_graph(model_dir + model_name + '.meta')
    saver.restore(session, tf.train.latest_checkpoint(model_dir))

    graph = tf.get_default_graph()

    # Отвечает за предсказание сети
    y_pred = graph.get_tensor_by_name("y_pred:0")

    # Передаем данные сети на вход
    x = graph.get_tensor_by_name("x:0")
    y = graph.get_tensor_by_name("y:0")

    # Узнаем сколько/каких классов нужно нам распознать
    classes = next(os.walk(train_path))[1]

    y_test_images = np.zeros((1, len(classes)))

    name_of_classes = read_json(train_path + '/classes.json')

    # ПРЕДСКАЗАНИЕ
    # ----------
    feed_dict_test = {x: x_batch, y: y_test_images}
    result = session.run(y_pred, feed_dict=feed_dict_test)

    # Возвращаем какой класс и какая вертоятность что это он
    cls_id = classes[np.argmax(result)]
    return name_of_classes[cls_id] + ' (' + cls_id + ')', np.amax(result) * 100
Ejemplo n.º 2
0
    def test_from_dict_empty(self):
        """
        Asserts a Params class is correctly instantiated from
        an empty dict
        """

        params = Params.from_dict({})
        self.assertTrue(params,
                        msg='Params could not be instantiated from dict.')
        keys = Params.__dataclass_fields__.keys()
        default_params = Params()
        for k in keys:
            self.assertEqual(default_params.__getattribute__(k),
                             params.__getattribute__(k),
                             msg=f'Key {k} does not match in Params class.')
Ejemplo n.º 3
0
def console_train(path=train_path):
    # Инициализируем наши параметры
    params = Params()

    # ВХОДНЫЕ ДАННЫЕ
    # ----------
    # Подготавливаем входные данные
    # Классы и их количество, которые хотим в дальнейшмем будем распознавать (пример: 'Цветок', 'Машина')
    classes = next(os.walk(path))[1]
    num_classes = len(classes)

    # Подгружаем входные данные для тренировки сети
    data = read_train_sets(path,
                           params.base_params.image_size,
                           classes,
                           test_size=0.2)

    print("Тренировочные данные: {}".format(len(data.train.labels)))
    print("Проверочные данные: {}".format(len(data.test.labels)))

    train(num_classes, data, params)
Ejemplo n.º 4
0
    def test_from_dict_partial(self):
        """
        Asserts a Params class is correctly instantiated from
        a partial dict
        """

        params_dict = {'GEOHASH_PRECISION_GROUPING': 4}
        params = Params.from_dict(params_dict)
        self.assertTrue(params,
                        msg='Params could not be instantiated from dict.')
        keys = Params.__dataclass_fields__.keys()
        default_params = Params()
        for k in keys:
            if k in params_dict.keys():
                self.assertEqual(
                    params_dict[k],
                    params.__getattribute__(k),
                    msg=f'Key {k} does not match in Params class.')
            else:
                self.assertEqual(
                    default_params.__getattribute__(k),
                    params.__getattribute__(k),
                    msg=f'Key {k} does not match in Params class.')
Ejemplo n.º 5
0
 def create_params_from_attributes(attrs):
     params = Params()
     # Performance variables
     for variable, value in attrs.items():
         setattr(params, variable, value)
     return params
Ejemplo n.º 6
0
def get_params() -> Params:
    """Method to obtain a set of Params"""

    return Params()