def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)
  config_util.log_and_save_config(config, FLAGS.model_dir)

  # Create the estimator.
  run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
  estimator = estimator_util.create_estimator(model_class, config.hparams,
                                              run_config, FLAGS.model_dir)

  # Create an input function that reads the training dataset. We iterate
  # through the dataset one epoch at a time if we are alternating with
  # evaluation; otherwise we iterate indefinitely.
  train_input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.TRAIN,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=1 if FLAGS.eval_files else None)

  if not FLAGS.eval_files:
    estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  else:
    eval_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_args = {
        "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
    }

    for _ in estimator_runner.continuous_train_and_eval(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # training epoch. We don't do anything with them here.
      pass
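
# A minimal sketch of the alternating train/eval loop that
# estimator_runner.continuous_train_and_eval() is assumed to implement above:
# train for one epoch, evaluate each named eval set, yield the metrics, and
# stop once the step budget is spent. This generator is hypothetical (the real
# project helper may differ); only the tf.estimator calls are standard API.
def continuous_train_and_eval_sketch(estimator, train_input_fn, eval_args,
                                     train_steps=None):
  while True:
    # train_input_fn repeats the dataset once, so train() returns at the end
    # of each epoch (or at train_steps, whichever comes first).
    estimator.train(train_input_fn, max_steps=train_steps)
    for eval_name, (eval_input_fn, eval_steps) in eval_args.items():
      metrics = estimator.evaluate(
          eval_input_fn, steps=eval_steps, name=eval_name)
      yield metrics
    # evaluate() reports the evaluated checkpoint's global step; stop once
    # training has consumed the full step budget.
    if train_steps is not None and metrics["global_step"] >= train_steps:
      break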
def train(model, config):
  if FLAGS.model_dir:
    dir_name = "{}/{}_{}_{}".format(
        FLAGS.model_dir, FLAGS.model, FLAGS.config_name,
        datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
    config_util.log_and_save_config(config, dir_name)

  # Build the training dataset. It repeats indefinitely; model.fit() bounds
  # each epoch with steps_per_epoch.
  ds = input_ds.build_dataset(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      batch_size=config.hparams.batch_size,
      include_labels=True,
      shuffle_filenames=True,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=None)

  if FLAGS.eval_files:
    eval_ds = input_ds.build_dataset(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        batch_size=config.hparams.batch_size,
        include_labels=True,
        shuffle_filenames=False,
        repeat=1)
  else:
    eval_ds = None

  assert config.hparams.optimizer == "adam"
  lr = config.hparams.learning_rate
  beta_1 = 1.0 - config.hparams.one_minus_adam_beta_1
  beta_2 = 1.0 - config.hparams.one_minus_adam_beta_2
  epsilon = config.hparams.adam_epsilon
  optimizer = tf.keras.optimizers.Adam(
      learning_rate=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)

  loss = tf.keras.losses.BinaryCrossentropy()
  metrics = [
      tf.keras.metrics.Recall(
          name="r",
          class_id=config.inputs.primary_class,
          thresholds=config.hparams.prediction_threshold),
      tf.keras.metrics.Precision(
          name="p",
          class_id=config.inputs.primary_class,
          thresholds=config.hparams.prediction_threshold),
  ]

  model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

  if getattr(config.hparams, "decreasing_lr", False):
    # Keep the initial rate for the first two epochs, then divide the current
    # rate by 10 at the start of each subsequent epoch.
    def scheduler(epoch, lr):
      if epoch > 1:
        return lr / 10
      return lr

    callbacks = [tf.keras.callbacks.LearningRateScheduler(scheduler)]
  else:
    callbacks = []

  history = model.fit(
      ds,
      epochs=FLAGS.train_epochs,
      steps_per_epoch=FLAGS.train_steps,
      validation_data=eval_ds,
      callbacks=callbacks)

  if FLAGS.model_dir:
    model.save(dir_name)

  return history
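
# A small, self-contained illustration of two details in train() above: the
# Adam betas are stored as (1 - beta) in the hparams, and the optional
# decreasing_lr schedule compounds once it kicks in, because
# LearningRateScheduler calls scheduler(epoch, current_lr) at the start of
# each (0-indexed) epoch. The hparam values below are made-up examples, not
# the project's defaults.
import tensorflow as tf

one_minus_adam_beta_1 = 0.1    # hypothetical hparam value -> beta_1 = 0.9
one_minus_adam_beta_2 = 0.001  # hypothetical hparam value -> beta_2 = 0.999
optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-3,
    beta_1=1.0 - one_minus_adam_beta_1,
    beta_2=1.0 - one_minus_adam_beta_2,
    epsilon=1e-7)

def scheduler(epoch, lr):
  return lr / 10 if epoch > 1 else lr

# Epochs 0-1 keep 1e-3; then the rate drops 10x per epoch:
# epoch 2 -> 1e-4, epoch 3 -> 1e-5, epoch 4 -> 1e-6.
lr = 1e-3
for epoch in range(5):
  lr = scheduler(epoch, lr)
  print(epoch, lr)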