示例#1
0
    def test_from_dict_empty(self):
        """
        Asserts a Params class is correctly instantiated from
        an empty dict
        """

        params = Params.from_dict({})
        self.assertTrue(params,
                        msg='Params could not be instantiated from dict.')
        keys = Params.__dataclass_fields__.keys()
        default_params = Params()
        for k in keys:
            self.assertEqual(default_params.__getattribute__(k),
                             params.__getattribute__(k),
                             msg=f'Key {k} does not match in Params class.')
示例#2
0
    def test_from_dict_partial(self):
        """
        Asserts a Params class is correctly instantiated from
        a partial dict
        """

        params_dict = {'GEOHASH_PRECISION_GROUPING': 4}
        params = Params.from_dict(params_dict)
        self.assertTrue(params,
                        msg='Params could not be instantiated from dict.')
        keys = Params.__dataclass_fields__.keys()
        default_params = Params()
        for k in keys:
            if k in params_dict.keys():
                self.assertEqual(
                    params_dict[k],
                    params.__getattribute__(k),
                    msg=f'Key {k} does not match in Params class.')
            else:
                self.assertEqual(
                    default_params.__getattribute__(k),
                    params.__getattribute__(k),
                    msg=f'Key {k} does not match in Params class.')
示例#3
0
def predict(image):
    if not is_model_exists():
        raise Exception('В начале обучите данные! Модель не найдена!')

    # Инициализируем наши параметры
    params = Params()

    # ПОДГОТОВКА ВХОДНЫХ ДАННЫХ
    # ----------
    # Считываем изображение, которое необходимо распознать
    # Ввиду того, что вход НС имеет вид [None, image_height, image_width, num_channels]
    # мы преобразуем наши данные к нужной форме
    images = prepare_image_for_predict(image, params.base_params.image_size)
    x_batch = images.reshape(1, params.base_params.image_height,
                             params.base_params.image_width,
                             params.base_params.num_channels)

    # ВОССТАНОВЛЕНИЕ МОДЕЛИ
    # ----------
    session = tf.Session()

    # Загружаем/восстанавливаем сохраненную обученную модель
    saver = tf.train.import_meta_graph(model_dir + model_name + '.meta')
    saver.restore(session, tf.train.latest_checkpoint(model_dir))

    graph = tf.get_default_graph()

    # Отвечает за предсказание сети
    y_pred = graph.get_tensor_by_name("y_pred:0")

    # Передаем данные сети на вход
    x = graph.get_tensor_by_name("x:0")
    y = graph.get_tensor_by_name("y:0")

    # Узнаем сколько/каких классов нужно нам распознать
    classes = next(os.walk(train_path))[1]

    y_test_images = np.zeros((1, len(classes)))

    name_of_classes = read_json(train_path + '/classes.json')

    # ПРЕДСКАЗАНИЕ
    # ----------
    feed_dict_test = {x: x_batch, y: y_test_images}
    result = session.run(y_pred, feed_dict=feed_dict_test)

    # Возвращаем какой класс и какая вертоятность что это он
    cls_id = classes[np.argmax(result)]
    return name_of_classes[cls_id] + ' (' + cls_id + ')', np.amax(result) * 100
示例#4
0
def read_entities(
    input_dir: str
) -> Tuple[Dict[str, Rider], Dict[str, Vehicle], Dict[str, Depot], Params]:
    """Method to parse the Riders, Vehicles and Depots from JSON to Dict"""

    riders_file = RIDERS_FILE.format(input_dir=input_dir)
    with open(riders_file) as f:
        logging.info(f'Read riders from {riders_file}.')
        riders_dicts = json.load(f)
    riders = {
        r_dict['rider_id']: Rider.from_dict(r_dict)
        for r_dict in riders_dicts
    }
    logging.info(f'Successfully parsed {len(riders)} riders.')

    vehicles_file = VEHICLES_FILE.format(input_dir=input_dir)
    with open(vehicles_file) as f:
        vehicles_dicts = json.load(f)
        logging.info(f'Read vehicles from {vehicles_file}.')
    vehicles = {
        v_dict['vehicle_id']: Vehicle.from_dict(v_dict)
        for v_dict in vehicles_dicts
    }
    logging.info(f'Successfully parsed {len(vehicles)} vehicles.')

    depots_file = DEPOTS_FILE.format(input_dir=input_dir)
    with open(depots_file) as f:
        depots_dicts = json.load(f)
        logging.info(f'Read depots from {depots_file}.')
    depots = {
        d_dict['depot_id']: Depot.from_dict(d_dict)
        for d_dict in depots_dicts
    }
    logging.info(f'Successfully parsed {len(depots)} depots.')

    params_file = PARAMS_FILE.format(input_dir=input_dir)
    with open(params_file) as f:
        logging.info(f'Read params from {params_file}.')
        params_dict = json.load(f)
    params = Params.from_dict(params_dict)
    logging.info(f'Successfully parsed {len(params_dict)} params.')

    return riders, vehicles, depots, params
示例#5
0
def console_train(path=train_path):
    # Инициализируем наши параметры
    params = Params()

    # ВХОДНЫЕ ДАННЫЕ
    # ----------
    # Подготавливаем входные данные
    # Классы и их количество, которые хотим в дальнейшмем будем распознавать (пример: 'Цветок', 'Машина')
    classes = next(os.walk(path))[1]
    num_classes = len(classes)

    # Подгружаем входные данные для тренировки сети
    data = read_train_sets(path,
                           params.base_params.image_size,
                           classes,
                           test_size=0.2)

    print("Тренировочные данные: {}".format(len(data.train.labels)))
    print("Проверочные данные: {}".format(len(data.test.labels)))

    train(num_classes, data, params)
示例#6
0
    def test_from_dict_complete(self):
        """
        Asserts a Params class is correctly instantiated from
        a complete dict
        """

        params_dict = {
            'GEOHASH_PRECISION_GROUPING': 8,
            'FIRST_SOLUTION_STRATEGY': 'AUTOMATIC',
            'SEARCH_METAHEURISTIC': 'AUTOMATIC',
            'SEARCH_TIME_LIMIT': 4,
            'SEARCH_SOLUTIONS_LIMIT': 3000,
        }
        params = Params.from_dict(params_dict)
        self.assertTrue(params,
                        msg='Params could not be instantiated from dict.')

        keys = Params.__dataclass_fields__.keys()
        for k in keys:
            self.assertEqual(params_dict[k],
                             params.__getattribute__(k),
                             msg=f'Key {k} does not match in Params class.')
示例#7
0
 def create_params_from_attributes(attrs):
     params = Params()
     # Performance variables
     for variable, value in attrs.items():
         setattr(params, variable, value)
     return params
示例#8
0
def get_params() -> Params:
    """Method to obtain a set of Params"""

    return Params()