# Shared imports assumed by the listings below (standard library, TensorFlow 2.x,
# absl, and the TensorBoard HParams plugin). Project-local helpers such as
# encoding, decoding, metrics, model_utils, common_voice, build_keras_model,
# get_dataset, get_dataset_fn, get_loss_fn, run_training and run_evaluate, plus
# the HP_* / METRIC_* constants, are assumed to come from the project's own modules.
import os
import shutil
import time
from datetime import datetime

import tensorflow as tf
from absl import flags, logging
from tensorboard.plugins.hparams import api as hp

FLAGS = flags.FLAGS
mixed_precision = tf.keras.mixed_precision.experimental


def main(_):
    configure_environment(FLAGS.fp16_run)

    # Default hyperparameter choices for this run.
    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[1],
        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],
        # Model
        HP_EMBEDDING_SIZE: HP_EMBEDDING_SIZE.domain.values[0],
        HP_ENCODER_LAYERS: HP_ENCODER_LAYERS.domain.values[0],
        HP_ENCODER_SIZE: HP_ENCODER_SIZE.domain.values[0],
        HP_PROJECTION_SIZE: HP_PROJECTION_SIZE.domain.values[0],
        HP_TIME_REDUCT_INDEX: HP_TIME_REDUCT_INDEX.domain.values[0],
        HP_TIME_REDUCT_FACTOR: HP_TIME_REDUCT_FACTOR.domain.values[0],
        HP_PRED_NET_LAYERS: HP_PRED_NET_LAYERS.domain.values[0],
        HP_PRED_NET_SIZE: HP_PRED_NET_SIZE.domain.values[0],
        HP_JOINT_NET_SIZE: HP_JOINT_NET_SIZE.domain.values[0],
        HP_LEARNING_RATE: HP_LEARNING_RATE.domain.values[0]
    }

    # Register hyperparameters and metrics with the TensorBoard HParams dashboard.
    with tf.summary.create_file_writer(
            os.path.join(FLAGS.tb_log_dir, 'hparams_tuning')).as_default():
        hp.hparams_config(
            hparams=[
                HP_TOKEN_TYPE, HP_VOCAB_SIZE,
                HP_EMBEDDING_SIZE, HP_ENCODER_LAYERS, HP_ENCODER_SIZE,
                HP_PROJECTION_SIZE, HP_TIME_REDUCT_INDEX, HP_TIME_REDUCT_FACTOR,
                HP_PRED_NET_LAYERS, HP_PRED_NET_SIZE, HP_JOINT_NET_SIZE
            ],
            metrics=[
                hp.Metric(METRIC_ACCURACY, display_name='Accuracy'),
                hp.Metric(METRIC_CER, display_name='CER'),
                hp.Metric(METRIC_WER, display_name='WER'),
            ],
        )

    _hparams = {k.name: v for k, v in hparams.items()}

    # Select devices: use all visible GPUs unless --gpus was given explicitly.
    if FLAGS.gpus is None or len(FLAGS.gpus) == 0:
        gpus = [
            x.name.strip('/physical_device:')
            for x in tf.config.experimental.list_physical_devices('GPU')
        ]
    else:
        gpus = ['GPU:' + str(gpu_id) for gpu_id in FLAGS.gpus]

    strategy = tf.distribute.MirroredStrategy(devices=gpus)
    # strategy = None

    dtype = tf.float32
    if FLAGS.fp16_run:
        dtype = tf.float16

    # initializer = tf.keras.initializers.RandomUniform(minval=-0.1, maxval=0.1)
    initializer = None

    if FLAGS.checkpoint is not None:
        # Resume: reuse the hyperparameters and encoder stored next to the checkpoint.
        checkpoint_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))
        _hparams = model_utils.load_hparams(checkpoint_dir)

        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            checkpoint_dir, hparams=_hparams)

        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams,
                                          initializer=initializer,
                                          dtype=dtype)
                model.load_weights(FLAGS.checkpoint)
        else:
            model = build_keras_model(_hparams,
                                      initializer=initializer,
                                      dtype=dtype)
            model.load_weights(FLAGS.checkpoint)

        logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))
    else:
        # Fresh run: copy the subword encoder into the output directory and build a new model.
        os.makedirs(FLAGS.output_dir, exist_ok=True)
        shutil.copy(os.path.join(FLAGS.data_dir, 'encoder.subwords'),
                    os.path.join(FLAGS.output_dir, 'encoder.subwords'))

        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            FLAGS.output_dir, hparams=_hparams)
        _hparams[HP_VOCAB_SIZE.name] = vocab_size

        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams,
                                          initializer=initializer,
                                          dtype=dtype)
        else:
            model = build_keras_model(_hparams,
                                      initializer=initializer,
                                      dtype=dtype)

        model_utils.save_hparams(_hparams, FLAGS.output_dir)

    logging.info('Using {} encoder with vocab size: {}'.format(
        _hparams[HP_TOKEN_TYPE.name], vocab_size))

    loss_fn = get_loss_fn(
        reduction_factor=_hparams[HP_TIME_REDUCT_FACTOR.name])

    # Greedy decoding is used for all evaluation metrics.
    start_token = encoder_fn('')[0]
    decode_fn = decoding.greedy_decode_fn(model, start_token=start_token)

    accuracy_fn = metrics.build_accuracy_fn(decode_fn)
    cer_fn = metrics.build_cer_fn(decode_fn, idx_to_text)
    wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

    optimizer = tf.keras.optimizers.SGD(_hparams[HP_LEARNING_RATE.name],
                                        momentum=0.9)

    if FLAGS.fp16_run:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer,
                                                       loss_scale='dynamic')

    encoder = model.layers[2]
    prediction_network = model.layers[3]

    encoder.summary()
    prediction_network.summary()
    model.summary()

    dev_dataset, _ = get_dataset(FLAGS.data_dir, 'dev',
                                 batch_size=FLAGS.batch_size,
                                 n_epochs=FLAGS.n_epochs,
                                 strategy=strategy,
                                 max_size=FLAGS.eval_size)
    # dev_steps = dev_specs['size'] // FLAGS.batch_size

    log_dir = os.path.join(FLAGS.tb_log_dir,
                           datetime.now().strftime('%Y%m%d-%H%M%S'))

    # Log this run's hyperparameter values for the HParams dashboard.
    with tf.summary.create_file_writer(log_dir).as_default():
        hp.hparams(hparams)

    if FLAGS.mode == 'train':
        train_dataset, _ = get_dataset(FLAGS.data_dir, 'train',
                                       batch_size=FLAGS.batch_size,
                                       n_epochs=FLAGS.n_epochs,
                                       strategy=strategy)
        # train_steps = train_specs['size'] // FLAGS.batch_size

        os.makedirs(FLAGS.output_dir, exist_ok=True)
        checkpoint_template = os.path.join(
            FLAGS.output_dir, 'checkpoint_{step}_{val_loss:.4f}.hdf5')

        run_training(model=model,
                     optimizer=optimizer,
                     loss_fn=loss_fn,
                     train_dataset=train_dataset,
                     batch_size=FLAGS.batch_size,
                     n_epochs=FLAGS.n_epochs,
                     checkpoint_template=checkpoint_template,
                     strategy=strategy,
                     steps_per_log=FLAGS.steps_per_log,
                     steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                     eval_dataset=dev_dataset,
                     train_metrics=[],
                     eval_metrics=[accuracy_fn, cer_fn, wer_fn],
                     gpus=gpus)

    elif FLAGS.mode == 'eval' or FLAGS.mode == 'test':
        if FLAGS.checkpoint is None:
            raise Exception('You must provide a checkpoint to perform eval.')

        if FLAGS.mode == 'test':
            dataset, test_specs = get_dataset(FLAGS.data_dir, 'test',
                                              batch_size=FLAGS.batch_size,
                                              n_epochs=FLAGS.n_epochs)
        else:
            dataset = dev_dataset

        eval_start_time = time.time()
        eval_loss, eval_metrics_results = run_evaluate(
            model=model,
            optimizer=optimizer,
            loss_fn=loss_fn,
            eval_dataset=dataset,
            batch_size=FLAGS.batch_size,
            strategy=strategy,
            metrics=[accuracy_fn, cer_fn, wer_fn],
            gpus=gpus)

        validation_log_str = 'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format(
            time.time() - eval_start_time, eval_loss)
        for metric_name, metric_result in eval_metrics_results.items():
            validation_log_str += ', {}: {:.4f}'.format(metric_name, metric_result)

        print(validation_log_str)
def train():
    # Default hyperparameter choices for this run.
    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[0],
        HP_VOCAB_SIZE: HP_VOCAB_SIZE.domain.values[0],
        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],
        # Model
        HP_EMBEDDING_SIZE: HP_EMBEDDING_SIZE.domain.values[0],
        HP_ENCODER_LAYERS: HP_ENCODER_LAYERS.domain.values[0],
        HP_ENCODER_SIZE: HP_ENCODER_SIZE.domain.values[0],
        HP_TIME_REDUCT_INDEX: HP_TIME_REDUCT_INDEX.domain.values[0],
        HP_TIME_REDUCT_FACTOR: HP_TIME_REDUCT_FACTOR.domain.values[0],
        HP_PRED_NET_LAYERS: HP_PRED_NET_LAYERS.domain.values[0],
        HP_JOINT_NET_SIZE: HP_JOINT_NET_SIZE.domain.values[0],
        HP_SOFTMAX_SIZE: HP_SOFTMAX_SIZE.domain.values[0],
        HP_LEARNING_RATE: HP_LEARNING_RATE.domain.values[0]
    }

    if os.path.exists(os.path.join(FLAGS.model_dir, 'hparams.json')):
        # Resume: reuse the saved hyperparameters, encoder and model.
        _hparams = model_utils.load_hparams(FLAGS.model_dir)

        encoder_fn, vocab_size = encoding.load_encoder(FLAGS.model_dir,
                                                       hparams=_hparams)
        model, loss_fn = model_utils.load_model(FLAGS.model_dir,
                                                vocab_size=vocab_size,
                                                hparams=_hparams)
    else:
        # Fresh run: build the text encoder from the corpus and a new model.
        _hparams = {k.name: v for k, v in hparams.items()}

        texts_gen = common_voice.texts_generator(FLAGS.data_dir)
        encoder_fn, vocab_size = encoding.build_encoder(texts_gen,
                                                        model_dir=FLAGS.model_dir,
                                                        hparams=_hparams)
        model, loss_fn = build_keras_model(vocab_size, _hparams)

    logging.info('Using {} encoder with vocab size: {}'.format(
        _hparams[HP_TOKEN_TYPE.name], vocab_size))

    dataset_fn = get_dataset_fn(FLAGS.data_dir,
                                encoder_fn=encoder_fn,
                                batch_size=FLAGS.batch_size,
                                hparams=_hparams)  # name-keyed dict, consistent with both branches above

    train_dataset, train_steps = dataset_fn('train')
    dev_dataset, dev_steps = dataset_fn('dev')

    optimizer = tf.keras.optimizers.Adam(_hparams[HP_LEARNING_RATE.name])

    model.compile(loss=loss_fn,
                  optimizer=optimizer,
                  experimental_run_tf_function=False)

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    checkpoint_fp = os.path.join(FLAGS.model_dir,
                                 'model.{epoch:03d}-{val_loss:.4f}.hdf5')

    model_utils.save_hparams(_hparams, FLAGS.model_dir)

    model.fit(train_dataset,
              epochs=FLAGS.n_epochs,
              steps_per_epoch=train_steps,
              validation_data=dev_dataset,
              validation_steps=dev_steps,
              callbacks=[
                  tf.keras.callbacks.TensorBoard(FLAGS.tb_log_dir),
                  tf.keras.callbacks.ModelCheckpoint(checkpoint_fp,
                                                     save_weights_only=True)
              ])
def main(_):
    strategy, dtype = configure_environment(
        gpu_names=FLAGS.gpus,
        fp16_run=FLAGS.fp16_run)

    hparams, tb_hparams = setup_hparams(
        log_dir=FLAGS.tb_log_dir,
        checkpoint=FLAGS.checkpoint)

    os.makedirs(FLAGS.output_dir, exist_ok=True)

    # Reuse the encoder next to the checkpoint when resuming, otherwise take it from the data dir.
    if FLAGS.checkpoint is None:
        encoder_dir = FLAGS.data_dir
    else:
        encoder_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))

    shutil.copy(
        os.path.join(encoder_dir, 'encoder.subwords'),
        os.path.join(FLAGS.output_dir, 'encoder.subwords'))

    encoder_fn, idx_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=FLAGS.output_dir,
        hparams=hparams)

    if HP_VOCAB_SIZE.name not in hparams:
        hparams[HP_VOCAB_SIZE.name] = vocab_size

    with strategy.scope():
        model = build_keras_model(hparams, dtype=dtype)

        if FLAGS.checkpoint is not None:
            model.load_weights(FLAGS.checkpoint)
            logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))

    model_utils.save_hparams(hparams, FLAGS.output_dir)

    optimizer = tf.keras.optimizers.SGD(hparams[HP_LEARNING_RATE.name],
                                        momentum=0.9)

    if FLAGS.fp16_run:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer,
                                                       loss_scale='dynamic')

    logging.info('Using {} encoder with vocab size: {}'.format(
        hparams[HP_TOKEN_TYPE.name], vocab_size))

    loss_fn = get_loss_fn(
        reduction_factor=hparams[HP_TIME_REDUCT_FACTOR.name])

    # Greedy decoding is used for all evaluation metrics.
    decode_fn = decoding.greedy_decode_fn(model, hparams)
    accuracy_fn = metrics.build_accuracy_fn(decode_fn)
    wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

    encoder = model.layers[2]
    prediction_network = model.layers[3]

    encoder.summary()
    prediction_network.summary()
    model.summary()

    dev_dataset = None
    if FLAGS.eval_size != 0:
        dev_dataset = get_dataset(FLAGS.data_dir, 'dev',
                                  batch_size=FLAGS.batch_size,
                                  n_epochs=FLAGS.n_epochs,
                                  strategy=strategy,
                                  max_size=FLAGS.eval_size)

    log_dir = os.path.join(FLAGS.tb_log_dir,
                           datetime.now().strftime('%Y%m%d-%H%M%S'))

    # Log this run's hyperparameter values for the HParams dashboard.
    with tf.summary.create_file_writer(log_dir).as_default():
        hp.hparams(tb_hparams)

    if FLAGS.mode == 'train':
        train_dataset = get_dataset(FLAGS.data_dir, 'train',
                                    batch_size=FLAGS.batch_size,
                                    n_epochs=FLAGS.n_epochs,
                                    strategy=strategy)

        os.makedirs(FLAGS.output_dir, exist_ok=True)
        checkpoint_template = os.path.join(
            FLAGS.output_dir, 'checkpoint_{step}_{val_loss:.4f}.hdf5')

        run_training(model=model,
                     optimizer=optimizer,
                     loss_fn=loss_fn,
                     train_dataset=train_dataset,
                     batch_size=FLAGS.batch_size,
                     n_epochs=FLAGS.n_epochs,
                     checkpoint_template=checkpoint_template,
                     hparams=hparams,
                     strategy=strategy,
                     steps_per_log=FLAGS.steps_per_log,
                     steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                     eval_dataset=dev_dataset,
                     train_metrics=[],
                     eval_metrics=[accuracy_fn, wer_fn])

    elif FLAGS.mode == 'eval' or FLAGS.mode == 'test':
        if FLAGS.checkpoint is None:
            raise Exception('You must provide a checkpoint to perform eval.')

        if FLAGS.mode == 'test':
            dataset = get_dataset(FLAGS.data_dir, 'test',
                                  batch_size=FLAGS.batch_size,
                                  n_epochs=FLAGS.n_epochs)
        else:
            dataset = dev_dataset

        eval_start_time = time.time()
        eval_loss, eval_metrics_results = run_evaluate(
            model=model,
            optimizer=optimizer,
            loss_fn=loss_fn,
            eval_dataset=dataset,
            batch_size=FLAGS.batch_size,
            hparams=hparams,
            strategy=strategy,
            metrics=[accuracy_fn, wer_fn])

        validation_log_str = 'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format(
            time.time() - eval_start_time, eval_loss)
        for metric_name, metric_result in eval_metrics_results.items():
            validation_log_str += ', {}: {:.4f}'.format(metric_name, metric_result)

        print(validation_log_str)
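# --- Entry-point sketch (not part of the original listing) -------------------
# The functions above read their configuration from absl FLAGS. The wiring below
# is a minimal, assumed sketch of how the last main() might be invoked: the flag
# names are taken from the code above, but their types, defaults and help texts
# are guesses and would live in the project's own flag definitions.
from absl import app

flags.DEFINE_string('data_dir', None, 'Directory with the preprocessed dataset.')
flags.DEFINE_string('output_dir', './model', 'Where checkpoints and hparams are written.')
flags.DEFINE_string('tb_log_dir', './logs', 'TensorBoard log directory.')
flags.DEFINE_string('checkpoint', None, 'Optional checkpoint to restore weights from.')
flags.DEFINE_enum('mode', 'train', ['train', 'eval', 'test'], 'Run mode.')
flags.DEFINE_integer('batch_size', 8, 'Global batch size.')
flags.DEFINE_integer('n_epochs', 1, 'Number of training epochs.')
flags.DEFINE_integer('eval_size', -1, 'Max number of dev examples; 0 disables the dev set.')
flags.DEFINE_integer('steps_per_log', 100, 'Steps between metric logs.')
flags.DEFINE_integer('steps_per_checkpoint', 1000, 'Steps between checkpoints.')
flags.DEFINE_boolean('fp16_run', False, 'Enable mixed-precision training.')
flags.DEFINE_multi_integer('gpus', [], 'GPU ids to use; empty means all visible GPUs.')

if __name__ == '__main__':
    flags.mark_flag_as_required('data_dir')
    app.run(main)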