예제 #1
0
def get_image_classese_train_eval(data_path):
    """Get labelled image lists when the data is already split into train/eval.

    Expected data structure::

        data_url
        |-- train
            |-- class_1
                |-- a.jpg
                ...
            |-- class_2
                |-- b.jpg
                ...
            ...
        |-- eval
            |-- class_1
                |-- c.jpg
                ...
            |-- class_2
                |-- d.jpg

    Each sub-directory of ``train`` is a class. Labels are assigned 0, 1, ...
    in the order ``file.list_directory`` returns those directories. Class
    directories under ``eval`` reuse the label of the same-named ``train``
    directory. Files whose names contain ``.xml`` or ``.txt`` are skipped.
    Non-directory entries directly under ``train``/``eval`` are ignored.

    :param data_path: data store url
    Returns:
      train_data_list: list of ``[image_path, label]`` pairs,
      eval_data_list: list of ``[image_path, label]`` pairs,
      class_name: list of class directory names, one per label, so that
        ``class_name[label]`` is the directory that produced ``label``.
    Raises:
      AssertionError: if ``train`` or ``eval`` has no entries.
      KeyError: if a non-empty ``eval`` class directory has no ``train``
        counterpart.
    """
    train_dir = os.path.join(data_path, 'train')
    eval_dir = os.path.join(data_path, 'eval')

    # get all labeled train data
    image_label_name = {}
    class_name = []
    image_list_train = []
    entries = file.list_directory(train_dir)
    # Explicit raise (same exception type as the former assert) so the
    # check is not stripped when running under `python -O`.
    if not entries:
        raise AssertionError('there is no file in data url')
    for entry in entries:
        class_dir = os.path.join(train_dir, entry)
        if not file.is_directory(class_dir):
            continue
        label = len(class_name)
        image_label_name[entry] = label
        # Record each class exactly once (not once per file) so that
        # class_name[label] maps a label back to its class directory.
        class_name.append(entry)
        for img in file.list_directory(class_dir):
            if '.xml' not in img and '.txt' not in img:
                image_list_train.append([os.path.join(class_dir, img), label])

    # get all labeled eval data
    image_list_eval = []
    entries = file.list_directory(eval_dir)
    if not entries:
        raise AssertionError('there is no file in data url')
    for entry in entries:
        class_dir = os.path.join(eval_dir, entry)
        if not file.is_directory(class_dir):
            continue
        img_list = file.list_directory(class_dir)
        if not img_list:
            continue
        # Eval classes must also exist under train; KeyError otherwise.
        label = image_label_name[entry]
        for img in img_list:
            if '.xml' not in img and '.txt' not in img:
                image_list_eval.append([os.path.join(class_dir, img), label])

    return image_list_train, image_list_eval, class_name
예제 #2
0
def get_image_classese_raw(data_path, split_spec):
    """Get labelled image lists from class folders and split them randomly.

    Expected data structure::

        data_url
        |-- class_1
            |-- a.jpg
            |-- b.jpg
        |-- class_2
            |-- c.jpg
            |-- d.jpg
            ...
        |-- label_map_dict (optional)

    Every sub-directory of ``data_path`` is a class. Labels are assigned
    0, 1, ... in the order ``file.list_directory`` returns those directories.
    Non-directory entries (e.g. ``label_map_dict``) are ignored. Files whose
    names contain ``.xml`` or ``.txt`` are skipped. Within each class the
    images are shuffled with ``random.shuffle``. The first
    ``int(n * split_spec)`` images go to train and the rest go to eval.

    :param data_path: data store url
    :param split_spec: fraction of each class's images used for training
    Returns:
      train_data_list: list of ``[image_path, label]`` pairs,
      eval_data_list: list of ``[image_path, label]`` pairs,
      class_name: list of class directory names, one per label.
    Raises:
      AssertionError: if ``data_path`` contains no class directory.
    """
    # Filter with a comprehension: removing items from a list while
    # iterating over it skips the entry that follows each removed one.
    class_dirs = [entry for entry in file.list_directory(data_path)
                  if file.is_directory(os.path.join(data_path, entry))]
    if not class_dirs:
        raise AssertionError('there is no file in data url')

    # get all labeled data, grouped per class
    class_name = []
    images_per_class = []
    for label, entry in enumerate(class_dirs):
        class_dir = os.path.join(data_path, entry)
        class_name.append(entry)
        images_per_class.append([
            [os.path.join(class_dir, img), label]
            for img in file.list_directory(class_dir)
            if '.xml' not in img and '.txt' not in img
        ])

    # Split each class into train and eval. Working on the filtered per-class
    # list keeps the count aligned with the images actually collected, so
    # skipped .xml/.txt files cannot shift samples into another class.
    image_list_train = []
    image_list_eval = []
    for images in images_per_class:
        train_num = int(len(images) * split_spec)
        random.shuffle(images)
        image_list_train.extend(images[:train_num])
        image_list_eval.extend(images[train_num:])
    return image_list_train, image_list_eval, class_name
예제 #3
0
def get_image_list(data_path, split_spec):
    """Detect the dataset layout under *data_path* and build image lists.

    Dispatch order (a ``cache`` entry is ignored at both levels):

    1. ``Images`` and ``Annotations`` at top level ->
       ``get_image_images_annotation``.
    2. ``train`` and ``eval`` at top level, then by the contents of
       ``train``: ``Images`` and ``Annotations`` -> ``get_image_train_eval``;
       ``image_to_annotation.csv`` -> ``get_image_csv``; no sub-directories
       -> ``get_image_train_eval_raw``; otherwise ->
       ``get_image_classese_train_eval``.
    3. No sub-directories at top level -> ``get_image_raw_txt``.
    4. Otherwise -> ``get_image_classese_raw``.

    :param data_path: data store url
    :param split_spec: split train percent if data doesn't have evaluation data
    Returns:
        train_data_list,
        eval_data_list,
        class_name (``None`` for layouts whose helper does not return it)
    """
    image_list_train = []
    image_list_eval = []
    class_name = None
    file_list = file.list_directory(data_path)
    if 'cache' in file_list:
        file_list.remove('cache')
    donot_have_directory = not any(
        file.is_directory(os.path.join(data_path, i)) for i in file_list)
    # Each name needs its own `in` test: `'a' and 'b' in xs` parses as
    # `'a' and ('b' in xs)` and would only check the second name.
    if 'Images' in file_list and 'Annotations' in file_list:
        image_list_train, image_list_eval, class_name = \
            get_image_images_annotation(data_path, split_spec)
    elif 'train' in file_list and 'eval' in file_list:
        train_path = os.path.join(data_path, 'train')
        train_list = file.list_directory(train_path)
        if 'cache' in train_list:
            train_list.remove('cache')
        is_raw = not any(
            file.is_directory(os.path.join(train_path, i)) for i in train_list)
        if 'Images' in train_list and 'Annotations' in train_list:
            image_list_train, image_list_eval = get_image_train_eval(data_path)
        elif 'image_to_annotation.csv' in train_list:
            image_list_train, image_list_eval = get_image_csv(data_path)
        elif is_raw:
            image_list_train, image_list_eval = \
                get_image_train_eval_raw(data_path)
        else:
            image_list_train, image_list_eval, class_name = \
                get_image_classese_train_eval(data_path)
    elif donot_have_directory:
        image_list_train, image_list_eval, class_name = \
            get_image_raw_txt(data_path, split_spec)
    else:
        image_list_train, image_list_eval, class_name = \
            get_image_classese_raw(data_path, split_spec)
    return image_list_train, image_list_eval, class_name
예제 #4
0
def eval_model(FLAGS):
    """Run the evaluation routine that matches the configured model path.

    A non-empty ``FLAGS.eval_weights_path`` wins. A directory is evaluated
    with ``test_batch_h5`` and a single file with ``test_single_h5``.
    Otherwise, a non-empty ``FLAGS.eval_pb_path`` is evaluated with
    ``test_single_model``. When both are empty, nothing runs.
    """
    weights_path = FLAGS.eval_weights_path
    if weights_path == '':
        if FLAGS.eval_pb_path != '':
            test_single_model(FLAGS)
        return
    if file.is_directory(weights_path):
        test_batch_h5(FLAGS)
    else:
        test_single_h5(FLAGS, weights_path)
예제 #5
0
def check_args(FLAGS):
    """Validate FLAGS for the selected mode, raising on the first problem.

    Checks common to all modes: ``FLAGS.mode`` must be one of ``train``,
    ``save_pb`` or ``eval``, and ``FLAGS.num_classes`` must be positive. The
    remaining checks depend on the mode (required flags, path existence,
    file vs. directory). String flags equal to ``''`` are treated as unset.

    Side effect: in ``train`` mode, ``FLAGS.train_url`` is created with
    ``file.make_dirs`` when it does not exist yet.

    :param FLAGS: flags object exposing the attributes read below
    Raises:
      Exception: describing the first invalid flag found.
    """
    if FLAGS.mode not in ['train', 'save_pb', 'eval']:
        raise Exception('FLAGS.mode error, should be train, save_pb or eval')
    # The message promises a positive number, so reject negatives too.
    if FLAGS.num_classes <= 0:
        raise Exception(
            'FLAGS.num_classes error, '
            'should be a positive number associated with your classification task'
        )

    if FLAGS.mode == 'train':
        if FLAGS.data_url == '':
            raise Exception('you must specify FLAGS.data_url')
        if not file.exists(FLAGS.data_url):
            raise Exception('FLAGS.data_url: %s is not exist' % FLAGS.data_url)
        # restore_model_path is optional; only inspect it when it is set.
        if FLAGS.restore_model_path != '':
            if not file.exists(FLAGS.restore_model_path):
                raise Exception('FLAGS.restore_model_path: %s is not exist' %
                                FLAGS.restore_model_path)
            if file.is_directory(FLAGS.restore_model_path):
                raise Exception(
                    'FLAGS.restore_model_path must be a file path, not a directory, %s'
                    % FLAGS.restore_model_path)
        if FLAGS.train_url == '':
            raise Exception('you must specify FLAGS.train_url')
        elif not file.exists(FLAGS.train_url):
            file.make_dirs(FLAGS.train_url)
        if FLAGS.deploy_script_path != '' and (not file.exists(
                FLAGS.deploy_script_path)):
            raise Exception('FLAGS.deploy_script_path: %s is not exist' %
                            FLAGS.deploy_script_path)
        if FLAGS.deploy_script_path != '' and file.exists(FLAGS.train_url +
                                                          '/model'):
            raise Exception(
                FLAGS.train_url +
                '/model is already exist, only one model directoty is allowed to exist'
            )
        if FLAGS.test_data_url != '' and (not file.exists(
                FLAGS.test_data_url)):
            raise Exception('FLAGS.test_data_url: %s is not exist' %
                            FLAGS.test_data_url)

    if FLAGS.mode == 'save_pb':
        if FLAGS.deploy_script_path == '' or FLAGS.freeze_weights_file_path == '':
            raise Exception(
                'you must specify FLAGS.deploy_script_path '
                'and FLAGS.freeze_weights_file_path when you want to save pb')
        if not file.exists(FLAGS.deploy_script_path):
            raise Exception('FLAGS.deploy_script_path: %s is not exist' %
                            FLAGS.deploy_script_path)
        if not file.is_directory(FLAGS.deploy_script_path):
            raise Exception(
                'FLAGS.deploy_script_path must be a directory, not a file path, %s'
                % FLAGS.deploy_script_path)
        if not file.exists(FLAGS.freeze_weights_file_path):
            raise Exception('FLAGS.freeze_weights_file_path: %s is not exist' %
                            FLAGS.freeze_weights_file_path)
        if file.is_directory(FLAGS.freeze_weights_file_path):
            raise Exception(
                'FLAGS.freeze_weights_file_path must be a file path, not a directory, %s '
                % FLAGS.freeze_weights_file_path)
        # The exported model is written next to the weights file; refuse to
        # overwrite an existing `model` directory there.
        if file.exists(
                FLAGS.freeze_weights_file_path.rsplit('/', 1)[0] + '/model'):
            raise Exception('a model directory is already exist in ' +
                            FLAGS.freeze_weights_file_path.rsplit('/', 1)[0] +
                            ', please rename or remove the model directory ')

    if FLAGS.mode == 'eval':
        if FLAGS.eval_weights_path == '' and FLAGS.eval_pb_path == '':
            raise Exception(
                'you must specify FLAGS.eval_weights_path '
                'or FLAGS.eval_pb_path when you want to evaluate a model')
        if FLAGS.eval_weights_path != '' and FLAGS.eval_pb_path != '':
            raise Exception(
                'you must specify only one of FLAGS.eval_weights_path '
                'and FLAGS.eval_pb_path when you want to evaluate a model')
        if FLAGS.eval_weights_path != '' and (not file.exists(
                FLAGS.eval_weights_path)):
            raise Exception('FLAGS.eval_weights_path: %s is not exist' %
                            FLAGS.eval_weights_path)
        # Only validate the pb path when it is the one in use. Unguarded, an
        # empty path always failed the `endswith('model')` test, which made
        # weights-based evaluation impossible.
        if FLAGS.eval_pb_path != '':
            if not file.exists(FLAGS.eval_pb_path):
                raise Exception('FLAGS.eval_pb_path: %s is not exist' %
                                FLAGS.eval_pb_path)
            if not file.is_directory(FLAGS.eval_pb_path) or (
                    not FLAGS.eval_pb_path.endswith('model')):
                raise Exception(
                    'FLAGS.eval_pb_path must be a directory named model '
                    'which contain saved_model.pb and variables, %s' %
                    FLAGS.eval_pb_path)
        if FLAGS.test_data_url == '':
            raise Exception(
                'you must specify FLAGS.test_data_url when you want to evaluate a model'
            )
        if not file.exists(FLAGS.test_data_url):
            raise Exception('FLAGS.test_data_url: %s is not exist' %
                            FLAGS.test_data_url)