Example #1
File: train.py  Project: zzmcdc/CRNN.tf2
    print('Num of val samples: {}'.format(val_size))
    saved_model_prefix = saved_model_prefix + '_{val_word_accuracy:.4f}'
else:
    val_ds = None
saved_model_path = ('saved_models/{}/'.format(localtime) + saved_model_prefix +
                    '.h5')
os.makedirs('saved_models/{}'.format(localtime))
print('Training start at {}'.format(localtime))

model = build_model(dataset_builder.num_classes,
                    args.img_width,
                    channels=args.img_channels)
model.compile(optimizer=keras.optimizers.SGD(args.learning_rate,
                                             momentum=0.9,
                                             clipnorm=1.0),
              loss=CTCLoss(),
              metrics=[WordAccuracy()])

if args.restore:
    model.load_weights(args.restore, by_name=True, skip_mismatch=True)

epoch_batch = 975000 // args.batch_size  # steps per epoch; integer division keeps step counts whole

warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=args.learning_rate,
                                        total_steps=args.epochs * epoch_batch,
                                        warmup_learning_rate=0.0,
                                        warmup_steps=epoch_batch,
                                        hold_base_rate_steps=4 * epoch_batch)

callbacks = [
    warm_up_lr,
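
Example #1 is cut off inside the callbacks list. A minimal sketch of how it plausibly continues, patterned on Example #2 below; the ModelCheckpoint and TensorBoard wiring here is an assumption, not the project's verbatim code:

callbacks = [
    warm_up_lr,  # the scheduler built above
    keras.callbacks.ModelCheckpoint(saved_model_path),  # assumed, per Example #2
    keras.callbacks.TensorBoard(log_dir='logs/{}'.format(localtime),
                                profile_batch=0),       # assumed, per Example #2
]
model.fit(train_ds, epochs=args.epochs, callbacks=callbacks,
          validation_data=val_ds)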
Example #2
File: train.py  Project: xs06974/CRNN.tf2
localtime = time.asctime()
dataset_builder = DatasetBuilder(args.table_path, args.img_width,
                                 args.img_channels, args.ignore_case)
train_ds, train_size = dataset_builder.build(args.train_ann_paths, True, 
                                             args.batch_size)
print('Num of training samples: {}'.format(train_size))
saved_model_prefix = '{epoch:03d}_{word_accuracy:.4f}'  # filled in by ModelCheckpoint from the training logs
if args.val_ann_paths:
    val_ds, val_size = dataset_builder.build(args.val_ann_paths, False,
                                             args.batch_size)
    print('Num of val samples: {}'.format(val_size))
    saved_model_prefix = saved_model_prefix + '_{val_word_accuracy:.4f}'
else:
    val_ds = None
saved_model_path = ('saved_models/{}/'.format(localtime) + 
                    saved_model_prefix + '.h5')
os.makedirs('saved_models/{}'.format(localtime))
print('Training start at {}'.format(localtime))

model = build_model(dataset_builder.num_classes, channels=args.img_channels)
model.compile(optimizer=keras.optimizers.Adam(args.learning_rate),
              loss=CTCLoss(), metrics=[WordAccuracy()])

if args.restore:
    model.load_weights(args.restore, by_name=True, skip_mismatch=True)

callbacks = [keras.callbacks.ModelCheckpoint(saved_model_path),
             keras.callbacks.TensorBoard(log_dir='logs/{}'.format(localtime),
                                         profile_batch=0)]
model.fit(train_ds, epochs=args.epochs, callbacks=callbacks,
          validation_data=val_ds)
Example #3
parser.add_argument("-t", "--table_path", type=str, required=True, 
                    help="The path of table file.")
parser.add_argument("-w", "--image_width", type=int, default=100, 
                    help="Image width, this parameter will affect the output "
                         "shape of the model, default is 100, so this model "
                         "can only predict up to 24 characters.")
parser.add_argument("-b", "--batch_size", type=int, default=256, 
                    help="Batch size.")
parser.add_argument("-m", "--model", type=str, required=True, 
                    help="The saved model.")
parser.add_argument("--channels", type=int, default=1, help="Image channels, "
                    "0: Use the number of channels in the image, "
                    "1: Grayscale image, "
                    "3: RGB image")
parser.add_argument("--ignore_case", action="store_true", 
                    help="Whether ignore case.(default false)")
args = parser.parse_args()

eval_ds, size, num_classes = build_dataset(
    args.annotation_paths,
    args.table_path,
    args.image_width,
    args.channels,
    args.ignore_case,
    batch_size=args.batch_size)
print("Num of eval samples: {}".format(size))

model = keras.models.load_model(args.model, compile=False)
model.summary()
model.compile(loss=CTCLoss(), metrics=[WordAccuracy()])
model.evaluate(eval_ds)
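
The --image_width help above ties input width to a maximum decodable length (width 100 gives at most 24 characters). A rough sanity check of that relationship, assuming the CRNN backbone downsamples width by a factor of 4 and drops one frame; both the factor and the offset are inferred from the help text, not read from the project code:

def max_decodable_chars(image_width, downsample=4, offset=1):
    # Hypothetical estimate of CTC time steps, an upper bound on the
    # number of characters the model can emit for a given input width.
    return image_width // downsample - offset

print(max_decodable_chars(100))  # 24, matching the default in the help text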
Example #4
    print("Num of val samples: {}".format(len(val_dl)))
    # saved_model_path = ("saved_models/{}/".format(localtime) +
    #     "{epoch:03d}_{word_accuracy:.4f}_{val_word_accuracy:.4f}.h5")
else:
    val_dl = lambda: None  # placeholder standing in for an absent validation loader

print("Start at {}".format(localtime))
# os.makedirs("saved_models/{}".format(localtime))

print('train_dl.num_classes', train_dl.num_classes)
model = crnn(train_dl.num_classes)
# model.build(input_shape=())
# print('model.summary={}'.format(model.summary()))

print('start compile')
custom_loss = CTCLoss()
print('custom_loss={}'.format(custom_loss))
# compute_accuracy=WordAccuracy()

start_learning_rate = args.learning_rate
learning_rate = tf.Variable(start_learning_rate, dtype=tf.float32)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# model.compile(
#               optimizer=keras.optimizers.Adam(lr=args.learning_rate),
#               loss=custom_loss,
#               metrics=[WordAccuracy()]
# )
#
#
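
Example #4 stops after constructing the optimizer, with the compile call commented out, which points toward a hand-written training loop. A minimal sketch of that loop using the names defined above, mirroring Example #6's pattern; the (imgs, labels) batch interface of train_dl and the label handling are assumptions:

for imgs, labels in train_dl:  # assumed batch interface
    with tf.GradientTape() as tape:
        y_pred = model(imgs)
        loss = custom_loss(labels, y_pred)  # labels may need encoding first, as in Example #6
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))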
Example #5
import argparse

import yaml
from tensorflow import keras

from dataset_factory import DatasetBuilder
from losses import CTCLoss
from metrics import SequenceAccuracy, EditDistance

parser = argparse.ArgumentParser()
parser.add_argument('--config',
                    type=str,
                    required=True,
                    help='The config file path.')
parser.add_argument('--model',
                    type=str,
                    required=True,
                    help='The saved model path.')
args = parser.parse_args()

with open(args.config) as f:
    config = yaml.load(f, Loader=yaml.Loader)['eval']

dataset_builder = DatasetBuilder(**config['dataset_builder'])
ds = dataset_builder.build(config['ann_paths'], config['batch_size'], False)
model = keras.models.load_model(args.model, compile=False)
model.compile(loss=CTCLoss(), metrics=[SequenceAccuracy(), EditDistance()])
model.evaluate(ds)
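
Example #5 reads an eval section from a YAML config with dataset_builder kwargs, ann_paths, and batch_size; those three keys come straight from the script. A hypothetical config with placeholder values, parsed the same way the script parses it:

import yaml

config_text = """
eval:
  dataset_builder: {}          # kwargs forwarded to DatasetBuilder (placeholder)
  ann_paths:
    - data/val_annotations.txt # placeholder path
  batch_size: 256
"""
config = yaml.load(config_text, Loader=yaml.Loader)['eval']
print(config['batch_size'])  # 256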
Example #6
File: train.py  Project: modemlxsg/docs
config_file.close()
config = yaml.full_load(config)  # 'config' is presumably the raw YAML text read before close()

# data
dataset = Mj_Dataset('train')
train_ds = dataset.getDS().batch(16)

val_ds = Mj_Dataset('val').getDS().batch(16)

nclass = config['crnn']['nClass']
model = CRNN(nclass)

optimizer = keras.optimizers.Adam(learning_rate=0.003)
criterion = CTCLoss(logits_time_major=False)

epochs = 10
for epoch in range(epochs):
    print(f'Start of epoch {epoch}')

    # train
    for step, (imgs, labels) in enumerate(train_ds):
        y_true = dataset.encode(labels)  # sparse_tensor

        with tf.GradientTape() as tape:
            y_pred = model(imgs)
            loss = criterion(y_true, y_pred)

        grads = tape.gradient(loss, model.trainable_weights)
        # assumed continuation: apply the gradients to finish the training step
        optimizer.apply_gradients(zip(grads, model.trainable_weights))