示例#1
0
def _count_csv_rows(csv_files):
    """Return the total number of lines across ``csv_files``.

    Each line is treated as one sample (one image filename + label per line),
    so the result is used as the sample count of a dataset.

    :param csv_files: iterable of paths to UTF-8 encoded CSV files.
    :return: total number of lines (0 for an empty iterable).
    """
    total = 0
    for csv_path in csv_files:
        with open(csv_path, mode='r', encoding='utf8') as csv_file:
            total += sum(1 for _ in csv_file)
    return total


def main(unused_argv):
    """Train and periodically evaluate the CRNN estimator, then export it.

    Parameters are loaded from ``FLAGS.params_file`` (JSON) when it is set;
    otherwise the hard-coded defaults below are combined with command-line
    flags. Training runs ``evaluate_every_epoch`` epochs at a time, followed
    by an evaluation pass on the eval CSV files. When training finishes, or
    is interrupted with Ctrl-C, the model is exported exactly once as a
    SavedModel under ``<output_model_dir>/export``.

    :param unused_argv: ignored.
    """
    # Create the output directory if it does not exist yet.
    os.makedirs(FLAGS.output_model_dir, exist_ok=True)

    if FLAGS.params_file:
        dict_params = import_params_from_json(json_filename=FLAGS.params_file)
        parameters = Params(**dict_params)
    else:
        parameters = Params(
            train_batch_size=128,
            eval_batch_size=59,
            learning_rate=0.001,  # 1e-3 recommended
            learning_decay_rate=0.5,
            learning_decay_steps=23438 * 2,
            evaluate_every_epoch=1,
            save_interval=5e3,
            input_shape=(32, 304),
            image_channels=3,
            optimizer='adam',
            digits_only=False,
            alphabet=Alphabet.CHINESECHAR_LETTERS_DIGITS_EXTENDED,
            alphabet_decoding='same',
            csv_delimiter=' ',
            csv_files_train=FLAGS.csv_files_train,
            csv_files_eval=FLAGS.csv_files_eval,
            output_model_dir=FLAGS.output_model_dir,
            n_epochs=FLAGS.nb_epochs,
            gpu=FLAGS.gpu)

    model_params = {
        'Params': parameters,
    }
    # Save the experiment configuration.
    parameters.export_experiment_params()

    os.environ['CUDA_VISIBLE_DEVICES'] = parameters.gpu
    config_sess = tf.ConfigProto()
    # Let this process claim at most 40% of the GPU memory.
    config_sess.gpu_options.per_process_gpu_memory_fraction = 0.4

    # Count training and evaluation samples separately: the training count
    # drives checkpointing, the eval count bounds the evaluation pass.
    n_samples_train = _count_csv_rows(parameters.csv_files_train)
    n_samples_eval = _count_csv_rows(parameters.csv_files_eval)

    # One checkpoint per training epoch, and keep one checkpoint per epoch.
    save_checkpoints_steps = int(
        np.ceil(n_samples_train / parameters.train_batch_size))
    keep_checkpoint_max = parameters.n_epochs
    print(n_samples_train, 'save_checkpoints_steps', save_checkpoints_steps,
          ' keep_checkpoint_max', keep_checkpoint_max)

    # Number of full eval batches (the last partial batch is dropped).
    # Estimator.evaluate requires steps > 0, so when the eval set is smaller
    # than one batch fall back to None, i.e. run until the one-epoch eval
    # input is exhausted.
    eval_steps = int(n_samples_eval // parameters.eval_batch_size)
    if eval_steps <= 0:
        eval_steps = None

    # Config estimator
    est_config = tf.estimator.RunConfig()
    est_config = est_config.replace(
        keep_checkpoint_max=keep_checkpoint_max,
        save_checkpoints_steps=save_checkpoints_steps,
        session_config=config_sess,
        save_summary_steps=100,
        model_dir=parameters.output_model_dir)

    estimator = tf.estimator.Estimator(model_fn=crnn_fn,
                                       params=model_params,
                                       model_dir=parameters.output_model_dir,
                                       config=est_config)

    export_dir = os.path.join(parameters.output_model_dir, 'export')

    tensors_to_log = {'train_accuracy': 'train_accuracy'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                              every_n_iter=100)
    try:
        for _ in range(0, parameters.n_epochs,
                       parameters.evaluate_every_epoch):
            estimator.train(input_fn=data_loader(
                csv_filename=parameters.csv_files_train,
                params=parameters,
                batch_size=parameters.train_batch_size,
                num_epochs=parameters.evaluate_every_epoch,
                data_augmentation=True,
                image_summaries=True),
                            hooks=[logging_hook])
            eval_results = estimator.evaluate(
                input_fn=data_loader(csv_filename=parameters.csv_files_eval,
                                     params=parameters,
                                     batch_size=parameters.eval_batch_size,
                                     num_epochs=1),
                steps=eval_steps,
            )
            print('Evaluation results: %s' % (str(eval_results)))
    except KeyboardInterrupt:
        # Fall through to the single export below so that an interrupted
        # run still produces a SavedModel.
        print('Interrupted')

    estimator.export_savedmodel(
        export_dir,
        preprocess_image_for_prediction(min_width=10))
    print('Exported model to {}'.format(export_dir))
                                       config=est_config)

    # Count number of image filenames in csv
    n_samples = 0
    for file in parameters.csv_files_eval:
        with open(file, 'r', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile, delimiter=parameters.csv_delimiter)
            n_samples += len(list(reader))

    try:
        for e in trange(0, parameters.n_epochs,
                        parameters.evaluate_every_epoch):
            estimator.train(input_fn=data_loader(
                csv_filename=parameters.csv_files_train,
                params=parameters,
                batch_size=parameters.train_batch_size,
                num_epochs=parameters.evaluate_every_epoch,
                data_augmentation=True,
                image_summaries=True))
            estimator.evaluate(
                input_fn=data_loader(csv_filename=parameters.csv_files_eval,
                                     params=parameters,
                                     batch_size=parameters.eval_batch_size,
                                     num_epochs=1),
                steps=np.floor(n_samples / parameters.eval_batch_size))

    except KeyboardInterrupt:
        print('Interrupted')
        estimator.export_savedmodel(
            os.path.join(parameters.output_model_dir, 'export'),
            preprocess_image_for_prediction(min_width=10))
示例#3
0
                       save_checkpoints_secs=None,
                       save_summary_steps=1000)

    model_params = {
        'Params': params,
    }

    estimator = tf.estimator.Estimator(
        model_fn=crnn_fn,
        params=model_params,
        model_dir=args.get('model_dir'),
        config=est_config,
    )
    estimator.train(input_fn=data_loader(csv_filename=params.csv_files_train,
                                         params=params,
                                         batch_size=1,
                                         num_epochs=1,
                                         data_augmentation=True,
                                         image_summaries=True))

    estimator.export_savedmodel(
        args.get('export_dir'),
        serving_input_receiver_fn=preprocess_image_for_prediction(
            min_width=10))

#
# def _signature_def_to_tensors(signature_def):
#     g = tf.get_default_graph()
#     return {k: g.get_tensor_by_name(v.name) for k,v in signature_def.inputs.items()}, \
#            {k: g.get_tensor_by_name(v.name) for k,v in signature_def.outputs.items()}
#
# with tf.Session(graph=tf.Graph()) as sess:
save_checkpoints_steps=parameters.save_interval,
session_config=config_sess,
save_checkpoints_secs=None,
save_summary_steps=1000,
model_dir=parameters.output_model_dir)

estimator = tf.estimator.Estimator(model_fn=crnn_fn,
                                   params=model_params,
                                   model_dir=parameters.output_model_dir,
                                   config=est_config)

# Test CSV, used both as the prediction input and as the source of the
# expected labels parsed below (presumably the ground truth; verify).
sample_csv_path = '/Users/samueltin/Projects/sf/sf-image-generator/output/Test/sample.csv'

# estimator.predict yields predictions lazily; nothing runs until iterated.
predictResults = estimator.predict(input_fn=data_loader(csv_filename=sample_csv_path,
                                                        params=parameters,
                                                        batch_size=1,
                                                        num_epochs=1,
                                                        data_augmentation=False,
                                                        image_summaries=False))
is_vis = False

ans_dict = {}
# Read the whole answer file; the with-block closes it even if read() fails.
with open(sample_csv_path, mode='r', encoding='utf-8') as ans_file:
    content = ans_file.read()
lines = content.split('\n')
for line in lines:
    items = line.split('\t')
    if len(items)>1:
        chats = items[1].split('{')
        label = ''.join(chats)