parser.add_argument('--img_channels', type=int, default=1, help='0: Use the number of channels in the image, ' '1: Grayscale image, 3: RGB image') parser.add_argument('--ignore_case', action='store_true', help='Whether ignore case.(default false)') parser.add_argument('--restore', type=str, help='The model for restore, even if the number of ' 'characters is different') args = parser.parse_args() localtime = time.asctime() dataset_builder = DatasetBuilder(args.table_path, args.img_width, args.img_channels, args.ignore_case) train_ds = dataset_builder.build(args.train_ann_paths, args.batch_size, True) saved_model_prefix = '{epoch}_{word_accuracy:.4f}' if args.val_ann_paths: val_ds = dataset_builder.build(args.val_ann_paths, args.batch_size, False) saved_model_prefix = saved_model_prefix + '_{val_word_accuracy:.4f}' else: val_ds = None saved_model_path = f'saved_models/{localtime}/{saved_model_prefix}.h5' Path('saved_models', localtime).mkdir() print('Training start at {}'.format(localtime)) model = build_model(dataset_builder.num_classes, img_channels=args.img_channels) model.compile(optimizer=keras.optimizers.Adam(args.learning_rate), loss=CTCLoss(),
# Parse the command line; `parser` is configured earlier in the script.
args = parser.parse_args()

# Only the 'train' section of the YAML config is used by this script.
# NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
# config file must come from a trusted source. If the config only holds plain
# scalars, lists and mappings, yaml.safe_load would be the safer choice.
with args.config.open() as f:
    config = yaml.load(f, Loader=yaml.Loader)['train']
print(config)

# Create the output directory (including missing parents) and refuse to reuse
# one that already has content, so a previous run is never overwritten.
args.save_dir.mkdir(parents=True, exist_ok=True)
if any(args.save_dir.iterdir()):
    raise ValueError(f'{args.save_dir} is not a empty folder')
# Keep a copy of the exact config next to the files written to save_dir.
shutil.copy(args.config, args.save_dir / args.config.name)

# model_prefix is a plain string, so its {epoch}/{..._accuracy} placeholders
# stay unformatted inside model_path; presumably they are filled in by
# whatever consumes model_path later (not visible in this section).
model_prefix = '{epoch}_{sequence_accuracy:.4f}_{val_sequence_accuracy:.4f}'
model_path = f'{args.save_dir}/{model_prefix}.h5'

# The global batch size scales with the number of replicas in sync.
strategy = tf.distribute.MirroredStrategy()
batch_size = config['batch_size_per_replica'] * strategy.num_replicas_in_sync

dataset_builder = DatasetBuilder(**config['dataset_builder'])
train_ds = dataset_builder.build(config['train_ann_paths'], batch_size, True)
val_ds = dataset_builder.build(config['val_ann_paths'], batch_size, False)

# Build and compile inside the strategy scope so the model is created under
# the distribution strategy.
with strategy.scope():
    model = build_model(dataset_builder.num_classes,
                        config['dataset_builder']['img_shape'])
    model.compile(optimizer=keras.optimizers.Adam(config['learning_rate']),
                  loss=CTCLoss(), metrics=[SequenceAccuracy()])

# 'restore' is optional: a missing key is treated like an empty value.
# Weights are matched by layer name and mismatching layers are skipped.
restore_path = config.get('restore')
if restore_path:
    model.load_weights(restore_path, by_name=True, skip_mismatch=True)

model.summary()
import argparse

import yaml
from tensorflow import keras

from dataset_factory import DatasetBuilder

# Command-line interface: both options are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--config',
    type=str,
    required=True,
    help='The config file path.',
)
parser.add_argument(
    '--model',
    type=str,
    required=True,
    help='The saved model path.',
)
args = parser.parse_args()

# Read the YAML file and keep only its 'predict' section.
# NOTE(review): yaml.Loader can construct arbitrary Python objects, so the
# config file must come from a trusted source.
with open(args.config) as config_file:
    full_config = yaml.load(config_file, Loader=yaml.Loader)
config = full_config['predict']

# Build the input dataset from the annotation files named in the config.
# The trailing False is passed positionally; presumably it selects the
# non-training mode of DatasetBuilder.build -- confirm against its signature.
dataset_builder = DatasetBuilder(**config['dataset_builder'])
ds = dataset_builder.build(config['ann_paths'], config['batch_size'], False)

# Load the saved model without compiling it, then run prediction over the
# dataset and report only the shape of the result.
model = keras.models.load_model(args.model, compile=False)
predictions = model.predict(ds)
print(predictions.shape)