Example #1
def _get_label_names(self, index_file):
    """Return the list of class names, reading them from the HDF5 index
    file when it exists, otherwise rebuilding them from the annotations."""
    classes = []
    if index_file is not None and file.exists(index_file):
        # fast path: the label list was cached in an HDF5 index file
        label_file = h5py.File(index_file, 'r')
        classes_name = label_file['labels_list'][:]
        label_file.close()
        classes = [name.decode('utf-8') for name in classes_name]
    else:
        # slow path: collect class names from every annotation file
        for data_file in self.img_list:
            annotation_file = data_file[1]
            root = ET.fromstring(file.read(annotation_file, binary=True))
            objs = root.findall('object')
            # skip objects marked as 'difficult'
            objs = [obj for obj in objs
                    if int(obj.find('difficult').text) == 0]
            for obj in objs:
                class_name = obj.find('name').text.lower().strip()
                if class_name not in classes:
                    classes.append(class_name)
        index_file_path = index_file if index_file is not None else INDEX_FILE_NAME
        if file.exists(index_file_path):
            file.remove(index_file_path)
        # cache the rebuilt label list for subsequent runs
        label_file = h5py.File(index_file_path, 'w')
        label_file.create_dataset('labels_list',
                                  data=[item.encode() for item in classes])
        label_file.close()
    return classes
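
For reference, a minimal sketch of the HDF5 round trip this method relies on: class names are stored as a 'labels_list' dataset of UTF-8 encoded bytes and decoded on read. The snippet below is self-contained and only assumes h5py; the `file` helper used above is presumably MoXing's mox.file module.

import h5py

classes = ['cat', 'dog', 'person']
# write the label cache
with h5py.File('index.h5', 'w') as f:
    f.create_dataset('labels_list',
                     data=[c.encode('utf-8') for c in classes])
# read it back and decode
with h5py.File('index.h5', 'r') as f:
    restored = [name.decode('utf-8') for name in f['labels_list'][:]]
assert restored == classes
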
Example #2
def get_image_train_eval(data_path):
    """get image list when data struct is
    {
    |-- data_url
        |-- train
            |-- Images
                |-- a.jpg
                ...
            |-- Annotations
                |-- a.txt (or a.xml)
            |-- label_map_dict (optional)
        |-- eval
            |-- Images
                |-- b.jpg
                ...
            |-- Annotations
                |-- b.txt (or b.xml)
                ...
            |-- label_map_dict (optional)
        |-- label_map_dict (optional)
    }
    :param data_path: data store URL
    Returns:
      train_data_list,
      eval_data_list,
    """
    image_list_train = []
    # get all labeled train data
    image_list_set = file.list_directory(os.path.join(data_path, 'train', 'Images'))
    assert image_list_set, 'there is no file in the train data url'
    for i in image_list_set:
        if file.exists(os.path.join(data_path, 'train', 'Annotations', os.path.splitext(i)[0] + '.xml')):
            image_list_train.append([os.path.join(data_path, 'train', 'Images', i),
                                     os.path.join(data_path, 'train', 'Annotations', os.path.splitext(i)[0] + '.xml')])
        elif file.exists(os.path.join(data_path, 'train', 'Annotations', os.path.splitext(i)[0] + '.txt')):
            image_list_train.append([os.path.join(data_path, 'train', 'Images', i),
                                     file.read(os.path.join(data_path, 'train',
                                                            'Annotations',
                                                            os.path.splitext(i)[0] + '.txt'))])
    # get all labeled eval data
    image_list_eval = []
    image_list_set = file.list_directory(os.path.join(data_path, 'eval', 'Images'))
    assert image_list_set, 'there is no file in the eval data url'
    for i in image_list_set:
        if file.exists(os.path.join(data_path, 'eval', 'Annotations', os.path.splitext(i)[0] + '.xml')):
            image_list_eval.append([os.path.join(data_path, 'eval', 'Images', i),
                                    os.path.join(data_path, 'eval', 'Annotations', os.path.splitext(i)[0] + '.xml')])
        elif file.exists(os.path.join(data_path, 'eval', 'Annotations', os.path.splitext(i)[0] + '.txt')):
            image_list_eval.append([os.path.join(data_path, 'eval', 'Images', i),
                                    file.read(os.path.join(data_path, 'eval',
                                                           'Annotations',
                                                           os.path.splitext(i)[0] + '.txt'))])

    return image_list_train, image_list_eval
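
The `file` helper used throughout these examples appears to be MoXing's mox.file module, which works transparently on both local paths and OBS/S3 URLs. For local experimentation, a thin hypothetical stand-in covering only the calls used here could look like:

import os

class file:  # mimics the module-style name used in the examples
    @staticmethod
    def exists(path):
        return os.path.exists(path)

    @staticmethod
    def list_directory(path):
        return os.listdir(path)

    @staticmethod
    def read(path, binary=False):
        mode = 'rb' if binary else 'r'
        with open(path, mode) as f:
            return f.read()
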
Example #3
def get_image_images_annotation(data_path, split_spec):
    """get image list when data struct is
   {
   |-- data_url
       |-- Images
           |-- a.jpg
           |-- b.jpg
           ...
       |-- Annotations
           |-- a.txt (or a.xml)
           |-- b.txt (or b.xml)
           ...
       |-- label_map_dict (optional)
   }
   :param data_path: data store url
   :param split_spec: split train percent if data doesn't have evaluation data
   Returns:
       train_data_list,
       eval_data_list,
   """
    image_set = []
    label_dict = {}
    label_num = 0
    class_name = []
    # get all labeled data
    image_list_set = file.list_directory(os.path.join(data_path, 'Images'))
    assert image_list_set, 'there is no file in the data url'
    for i in image_list_set:
        if file.exists(os.path.join(data_path, 'Annotations', os.path.splitext(i)[0] + '.xml')):
            image_set.append([os.path.join(data_path, 'Images', i),
                              os.path.join(data_path, 'Annotations', os.path.splitext(i)[0] + '.xml')])
        elif file.exists(os.path.join(data_path, 'Annotations', os.path.splitext(i)[0] + '.txt')):
            label_name = file.read(os.path.join(data_path, 'Annotations',
                                                os.path.splitext(i)[0] + '.txt'))
            if label_name not in label_dict:
                label_dict[label_name] = label_num
                class_name.append(label_name)
                label_num += 1
            image_set.append([os.path.join(data_path, 'Images', i),
                              label_dict[label_name]])

    # split data to train and eval
    num_examples = len(image_set)
    train_num = int(num_examples * split_spec)
    shuffle_list = list(range(num_examples))
    random.shuffle(shuffle_list)
    image_list_train = []
    image_list_eval = []
    for idx, item in enumerate(shuffle_list):
        if idx < train_num:
            image_list_train.append(image_set[item])
        else:
            image_list_eval.append(image_set[item])
    return image_list_train, image_list_eval, class_name
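
The train/eval split at the end reduces to shuffling indices and cutting at split_spec; a self-contained illustration with made-up file names:

import random

image_set = [['img_%d.jpg' % i, 'ann_%d.xml' % i] for i in range(10)]
split_spec = 0.8
indices = list(range(len(image_set)))
random.shuffle(indices)
cut = int(len(image_set) * split_spec)
train = [image_set[i] for i in indices[:cut]]
eval_set = [image_set[i] for i in indices[cut:]]
print(len(train), len(eval_set))  # 8 2
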
Example #4
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local, FLAGS.batch_size,
                                                    FLAGS.num_classes, FLAGS.input_size)

    optimizer = adam(lr=FLAGS.learning_rate, clipnorm=0.001)
    objective = 'binary_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    if FLAGS.restore_model_path != '' and file.exists(FLAGS.restore_model_path):
        if FLAGS.restore_model_path.startswith('s3://'):
            # copy remote weights to a local cache directory before loading
            restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
            file.copy(FLAGS.restore_model_path, '/cache/tmp/' + restore_model_name)
            model.load_weights('/cache/tmp/' + restore_model_name, by_name=True)
            os.remove('/cache/tmp/' + restore_model_name)
        else:
            model.load_weights(FLAGS.restore_model_path, by_name=True)
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorBoard = TensorBoard(log_dir=FLAGS.train_local)
    history = LossHistory(FLAGS)
    model.fit_generator(
        train_sequence,
        steps_per_epoch=len(train_sequence),
        epochs=FLAGS.max_epochs,
        verbose=1,
        callbacks=[history, tensorBoard],
        validation_data=validation_sequence,
        max_queue_size=10,
        workers=int(multiprocessing.cpu_count() * 0.7),
        use_multiprocessing=True,
        shuffle=True
    )

    print('training done!')

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    if FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from eval import load_test_data
        img_names, test_data, test_labels = load_test_data(FLAGS)
        predictions = model.predict(test_data, verbose=0)

        right_count = 0
        for index, pred in enumerate(predictions):
            predict_label = np.argmax(pred, axis=0)
            test_label = test_labels[index]
            if predict_label == test_label:
                right_count += 1
        accuracy = right_count / len(img_names)
        print('accuracy: %0.4f' % accuracy)
        metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
        metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
        with open(metric_file_name, "w") as f:
            f.write(metric_file_content + '\n')
    print('end')
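
The metric.json content above is assembled by hand with a format string; an equivalent and less error-prone construction with the json module would be (the accuracy value is a placeholder for illustration):

import json

accuracy = 0.9123  # placeholder value
metric = {'total_metric': {'total_metric_values': {'accuracy': round(accuracy, 4)}}}
with open('metric.json', 'w') as f:
    json.dump(metric, f)
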
Example #5
def save_pb_model(FLAGS, model):
    if FLAGS.mode == 'train':
        pb_save_dir_local = FLAGS.train_local
        pb_save_dir_obs = FLAGS.train_url
    elif FLAGS.mode == 'save_pb':
        freeze_weights_file_dir = FLAGS.freeze_weights_file_path.rsplit(
            '/', 1)[0]
        if freeze_weights_file_dir.startswith('s3://'):
            pb_save_dir_local = '/cache/tmp'
            pb_save_dir_obs = freeze_weights_file_dir
        else:
            pb_save_dir_local = freeze_weights_file_dir
            pb_save_dir_obs = pb_save_dir_local

    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input_img': model.input},
        outputs={'output_score': model.output})
    builder = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(pb_save_dir_local, 'model'))
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=backend.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'predict_images': signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()
    print('saved pb to local path successfully')

    if pb_save_dir_obs.startswith('s3://'):
        file.copy_parallel(os.path.join(pb_save_dir_local, 'model'),
                           os.path.join(pb_save_dir_obs, 'model'))
        print('copied pb to %s successfully' % pb_save_dir_obs)

    file.copy(os.path.join(FLAGS.deploy_script_path, 'config.json'),
              os.path.join(pb_save_dir_obs, 'model/config.json'))
    file.copy(os.path.join(FLAGS.deploy_script_path, 'customize_service.py'),
              os.path.join(pb_save_dir_obs, 'model/customize_service.py'))
    if file.exists(os.path.join(pb_save_dir_obs, 'model/config.json')) and \
            file.exists(os.path.join(pb_save_dir_obs, 'model/customize_service.py')):
        print('copied config.json and customize_service.py successfully')
    else:
        print('failed to copy config.json and customize_service.py')
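
A quick way to verify that the export produced a loadable SavedModel is to load it back with the matching TF 1.x loader API (a sketch; 'model' is the export directory created above):

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], 'model')
    print('SavedModel loaded successfully')
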
Example #6
def check_args(FLAGS):
    if FLAGS.mode not in ['train', 'save_pb', 'eval']:
        raise Exception('FLAGS.mode error, should be train, save_pb or eval')
    if FLAGS.num_classes == 0:
        raise Exception(
            'FLAGS.num_classes error, '
            'should be a positive number associated with your classification task'
        )

    if FLAGS.mode == 'train':
        if FLAGS.data_url == '':
            raise Exception('you must specify FLAGS.data_url')
        if not file.exists(FLAGS.data_url):
            raise Exception('FLAGS.data_url: %s does not exist' % FLAGS.data_url)
        if FLAGS.restore_model_path != '' and (not file.exists(
                FLAGS.restore_model_path)):
            raise Exception('FLAGS.restore_model_path: %s does not exist' %
                            FLAGS.restore_model_path)
        if FLAGS.restore_model_path != '' and file.is_directory(
                FLAGS.restore_model_path):
            raise Exception(
                'FLAGS.restore_model_path must be a file path, not a directory: %s'
                % FLAGS.restore_model_path)
        if FLAGS.train_url == '':
            raise Exception('you must specify FLAGS.train_url')
        elif not file.exists(FLAGS.train_url):
            file.make_dirs(FLAGS.train_url)
        if FLAGS.deploy_script_path != '' and (not file.exists(
                FLAGS.deploy_script_path)):
            raise Exception('FLAGS.deploy_script_path: %s does not exist' %
                            FLAGS.deploy_script_path)
        if FLAGS.deploy_script_path != '' and file.exists(FLAGS.train_url +
                                                          '/model'):
            raise Exception(
                FLAGS.train_url +
                '/model already exists; only one model directory is allowed'
            )
        if FLAGS.test_data_url != '' and (not file.exists(
                FLAGS.test_data_url)):
            raise Exception('FLAGS.test_data_url: %s does not exist' %
                            FLAGS.test_data_url)

    if FLAGS.mode == 'save_pb':
        if FLAGS.deploy_script_path == '' or FLAGS.freeze_weights_file_path == '':
            raise Exception(
                'you must specify FLAGS.deploy_script_path '
                'and FLAGS.freeze_weights_file_path when you want to save pb')
        if not file.exists(FLAGS.deploy_script_path):
            raise Exception('FLAGS.deploy_script_path: %s does not exist' %
                            FLAGS.deploy_script_path)
        if not file.is_directory(FLAGS.deploy_script_path):
            raise Exception(
                'FLAGS.deploy_script_path must be a directory, not a file path: %s'
                % FLAGS.deploy_script_path)
        if not file.exists(FLAGS.freeze_weights_file_path):
            raise Exception('FLAGS.freeze_weights_file_path: %s does not exist' %
                            FLAGS.freeze_weights_file_path)
        if file.is_directory(FLAGS.freeze_weights_file_path):
            raise Exception(
                'FLAGS.freeze_weights_file_path must be a file path, not a directory: %s'
                % FLAGS.freeze_weights_file_path)
        if file.exists(
                FLAGS.freeze_weights_file_path.rsplit('/', 1)[0] + '/model'):
            raise Exception('a model directory already exists in ' +
                            FLAGS.freeze_weights_file_path.rsplit('/', 1)[0] +
                            ', please rename or remove the model directory')

    if FLAGS.mode == 'eval':
        if FLAGS.eval_weights_path == '' and FLAGS.eval_pb_path == '':
            raise Exception(
                'you must specify FLAGS.eval_weights_path '
                'or FLAGS.eval_pb_path when you want to evaluate a model')
        if FLAGS.eval_weights_path != '' and FLAGS.eval_pb_path != '':
            raise Exception(
                'you must specify only one of FLAGS.eval_weights_path '
                'and FLAGS.eval_pb_path when you want to evaluate a model')
        if FLAGS.eval_weights_path != '' and (not file.exists(
                FLAGS.eval_weights_path)):
            raise Exception('FLAGS.eval_weights_path: %s does not exist' %
                            FLAGS.eval_weights_path)
        if FLAGS.eval_pb_path != '' and (not file.exists(FLAGS.eval_pb_path)):
            raise Exception('FLAGS.eval_pb_path: %s does not exist' %
                            FLAGS.eval_pb_path)
        if FLAGS.eval_pb_path != '' and (
                not file.is_directory(FLAGS.eval_pb_path)
                or not FLAGS.eval_pb_path.endswith('model')):
            raise Exception(
                'FLAGS.eval_pb_path must be a directory named model '
                'which contains saved_model.pb and variables: %s' %
                FLAGS.eval_pb_path)
        if FLAGS.test_data_url == '':
            raise Exception(
                'you must specify FLAGS.test_data_url when you want to evaluate a model'
            )
        if not file.exists(FLAGS.test_data_url):
            raise Exception('FLAGS.test_data_url: %s does not exist' %
                            FLAGS.test_data_url)
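
check_args only reads flat attributes from FLAGS, so for a quick local test an argparse.Namespace with the expected fields is enough (hypothetical paths shown):

from argparse import Namespace

FLAGS = Namespace(mode='train', num_classes=10, data_url='/data',
                  restore_model_path='', train_url='/output',
                  deploy_script_path='', test_data_url='')
check_args(FLAGS)  # raises if /data is missing; creates /output if absent
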
Example #7
def get_data_iter(data_path, train_file=None, val_file=None, split_spec=1,
                  hyper_train=None, hyper_val=None, **kwargs):
    # avoid shared mutable default arguments
    hyper_train = hyper_train if hyper_train is not None else {}
    hyper_val = hyper_val if hyper_val is not None else {}
    train_set = None
    val_set = None
    train_list = None
    val_list = None
    if train_file is not None:
        assert file.exists(train_file), 'train file not found'
        train_path = file.read(train_file).split("\n")[0:-1]
        train_list = [path.replace('\r', '').split(' ') for path in train_path]
        train_list = [[os.path.join(data_path, path[0]),
                       os.path.join(data_path, path[1])] for path in train_list]
    if val_file is not None:
        assert file.exists(val_file), 'val file not found'
        val_path = file.read(val_file).split("\n")[0:-1]
        val_list = [path.replace('\r', '').split(' ') for path in val_path]
        val_list = [[os.path.join(data_path, path[0]),
                     os.path.join(data_path, path[1])] for path in val_list]
    if train_file is None and val_file is None:
        train_list, val_list, _ = get_image_list(data_path, split_spec)
    # defaults below are the standard YOLOv3 anchors (grouped per output
    # scale) and the 13x13/26x26/52x52 feature-map sizes for 416x416 input
    if 'anchors' not in kwargs:
        kwargs['anchors'] = [[116, 90, 156, 198, 373, 326],
                             [30, 61, 62, 45, 59, 119],
                             [10, 13, 16, 30, 33, 23]]
    if 'offsets' not in kwargs:
        kwargs['offsets'] = [(13, 13), (26, 26), (52, 52)]
    if train_list is not None and len(train_list) > 0:
        dataset = Detection_dataset(img_list=train_list,
                                    index_file=hyper_train.get('index_file'),
                                    width=hyper_train.get('width', 416),
                                    height=hyper_train.get('height', 416),
                                    is_train=True,
                                    **kwargs)
        max_gt_box_number = max(len(item) for item in dataset.label_cache)
        batch_size = hyper_train.get('batch_size', 32)
        train_set = gluon.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=hyper_train.get('shuffle', True),
            batchify_fn=_train_batchify_fn(max_gt_box_number),
            last_batch='rollover',
            num_workers=hyper_train.get('preprocess_threads', 4))
        next_data_batch = next(iter(train_set))
        setattr(train_set, 'reset', _reset)
        setattr(train_set, 'provide_data', _get_provide_data(next_data_batch))
        setattr(train_set, 'provide_label', _get_provide_label(
            next_data_batch, (batch_size, max_gt_box_number, 4), is_train=True))
    if val_list is not None and len(val_list) > 0:
        assert 'index_file' in hyper_val and file.exists(
            hyper_val['index_file']), 'label name file not found'
        dataset = Detection_dataset(img_list=val_list,
                                    index_file=hyper_val.get('index_file'),
                                    width=hyper_val.get('width', 416),
                                    height=hyper_val.get('height', 416),
                                    is_train=False,
                                    **kwargs)
        max_gt_box_number = max(len(item) for item in dataset.label_cache)
        batch_size = hyper_val.get('batch_size', 32)
        val_set = gluon.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=hyper_val.get('shuffle', True),
            batchify_fn=_val_batchify_fn(max_gt_box_number),
            last_batch='keep',
            num_workers=hyper_val.get('preprocess_threads', 4))
        next_data_batch = next(iter(val_set))
        setattr(val_set, 'reset', _reset)
        setattr(val_set, 'provide_data', _get_provide_data(next_data_batch))
        setattr(val_set, 'provide_label', _get_provide_label(
            next_data_batch, is_train=False))
    return train_set, val_set
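
The train_file/val_file lists above are plain text, one space-separated 'image annotation' pair per line with a trailing newline; a small illustration of the parsing logic:

content = 'images/a.jpg annotations/a.xml\nimages/b.jpg annotations/b.xml\n'
lines = content.split('\n')[0:-1]  # drop the empty trailing entry
pairs = [line.replace('\r', '').split(' ') for line in lines]
print(pairs)
# [['images/a.jpg', 'annotations/a.xml'], ['images/b.jpg', 'annotations/b.xml']]
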
Example #8
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size)

    # optimizer = adam(lr=FLAGS.learning_rate, clipnorm=0.001)
    optimizer = Nadam(lr=FLAGS.learning_rate,
                      beta_1=0.9,
                      beta_2=0.999,
                      epsilon=1e-08,
                      schedule_decay=0.004)
    # optimizer = SGD(lr=FLAGS.learning_rate, momentum=0.9)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    if FLAGS.restore_model_path != '' and file.exists(
            FLAGS.restore_model_path):
        if FLAGS.restore_model_path.startswith('s3://'):
            restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
            file.copy(FLAGS.restore_model_path,
                      '/cache/tmp/' + restore_model_name)
            model.load_weights('/cache/tmp/' + restore_model_name)
            os.remove('/cache/tmp/' + restore_model_name)
        else:
            model.load_weights(FLAGS.restore_model_path)
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorBoard = TensorBoard(log_dir=FLAGS.train_local)
    # reduce_lr = ks.callbacks.ReduceLROnPlateau(monitor='val_acc', factor=0.5, verbose=1, patience=1,
    #                                            min_lr=1e-7)
    # cosine-annealed learning rate with warm-up
    sample_count = len(train_sequence) * FLAGS.batch_size
    epochs = FLAGS.max_epochs
    warmup_epoch = 5
    batch_size = FLAGS.batch_size
    learning_rate_base = FLAGS.learning_rate
    # one scheduler step per batch, so total_steps == epochs * len(train_sequence)
    total_steps = int(epochs * sample_count / batch_size)
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    warm_up_lr = WarmUpCosineDecayScheduler(
        learning_rate_base=learning_rate_base,
        total_steps=total_steps,
        warmup_learning_rate=0,
        warmup_steps=warmup_steps,
        hold_base_rate_steps=0,
    )
    history = LossHistory(FLAGS)
    model.fit_generator(train_sequence,
                        steps_per_epoch=len(train_sequence),
                        epochs=FLAGS.max_epochs,
                        verbose=1,
                        callbacks=[history, tensorBoard, warm_up_lr],
                        validation_data=validation_sequence,
                        max_queue_size=10,
                        workers=int(multiprocessing.cpu_count() * 0.7),
                        use_multiprocessing=True,
                        shuffle=True)

    print('training done!')

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    if FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from eval import load_test_data
        img_names, test_data, test_labels = load_test_data(FLAGS)
        predictions = model.predict(test_data, verbose=0)

        right_count = 0
        for index, pred in enumerate(predictions):
            predict_label = np.argmax(pred, axis=0)
            test_label = test_labels[index]
            if predict_label == test_label:
                right_count += 1
        accuracy = right_count / len(img_names)
        print('accuracy: %0.4f' % accuracy)
        metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
        metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
        with open(metric_file_name, "w") as f:
            f.write(metric_file_content + '\n')
    print('end')
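
WarmUpCosineDecayScheduler is a project-local callback; the schedule it implements is commonly written as below (a sketch under that assumption, not the project's exact code): linear warm-up from warmup_learning_rate to learning_rate_base over warmup_steps, then cosine decay to zero over the remaining steps.

import numpy as np

def cosine_decay_with_warmup(step, learning_rate_base, total_steps,
                             warmup_learning_rate=0.0, warmup_steps=0):
    if step < warmup_steps:
        # linear warm-up toward the base learning rate
        slope = (learning_rate_base - warmup_learning_rate) / float(warmup_steps)
        return warmup_learning_rate + slope * step
    # cosine decay over the post-warm-up steps
    progress = (step - warmup_steps) / float(total_steps - warmup_steps)
    return 0.5 * learning_rate_base * (1 + np.cos(np.pi * progress))
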